diff --git a/aiserver.py b/aiserver.py index bf4cfac7..ab4fe521 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1422,12 +1422,17 @@ def get_model_info(model, directory=""): key_value = "" break_values = [] url = False + models_on_url = False gpu_count = torch.cuda.device_count() gpu_names = [] for i in range(gpu_count): gpu_names.append(torch.cuda.get_device_name(i)) if model in ['Colab', 'API']: url = True + elif model == 'CLUSTER': + models_on_url = True + url = True + key = True elif model in [x[1] for x in model_menu['apilist']]: if path.exists("settings/{}.settings".format(model)): with open("settings/{}.settings".format(model), "r") as file: @@ -1473,7 +1478,7 @@ def get_model_info(model, directory=""): 'gpu':gpu, 'layer_count':layer_count, 'breakmodel':breakmodel, 'disk_break_value': disk_blocks, 'accelerate': utils.HAS_ACCELERATE, 'break_values': break_values, 'gpu_count': gpu_count, - 'url': url, 'gpu_names': gpu_names}, broadcast=True) + 'url': url, 'gpu_names': gpu_names, 'models_on_url': models_on_url}, broadcast=True) if key_value != "": get_oai_models(key_value) @@ -1554,6 +1559,54 @@ def get_oai_models(key): print(req.json()) emit('from_server', {'cmd': 'errmsg', 'data': req.json()}) +def get_cluster_models(msg): + vars.oaiapikey = msg['key'] + url = msg['url'] + + + # Get list of models from OAI + print("{0}Retrieving engine list...{1}".format(colors.PURPLE, colors.END), end="") + req = requests.get( + url, + headers = { + 'Authorization': 'Bearer '+key + } + ) + if(req.status_code == 200): + engines = req.json()["data"] + try: + engines = [[en["id"], "{} ({})".format(en['id'], "Ready" if en["ready"] == True else "Not Ready")] for en in engines] + except: + print(engines) + raise + + online_model = "" + changed=False + + #Save the key + if not path.exists("settings"): + # If the client settings file doesn't exist, create it + # Write API key to file + os.makedirs('settings', exist_ok=True) + if path.exists("settings/{}.settings".format(vars.model_selected)): + with open("settings/{}.settings".format(vars.model_selected), "r") as file: + js = json.load(file) + if 'online_model' in js: + online_model = js['online_model'] + if "apikey" in js: + if js['apikey'] != key: + changed=True + if changed: + with open("settings/{}.settings".format(vars.model_selected), "w") as file: + js["apikey"] = key + file.write(json.dumps(js, indent=3)) + + emit('from_server', {'cmd': 'oai_engines', 'data': engines, 'online_model': online_model}, broadcast=True) + else: + # Something went wrong, print the message and quit since we can't initialize an engine + print("{0}ERROR!{1}".format(colors.RED, colors.END)) + print(req.json()) + emit('from_server', {'cmd': 'errmsg', 'data': req.json()}) # Function to patch transformers to use our soft prompt def patch_causallm(model): @@ -3777,6 +3830,8 @@ def get_message(msg): print(colors.RED + "WARNING!!: Someone maliciously attempted to delete " + msg['data'] + " the attempt has been blocked.") elif(msg['cmd'] == 'OAI_Key_Update'): get_oai_models(msg['key']) + elif(msg['cmd'] == 'Cluster_Key_Update'): + get_cluster_models(msg) elif(msg['cmd'] == 'loadselect'): vars.loadselect = msg["data"] elif(msg['cmd'] == 'spselect'): diff --git a/static/application.js b/static/application.js index 9107e161..48bf595a 100644 --- a/static/application.js +++ b/static/application.js @@ -2918,6 +2918,11 @@ $(document).ready(function(){ if (msg.key) { $("#modelkey").removeClass("hidden"); $("#modelkey")[0].value = msg.key_value; + if (msg.models_on_url) { + $("#modelkey").onblur = function () {socket.send({'cmd': 'Cluster_Key_Update', 'key': this.value, 'url': ${'modelurl')[].value});}; + } else { + $("#modelkey").onblur = function () {socket.send({'cmd': 'OAI_Key_Update', 'key': $('#modelkey')[0].value});}; + } //if we're in the API list, disable to load button until the model is selected (after the API Key is entered) disableButtons([load_model_accept]); } else {