From c5ee5d3ea2ed64d2f1eed89e14fd31ec4ac71dce Mon Sep 17 00:00:00 2001
From: Divided by Zer0
Date: Sat, 3 Sep 2022 19:36:06 +0200
Subject: [PATCH] Fixes Horde not saving as expected

Horde now saves separate settings per model, or a single set for All.
Refactored the code so that args.configname is no longer used as a global variable.
Added vars.online_model because we need to keep track of the selected online model.
---
 aiserver.py | 98 +++++++++++++++++++++++++++++++++--------------------
 1 file changed, 61 insertions(+), 37 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index 67aeedd5..aac7be06 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -239,7 +239,8 @@ class vars:
     lastact = "" # The last action received from the user
     submission = "" # Same as above, but after applying input formatting
     lastctx = "" # The last context submitted to the generator
-    model = "" # Model ID string chosen at startup
+    model = "ReadOnly" # Model ID string chosen at startup
+    online_model = "" # Used when Model ID is an online service, and there is a secondary option for the actual model name
     model_selected = "" #selected model in UI
     model_type = "" # Model Type (Automatically taken from the model config)
     noai = False # Runs the script without starting up the transformers pipeline
@@ -380,6 +381,7 @@ class vars:
     output_streaming = True
     token_stream_queue = TokenStreamQueue() # Queue for the token streaming
     show_probs = False # Whether or not to show token probabilities
+    configname = None # Name of the settings file to use (previously handled through args.configname)
 
 utils.vars = vars
 
@@ -615,6 +617,18 @@ api_v1 = KoboldAPISpec(
     tags=tags,
 )
 
+# Returns the expected config filename for the current setup.
+# If model_name is specified, it returns the settings file that model would use.
+def get_config_filename(model_name = None):
+    if model_name:
+        return(f"settings/{model_name.replace('/', '_')}.settings")
+    elif args.configname:
+        return(f"settings/{args.configname}.settings")
+    elif vars.configname != '':
+        return(f"settings/{vars.configname}.settings")
+    else:
+        print("Empty config file name sent back. Defaulting to ReadOnly")
Defaulting to ReadOnly") + return(f"settings/ReadOnly.settings") #==================================================================# # Function to get model selection at startup #==================================================================# @@ -722,9 +736,8 @@ def check_if_dir_is_model(path): # Return Model Name #==================================================================# def getmodelname(): - if(args.configname): - modelname = args.configname - return modelname + if(vars.online_model != ''): + return(f"{vars.model}/{vars.online_model}") if(vars.model in ("NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")): modelname = os.path.basename(os.path.normpath(vars.custmodpth)) return modelname @@ -1058,7 +1071,7 @@ def savesettings(): # Write it if not os.path.exists('settings'): os.mkdir('settings') - file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "w") + file = open(get_config_filename(), "w") try: file.write(json.dumps(js, indent=3)) finally: @@ -1084,9 +1097,9 @@ def loadsettings(): processsettings(js) file.close() - if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")): + if(path.exists(get_config_filename())): # Read file contents into JSON object - file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "r") + file = open(get_config_filename(), "r") js = json.load(file) processsettings(js) @@ -1444,8 +1457,8 @@ def get_model_info(model, directory=""): url = True key = True multi_online_models = True - if path.exists("settings/{}.settings".format(model)): - with open("settings/{}.settings".format(model), "r") as file: + if path.exists(get_config_filename(model)): + with open(get_config_filename(model), "r") as file: # Check if API key exists js = json.load(file) if("apikey" in js and js["apikey"] != ""): @@ -1454,8 +1467,8 @@ def get_model_info(model, directory=""): elif 'oaiapikey' in js and js['oaiapikey'] != "": key_value = js["oaiapikey"] elif model in [x[1] for x in model_menu['apilist']]: - if path.exists("settings/{}.settings".format(model)): - with open("settings/{}.settings".format(model), "r") as file: + if path.exists(get_config_filename(model)): + with open(get_config_filename(model), "r") as file: # Check if API key exists js = json.load(file) if("apikey" in js and js["apikey"] != ""): @@ -1559,8 +1572,8 @@ def get_oai_models(key): # If the client settings file doesn't exist, create it # Write API key to file os.makedirs('settings', exist_ok=True) - if path.exists("settings/{}.settings".format(vars.model_selected)): - with open("settings/{}.settings".format(vars.model_selected), "r") as file: + if path.exists(get_config_filename(vars.model_selected)): + with open(get_config_filename(vars.model_selected), "r") as file: js = json.load(file) if 'online_model' in js: online_model = js['online_model'] @@ -1571,7 +1584,7 @@ def get_oai_models(key): changed=True if changed: js={} - with open("settings/{}.settings".format(vars.model_selected), "w") as file: + with open(get_config_filename(vars.model_selected), "w") as file: js["apikey"] = key file.write(json.dumps(js, indent=3)) @@ -1609,8 +1622,8 @@ def get_cluster_models(msg): # If the client settings file doesn't exist, create it # Write API key to file os.makedirs('settings', exist_ok=True) - if path.exists("settings/{}.settings".format(vars.model_selected)): - with open("settings/{}.settings".format(vars.model_selected), "r") as file: + if path.exists(get_config_filename(vars.model_selected)): + with 
             js = json.load(file)
             if 'online_model' in js:
                 online_model = js['online_model']
@@ -1621,7 +1634,7 @@ def get_cluster_models(msg):
                 changed=True
     if changed:
         js={}
-        with open("settings/{}.settings".format(vars.model_selected), "w") as file:
+        with open(get_config_filename(vars.model_selected), "w") as file:
            js["apikey"] = vars.oaiapikey
            file.write(json.dumps(js, indent=3))
 
@@ -2065,6 +2078,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
     model = None
     generator = None
     model_config = None
+    vars.online_model = ''
     with torch.no_grad():
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", message="torch.distributed.reduce_op is deprecated")
@@ -2083,11 +2097,26 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
 
     #Reload our badwords
     vars.badwordsids = vars.badwordsids_default
+    if online_model == "":
+        vars.configname = vars.model
     #Let's set the GooseAI or OpenAI server URLs if that's applicable
-    if online_model != "":
-        if path.exists("settings/{}.settings".format(vars.model)):
+    else:
+        vars.online_model = online_model
+        # Swap OAI Server if GooseAI was selected
+        if(vars.model == "GooseAI"):
+            vars.oaiengines = "https://api.goose.ai/v1/engines"
+            vars.model = "OAI"
+            vars.configname = f"GooseAI_{online_model.replace('/', '_')}"
+        elif(vars.model == "CLUSTER") and type(online_model) is list:
+            if len(online_model) != 1:
+                vars.configname = vars.model
+            else:
+                vars.configname = f"{vars.model}_{online_model[0].replace('/', '_')}"
+        else:
+            vars.configname = f"{vars.model}_{online_model.replace('/', '_')}"
+        if path.exists(get_config_filename()):
             changed=False
-            with open("settings/{}.settings".format(vars.model), "r") as file:
+            with open(get_config_filename(), "r") as file:
                 # Check if API key exists
                 js = json.load(file)
                 if 'online_model' in js:
@@ -2098,15 +2127,8 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
                         changed=True
                         js['online_model'] = online_model
             if changed:
-                with open("settings/{}.settings".format(vars.model), "w") as file:
+                with open(get_config_filename(), "w") as file:
                     file.write(json.dumps(js, indent=3))
-        # Swap OAI Server if GooseAI was selected
-        if(vars.model == "GooseAI"):
-            vars.oaiengines = "https://api.goose.ai/v1/engines"
-            vars.model = "OAI"
-            args.configname = "GooseAI" + "/" + online_model
-        else:
-            args.configname = vars.model + "/" + online_model
 
         vars.oaiurl = vars.oaiengines + "/{0}/completions".format(online_model)
 
@@ -2193,12 +2215,12 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
     if(vars.model == "GooseAI"):
         vars.oaiengines = "https://api.goose.ai/v1/engines"
         vars.model = "OAI"
-        args.configname = "GooseAI"
+        vars.configname = "GooseAI"
 
     # Ask for API key if OpenAI was selected
     if(vars.model == "OAI"):
-        if not args.configname:
-            args.configname = "OAI"
+        if not vars.configname:
+            vars.configname = "OAI"
 
     if(vars.model == "ReadOnly"):
         vars.noai = True
@@ -2784,8 +2806,8 @@ def lua_startup():
     global _bridged
     global F
     global bridged
-    if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
-        file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "r")
+    if(path.exists(get_config_filename())):
+        file = open(get_config_filename(), "r")
         js = json.load(file)
         if("userscripts" in js):
             vars.userscripts = []
@@ -3845,7 +3867,7 @@ def get_message(msg):
                 else:
                     sendModelSelection(menu=msg['data'], folder=msg['path'])
             else:
-                vars.model_selected = msg['data']
+                vars.model_selected = msg['data']
                 if 'path' in msg:
                     vars.custmodpth = msg['path']
                     get_model_info(msg['data'], directory=msg['path'])
@@ -6052,7 +6074,9 @@ def oairequest(txt, min, max):
     vars.lastctx = txt
 
     # Build request JSON data
-    if 'GooseAI' in args.configname:
+    # GooseAI is a subtype of OAI, so we check the configname as a workaround,
+    # since vars.model will always be OAI
+    if 'GooseAI' in vars.configname:
         reqdata = {
             'prompt': txt,
             'max_tokens': vars.genamt,
@@ -6882,8 +6906,8 @@ def final_startup():
     threading.Thread(target=__preempt_tokenizer).start()
 
     # Load soft prompt specified by the settings file, if applicable
-    if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
-        file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "r")
+    if(path.exists(get_config_filename())):
+        file = open(get_config_filename(), "r")
         js = json.load(file)
         if(vars.allowsp and "softprompt" in js and type(js["softprompt"]) is str and all(q not in js["softprompt"] for q in ("..", ":")) and (len(js["softprompt"]) == 0 or all(js["softprompt"][0] not in q for q in ("/", "\\")))):
             spRequest(js["softprompt"])
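
Reviewer note (not part of the patch): the sketch below mirrors the configname selection rules that the load_model() hunk above introduces, so you can see which settings file a given selection resolves to. The helper name resolve_configname and the example model IDs are hypothetical and exist only for this illustration; the patch itself derives vars.configname inline and builds the path through get_config_filename().

# Illustrative sketch only; assumes the same rules as the load_model() hunk above.
def resolve_configname(model, online_model):
    # No secondary online model selected: settings are keyed by the service name alone.
    if not online_model:
        return model
    # GooseAI is served through the OAI code path but keeps its own settings file.
    if model == "GooseAI":
        return f"GooseAI_{online_model.replace('/', '_')}"
    # Horde (CLUSTER) passes a list: a single entry gets its own file, several share "CLUSTER".
    if model == "CLUSTER" and isinstance(online_model, list):
        if len(online_model) != 1:
            return model
        return f"{model}_{online_model[0].replace('/', '_')}"
    return f"{model}_{online_model.replace('/', '_')}"

print(f"settings/{resolve_configname('CLUSTER', ['KoboldAI/fairseq-dense-13B'])}.settings")
# -> settings/CLUSTER_KoboldAI_fairseq-dense-13B.settings
print(f"settings/{resolve_configname('CLUSTER', ['model-a', 'model-b'])}.settings")
# -> settings/CLUSTER.settings

This is what the commit message describes: a single-model Horde selection no longer overwrites the settings of other models, while selecting several models (the "All" case) keeps using the shared CLUSTER settings file.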