diff --git a/aiserver.py b/aiserver.py index b15615a7..35b20613 100644 --- a/aiserver.py +++ b/aiserver.py @@ -434,13 +434,15 @@ def getModelSelection(modellist): getModelSelection(mainmenu) def check_if_dir_is_model(path): - try: - from transformers import AutoConfig - model_config = AutoConfig.from_pretrained(path, revision=vars.revision, cache_dir="cache") - except: + if os.path.exists(path): + try: + from transformers import AutoConfig + model_config = AutoConfig.from_pretrained(path, revision=vars.revision, cache_dir="cache") + except: + return False + return True + else: return False - return True - #==================================================================# # Return all keys in tokenizer dictionary containing char @@ -3056,12 +3058,24 @@ def get_message(msg): # The data variable will contain the model name. But our Custom lines need a bit more processing # If we're on a custom line that we have selected a model for, the path variable will be in msg # so if that's missing we need to run the menu to show the model folders in the models folder - if msg['data'] in ('NeoCustom', 'GPT2Custom') and 'path' not in msg: + if msg['data'] in ('NeoCustom', 'GPT2Custom') and 'path' not in msg and 'path_modelname' not in msg: if 'folder' not in msg: folder = "./models" else: folder = msg['folder'] sendModelSelection(menu=msg['data'], folder=folder) + elif msg['data'] in ('NeoCustom', 'GPT2Custom') and 'path_modelname' in msg: + #Here the user entered custom text in the text box. This could be either a model name or a path. + if check_if_dir_is_model(msg['path_modelname']): + vars.model = msg['data'] + vars.custmodpth = msg['path_modelname'] + get_model_info(msg['data'], directory=msg['path']) + else: + vars.model = msg['path_modelname'] + try: + get_model_info(vars.model) + except: + emit('from_server', {'cmd': 'errmsg', 'data': "The model entered doesn't exist."}) elif msg['data'] in ('NeoCustom', 'GPT2Custom'): if check_if_dir_is_model(msg['path']): vars.model = msg['data'] diff --git a/static/application.js b/static/application.js index f4366ed8..22b0e0e6 100644 --- a/static/application.js +++ b/static/application.js @@ -1002,6 +1002,7 @@ function buildLoadModelList(ar, menu, breadcrumbs) { disableButtons([load_model_accept]); loadmodelcontent.html(""); $("#loadmodellistbreadcrumbs").html(""); + $("#custommodelname").addClass("hidden"); var i; for(i=0; i"+breadcrumbs[i][1]+"\\"); @@ -1056,6 +1057,8 @@ function buildLoadModelList(ar, menu, breadcrumbs) { highlightLoadLine($(this)); } })(i)); + $("#custommodelname").removeClass("hidden"); + $("#custommodelname")[0].setAttribute("menu", menu); //Normal load } else { $("#loadmodel"+i).off("click").on("click", (function () { diff --git a/templates/index.html b/templates/index.html index 80a32484..2050001a 100644 --- a/templates/index.html +++ b/templates/index.html @@ -289,6 +289,7 @@
+