Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Merge branch 'UI2' of https://github.com/ebolam/KoboldAI into UI2
aiserver.py (270 lines changed)
@@ -18,7 +18,8 @@ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
from eventlet import tpool

import logging
logging.basicConfig(format='%(levelname)s - %(module)s:%(lineno)d - %(message)s',level=logging.WARNING)
from logger import logger, set_logger_verbosity, quiesce_logger

logging.getLogger("urllib3").setLevel(logging.ERROR)

from os import path, getcwd
@@ -71,11 +72,12 @@ try:
except:
pass
import transformers.generation_utils

global tpu_mtj_backend

if lupa.LUA_VERSION[:2] != (5, 4):
print(f"Please install lupa==1.10. You have lupa {lupa.__version__}.", file=sys.stderr)
logger.error(f"Please install lupa==1.10. You have lupa {lupa.__version__}.")

patch_causallm_patched = False
@@ -234,7 +236,8 @@ class Send_to_socketio(object):
print(bar, end="")
time.sleep(0.01)
try:
emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", "&nbsp;")}, broadcast=True, room="UI_1")
gui_msg = bar.replace(f"{colors.PURPLE}INIT{colors.END} | ","").replace(" ", "&nbsp;")
emit('from_server', {'cmd': 'model_load_status', 'data': gui_msg}, broadcast=True, room="UI_1")
except:
pass
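For context, this class is the file-like sink that tqdm writes into later in this same commit; a minimal sketch of that wiring, with the names taken from the load_model hunk further down in this diff:

# Sketch only: tqdm renders its bar into Send_to_socketio, and the method patched
# above prints it locally and relays it to the UI_1 room over Socket.IO.
from tqdm.auto import tqdm
utils.bar = tqdm(total=num_tensors, desc="Loading model tensors", file=Send_to_socketio())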
@@ -429,8 +432,6 @@ import logging
log = logging.getLogger('werkzeug')
log.setLevel(logging.ERROR)

# Start flask & SocketIO
print("{0}Initializing Flask... {1}".format(colors.PURPLE, colors.END), end="")
from flask import Flask, render_template, Response, request, copy_current_request_context, send_from_directory, session, jsonify, abort, redirect, has_request_context
from flask_socketio import SocketIO, emit, join_room, leave_room
from flask_socketio import emit as _emit
@@ -442,14 +443,12 @@ app = Flask(__name__, root_path=os.getcwd())
app.secret_key = secrets.token_hex()
app.config['SESSION_TYPE'] = 'filesystem'
app.config['TEMPLATES_AUTO_RELOAD'] = True
Session(app)
socketio = SocketIO(app, async_method="eventlet", manage_session=False, cors_allowed_origins='*', max_http_buffer_size=10_000_000)
#socketio = SocketIO(app, async_method="eventlet", manage_session=False, cors_allowed_origins='*', logger=True, engineio_logger=True)
koboldai_vars = koboldai_settings.koboldai_vars(session, socketio)

utils.koboldai_vars = koboldai_vars

print("{0}OK!{1}".format(colors.GREEN, colors.END))

old_socketio_on = socketio.on
def new_socketio_on(*a, **k):
@@ -673,7 +672,7 @@ def get_config_filename(model_name = None):
elif koboldai_vars.configname != '':
return(f"settings/{koboldai_vars.configname.replace('/', '_')}.settings")
else:
print(f"Empty configfile name sent back. Defaulting to ReadOnly")
logger.warning(f"Empty configfile name sent back. Defaulting to ReadOnly")
return(f"settings/ReadOnly.settings")
#==================================================================#
# Function to get model selection at startup
@@ -845,14 +844,14 @@ def device_config(config):
breakmodel.disk_blocks = args.breakmodel_disklayers
n_layers -= args.breakmodel_disklayers
except:
print("WARNING: --breakmodel_gpulayers is malformatted. Please use the --help option to see correct usage of --breakmodel_gpulayers. Defaulting to all layers on device 0.", file=sys.stderr)
logger.warning("--breakmodel_gpulayers is malformatted. Please use the --help option to see correct usage of --breakmodel_gpulayers. Defaulting to all layers on device 0.")
breakmodel.gpu_blocks = [n_layers]
n_layers = 0
elif(args.breakmodel_layers is not None):
breakmodel.gpu_blocks = [n_layers - max(0, min(n_layers, args.breakmodel_layers))]
n_layers -= sum(breakmodel.gpu_blocks)
elif(args.model is not None):
print("Breakmodel not specified, assuming GPU 0")
logger.info("Breakmodel not specified, assuming GPU 0")
breakmodel.gpu_blocks = [n_layers]
n_layers = 0
else:
@@ -911,7 +910,7 @@ def device_config(config):
else:
print(f"{colors.RED}Please enter an integer between -1 and {n_layers}.{colors.END}")

print(colors.PURPLE + "\nFinal device configuration:")
logger.init_ok("Final device configuration:", status="Info")
device_list(n_layers)

# If all layers are on the same device, use the old GPU generation mode
@@ -924,7 +923,7 @@ def device_config(config):
return

if(not breakmodel.gpu_blocks):
print("Nothing assigned to a GPU, reverting to CPU only mode")
logger.warning("Nothing assigned to a GPU, reverting to CPU only mode")
import breakmodel
breakmodel.primary_device = "cpu"
koboldai_vars.breakmodel = False
@@ -1097,7 +1096,7 @@ def savesettings():
#==================================================================#
@debounce(2)
def settingschanged():
print("{0}Saving settings!{1}".format(colors.GREEN, colors.END))
logger.info("Saving settings.")
savesettings()

#==================================================================#
@@ -1223,6 +1222,9 @@ def general_startup(override_args=None):
parser.add_argument("--savemodel", action='store_true', help="Saves the model to the models folder even if --colab is used (Allows you to save models to Google Drive)")
parser.add_argument("--customsettings", help="Preloads arguements from json file. You only need to provide the location of the json file. Use customsettings.json template file. It can be renamed if you wish so that you can store multiple configurations. Leave any settings you want as default as null. Any values you wish to set need to be in double quotation marks")
parser.add_argument("--no_ui", action='store_true', default=False, help="Disables the GUI and Socket.IO server while leaving the API server running.")
parser.add_argument('-v', '--verbosity', action='count', default=0, help="The default logging level is ERROR or higher. This value increases the amount of logging seen in your screen")
parser.add_argument('-q', '--quiesce', action='count', default=0, help="The default logging level is ERROR or higher. This value decreases the amount of logging seen in your screen")

#args: argparse.Namespace = None
if "pytest" in sys.modules and override_args is None:
args = parser.parse_args([])
@@ -1257,6 +1259,8 @@ def general_startup(override_args=None):
setattr(args, arg, False)
else:
setattr(args, arg, os.environ[arg])
set_logger_verbosity(args.verbosity)
quiesce_logger(args.quiesce)
if args.customsettings:
f = open (args.customsettings)
importedsettings = json.load(f)
@@ -1324,9 +1328,10 @@ def general_startup(override_args=None):
koboldai_vars.model = "NeoCustom"
koboldai_vars.custmodpth = modpath
elif args.model:
print("Welcome to KoboldAI!\nYou have selected the following Model:", koboldai_vars.model)
logger.message(f"Welcome to KoboldAI!")
logger.message(f"You have selected the following Model: {koboldai_vars.model}")
if args.path:
print("You have selected the following path for your Model :", args.path)
logger.message(f"You have selected the following path for your Model: {args.path}")
koboldai_vars.custmodpth = args.path;
koboldai_vars.colaburl = args.path + "/request"; # Lets just use the same parameter to keep it simple
@@ -1492,7 +1497,7 @@ def get_oai_models(data):
return

# Get list of models from OAI
print("{0}Retrieving engine list...{1}".format(colors.PURPLE, colors.END), end="")
logger.init("OAI Engines", status="Retrieving")
req = requests.get(
url,
headers = {
@@ -1504,7 +1509,7 @@ def get_oai_models(data):
try:
engines = [[en["id"], "{} ({})".format(en['id'], "Ready" if en["ready"] == True else "Not Ready")] for en in engines]
except:
print(engines)
logger.error(engines)
raise

online_model = ""
@@ -1530,12 +1535,13 @@ def get_oai_models(data):
js["apikey"] = key
file.write(json.dumps(js, indent=3))

logger.init_ok("OAI Engines", status="OK")
emit('from_server', {'cmd': 'oai_engines', 'data': engines, 'online_model': online_model}, broadcast=True, room="UI_1")
emit('oai_engines', {'data': engines, 'online_model': online_model}, broadcast=False, room="UI_2")
else:
# Something went wrong, print the message and quit since we can't initialize an engine
print("{0}ERROR!{1}".format(colors.RED, colors.END))
print(req.json())
logger.init_err("OAI Engines", status="Failed")
logger.error(req.json())
emit('from_server', {'cmd': 'errmsg', 'data': req.json()})

@socketio.on("get_cluster_models")
@@ -1544,11 +1550,15 @@ def get_cluster_models(msg):
koboldai_vars.apikey = koboldai_vars.oaiapikey
model = msg['model']
url = msg['url']

# Get list of models from public cluster
print("{0}Retrieving engine list...{1}".format(colors.PURPLE, colors.END), end="")
req = requests.get("{}/models".format(url))
try:
req = requests.get("{}/models".format(url))
except:
logger.init_err("KAI Horde Models", status="Failed")
logger.error("Provided KoboldAI Horde URL unreachable")
emit('from_server', {'cmd': 'errmsg', 'data': "Provided KoboldAI Horde URL unreachable"})
return
if(req.status_code == 200):
engines = req.json()
print(engines)
@@ -1587,9 +1597,46 @@ def get_cluster_models(msg):
|
||||
emit('oai_engines', {'data': engines, 'online_model': online_model}, broadcast=False, room="UI_2")
|
||||
else:
|
||||
# Something went wrong, print the message and quit since we can't initialize an engine
|
||||
print("{0}ERROR!{1}".format(colors.RED, colors.END))
|
||||
print(req.json())
|
||||
logger.init_err("KAI Horde Models", status="Failed")
|
||||
logger.error(req.json())
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': req.json()}, room="UI_1")
|
||||
return
|
||||
|
||||
engines = req.json()
|
||||
logger.debug(engines)
|
||||
try:
|
||||
engines = [[en, en] for en in engines]
|
||||
except:
|
||||
logger.error(engines)
|
||||
raise
|
||||
|
||||
online_model = ""
|
||||
changed=False
|
||||
|
||||
#Save the key
|
||||
if not path.exists("settings"):
|
||||
# If the client settings file doesn't exist, create it
|
||||
# Write API key to file
|
||||
os.makedirs('settings', exist_ok=True)
|
||||
if path.exists(get_config_filename(koboldai_vars.model_selected)):
|
||||
with open(get_config_filename(koboldai_vars.model_selected), "r") as file:
|
||||
js = json.load(file)
|
||||
if 'online_model' in js:
|
||||
online_model = js['online_model']
|
||||
if "apikey" in js:
|
||||
if js['apikey'] != koboldai_vars.oaiapikey:
|
||||
changed=True
|
||||
else:
|
||||
changed=True
|
||||
if changed:
|
||||
js={}
|
||||
with open(get_config_filename(koboldai_vars.model_selected), "w") as file:
|
||||
js["apikey"] = koboldai_vars.oaiapikey
|
||||
file.write(json.dumps(js, indent=3))
|
||||
|
||||
logger.init_ok("KAI Horde Models", status="OK")
|
||||
emit('from_server', {'cmd': 'oai_engines', 'data': engines, 'online_model': online_model}, broadcast=True)
|
||||
|
||||
|
||||
# Function to patch transformers to use our soft prompt
|
||||
def patch_causallm(model):
|
||||
@@ -1637,11 +1684,11 @@ def patch_transformers_download():
|
||||
|
||||
def http_get(
|
||||
url: str,
|
||||
temp_file: transformers.utils.hub.BinaryIO,
|
||||
temp_file,
|
||||
proxies=None,
|
||||
resume_size=0,
|
||||
headers: transformers.utils.hub.Optional[transformers.utils.hub.Dict[str, str]] = None,
|
||||
file_name: transformers.utils.hub.Optional[str] = None,
|
||||
headers=None,
|
||||
file_name=None,
|
||||
):
|
||||
"""
|
||||
Download remote file. Do not gobble up errors.
|
||||
@@ -2249,29 +2296,28 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
|
||||
elif(koboldai_vars.model_type == "not_found" and koboldai_vars.model == "GPT2Custom"):
|
||||
koboldai_vars.model_type = "gpt2"
|
||||
elif(koboldai_vars.model_type == "not_found"):
|
||||
print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
|
||||
logger.warning("No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
|
||||
koboldai_vars.model_type = "gpt_neo"
|
||||
|
||||
if(not koboldai_vars.use_colab_tpu and koboldai_vars.model not in ["InferKit", "Colab", "API", "CLUSTER", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
|
||||
loadmodelsettings()
|
||||
loadsettings()
|
||||
print(2)
|
||||
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
|
||||
logger.init("GPU support", status="Searching")
|
||||
koboldai_vars.hascuda = torch.cuda.is_available()
|
||||
koboldai_vars.bmsupported = (utils.HAS_ACCELERATE or koboldai_vars.model_type in ("gpt_neo", "gptj", "xglm", "opt")) and not koboldai_vars.nobreakmodel
|
||||
if(args.breakmodel is not None and args.breakmodel):
|
||||
print("WARNING: --breakmodel is no longer supported. Breakmodel mode is now automatically enabled when --breakmodel_gpulayers is used (see --help for details).", file=sys.stderr)
|
||||
logger.warning("--breakmodel is no longer supported. Breakmodel mode is now automatically enabled when --breakmodel_gpulayers is used (see --help for details).")
|
||||
if(args.breakmodel_layers is not None):
|
||||
print("WARNING: --breakmodel_layers is deprecated. Use --breakmodel_gpulayers instead (see --help for details).", file=sys.stderr)
|
||||
logger.warning("--breakmodel_layers is deprecated. Use --breakmodel_gpulayers instead (see --help for details).")
|
||||
if(args.model and koboldai_vars.bmsupported and not args.breakmodel_gpulayers and not args.breakmodel_layers and (not utils.HAS_ACCELERATE or not args.breakmodel_disklayers)):
|
||||
print("WARNING: Model launched without the --breakmodel_gpulayers argument, defaulting to GPU only mode.", file=sys.stderr)
|
||||
logger.warning("Model launched without the --breakmodel_gpulayers argument, defaulting to GPU only mode.")
|
||||
koboldai_vars.bmsupported = False
|
||||
if(not koboldai_vars.bmsupported and (args.breakmodel_gpulayers is not None or args.breakmodel_layers is not None or args.breakmodel_disklayers is not None)):
|
||||
print("WARNING: This model does not support hybrid generation. --breakmodel_gpulayers will be ignored.", file=sys.stderr)
|
||||
logger.warning("This model does not support hybrid generation. --breakmodel_gpulayers will be ignored.")
|
||||
if(koboldai_vars.hascuda):
|
||||
print("{0}FOUND!{1}".format(colors.GREEN, colors.END))
|
||||
logger.init_ok("GPU support", status="Found")
|
||||
else:
|
||||
print("{0}NOT FOUND!{1}".format(colors.YELLOW, colors.END))
|
||||
logger.init_warn("GPU support", status="Not Found")
|
||||
|
||||
if args.cpu:
|
||||
koboldai_vars.usegpu = False
|
||||
@@ -2308,7 +2354,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
|
||||
# Start transformers and create pipeline
|
||||
if(not koboldai_vars.use_colab_tpu and koboldai_vars.model not in ["InferKit", "Colab", "API", "CLUSTER", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
|
||||
if(not koboldai_vars.noai):
|
||||
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
|
||||
logger.init("Transformers", status='Starting')
|
||||
for m in ("GPTJModel", "XGLMModel"):
|
||||
try:
|
||||
globals()[m] = getattr(__import__("transformers"), m)
|
||||
@@ -2376,7 +2422,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
|
||||
print(flush=True)
|
||||
koboldai_vars.total_layers = num_tensors
|
||||
koboldai_vars.loaded_layers = 0
|
||||
utils.bar = tqdm(total=num_tensors, desc="Loading model tensors", file=Send_to_socketio())
|
||||
utils.bar = tqdm(total=num_tensors, desc=f"{colors.PURPLE}INIT{colors.END} | Loading model tensors", file=Send_to_socketio())
|
||||
|
||||
with zipfile.ZipFile(f, "r") as z:
|
||||
try:
|
||||
@@ -2447,7 +2493,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
|
||||
|
||||
def maybe_low_cpu_mem_usage() -> Dict[str, Any]:
|
||||
if(packaging.version.parse(transformers_version) < packaging.version.parse("4.11.0")):
|
||||
print(f"\nWARNING: Please upgrade to transformers 4.11.0 for lower RAM usage. You have transformers {transformers_version}.", file=sys.stderr)
|
||||
logger.warning(f"Please upgrade to transformers 4.11.0 for lower RAM usage. You have transformers {transformers_version}.")
|
||||
return {}
|
||||
return {"low_cpu_mem_usage": True}
|
||||
|
||||
@@ -2504,7 +2550,6 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
|
||||
if os.path.isdir(koboldai_vars.model.replace('/', '_')):
|
||||
import shutil
|
||||
shutil.move(koboldai_vars.model.replace('/', '_'), "models/{}".format(koboldai_vars.model.replace('/', '_')))
|
||||
print("\n", flush=True)
|
||||
if(koboldai_vars.lazy_load): # If we're using lazy loader, we need to figure out what the model's hidden layers are called
|
||||
with torch_lazy_loader.use_lazy_torch_load(dematerialized_modules=True, use_accelerate_init_empty_weights=True):
|
||||
try:
|
||||
@@ -2597,11 +2642,13 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
|
||||
import transformers.configuration_utils
|
||||
import transformers.modeling_utils
|
||||
import transformers.file_utils
|
||||
import huggingface_hub
|
||||
legacy = packaging.version.parse(transformers_version) < packaging.version.parse("4.22.0.dev0")
|
||||
# Save the config.json
|
||||
shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(koboldai_vars.model, transformers.configuration_utils.CONFIG_NAME, revision=koboldai_vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(koboldai_vars.model.replace('/', '_')), transformers.configuration_utils.CONFIG_NAME))
|
||||
shutil.move(os.path.realpath(huggingface_hub.hf_hub_download(koboldai_vars.model, transformers.configuration_utils.CONFIG_NAME, revision=koboldai_vars.revision, cache_dir="cache", local_files_only=True, legacy_cache_layout=legacy)), os.path.join("models/{}".format(koboldai_vars.model.replace('/', '_')), transformers.configuration_utils.CONFIG_NAME))
|
||||
if(utils.num_shards is None):
|
||||
# Save the pytorch_model.bin of an unsharded model
|
||||
shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(koboldai_vars.model, transformers.modeling_utils.WEIGHTS_NAME, revision=koboldai_vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(koboldai_vars.model.replace('/', '_')), transformers.modeling_utils.WEIGHTS_NAME))
|
||||
shutil.move(os.path.realpath(huggingface_hub.hf_hub_download(koboldai_vars.model, transformers.modeling_utils.WEIGHTS_NAME, revision=koboldai_vars.revision, cache_dir="cache", local_files_only=True, legacy_cache_layout=legacy)), os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.modeling_utils.WEIGHTS_NAME))
|
||||
else:
|
||||
with open(utils.from_pretrained_index_filename) as f:
|
||||
map_data = json.load(f)
|
||||
@@ -2610,7 +2657,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
|
||||
shutil.move(utils.from_pretrained_index_filename, os.path.join("models/{}".format(koboldai_vars.model.replace('/', '_')), transformers.modeling_utils.WEIGHTS_INDEX_NAME))
|
||||
# Then save the pytorch_model-#####-of-#####.bin files
|
||||
for filename in filenames:
|
||||
shutil.move(transformers.file_utils.get_from_cache(transformers.file_utils.hf_bucket_url(koboldai_vars.model, filename, revision=koboldai_vars.revision), cache_dir="cache", local_files_only=True), os.path.join("models/{}".format(koboldai_vars.model.replace('/', '_')), filename))
|
||||
shutil.move(os.path.realpath(huggingface_hub.hf_hub_download(koboldai_vars.model, filename, revision=koboldai_vars.revision, cache_dir="cache", local_files_only=True, legacy_cache_layout=legacy)), os.path.join("models/{}".format(koboldai_vars.model.replace('/', '_')), filename))
|
||||
shutil.rmtree("cache/")
|
||||
|
||||
if(koboldai_vars.badwordsids is koboldai_settings.badwordsids_default and koboldai_vars.model_type not in ("gpt2", "gpt_neo", "gptj")):
|
||||
@@ -2652,7 +2699,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
|
||||
#for key in koboldai_vars.badwords:
|
||||
# koboldai_vars.badwordsids.append([vocab[key]])
|
||||
|
||||
print("{0}OK! {1} pipeline created!{2}".format(colors.GREEN, koboldai_vars.model, colors.END))
|
||||
logger.info(f"Pipeline created: {koboldai_vars.model}")
|
||||
|
||||
else:
|
||||
from transformers import GPT2TokenizerFast
|
||||
@@ -2955,7 +3002,7 @@ def lua_startup():
|
||||
#==================================================================#
|
||||
|
||||
print("", end="", flush=True)
|
||||
print(colors.PURPLE + "Initializing Lua Bridge... " + colors.END, end="", flush=True)
|
||||
logger.init("LUA bridge", status="Starting")
|
||||
|
||||
# Set up Lua state
|
||||
koboldai_vars.lua_state = lupa.LuaRuntime(unpack_returned_tuples=True)
|
||||
@@ -2978,11 +3025,11 @@ def lua_startup():
|
||||
except lupa.LuaError as e:
|
||||
print(colors.RED + "ERROR!" + colors.END)
|
||||
koboldai_vars.lua_koboldbridge.obliterate_multiverse()
|
||||
print("{0}{1}{2}".format(colors.RED, "***LUA ERROR***: ", colors.END), end="", file=sys.stderr)
|
||||
print("{0}{1}{2}".format(colors.RED, str(e).replace("\033", ""), colors.END), file=sys.stderr)
|
||||
logger.debug('LUA ERROR: ' + str(e).replace("\033", ""))
|
||||
logger.warning("Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.")
|
||||
socketio.emit("error", str(e), broadcast=True, room="UI_2")
|
||||
exit(1)
|
||||
print(colors.GREEN + "OK!" + colors.END)
|
||||
logger.init_ok("LUA bridge", status="OK")
|
||||
|
||||
|
||||
def lua_log_format_name(name):
|
||||
@@ -3006,7 +3053,7 @@ def load_callback(filename, modulename):
|
||||
# Load all Lua scripts
|
||||
#==================================================================#
|
||||
def load_lua_scripts():
|
||||
print(colors.GREEN + "Loading Core Script" + colors.END)
|
||||
logger.init("LUA Scripts", status="Starting")
|
||||
|
||||
filenames = []
|
||||
modulenames = []
|
||||
@@ -3038,12 +3085,12 @@ def load_lua_scripts():
|
||||
if(koboldai_vars.serverstarted):
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': 'Lua script error; please check console.'}, broadcast=True, room="UI_1")
|
||||
sendUSStatItems()
|
||||
print("{0}{1}{2}".format(colors.RED, "***LUA ERROR***: ", colors.END), end="", file=sys.stderr)
|
||||
print("{0}{1}{2}".format(colors.RED, str(e).replace("\033", ""), colors.END), file=sys.stderr)
|
||||
print("{0}{1}{2}".format(colors.YELLOW, "Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.", colors.END), file=sys.stderr)
|
||||
logger.debug('LUA ERROR: ' + str(e).replace("\033", ""))
|
||||
logger.warning("Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.")
|
||||
socketio.emit("error", str(e), broadcast=True, room="UI_2")
|
||||
if(koboldai_vars.serverstarted):
|
||||
set_aibusy(0)
|
||||
logger.init_ok("LUA Scripts", status="OK")
|
||||
|
||||
#==================================================================#
|
||||
# Print message that originates from the userscript with the given name
|
||||
@@ -3531,9 +3578,8 @@ def execute_inmod():
|
||||
koboldai_vars.lua_running = False
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': 'Lua script error; please check console.'}, broadcast=True, room="UI_1")
|
||||
sendUSStatItems()
|
||||
print("{0}{1}{2}".format(colors.RED, "***LUA ERROR***: ", colors.END), end="", file=sys.stderr)
|
||||
print("{0}{1}{2}".format(colors.RED, str(e).replace("\033", ""), colors.END), file=sys.stderr)
|
||||
print("{0}{1}{2}".format(colors.YELLOW, "Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.", colors.END), file=sys.stderr)
|
||||
logger.debug('LUA ERROR: ' + str(e).replace("\033", ""))
|
||||
logger.warning("Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.")
|
||||
socketio.emit("error", str(e), broadcast=True, room="UI_2")
|
||||
set_aibusy(0)
|
||||
|
||||
@@ -3550,9 +3596,8 @@ def execute_outmod():
|
||||
koboldai_vars.lua_running = False
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': 'Lua script error; please check console.'}, broadcast=True, room="UI_1")
|
||||
sendUSStatItems()
|
||||
print("{0}{1}{2}".format(colors.RED, "***LUA ERROR***: ", colors.END), end="", file=sys.stderr)
|
||||
print("{0}{1}{2}".format(colors.RED, str(e).replace("\033", ""), colors.END), file=sys.stderr)
|
||||
print("{0}{1}{2}".format(colors.YELLOW, "Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.", colors.END), file=sys.stderr)
|
||||
logger.debug('LUA ERROR: ' + str(e).replace("\033", ""))
|
||||
logger.warning("Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.")
|
||||
socketio.emit("error", str(e), broadcast=True, room="UI_2")
|
||||
set_aibusy(0)
|
||||
if(koboldai_vars.lua_koboldbridge.resend_settings_required):
|
||||
@@ -3573,6 +3618,7 @@ def execute_outmod():
|
||||
#==================================================================#
|
||||
@socketio.on('connect')
|
||||
def do_connect():
|
||||
logger.info("Client connected!")
|
||||
if request.args.get("rely") == "true":
|
||||
return
|
||||
join_room("UI_{}".format(request.args.get('ui')))
|
||||
@@ -3633,7 +3679,7 @@ def do_connect():
|
||||
@socketio.on('message')
|
||||
def get_message(msg):
|
||||
if not koboldai_vars.quiet:
|
||||
print("{0}Data received:{1}{2}".format(colors.GREEN, msg, colors.END))
|
||||
logger.debug(f"Data received: {msg}")
|
||||
# Submit action
|
||||
if(msg['cmd'] == 'submit'):
|
||||
if(koboldai_vars.mode == "play"):
|
||||
@@ -3918,8 +3964,7 @@ def get_message(msg):
|
||||
elif(msg['cmd'] == 'list_model'):
|
||||
sendModelSelection(menu=msg['data'])
|
||||
elif(msg['cmd'] == 'load_model'):
|
||||
print(msg)
|
||||
print(koboldai_vars.model_selected)
|
||||
logger.debug(f"Selected Model: {koboldai_vars.model_selected}")
|
||||
if not os.path.exists("settings/"):
|
||||
os.mkdir("settings")
|
||||
changed = True
|
||||
@@ -3953,7 +3998,7 @@ def get_message(msg):
|
||||
koboldai_vars.cluster_requested_models = msg['online_model']
|
||||
load_model(use_gpu=msg['use_gpu'], gpu_layers=msg['gpu_layers'], disk_layers=msg['disk_layers'], online_model=msg['online_model'])
|
||||
elif(msg['cmd'] == 'show_model'):
|
||||
print("Model Name: {}".format(getmodelname()))
|
||||
logger.info(f"Model Name: {getmodelname()}")
|
||||
emit('from_server', {'cmd': 'show_model_name', 'data': getmodelname()}, broadcast=True, room="UI_1")
|
||||
elif(msg['cmd'] == 'selectmodel'):
|
||||
# This is run when a model line is selected from the UI (line from the model_menu variable) that is tagged as not a menu
|
||||
@@ -4004,14 +4049,14 @@ def get_message(msg):
|
||||
elif(msg['cmd'] == 'delete_model'):
|
||||
if "{}/models".format(os.getcwd()) in os.path.abspath(msg['data']) or "{}\\models".format(os.getcwd()) in os.path.abspath(msg['data']):
|
||||
if check_if_dir_is_model(msg['data']):
|
||||
print(colors.YELLOW + "WARNING: Someone deleted " + msg['data'])
|
||||
logger.warning(f"Someone deleted {msg['data']}")
|
||||
import shutil
|
||||
shutil.rmtree(msg['data'])
|
||||
sendModelSelection(menu=msg['menu'])
|
||||
else:
|
||||
print(colors.RED + "ERROR: Someone attempted to delete " + msg['data'] + " but this is not a valid model")
|
||||
logger.error(f"Someone attempted to delete {msg['data']} but this is not a valid model")
|
||||
else:
|
||||
print(colors.RED + "WARNING!!: Someone maliciously attempted to delete " + msg['data'] + " the attempt has been blocked.")
|
||||
logger.critical(f"Someone maliciously attempted to delete {msg['data']}. The attempt has been blocked.")
|
||||
elif(msg['cmd'] == 'OAI_Key_Update'):
|
||||
get_oai_models({'model': koboldai_vars.model, 'key': msg['key']})
|
||||
elif(msg['cmd'] == 'Cluster_Key_Update'):
|
||||
@@ -4250,7 +4295,7 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,
|
||||
except:
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, revision=koboldai_vars.revision, cache_dir="cache", use_fast=False)
|
||||
except:
|
||||
print(f"WARNING: Unknown tokenizer {repr(tokenizer_id)}")
|
||||
logger.warning(f"Unknown tokenizer {repr(tokenizer_id)}")
|
||||
koboldai_vars.api_tokenizer_id = tokenizer_id
|
||||
|
||||
if(disable_recentrng):
|
||||
@@ -4403,7 +4448,8 @@ def apiactionsubmit_generate(txt, minimum, maximum):
|
||||
koboldai_vars.generated_tkns = 0
|
||||
|
||||
if not koboldai_vars.quiet:
|
||||
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, utils.decodenewlines(tokenizer.decode(txt)), colors.END))
|
||||
logger.debug(f"Prompt Min:{minimum}, Max:{maximum}")
|
||||
logger.prompt(utils.decodenewlines(tokenizer.decode(txt)).encode("unicode_escape").decode("utf-8"))
|
||||
|
||||
# Clear CUDA cache if using GPU
|
||||
if(koboldai_vars.hascuda and (koboldai_vars.usegpu or koboldai_vars.breakmodel)):
|
||||
@@ -4430,7 +4476,8 @@ def apiactionsubmit_tpumtjgenerate(txt, minimum, maximum):
|
||||
tpu_mtj_backend.set_rng_seed(koboldai_vars.seed)
|
||||
|
||||
if not koboldai_vars.quiet:
|
||||
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, utils.decodenewlines(tokenizer.decode(txt)), colors.END))
|
||||
logger.debug(f"Prompt Min:{minimum}, Max:{maximum}")
|
||||
logger.prompt(utils.decodenewlines(tokenizer.decode(txt)).encode("unicode_escape").decode("utf-8"))
|
||||
|
||||
koboldai_vars._prompt = koboldai_vars.prompt
|
||||
|
||||
@@ -4917,7 +4964,8 @@ def generate(txt, minimum, maximum, found_entries=None):
|
||||
found_entries = tuple(found_entries.copy() for _ in range(koboldai_vars.numseqs))
|
||||
|
||||
if not koboldai_vars.quiet:
|
||||
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, utils.decodenewlines(tokenizer.decode(txt)), colors.END))
|
||||
logger.debug(f"Prompt Min:{minimum}, Max:{maximum}")
|
||||
logger.prompt(utils.decodenewlines(tokenizer.decode(txt)).encode("unicode_escape").decode("utf-8"))
|
||||
|
||||
# Store context in memory to use it for comparison with generated content
|
||||
koboldai_vars.lastctx = utils.decodenewlines(tokenizer.decode(txt))
|
||||
@@ -4936,13 +4984,12 @@ def generate(txt, minimum, maximum, found_entries=None):
|
||||
koboldai_vars.lua_running = False
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': 'Lua script error; please check console.'}, broadcast=True, room="UI_1")
|
||||
sendUSStatItems()
|
||||
print("{0}{1}{2}".format(colors.RED, "***LUA ERROR***: ", colors.END), end="", file=sys.stderr)
|
||||
print("{0}{1}{2}".format(colors.RED, str(e).replace("\033", ""), colors.END), file=sys.stderr)
|
||||
print("{0}{1}{2}".format(colors.YELLOW, "Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.", colors.END), file=sys.stderr)
|
||||
logger.debug('LUA ERROR: ' + str(e).replace("\033", ""))
|
||||
logger.warning("Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.")
|
||||
socketio.emit("error", str(e), broadcast=True, room="UI_2")
|
||||
else:
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': 'Error occurred during generator call; please check console.'}, broadcast=True, room="UI_1")
|
||||
print("{0}{1}{2}".format(colors.RED, traceback.format_exc().replace("\033", ""), colors.END), file=sys.stderr)
|
||||
logger.error(traceback.format_exc().replace("\033", ""))
|
||||
socketio.emit("error", str(e), broadcast=True, room="UI_2")
|
||||
set_aibusy(0)
|
||||
return
|
||||
@@ -4985,7 +5032,7 @@ def generate(txt, minimum, maximum, found_entries=None):
|
||||
#==================================================================#
|
||||
def genresult(genout, flash=True, ignore_formatting=False):
|
||||
if not koboldai_vars.quiet:
|
||||
print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
|
||||
logger.generation(genout.encode("unicode_escape").decode("utf-8"))
|
||||
|
||||
# Format output before continuing
|
||||
if not ignore_formatting:
|
||||
@@ -5015,7 +5062,8 @@ def genselect(genout):
|
||||
# Apply output formatting rules to sequences
|
||||
result["generated_text"] = applyoutputformatting(result["generated_text"])
|
||||
if not koboldai_vars.quiet:
|
||||
print("{0}[Result {1}]\n{2}{3}".format(colors.CYAN, i, result["generated_text"], colors.END))
|
||||
logger.info(f"Generation Result {i}")
|
||||
logger.generation(result["generated_text"].encode("unicode_escape").decode("utf-8"))
|
||||
i += 1
|
||||
|
||||
|
||||
@@ -5228,11 +5276,11 @@ def sendtoapi(txt, min, max):
|
||||
def sendtocluster(txt, min, max):
|
||||
# Log request to console
|
||||
if not koboldai_vars.quiet:
|
||||
print("{0}Tokens:{1}, Txt:{2}{3}".format(colors.YELLOW, min-1, txt, colors.END))
|
||||
logger.debug(f"Tokens Min:{min-1}")
|
||||
logger.prompt(txt.encode("unicode_escape").decode("utf-8"))
|
||||
|
||||
# Store context in memory to use it for comparison with generated content
|
||||
koboldai_vars.lastctx = txt
|
||||
|
||||
# Build request JSON data
|
||||
reqdata = {
|
||||
'max_length': max - min + 1,
|
||||
@@ -5254,39 +5302,41 @@ def sendtocluster(txt, min, max):
|
||||
'api_key': koboldai_vars.apikey,
|
||||
'models': koboldai_vars.cluster_requested_models,
|
||||
}
|
||||
logger.debug(f"Horde Payload: {cluster_metadata}")
|
||||
try:
|
||||
# Create request
|
||||
req = requests.post(
|
||||
koboldai_vars.colaburl[:-8] + "/api/v1/generate/sync",
|
||||
json=cluster_metadata,
|
||||
)
|
||||
js = req.json()
|
||||
except requests.exceptions.ConnectionError:
|
||||
errmsg = f"Horde unavailable. Please try again later"
|
||||
print("{0}{1}{2}".format(colors.RED, errmsg, colors.END))
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
|
||||
set_aibusy(0)
|
||||
return
|
||||
except requests.exceptions.JSONDecodeError:
|
||||
errmsg = f"Unexpected message received from the Horde: '{req.text}'"
|
||||
print("{0}{1}{2}".format(colors.RED, errmsg, colors.END))
|
||||
logger.error(errmsg)
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
|
||||
set_aibusy(0)
|
||||
return
|
||||
if(req.status_code == 503):
|
||||
errmsg = f"KoboldAI API Error: No available KoboldAI servers found in Horde to fulfil this request using the selected models or other properties."
|
||||
print("{0}{1}{2}".format(colors.RED, json.dumps(js, indent=2), colors.END))
|
||||
logger.error(req.text)
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
|
||||
set_aibusy(0)
|
||||
return
|
||||
if(req.status_code != 200):
|
||||
if(not req.ok):
|
||||
errmsg = f"KoboldAI API Error: Failed to get a standard reply from the Horde. Please check the console."
|
||||
print("{0}{1}{2}".format(colors.RED, json.dumps(js, indent=2), colors.END))
|
||||
logger.error(req.text)
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
|
||||
set_aibusy(0)
|
||||
return
|
||||
try:
|
||||
js = req.json()
|
||||
except requests.exceptions.JSONDecodeError:
|
||||
errmsg = f"Unexpected message received from the Horde: '{req.text}'"
|
||||
logger.error(errmsg)
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
|
||||
set_aibusy(0)
|
||||
return
|
||||
gen_servers = [(cgen['server_name'],cgen['server_id']) for cgen in js]
|
||||
print(f"{colors.GREEN}Generations by: {gen_servers}{colors.END}")
|
||||
logger.info(f"Generations by: {gen_servers}")
|
||||
# Just in case we want to announce it to the user
|
||||
if len(js) == 1:
|
||||
warnmsg = f"Text generated by {js[0]['server_name']}"
|
||||
@@ -5336,7 +5386,8 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
||||
found_entries = tuple(found_entries.copy() for _ in range(koboldai_vars.numseqs))
|
||||
|
||||
if not koboldai_vars.quiet:
|
||||
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, utils.decodenewlines(tokenizer.decode(txt)), colors.END))
|
||||
logger.debug(f"Prompt Min:{minimum}, Max:{maximum}")
|
||||
logger.prompt(utils.decodenewlines(tokenizer.decode(txt)).encode("unicode_escape").decode("utf-8"))
|
||||
|
||||
koboldai_vars._prompt = koboldai_vars.prompt
|
||||
|
||||
@@ -5425,9 +5476,8 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
||||
koboldai_vars.lua_running = False
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': 'Lua script error; please check console.'}, broadcast=True, room="UI_1")
|
||||
sendUSStatItems()
|
||||
print("{0}{1}{2}".format(colors.RED, "***LUA ERROR***: ", colors.END), end="", file=sys.stderr)
|
||||
print("{0}{1}{2}".format(colors.RED, str(e).replace("\033", ""), colors.END), file=sys.stderr)
|
||||
print("{0}{1}{2}".format(colors.YELLOW, "Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.", colors.END), file=sys.stderr)
|
||||
logger.debug('LUA ERROR: ' + str(e).replace("\033", ""))
|
||||
logger.warning("Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.")
|
||||
socketio.emit("error", str(e), broadcast=True, room="UI_2")
|
||||
else:
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': 'Error occurred during generator call; please check console.'}, broadcast=True, room="UI_1")
|
||||
@@ -5704,7 +5754,7 @@ def inlineedit(chunk, data):
|
||||
if(chunk-1 in koboldai_vars.actions):
|
||||
koboldai_vars.actions[chunk-1] = data
|
||||
else:
|
||||
print(f"WARNING: Attempted to edit non-existent chunk {chunk}")
|
||||
logger.warning(f"Attempted to edit non-existent chunk {chunk}")
|
||||
|
||||
setgamesaved(False)
|
||||
update_story_chunk(chunk)
|
||||
@@ -5728,7 +5778,7 @@ def inlinedelete(chunk):
|
||||
if(chunk-1 in koboldai_vars.actions):
|
||||
koboldai_vars.actions.delete_action(chunk-1)
|
||||
else:
|
||||
print(f"WARNING: Attempted to delete non-existent chunk {chunk}")
|
||||
logger.warning(f"Attempted to delete non-existent chunk {chunk}")
|
||||
setgamesaved(False)
|
||||
remove_story_chunk(chunk)
|
||||
emit('from_server', {'cmd': 'editmode', 'data': 'false'}, broadcast=True, room="UI_1")
|
||||
@@ -10847,9 +10897,13 @@ def startup():
|
||||
|
||||
print("", end="", flush=True)
|
||||
if __name__ == "__main__":
|
||||
print("{0}\nStarting webserver...{1}".format(colors.GREEN, colors.END), flush=True)
|
||||
|
||||
general_startup()
|
||||
# Start flask & SocketIO
|
||||
logger.init("Flask", status="Starting")
|
||||
Session(app)
|
||||
logger.init_ok("Flask", status="OK")
|
||||
logger.init("Webserver", status="Starting")
|
||||
patch_transformers()
|
||||
startup()
|
||||
# Start Flask/SocketIO (Blocking, so this must be last method!)
|
||||
@@ -10883,12 +10937,12 @@ if __name__ == "__main__":
|
||||
if(args.localtunnel or args.ngrok or args.remote):
|
||||
with open('cloudflare.log', 'w') as cloudflarelog:
|
||||
cloudflarelog.write("KoboldAI has finished loading and is available at the following link : " + cloudflare)
|
||||
koboldai_vars.cloudflare_link = cloudflare
|
||||
print(format(colors.GREEN) + "KoboldAI has finished loading and is available at the following link for UI 1: " + cloudflare + format(colors.END))
|
||||
print(format(colors.GREEN) + "KoboldAI has finished loading and is available at the following link for UI 2: " + cloudflare + "/new_ui" + format(colors.END))
|
||||
logger.init_ok("Webserver", status="OK")
|
||||
logger.message(f"KoboldAI has finished loading and is available at the following link for UI 1: {cloudflare}")
|
||||
logger.message(f"KoboldAI has finished loading and is available at the following link for UI 2: {cloudflare}/new_ui")
|
||||
else:
|
||||
print("{0}Webserver has started, you can now connect to this machine at port {1}{2}"
|
||||
.format(colors.GREEN, port, colors.END))
|
||||
logger.init_ok("Webserver", status="OK")
|
||||
logger.message(f"Webserver has started, you can now connect to this machine at port: {port}")
|
||||
koboldai_vars.serverstarted = True
|
||||
socketio.run(app, host='0.0.0.0', port=port)
|
||||
else:
|
||||
@@ -10896,8 +10950,8 @@ if __name__ == "__main__":
|
||||
if not args.no_ui:
|
||||
import webbrowser
|
||||
webbrowser.open_new('http://localhost:{0}'.format(port))
|
||||
print("{0}Server started!\nYou may now connect with a browser at http://127.0.0.1:{1}/{2}"
|
||||
.format(colors.GREEN, port, colors.END))
|
||||
logger.init_ok("Webserver", status="OK")
|
||||
logger.message(f"Webserver started! You may now connect with a browser at http://127.0.0.1:{port}")
|
||||
koboldai_vars.serverstarted = True
|
||||
socketio.run(app, port=port, host='0.0.0.0')
|
||||
else:
|
||||
@@ -10910,13 +10964,19 @@ if __name__ == "__main__":
|
||||
if not args.no_ui:
|
||||
import webbrowser
|
||||
webbrowser.open_new('http://localhost:{0}'.format(port))
|
||||
print("{0}Server started!\nYou may now connect with a browser at http://127.0.0.1:{1}/{2}"
|
||||
.format(colors.GREEN, port, colors.END))
|
||||
logger.init_ok("Webserver", status="OK")
|
||||
logger.message(f"Webserver started! You may now connect with a browser at http://127.0.0.1:{port}")
|
||||
koboldai_vars.serverstarted = True
|
||||
socketio.run(app, port=port)
|
||||
logger.init("Webserver", status="Closed")
|
||||
|
||||
|
||||
else:
|
||||
general_startup()
|
||||
# Start flask & SocketIO
|
||||
logger.init("Flask", status="Starting")
|
||||
Session(app)
|
||||
logger.init_ok("Flask", status="OK")
|
||||
patch_transformers()
|
||||
startup()
|
||||
koboldai_settings.port = args.port if "port" in args and args.port is not None else 5000
|
||||
|
@@ -18,6 +18,7 @@ dependencies:
- git=2.35.1
- marshmallow>=3.13
- apispec-webframeworks
- loguru
- pip:
- git+https://github.com/finetuneanon/transformers@gpt-neo-localattention3-rp-b
- flask-cloudflared

@@ -19,6 +19,7 @@ dependencies:
- protobuf
- marshmallow>=3.13
- apispec-webframeworks
- loguru
- pip:
- flask-cloudflared
- flask-ngrok

@@ -14,6 +14,7 @@ dependencies:
- git=2.35.1
- marshmallow>=3.13
- apispec-webframeworks
- loguru
- pip:
- --find-links https://download.pytorch.org/whl/rocm4.2/torch_stable.html
- torch

@@ -16,6 +16,7 @@ dependencies:
- protobuf
- marshmallow>=3.13
- apispec-webframeworks
- loguru
- pip:
- --find-links https://download.pytorch.org/whl/rocm4.2/torch_stable.html
- torch==1.10.*
@@ -3,6 +3,7 @@ from typing import Tuple, Union, Optional
import os
import json
import zipfile
from logger import logger

#==================================================================#
# Generic Method for prompting for file path

@@ -156,16 +157,16 @@ def getspfiles(model_dimension: int):
continue
z, version, shape, fortran_order, dtype = checksp("./softprompts/"+file, model_dimension)
if z == 1:
print(f"Browser SP loading error: {file} is malformed or not a soft prompt ZIP file.")
logger.warning(f"Softprompt {file} is malformed or not a soft prompt ZIP file.")
continue
if z == 2:
print(f"Browser SP loading error: {file} tensor.npy has unsupported dtype '{dtype.name}'.")
logger.warning(f"Softprompt {file} tensor.npy has unsupported dtype '{dtype.name}'.")
continue
if z == 3:
print(f"Browser SP loading error: {file} tensor.npy has model dimension {shape[1]} which does not match your model's model dimension of {model_dimension}. This usually means this soft prompt is not compatible with your model.")
logger.debug(f"Softprompt {file} tensor.npy has model dimension {shape[1]} which does not match your model's model dimension of {model_dimension}. This usually means this soft prompt is not compatible with your model.")
continue
if z == 4:
print(f"Browser SP loading error: {file} tensor.npy has {shape[0]} tokens but it is supposed to have less than 2048 tokens.")
logger.warning(f"Softprompt {file} tensor.npy has {shape[0]} tokens but it is supposed to have less than 2048 tokens.")
continue
assert isinstance(z, zipfile.ZipFile)
try:
@@ -185,7 +185,7 @@ class koboldai_vars(object):
prompt_text = self.tokenizer.decode(self.tokenizer.encode(self.prompt)[-self.max_prompt_length-1:])

text += prompt_text
context.append({"type": "prompt", "text": self.prompt_text})
context.append({"type": "prompt", "text": prompt_text})
self.prompt_in_ai = True
else:
self.prompt_in_ai = False
logger.py (new file, 99 lines)
@@ -0,0 +1,99 @@
import sys
from functools import partialmethod
from loguru import logger

STDOUT_LEVELS = ["GENERATION", "PROMPT"]
INIT_LEVELS = ["INIT", "INIT_OK", "INIT_WARN", "INIT_ERR"]
MESSAGE_LEVELS = ["MESSAGE"]
# By default we're at error level or higher
verbosity = 20
quiet = 0

def set_logger_verbosity(count):
    global verbosity
    # The count comes reversed. So count = 0 means minimum verbosity
    # While count 5 means maximum verbosity
    # So the more count we have, the lowe we drop the versbosity maximum
    verbosity = 20 - (count * 10)

def quiesce_logger(count):
    global quiet
    # The bigger the count, the more silent we want our logger
    quiet = count * 10

def is_stdout_log(record):
    if record["level"].name not in STDOUT_LEVELS:
        return(False)
    if record["level"].no < verbosity + quiet:
        return(False)
    return(True)

def is_init_log(record):
    if record["level"].name not in INIT_LEVELS:
        return(False)
    if record["level"].no < verbosity + quiet:
        return(False)
    return(True)

def is_msg_log(record):
    if record["level"].name not in MESSAGE_LEVELS:
        return(False)
    if record["level"].no < verbosity + quiet:
        return(False)
    return(True)

def is_stderr_log(record):
    if record["level"].name in STDOUT_LEVELS + INIT_LEVELS + MESSAGE_LEVELS:
        return(False)
    if record["level"].no < verbosity + quiet:
        return(False)
    return(True)

def test_logger():
    logger.generation("This is a generation message\nIt is typically multiline\nThee Lines".encode("unicode_escape").decode("utf-8"))
    logger.prompt("This is a prompt message")
    logger.debug("Debug Message")
    logger.info("Info Message")
    logger.warning("Info Warning")
    logger.error("Error Message")
    logger.critical("Critical Message")
    logger.init("This is an init message", status="Starting")
    logger.init_ok("This is an init message", status="OK")
    logger.init_warn("This is an init message", status="Warning")
    logger.init_err("This is an init message", status="Error")
    logger.message("This is user message")
    sys.exit()

logfmt = "<level>{level: <10}</level> | <green>{name}</green>:<green>{function}</green>:<green>{line}</green> - <level>{message}</level>"
genfmt = "<level>{level: <10}</level> @ <green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{message}</level>"
initfmt = "<magenta>INIT </magenta> | <level>{extra[status]: <10}</level> | <magenta>{message}</magenta>"
msgfmt = "<level>{level: <10}</level> | <level>{message}</level>"

logger.level("GENERATION", no=24, color="<cyan>")
logger.level("PROMPT", no=23, color="<yellow>")
logger.level("INIT", no=31, color="<white>")
logger.level("INIT_OK", no=31, color="<green>")
logger.level("INIT_WARN", no=31, color="<yellow>")
logger.level("INIT_ERR", no=31, color="<red>")
# Messages contain important information without which this application might not be able to be used
# As such, they have the highest priority
logger.level("MESSAGE", no=61, color="<green>")

logger.__class__.generation = partialmethod(logger.__class__.log, "GENERATION")
logger.__class__.prompt = partialmethod(logger.__class__.log, "PROMPT")
logger.__class__.init = partialmethod(logger.__class__.log, "INIT")
logger.__class__.init_ok = partialmethod(logger.__class__.log, "INIT_OK")
logger.__class__.init_warn = partialmethod(logger.__class__.log, "INIT_WARN")
logger.__class__.init_err = partialmethod(logger.__class__.log, "INIT_ERR")
logger.__class__.message = partialmethod(logger.__class__.log, "MESSAGE")

config = {
    "handlers": [
        {"sink": sys.stderr, "format": logfmt, "colorize":True, "filter": is_stderr_log},
        {"sink": sys.stdout, "format": genfmt, "level": "PROMPT", "colorize":True, "filter": is_stdout_log},
        {"sink": sys.stdout, "format": initfmt, "level": "INIT", "colorize":True, "filter": is_init_log},
        {"sink": sys.stdout, "format": msgfmt, "level": "MESSAGE", "colorize":True, "filter": is_msg_log}
    ],
}
logger.configure(**config)
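For reference, a minimal sketch of how aiserver.py consumes this module elsewhere in this diff; the flag counts below are illustrative only:

# Mirrors the pattern in general_startup(): -v/--verbosity and -q/--quiesce are
# argparse 'count' options fed into the two helpers defined above.
from logger import logger, set_logger_verbosity, quiesce_logger

set_logger_verbosity(2)   # e.g. "-vv": verbosity = 20 - 2*10 = 0, so DEBUG records pass the filters
quiesce_logger(0)         # no "-q": quiet stays 0

logger.init("Webserver", status="Starting")      # INIT line on stdout
logger.message("KoboldAI has finished loading")  # MESSAGE level (no=61), always shown
logger.debug("Visible only because verbosity was raised")
logger.init_ok("Webserver", status="OK")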
@@ -14,4 +14,5 @@ protobuf
accelerate
flask_session
marshmallow>=3.13
apispec-webframeworks
apispec-webframeworks
loguru

@@ -2,11 +2,10 @@ torch >= 1.9, <= 1.11
numpy
tqdm
requests
optax >= 0.0.5, <= 0.0.9
dm-haiku == 0.0.5
jax == 0.2.21
jaxlib >= 0.1.69, <= 0.3.7
transformers >= 4.19
transformers ==4.21.3
progressbar2
git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck
flask

@@ -17,7 +16,7 @@ eventlet
lupa==1.10
markdown
bleach==4.1.0
chex==0.1.4
flask-session
marshmallow>=3.13
apispec-webframeworks
apispec-webframeworks
loguru
@@ -2913,6 +2913,7 @@ $(document).ready(function(){
$("#oaimodel").addClass("hidden")
buildLoadModelList(msg.data, msg.menu, msg.breadcrumbs, msg.showdelete);
} else if(msg.cmd == 'selected_model_info') {
console.log(msg);
enableButtons([load_model_accept]);
$("#oaimodel").addClass("hidden")
$("#oaimodel")[0].options[0].selected = true;

@@ -2946,7 +2947,7 @@
if (msg.url) {
$("#modelurl").removeClass("hidden");
if (msg.default_url != null) {
$("#modelurl").value = msg.default_url;
document.getElementById("modelurl").value = msg.default_url;
}
} else {
$("#modelurl").addClass("hidden");
@@ -2355,4 +2355,19 @@ h2 .material-icons-outlined {
.horde_trigger[model_model="ReadOnly"],
.horde_trigger[model_model="CLUSTER"] {
display: none;
}

.preset_area {
width: 100%;
padding: 10px;
text-align: center;
}

.preset_area .settings_button {
transform: translateY(6px);
width: 155px;
}

input[type='range'] {
border: none !important;
}
@@ -107,12 +107,12 @@
<span id="debug-dump" class="cursor" onclick="document.getElementById('debug-file-container').classList.remove('hidden');">Download debug dump</span>
</div>
<div id="setting_menu_settings" class="hidden settings_category_area tab-target tab-target-settings">
<div class="force_center">
<select class="var_sync_model_selected_preset settings_select presets" onchange='sync_to_server(this)'><option>Preset</option></select>
<div class="preset_area">
<button class="settings_button" onclick="show_save_preset();">
<span class="material-icons-outlined cursor" title="Save Preset">save</span>
<span class="button_label">Save Preset</span>
</button>
<select class="var_sync_model_selected_preset settings_select presets" onchange='sync_to_server(this)'><option>Preset</option></select>
</div>
{% with menu='Settings' %}
<div class="collapsable_header" onclick="toggle_setting_category(this);">
@@ -30,7 +30,7 @@ SOFTWARE.
import utils

import multiprocessing
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, TypeVar
import progressbar
import time
import os

@@ -45,7 +45,6 @@ from jax.config import config
from jax.experimental import maps
import jax.numpy as jnp
import numpy as np
import optax
import haiku as hk
from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM
from tokenizers import Tokenizer

@@ -136,6 +135,14 @@ def __batch_xmap(shard_dim=1):
return inner


class _EmptyState(NamedTuple):
pass

class _DummyOptimizer:
def init(*args, **kwargs):
return _EmptyState()


def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
'''
This gets called by generate_loop_fn to apply repetition penalty

@@ -1167,7 +1174,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo

cores_per_replica = params["cores_per_replica"]
seq = params["seq"]
params["optimizer"] = optax.scale(0)
params["optimizer"] = _DummyOptimizer()
mesh_shape = (1, cores_per_replica)
devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
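The last hunk above swaps the optax.scale(0) placeholder for the new _DummyOptimizer, so loading a model for inference no longer constructs an optax state object. Only init() appears in the hunk; the sketch below restates that shape, and the update() method is a hypothetical addition for illustration only, not part of the commit:

# Sketch under assumptions: init() matches the diff; update() is hypothetical.
from typing import NamedTuple

class _EmptyState(NamedTuple):
    pass  # no fields, so the stored "optimizer state" is effectively free

class _DummyOptimizer:
    def init(*args, **kwargs):  # as in the diff: no explicit self, *args absorbs the instance
        return _EmptyState()
    def update(self, grads, state, params=None):  # hypothetical no-op; assumed unused on the inference path
        return grads, state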
utils.py (375 lines changed)
@@ -4,6 +4,7 @@ import shutil
import json
import subprocess
import tempfile
from urllib.error import HTTPError
import requests
import requests.adapters
import time

@@ -13,6 +14,10 @@ import packaging.version
from tqdm.auto import tqdm
import os
import itertools
import hashlib
import huggingface_hub
import packaging.version
from pathlib import Path
from typing import List, Optional

HAS_ACCELERATE = packaging.version.parse(transformers_version) >= packaging.version.parse("4.20.0.dev0")
@@ -182,81 +187,9 @@ class Send_to_socketio(object):
except:
pass

def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
def _download_with_aria2(aria2_config: str, total_length: int, directory: str = ".", user_agent=None, force_download=False, use_auth_token=None):
import transformers
import transformers.modeling_utils
from huggingface_hub import HfFolder
if shutil.which("aria2c") is None: # Don't do anything if aria2 is not installed
return
if local_files_only: # If local_files_only is true, we obviously don't need to download anything
return
if os.path.isdir(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path + ".index") or transformers.modeling_utils.is_remote_url(pretrained_model_name_or_path):
return
if proxies:
print("WARNING: KoboldAI does not support using aria2 to download models from huggingface.co through a proxy. Disabling aria2 download mode.")
return
if use_auth_token:
if isinstance(use_auth_token, str):
token = use_auth_token
else:
token = HfFolder.get_token()
if token is None:
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
_cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
sharded = False
headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
if use_auth_token:
headers["authorization"] = f"Bearer {use_auth_token}"
def is_cached(url):
try:
transformers.file_utils.get_from_cache(url, cache_dir=cache_dir, local_files_only=True)
except (FileNotFoundError, transformers.file_utils.EntryNotFoundError):
return False
return True
while True: # Try to get the huggingface.co URL of the model's pytorch_model.bin or pytorch_model.bin.index.json file
try:
filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
except AttributeError:
return
url = transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, filename, revision=revision, mirror=mirror)
if is_cached(url) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
break
if sharded:
return
else:
sharded = True
if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download
filenames = [transformers.modeling_utils.WEIGHTS_NAME]
else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
map_filename = transformers.file_utils.cached_path(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent)
with open(map_filename) as f:
map_data = json.load(f)
filenames = set(map_data["weight_map"].values())
urls = [transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, n, revision=revision, mirror=mirror) for n in filenames]
if not force_download:
urls = [u for u in urls if not is_cached(u)]
if not urls:
return
etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]]
headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]
filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)]
for n in filenames:
path = os.path.join(_cache_dir, "kai-tempfile." + n + ".aria2")
if os.path.exists(path):
os.remove(path)
path = os.path.join(_cache_dir, "kai-tempfile." + n)
if os.path.exists(path):
os.remove(path)
if force_download:
path = os.path.join(_cache_dir, n + ".json")
if os.path.exists(path):
os.remove(path)
path = os.path.join(_cache_dir, n)
if os.path.exists(path):
os.remove(path)
total_length = sum(int(h["Content-Length"]) for h in headers)
lengths = {}
aria2_config = "\n".join(f"{u}\n out=kai-tempfile.{n}" for u, n in zip(urls, filenames)).encode()
s = requests.Session()
s.mount("http://", requests.adapters.HTTPAdapter(max_retries=requests.adapters.Retry(total=120, backoff_factor=1)))
bar = None
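
Editor's note: for sharded checkpoints, the hook downloads only pytorch_model.bin.index.json itself and derives the shard list from its weight_map, exactly as the json.load / set(...) lines above do. A self-contained sketch of that step (the index content below is a made-up miniature example):

import json
import tempfile

# Fabricated index in the same shape as pytorch_model.bin.index.json.
index = {
    "metadata": {"total_size": 123},
    "weight_map": {
        "transformer.wte.weight": "pytorch_model-00001-of-00002.bin",
        "transformer.h.0.attn.weight": "pytorch_model-00001-of-00002.bin",
        "lm_head.weight": "pytorch_model-00002-of-00002.bin",
    },
}

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(index, f)
    map_filename = f.name

with open(map_filename) as f:
    map_data = json.load(f)

# Each shard appears once, no matter how many tensors it stores.
shard_filenames = sorted(set(map_data["weight_map"].values()))
print(shard_filenames)
# ['pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin']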
@@ -266,7 +199,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
f.write(aria2_config)
f.flush()
p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--enable-rpc=true", f"--rpc-secret={secret}", "--rpc-listen-port", str(koboldai_vars.aria2_port), "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming=false", "-d", _cache_dir, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {token}'"] if use_auth_token else []), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
p = subprocess.Popen(["aria2c", "-x", "10", "-s", "10", "-j", "10", "--enable-rpc=true", f"--rpc-secret={secret}", "--rpc-listen-port", str(koboldai_vars.aria2_port), "--disable-ipv6", "--file-allocation=trunc", "--allow-overwrite", "--auto-file-renaming=false", "-d", directory, "-i", f.name, "-U", transformers.file_utils.http_user_agent(user_agent)] + (["-c"] if not force_download else []) + ([f"--header='Authorization: Bearer {use_auth_token}'"] if use_auth_token else []), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
while p.poll() is None:
r = s.post(f"http://localhost:{koboldai_vars.aria2_port}/jsonrpc", json={"jsonrpc": "2.0", "id": "kai", "method": "aria2.tellActive", "params": [f"token:{secret}"]}).json()["result"]
if not r:
@@ -278,7 +211,7 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
done = True
break
if bar is None:
bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000, file=Send_to_socketio())
bar = tqdm(total=total_length, desc=f"[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000)
visited = set()
for x in r:
filename = x["files"][0]["path"]
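
Editor's note: the progress loop above drives the tqdm bar by polling aria2's JSON-RPC endpoint (aria2.tellActive) while the aria2c subprocess runs. A stripped-down sketch of that polling call, assuming an aria2c instance is already listening with the placeholder port and secret below:

import requests

ARIA2_PORT = 6799        # placeholder --rpc-listen-port value
ARIA2_SECRET = "secret"  # placeholder --rpc-secret value


def completed_bytes(session: requests.Session) -> int:
    payload = {
        "jsonrpc": "2.0",
        "id": "kai",
        "method": "aria2.tellActive",
        "params": [f"token:{ARIA2_SECRET}"],
    }
    r = session.post(f"http://localhost:{ARIA2_PORT}/jsonrpc", json=payload, timeout=10)
    r.raise_for_status()
    active = r.json()["result"]
    # aria2 reports byte counts as strings; each entry is one active download.
    return sum(int(x["completedLength"]) for x in active)

# Usage: call completed_bytes(requests.Session()) inside the poll loop and feed
# the delta since the previous call to tqdm's update().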
@@ -296,12 +229,298 @@ def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_d
raise e
finally:
try:
os.remove(path)
if os.path.exists(path):
os.remove(path)
except OSError:
pass
code = p.wait()
if not done and code:
raise OSError(f"aria2 exited with exit code {code}")

def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
import transformers
import transformers.modeling_utils
from huggingface_hub import HfFolder
if use_auth_token:
if isinstance(use_auth_token, str):
token = use_auth_token
else:
token = HfFolder.get_token()
if token is None:
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
_cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
_revision = revision if revision is not None else huggingface_hub.constants.DEFAULT_REVISION
sharded = False
headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
if use_auth_token:
headers["authorization"] = f"Bearer {use_auth_token}"

storage_folder = os.path.join(_cache_dir, huggingface_hub.file_download.repo_folder_name(repo_id=pretrained_model_name_or_path, repo_type="model"))
os.makedirs(storage_folder, exist_ok=True)

def is_cached(filename):
try:
huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=cache_dir, local_files_only=True)
except ValueError:
return False
return True
while True: # Try to get the huggingface.co URL of the model's pytorch_model.bin or pytorch_model.bin.index.json file
try:
filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
except AttributeError:
return
url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, filename, revision=revision)
if is_cached(filename) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
break
if sharded:
return
else:
sharded = True
if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download
filenames = [transformers.modeling_utils.WEIGHTS_NAME]
else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
map_filename = huggingface_hub.hf_hub_download(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent)
with open(map_filename) as f:
map_data = json.load(f)
filenames = set(map_data["weight_map"].values())
urls = [huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=revision) for n in filenames]
if not force_download:
urls = [u for u, n in zip(urls, filenames) if not is_cached(n)]
if not urls:
return

blob_paths = []

# This section is a modified version of hf_hub_download from huggingface_hub
# See https://github.com/huggingface/huggingface_hub/blob/main/LICENSE for license
for u, n in zip(urls, filenames):
relative_filename = os.path.join(*n.split("/"))
if not local_files_only:
try:
r = huggingface_hub.file_download._request_wrapper(
method="HEAD",
url=u,
headers=headers,
allow_redirects=False,
follow_relative_redirects=True,
proxies=proxies,
timeout=10,
)
try:
r.raise_for_status()
except HTTPError as e:
error_code = r.headers.get("X-Error-Code")
if error_code != "EntryNotFound":
raise RuntimeError(f"HEAD {u} failed with error code {r.status_code}")
commit_hash = r.headers.get(huggingface_hub.file_download.HUGGINGFACE_HEADER_X_REPO_COMMIT)
if commit_hash is not None:
no_exist_file_path = (
Path(storage_folder)
/ ".no_exist"
/ commit_hash
/ relative_filename
)
no_exist_file_path.parent.mkdir(parents=True, exist_ok=True)
no_exist_file_path.touch()
huggingface_hub.file_download._cache_commit_hash_for_specific_revision(
storage_folder, _revision, commit_hash
)
raise
commit_hash = r.headers[huggingface_hub.file_download.HUGGINGFACE_HEADER_X_REPO_COMMIT]
if commit_hash is None:
raise OSError(
"Distant resource does not seem to be on huggingface.co (missing"
" commit header)."
)
etag = r.headers.get(huggingface_hub.file_download.HUGGINGFACE_HEADER_X_LINKED_ETAG) or r.headers.get(
"ETag"
)
# We favor a custom header indicating the etag of the linked resource, and
# we fallback to the regular etag header.
# If we don't have any of those, raise an error.
if etag is None:
raise OSError(
"Distant resource does not have an ETag, we won't be able to"
" reliably ensure reproducibility."
)
etag = huggingface_hub.file_download._normalize_etag(etag)
# In case of a redirect, save an extra redirect on the request.get call,
# and ensure we download the exact atomic version even if it changed
# between the HEAD and the GET (unlikely, but hey).
# Useful for lfs blobs that are stored on a CDN.
if 300 <= r.status_code <= 399:
url_to_download = r.headers["Location"]
if (
"lfs.huggingface.co" in url_to_download
or "lfs-staging.huggingface.co" in url_to_download
):
# Remove authorization header when downloading a LFS blob
headers.pop("authorization", None)
except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
# Actually raise for those subclasses of ConnectionError
raise
except (
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
huggingface_hub.file_download.OfflineModeIsEnabled,
):
# Otherwise, our Internet connection is down.
# etag is None
pass
if etag is None:
# In those cases, we cannot force download.
if force_download:
raise ValueError(
"We have no connection or you passed local_files_only, so"
" force_download is not an accepted option."
)
if huggingface_hub.file_download.REGEX_COMMIT_HASH.match(_revision):
commit_hash = _revision
else:
ref_path = os.path.join(storage_folder, "refs", _revision)
with open(ref_path) as f:
commit_hash = f.read()
pointer_path = os.path.join(
storage_folder, "snapshots", commit_hash, relative_filename
)
if os.path.exists(pointer_path):
return pointer_path
# If we couldn't find an appropriate file on disk,
# raise an error.
# If files cannot be found and local_files_only=True,
# the models might've been found if local_files_only=False
# Notify the user about that
if local_files_only:
raise huggingface_hub.file_download.LocalEntryNotFoundError(
"Cannot find the requested files in the disk cache and"
" outgoing traffic has been disabled. To enable hf.co look-ups"
" and downloads online, set 'local_files_only' to False."
)
else:
raise huggingface_hub.file_download.LocalEntryNotFoundError(
"Connection error, and we cannot find the requested files in"
" the disk cache. Please try again or make sure your Internet"
" connection is on."
)
# From now on, etag and commit_hash are not None.
blob_path = os.path.join(storage_folder, "blobs", etag)
pointer_path = os.path.join(
storage_folder, "snapshots", commit_hash, relative_filename
)
os.makedirs(os.path.dirname(blob_path), exist_ok=True)
os.makedirs(os.path.dirname(pointer_path), exist_ok=True)
# if passed revision is not identical to commit_hash
# then revision has to be a branch name or tag name.
# In that case store a ref.
huggingface_hub.file_download._cache_commit_hash_for_specific_revision(storage_folder, _revision, commit_hash)
if os.path.exists(pointer_path) and not force_download:
return pointer_path
if os.path.exists(blob_path) and not force_download:
# we have the blob already, but not the pointer
huggingface_hub.file_download.logger.info("creating pointer to %s from %s", blob_path, pointer_path)
huggingface_hub.file_download._create_relative_symlink(blob_path, pointer_path)
return pointer_path
# Some Windows versions do not allow for paths longer than 255 characters.
# In this case, we must specify it is an extended path by using the "\\?\" prefix.
if os.name == "nt" and len(os.path.abspath(blob_path)) > 255:
blob_path = "\\\\?\\" + os.path.abspath(blob_path)
blob_paths.append(blob_path)

filenames = blob_paths
headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]

for n in filenames:
prefix, suffix = n.rsplit("/", 1)
path = os.path.join(prefix, "kai-tempfile." + suffix + ".aria2")
if os.path.exists(path):
os.remove(path)
path = os.path.join(prefix, "kai-tempfile." + suffix)
if os.path.exists(path):
os.remove(path)
total_length = sum(int(h["Content-Length"]) for h in headers)
aria2_config = "\n".join(f"{u}\n out={os.path.join(prefix, 'kai-tempfile.' + suffix)}" for u, n in zip(urls, filenames) for prefix, suffix in [n.rsplit("/", 1)]).encode()
_download_with_aria2(aria2_config, total_length, use_auth_token=token if use_auth_token else None, user_agent=user_agent, force_download=force_download)
for u, n in zip(urls, filenames):
prefix, suffix = n.rsplit("/", 1)
os.rename(os.path.join(prefix, "kai-tempfile." + suffix), os.path.join(prefix, suffix))

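Editor's note: _transformers22_aria2_hook mirrors huggingface_hub's newer cache layout, where the payload lives once under blobs/<etag> and each revision gets a snapshots/<commit>/<filename> pointer (a relative symlink) to that blob. A toy sketch of that layout using only the standard library; every path and value below is invented for the demo, and symlink creation may require extra privileges on Windows:

import os
import tempfile

cache = tempfile.mkdtemp()
etag = "abc123"                          # stands in for the normalized ETag
commit = "deadbeef"                      # stands in for the resolved commit hash
relative_filename = "pytorch_model.bin"

blob_path = os.path.join(cache, "blobs", etag)
pointer_path = os.path.join(cache, "snapshots", commit, relative_filename)

os.makedirs(os.path.dirname(blob_path), exist_ok=True)
os.makedirs(os.path.dirname(pointer_path), exist_ok=True)

with open(blob_path, "wb") as f:         # pretend this is the downloaded weight file
    f.write(b"weights")

# The pointer is a symlink relative to the snapshot directory, so the cache stays relocatable.
relative_target = os.path.relpath(blob_path, start=os.path.dirname(pointer_path))
os.symlink(relative_target, pointer_path)

print(open(pointer_path, "rb").read())   # b'weights', read through the pointer
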
def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
import transformers
import transformers.modeling_utils
from huggingface_hub import HfFolder
if shutil.which("aria2c") is None: # Don't do anything if aria2 is not installed
return
if local_files_only: # If local_files_only is true, we obviously don't need to download anything
return
if os.path.isdir(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path + ".index") or transformers.modeling_utils.is_remote_url(pretrained_model_name_or_path):
return
if proxies:
print("WARNING: KoboldAI does not support using aria2 to download models from huggingface.co through a proxy. Disabling aria2 download mode.")
return
if packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.22.0.dev0"):
return _transformers22_aria2_hook(pretrained_model_name_or_path, force_download=force_download, cache_dir=cache_dir, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, **kwargs)
if use_auth_token:
if isinstance(use_auth_token, str):
token = use_auth_token
else:
token = HfFolder.get_token()
if token is None:
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
_cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
sharded = False
headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
if use_auth_token:
headers["authorization"] = f"Bearer {use_auth_token}"
def is_cached(url):
try:
huggingface_hub.cached_download(url, cache_dir=cache_dir, local_files_only=True)
except ValueError:
return False
return True
while True: # Try to get the huggingface.co URL of the model's pytorch_model.bin or pytorch_model.bin.index.json file
try:
filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
except AttributeError:
return
url = huggingface_hub.hf_hub_url(pretrained_model_name_or_path, filename, revision=revision)
if is_cached(url) or requests.head(url, allow_redirects=True, proxies=proxies, headers=headers):
break
if sharded:
return
else:
sharded = True
if not sharded: # If the model has a pytorch_model.bin file, that's the only file to download
filenames = [transformers.modeling_utils.WEIGHTS_NAME]
else: # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
map_filename = huggingface_hub.cached_download(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent)
with open(map_filename) as f:
map_data = json.load(f)
filenames = set(map_data["weight_map"].values())
urls = [huggingface_hub.hf_hub_url(pretrained_model_name_or_path, n, revision=revision) for n in filenames]
if not force_download:
urls = [u for u in urls if not is_cached(u)]
if not urls:
return
etags = [h.get("X-Linked-Etag") or h.get("ETag") for u in urls for h in [requests.head(u, headers=headers, allow_redirects=False, proxies=proxies, timeout=10).headers]]
headers = [requests.head(u, headers=headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]
filenames = [hashlib.sha256(u.encode("utf-8")).hexdigest() + "." + hashlib.sha256(t.encode("utf-8")).hexdigest() for u, t in zip(urls, etags)]
for n in filenames:
path = os.path.join(_cache_dir, "kai-tempfile." + n + ".aria2")
if os.path.exists(path):
os.remove(path)
path = os.path.join(_cache_dir, "kai-tempfile." + n)
if os.path.exists(path):
os.remove(path)
if force_download:
path = os.path.join(_cache_dir, n + ".json")
if os.path.exists(path):
os.remove(path)
path = os.path.join(_cache_dir, n)
if os.path.exists(path):
os.remove(path)
total_length = sum(int(h["Content-Length"]) for h in headers)
aria2_config = "\n".join(f"{u}\n out=kai-tempfile.{n}" for u, n in zip(urls, filenames)).encode()
_download_with_aria2(aria2_config, total_length, directory=_cache_dir, use_auth_token=token if use_auth_token else None, user_agent=user_agent, force_download=force_download)
for u, t, n in zip(urls, etags, filenames):
os.rename(os.path.join(_cache_dir, "kai-tempfile." + n), os.path.join(_cache_dir, n))
with open(os.path.join(_cache_dir, n + ".json"), "w") as f:
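
Editor's note: in the pre-4.22 path just above, cached files are named by hashing the URL and the ETag, the naming scheme the older transformers/huggingface_hub cache used. A tiny sketch of that naming step (the URL and ETag values are placeholders):

import hashlib


def cache_filename(url: str, etag: str) -> str:
    # <sha256 of url>.<sha256 of etag>, matching the list comprehension above.
    return (
        hashlib.sha256(url.encode("utf-8")).hexdigest()
        + "."
        + hashlib.sha256(etag.encode("utf-8")).hexdigest()
    )


print(cache_filename("https://huggingface.co/some-model/resolve/main/pytorch_model.bin", '"0123abcd"'))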
@@ -321,10 +540,10 @@ def get_num_shards(filename):
# pytorch_model.bin.index.json, returns a list of weight names in the
# sharded model. Requires lazy loader to be enabled to work properly
#==================================================================#
def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, **kwargs):
import transformers.modeling_utils
import torch
shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision, mirror=mirror)
shard_paths, _ = transformers.modeling_utils.get_checkpoint_shard_files(pretrained_model_name_or_path, filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, revision=revision)
return list(itertools.chain(*(torch.load(p, map_location="cpu").keys() for p in shard_paths)))

#==================================================================#