diff --git a/aiserver.py b/aiserver.py index 5002feb5..ef785313 100644 --- a/aiserver.py +++ b/aiserver.py @@ -239,6 +239,7 @@ class vars: submission = "" # Same as above, but after applying input formatting lastctx = "" # The last context submitted to the generator model = "" # Model ID string chosen at startup + model_selected = "" #selected model in UI model_type = "" # Model Type (Automatically taken from the model config) noai = False # Runs the script without starting up the transformers pipeline aibusy = False # Stops submissions while the AI is working @@ -1479,11 +1480,11 @@ def get_layer_count(model, directory=""): else: from transformers import AutoConfig if directory == "": - model_config = AutoConfig.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache") - elif(os.path.isdir(vars.custmodpth.replace('/', '_'))): - model_config = AutoConfig.from_pretrained(vars.custmodpth.replace('/', '_'), revision=vars.revision, cache_dir="cache") + model_config = AutoConfig.from_pretrained(model, revision=vars.revision, cache_dir="cache") elif(os.path.isdir(directory)): model_config = AutoConfig.from_pretrained(directory, revision=vars.revision, cache_dir="cache") + elif(os.path.isdir(vars.custmodpth.replace('/', '_'))): + model_config = AutoConfig.from_pretrained(vars.custmodpth.replace('/', '_'), revision=vars.revision, cache_dir="cache") else: model_config = AutoConfig.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache") @@ -1496,9 +1497,9 @@ def get_layer_count(model, directory=""): def get_oai_models(key): vars.oaiapikey = key - if vars.model == 'OAI': + if vars.model_selected == 'OAI': url = "https://api.openai.com/v1/engines" - elif vars.model == 'GooseAI': + elif vars.model_selected == 'GooseAI': url = "https://api.goose.ai/v1/engines" else: return @@ -1527,8 +1528,8 @@ def get_oai_models(key): # If the client settings file doesn't exist, create it # Write API key to file os.makedirs('settings', exist_ok=True) - if path.exists("settings/{}.settings".format(vars.model)): - with open("settings/{}.settings".format(vars.model), "r") as file: + if path.exists("settings/{}.settings".format(vars.model_selected)): + with open("settings/{}.settings".format(vars.model_selected), "r") as file: js = json.load(file) if 'online_model' in js: online_model = js['online_model'] @@ -1536,7 +1537,7 @@ def get_oai_models(key): if js['apikey'] != key: changed=True if changed: - with open("settings/{}.settings".format(vars.model), "w") as file: + with open("settings/{}.settings".format(vars.model_selected), "w") as file: js["apikey"] = key file.write(json.dumps(js, indent=3)) @@ -3674,8 +3675,8 @@ def get_message(msg): changed = True if not utils.HAS_ACCELERATE: msg['disk_layers'] = "0" - if os.path.exists("settings/" + vars.model.replace('/', '_') + ".breakmodel"): - with open("settings/" + vars.model.replace('/', '_') + ".breakmodel", "r") as file: + if os.path.exists("settings/" + vars.model_selected.replace('/', '_') + ".breakmodel"): + with open("settings/" + vars.model_selected.replace('/', '_') + ".breakmodel", "r") as file: data = file.read().split('\n')[:2] if len(data) < 2: data.append("0") @@ -3683,14 +3684,15 @@ def get_message(msg): if gpu_layers == msg['gpu_layers'] and disk_layers == msg['disk_layers']: changed = False if changed: - if vars.model in ["NeoCustom", "GPT2Custom"]: + if vars.model_selected in ["NeoCustom", "GPT2Custom"]: filename = "settings/{}.breakmodel".format(os.path.basename(os.path.normpath(vars.custmodpth))) else: - filename = "settings/{}.breakmodel".format(vars.model.replace('/', '_')) + filename = "settings/{}.breakmodel".format(vars.model_selected.replace('/', '_')) f = open(filename, "w") f.write(str(msg['gpu_layers']) + '\n' + str(msg['disk_layers'])) f.close() vars.colaburl = msg['url'] + "/request" + vars.model = vars.model_selected load_model(use_gpu=msg['use_gpu'], gpu_layers=msg['gpu_layers'], disk_layers=msg['disk_layers'], online_model=msg['online_model']) elif(msg['cmd'] == 'show_model'): print("Model Name: {}".format(getmodelname())) @@ -3715,18 +3717,18 @@ def get_message(msg): elif msg['data'] in ('NeoCustom', 'GPT2Custom') and 'path_modelname' in msg: #Here the user entered custom text in the text box. This could be either a model name or a path. if check_if_dir_is_model(msg['path_modelname']): - vars.model = msg['data'] + vars.model_selected = msg['data'] vars.custmodpth = msg['path_modelname'] get_model_info(msg['data'], directory=msg['path']) else: - vars.model = msg['path_modelname'] + vars.model_selected = msg['path_modelname'] try: - get_model_info(vars.model) + get_model_info(vars.model_selected) except: emit('from_server', {'cmd': 'errmsg', 'data': "The model entered doesn't exist."}) elif msg['data'] in ('NeoCustom', 'GPT2Custom'): if check_if_dir_is_model(msg['path']): - vars.model = msg['data'] + vars.model_selected = msg['data'] vars.custmodpth = msg['path'] get_model_info(msg['data'], directory=msg['path']) else: @@ -3735,12 +3737,12 @@ def get_message(msg): else: sendModelSelection(menu=msg['data'], folder=msg['path']) else: - vars.model = msg['data'] + vars.model_selected = msg['data'] if 'path' in msg: vars.custmodpth = msg['path'] get_model_info(msg['data'], directory=msg['path']) else: - get_model_info(vars.model) + get_model_info(vars.model_selected) elif(msg['cmd'] == 'delete_model'): if "{}/models".format(os.getcwd()) in os.path.abspath(msg['data']) or "{}\\models".format(os.getcwd()) in os.path.abspath(msg['data']): if check_if_dir_is_model(msg['data']):