diff --git a/aiserver.py b/aiserver.py index a09a0714..f8435737 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1431,12 +1431,28 @@ def get_model_info(model, directory=""): key_value = "" break_values = [] url = False + models_on_url = False + multi_online_models = False gpu_count = torch.cuda.device_count() gpu_names = [] for i in range(gpu_count): gpu_names.append(torch.cuda.get_device_name(i)) if model in ['Colab', 'API']: url = True + elif model == 'CLUSTER': + models_on_url = True + url = True + key = True + multi_online_models = True + if path.exists("settings/{}.settings".format(model)): + with open("settings/{}.settings".format(model), "r") as file: + # Check if API key exists + js = json.load(file) + if("apikey" in js and js["apikey"] != ""): + # API key exists, grab it and close the file + key_value = js["apikey"] + elif 'oaiapikey' in js and js['oaiapikey'] != "": + key_value = js["oaiapikey"] elif model in [x[1] for x in model_menu['apilist']]: if path.exists("settings/{}.settings".format(model)): with open("settings/{}.settings".format(model), "r") as file: @@ -1481,8 +1497,8 @@ def get_model_info(model, directory=""): emit('from_server', {'cmd': 'selected_model_info', 'key_value': key_value, 'key':key, 'gpu':gpu, 'layer_count':layer_count, 'breakmodel':breakmodel, 'disk_break_value': disk_blocks, 'accelerate': utils.HAS_ACCELERATE, - 'break_values': break_values, 'gpu_count': gpu_count, - 'url': url, 'gpu_names': gpu_names}, broadcast=True) + 'break_values': break_values, 'gpu_count': gpu_count, 'multi_online_models': multi_online_models, + 'url': url, 'gpu_names': gpu_names, 'models_on_url': models_on_url}, broadcast=True) if key_value != "": get_oai_models(key_value) @@ -1551,7 +1567,10 @@ def get_oai_models(key): if "apikey" in js: if js['apikey'] != key: changed=True + else: + changed=True if changed: + js={} with open("settings/{}.settings".format(vars.model_selected), "w") as file: js["apikey"] = key file.write(json.dumps(js, indent=3)) @@ -1563,6 +1582,55 @@ def get_oai_models(key): print(req.json()) emit('from_server', {'cmd': 'errmsg', 'data': req.json()}) +def get_cluster_models(msg): + vars.oaiapikey = msg['key'] + vars.apikey = vars.oaiapikey + url = msg['url'] + + + # Get list of models from public cluster + print("{0}Retrieving engine list...{1}".format(colors.PURPLE, colors.END), end="") + req = requests.get("{}/models".format(url)) + if(req.status_code == 200): + engines = req.json() + print(engines) + try: + engines = [[en, en] for en in engines] + except: + print(engines) + raise + print(engines) + + online_model = "" + changed=False + + #Save the key + if not path.exists("settings"): + # If the client settings file doesn't exist, create it + # Write API key to file + os.makedirs('settings', exist_ok=True) + if path.exists("settings/{}.settings".format(vars.model_selected)): + with open("settings/{}.settings".format(vars.model_selected), "r") as file: + js = json.load(file) + if 'online_model' in js: + online_model = js['online_model'] + if "apikey" in js: + if js['apikey'] != vars.oaiapikey: + changed=True + else: + changed=True + if changed: + js={} + with open("settings/{}.settings".format(vars.model_selected), "w") as file: + js["apikey"] = vars.oaiapikey + file.write(json.dumps(js, indent=3)) + + emit('from_server', {'cmd': 'oai_engines', 'data': engines, 'online_model': online_model}, broadcast=True) + else: + # Something went wrong, print the message and quit since we can't initialize an engine + print("{0}ERROR!{1}".format(colors.RED, colors.END)) + print(req.json()) + emit('from_server', {'cmd': 'errmsg', 'data': req.json()}) # Function to patch transformers to use our soft prompt def patch_causallm(model): @@ -3700,6 +3768,8 @@ def get_message(msg): elif(msg['cmd'] == 'list_model'): sendModelSelection(menu=msg['data']) elif(msg['cmd'] == 'load_model'): + print(msg) + print(vars.model_selected) if not os.path.exists("settings/"): os.mkdir("settings") changed = True @@ -3723,6 +3793,14 @@ def get_message(msg): f.close() vars.colaburl = msg['url'] + "/request" vars.model = vars.model_selected + if vars.model == "CLUSTER": + if type(msg['online_model']) is not list: + if msg['online_model'] == '': + vars.cluster_requested_models = [] + else: + vars.cluster_requested_models = [msg['online_model']] + else: + vars.cluster_requested_models = msg['online_model'] load_model(use_gpu=msg['use_gpu'], gpu_layers=msg['gpu_layers'], disk_layers=msg['disk_layers'], online_model=msg['online_model']) elif(msg['cmd'] == 'show_model'): print("Model Name: {}".format(getmodelname())) @@ -3786,6 +3864,8 @@ def get_message(msg): print(colors.RED + "WARNING!!: Someone maliciously attempted to delete " + msg['data'] + " the attempt has been blocked.") elif(msg['cmd'] == 'OAI_Key_Update'): get_oai_models(msg['key']) + elif(msg['cmd'] == 'Cluster_Key_Update'): + get_cluster_models(msg) elif(msg['cmd'] == 'loadselect'): vars.loadselect = msg["data"] elif(msg['cmd'] == 'spselect'): diff --git a/static/application.js b/static/application.js index 9107e161..d402b4a8 100644 --- a/static/application.js +++ b/static/application.js @@ -2918,12 +2918,30 @@ $(document).ready(function(){ if (msg.key) { $("#modelkey").removeClass("hidden"); $("#modelkey")[0].value = msg.key_value; + if (msg.models_on_url) { + $("#modelkey")[0].onblur = function () {socket.send({'cmd': 'Cluster_Key_Update', 'key': this.value, 'url': document.getElementById("modelurl").value});}; + $("#modelurl")[0].onblur = function () {socket.send({'cmd': 'Cluster_Key_Update', 'key': document.getElementById("modelkey").value, 'url': this.value});}; + } else { + $("#modelkey")[0].onblur = function () {socket.send({'cmd': 'OAI_Key_Update', 'key': $('#modelkey')[0].value});}; + $("#modelurl")[0].onblur = null; + } //if we're in the API list, disable to load button until the model is selected (after the API Key is entered) disableButtons([load_model_accept]); } else { $("#modelkey").addClass("hidden"); - } + + console.log(msg.multi_online_models); + if (msg.multi_online_models) { + $("#oaimodel")[0].setAttribute("multiple", ""); + $("#oaimodel")[0].options[0].textContent = "All" + } else { + $("#oaimodel")[0].removeAttribute("multiple"); + $("#oaimodel")[0].options[0].textContent = "Select Model(s)" + } + + + if (msg.url) { $("#modelurl").removeClass("hidden"); } else { diff --git a/templates/index.html b/templates/index.html index 27b50b78..2880914f 100644 --- a/templates/index.html +++ b/templates/index.html @@ -295,12 +295,12 @@
- +
- +