Merge pull request #184 from ebolam/united

Fix for vars.model getting set on AI selection in the UI rather than when actually loaded
This commit is contained in:
henk717 2022-08-18 00:05:55 +02:00 committed by GitHub
commit a3862946aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 20 additions and 18 deletions

View File

@ -239,6 +239,7 @@ class vars:
submission = "" # Same as above, but after applying input formatting submission = "" # Same as above, but after applying input formatting
lastctx = "" # The last context submitted to the generator lastctx = "" # The last context submitted to the generator
model = "" # Model ID string chosen at startup model = "" # Model ID string chosen at startup
model_selected = "" #selected model in UI
model_type = "" # Model Type (Automatically taken from the model config) model_type = "" # Model Type (Automatically taken from the model config)
noai = False # Runs the script without starting up the transformers pipeline noai = False # Runs the script without starting up the transformers pipeline
aibusy = False # Stops submissions while the AI is working aibusy = False # Stops submissions while the AI is working
@ -1479,11 +1480,11 @@ def get_layer_count(model, directory=""):
else: else:
from transformers import AutoConfig from transformers import AutoConfig
if directory == "": if directory == "":
model_config = AutoConfig.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache") model_config = AutoConfig.from_pretrained(model, revision=vars.revision, cache_dir="cache")
elif(os.path.isdir(vars.custmodpth.replace('/', '_'))):
model_config = AutoConfig.from_pretrained(vars.custmodpth.replace('/', '_'), revision=vars.revision, cache_dir="cache")
elif(os.path.isdir(directory)): elif(os.path.isdir(directory)):
model_config = AutoConfig.from_pretrained(directory, revision=vars.revision, cache_dir="cache") model_config = AutoConfig.from_pretrained(directory, revision=vars.revision, cache_dir="cache")
elif(os.path.isdir(vars.custmodpth.replace('/', '_'))):
model_config = AutoConfig.from_pretrained(vars.custmodpth.replace('/', '_'), revision=vars.revision, cache_dir="cache")
else: else:
model_config = AutoConfig.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache") model_config = AutoConfig.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
@ -1496,9 +1497,9 @@ def get_layer_count(model, directory=""):
def get_oai_models(key): def get_oai_models(key):
vars.oaiapikey = key vars.oaiapikey = key
if vars.model == 'OAI': if vars.model_selected == 'OAI':
url = "https://api.openai.com/v1/engines" url = "https://api.openai.com/v1/engines"
elif vars.model == 'GooseAI': elif vars.model_selected == 'GooseAI':
url = "https://api.goose.ai/v1/engines" url = "https://api.goose.ai/v1/engines"
else: else:
return return
@ -1527,8 +1528,8 @@ def get_oai_models(key):
# If the client settings file doesn't exist, create it # If the client settings file doesn't exist, create it
# Write API key to file # Write API key to file
os.makedirs('settings', exist_ok=True) os.makedirs('settings', exist_ok=True)
if path.exists("settings/{}.settings".format(vars.model)): if path.exists("settings/{}.settings".format(vars.model_selected)):
with open("settings/{}.settings".format(vars.model), "r") as file: with open("settings/{}.settings".format(vars.model_selected), "r") as file:
js = json.load(file) js = json.load(file)
if 'online_model' in js: if 'online_model' in js:
online_model = js['online_model'] online_model = js['online_model']
@ -1536,7 +1537,7 @@ def get_oai_models(key):
if js['apikey'] != key: if js['apikey'] != key:
changed=True changed=True
if changed: if changed:
with open("settings/{}.settings".format(vars.model), "w") as file: with open("settings/{}.settings".format(vars.model_selected), "w") as file:
js["apikey"] = key js["apikey"] = key
file.write(json.dumps(js, indent=3)) file.write(json.dumps(js, indent=3))
@ -3674,8 +3675,8 @@ def get_message(msg):
changed = True changed = True
if not utils.HAS_ACCELERATE: if not utils.HAS_ACCELERATE:
msg['disk_layers'] = "0" msg['disk_layers'] = "0"
if os.path.exists("settings/" + vars.model.replace('/', '_') + ".breakmodel"): if os.path.exists("settings/" + vars.model_selected.replace('/', '_') + ".breakmodel"):
with open("settings/" + vars.model.replace('/', '_') + ".breakmodel", "r") as file: with open("settings/" + vars.model_selected.replace('/', '_') + ".breakmodel", "r") as file:
data = file.read().split('\n')[:2] data = file.read().split('\n')[:2]
if len(data) < 2: if len(data) < 2:
data.append("0") data.append("0")
@ -3683,14 +3684,15 @@ def get_message(msg):
if gpu_layers == msg['gpu_layers'] and disk_layers == msg['disk_layers']: if gpu_layers == msg['gpu_layers'] and disk_layers == msg['disk_layers']:
changed = False changed = False
if changed: if changed:
if vars.model in ["NeoCustom", "GPT2Custom"]: if vars.model_selected in ["NeoCustom", "GPT2Custom"]:
filename = "settings/{}.breakmodel".format(os.path.basename(os.path.normpath(vars.custmodpth))) filename = "settings/{}.breakmodel".format(os.path.basename(os.path.normpath(vars.custmodpth)))
else: else:
filename = "settings/{}.breakmodel".format(vars.model.replace('/', '_')) filename = "settings/{}.breakmodel".format(vars.model_selected.replace('/', '_'))
f = open(filename, "w") f = open(filename, "w")
f.write(str(msg['gpu_layers']) + '\n' + str(msg['disk_layers'])) f.write(str(msg['gpu_layers']) + '\n' + str(msg['disk_layers']))
f.close() f.close()
vars.colaburl = msg['url'] + "/request" vars.colaburl = msg['url'] + "/request"
vars.model = vars.model_selected
load_model(use_gpu=msg['use_gpu'], gpu_layers=msg['gpu_layers'], disk_layers=msg['disk_layers'], online_model=msg['online_model']) load_model(use_gpu=msg['use_gpu'], gpu_layers=msg['gpu_layers'], disk_layers=msg['disk_layers'], online_model=msg['online_model'])
elif(msg['cmd'] == 'show_model'): elif(msg['cmd'] == 'show_model'):
print("Model Name: {}".format(getmodelname())) print("Model Name: {}".format(getmodelname()))
@ -3715,18 +3717,18 @@ def get_message(msg):
elif msg['data'] in ('NeoCustom', 'GPT2Custom') and 'path_modelname' in msg: elif msg['data'] in ('NeoCustom', 'GPT2Custom') and 'path_modelname' in msg:
#Here the user entered custom text in the text box. This could be either a model name or a path. #Here the user entered custom text in the text box. This could be either a model name or a path.
if check_if_dir_is_model(msg['path_modelname']): if check_if_dir_is_model(msg['path_modelname']):
vars.model = msg['data'] vars.model_selected = msg['data']
vars.custmodpth = msg['path_modelname'] vars.custmodpth = msg['path_modelname']
get_model_info(msg['data'], directory=msg['path']) get_model_info(msg['data'], directory=msg['path'])
else: else:
vars.model = msg['path_modelname'] vars.model_selected = msg['path_modelname']
try: try:
get_model_info(vars.model) get_model_info(vars.model_selected)
except: except:
emit('from_server', {'cmd': 'errmsg', 'data': "The model entered doesn't exist."}) emit('from_server', {'cmd': 'errmsg', 'data': "The model entered doesn't exist."})
elif msg['data'] in ('NeoCustom', 'GPT2Custom'): elif msg['data'] in ('NeoCustom', 'GPT2Custom'):
if check_if_dir_is_model(msg['path']): if check_if_dir_is_model(msg['path']):
vars.model = msg['data'] vars.model_selected = msg['data']
vars.custmodpth = msg['path'] vars.custmodpth = msg['path']
get_model_info(msg['data'], directory=msg['path']) get_model_info(msg['data'], directory=msg['path'])
else: else:
@ -3735,12 +3737,12 @@ def get_message(msg):
else: else:
sendModelSelection(menu=msg['data'], folder=msg['path']) sendModelSelection(menu=msg['data'], folder=msg['path'])
else: else:
vars.model = msg['data'] vars.model_selected = msg['data']
if 'path' in msg: if 'path' in msg:
vars.custmodpth = msg['path'] vars.custmodpth = msg['path']
get_model_info(msg['data'], directory=msg['path']) get_model_info(msg['data'], directory=msg['path'])
else: else:
get_model_info(vars.model) get_model_info(vars.model_selected)
elif(msg['cmd'] == 'delete_model'): elif(msg['cmd'] == 'delete_model'):
if "{}/models".format(os.getcwd()) in os.path.abspath(msg['data']) or "{}\\models".format(os.getcwd()) in os.path.abspath(msg['data']): if "{}/models".format(os.getcwd()) in os.path.abspath(msg['data']) or "{}\\models".format(os.getcwd()) in os.path.abspath(msg['data']):
if check_if_dir_is_model(msg['data']): if check_if_dir_is_model(msg['data']):