From 2c82e9c5e0fe0903f16291bcdb3816427a5af7f2 Mon Sep 17 00:00:00 2001 From: ebolam Date: Fri, 26 May 2023 11:08:30 -0400 Subject: [PATCH] GooseAI Fixes --- modeling/inference_models/api/class.py | 2 +- modeling/inference_models/basic_api/class.py | 2 +- modeling/inference_models/horde/class.py | 10 +++++++--- modeling/inference_models/openai_gooseai.py | 18 ++++++++++++------ modeling/inference_models/readonly/class.py | 2 +- 5 files changed, 22 insertions(+), 12 deletions(-) diff --git a/modeling/inference_models/api/class.py b/modeling/inference_models/api/class.py index 3d54edd9..b3129d5a 100644 --- a/modeling/inference_models/api/class.py +++ b/modeling/inference_models/api/class.py @@ -32,7 +32,7 @@ class model_backend(InferenceModel): def is_valid(self, model_name, model_path, menu_path): return model_name == "API" - def get_requested_parameters(self, model_name, model_path, menu_path): + def get_requested_parameters(self, model_name, model_path, menu_path, parameters = {}): if os.path.exists("settings/api.model_backend.settings") and 'base_url' not in vars(self): with open("settings/api.model_backend.settings", "r") as f: self.base_url = json.load(f)['base_url'] diff --git a/modeling/inference_models/basic_api/class.py b/modeling/inference_models/basic_api/class.py index 2094d34e..b492c039 100644 --- a/modeling/inference_models/basic_api/class.py +++ b/modeling/inference_models/basic_api/class.py @@ -33,7 +33,7 @@ class model_backend(InferenceModel): def is_valid(self, model_name, model_path, menu_path): return model_name == "Colab" - def get_requested_parameters(self, model_name, model_path, menu_path): + def get_requested_parameters(self, model_name, model_path, menu_path, parameters = {}): if os.path.exists("settings/api.model_backend.settings") and 'colaburl' not in vars(self): with open("settings/api.model_backend.settings", "r") as f: self.colaburl = json.load(f)['base_url'] diff --git a/modeling/inference_models/horde/class.py b/modeling/inference_models/horde/class.py index 3b102b46..2cc01708 100644 --- a/modeling/inference_models/horde/class.py +++ b/modeling/inference_models/horde/class.py @@ -39,19 +39,23 @@ class model_backend(InferenceModel): logger.debug("Horde Models: {}".format(self.models)) return model_name == "CLUSTER" or model_name in [x['value'] for x in self.models] - def get_requested_parameters(self, model_name, model_path, menu_path): + def get_requested_parameters(self, model_name, model_path, menu_path, parameters = {}): if os.path.exists("settings/api.model_backend.settings") and 'base_url' not in vars(self): with open("settings/horde.model_backend.settings", "r") as f: temp = json.load(f) self.base_url = temp['url'] self.key = temp['key'] + if 'key' in parameters: + self.key = parameters['key'] + if 'url' in parameters: + self.url = parameters['url'] requested_parameters = [] requested_parameters.extend([{ "uitype": "text", "unit": "text", "label": "URL", "id": "url", - "default": self.url, + "default": self.url if 'url' not in parameters else parameters['url'], "tooltip": "URL to the horde.", "menu_path": "", "check": {"value": "", 'check': "!="}, @@ -63,7 +67,7 @@ class model_backend(InferenceModel): "unit": "text", "label": "Key", "id": "key", - "default": self.key, + "default": self.key if 'key' not in parameters else parameters['key'], "check": {"value": "", 'check': "!="}, "tooltip": "User Key to use when connecting to Horde (0000000000 is anonymous).", "menu_path": "", diff --git a/modeling/inference_models/openai_gooseai.py b/modeling/inference_models/openai_gooseai.py index e4b9dfb8..0195f650 100644 --- a/modeling/inference_models/openai_gooseai.py +++ b/modeling/inference_models/openai_gooseai.py @@ -1,5 +1,5 @@ import torch -import requests +import requests,json import numpy as np from typing import List, Optional, Union import os @@ -30,10 +30,15 @@ class model_backend(InferenceModel): def is_valid(self, model_name, model_path, menu_path): return model_name == "OAI" or model_name == "GooseAI" - def get_requested_parameters(self, model_name, model_path, menu_path): + def get_requested_parameters(self, model_name, model_path, menu_path, parameters = {}): if os.path.exists("settings/{}.model_backend.settings".format(self.source)) and 'colaburl' not in vars(self): with open("settings/{}.model_backend.settings".format(self.source), "r") as f: - self.key = json.load(f)['key'] + try: + self.key = json.load(f)['key'] + except: + pass + if 'key' in parameters: + self.key = parameters['key'] self.source = model_name requested_parameters = [] requested_parameters.extend([{ @@ -66,7 +71,7 @@ class model_backend(InferenceModel): def set_input_parameters(self, parameters): self.key = parameters['key'].strip() - self.model = parameters['model'] + self.model_name = parameters['model'] def get_oai_models(self): if self.key == "": @@ -94,6 +99,7 @@ class model_backend(InferenceModel): logger.init_ok("OAI Engines", status="OK") + logger.debug("OAI Engines: {}".format(engines)) return engines else: # Something went wrong, print the message and quit since we can't initialize an engine @@ -134,7 +140,7 @@ class model_backend(InferenceModel): # Build request JSON data # GooseAI is a subntype of OAI. So to check if it's this type, we check the configname as a workaround # as the koboldai_vars.model will always be OAI - if "GooseAI" in utils.koboldai_vars.configname: + if self.source == "GooseAI": reqdata = { "prompt": decoded_prompt, "max_tokens": max_new, @@ -163,7 +169,7 @@ class model_backend(InferenceModel): } req = requests.post( - self.url, + "{}/{}/completions".format(self.url, self.model_name), json=reqdata, headers={ "Authorization": "Bearer " + self.key, diff --git a/modeling/inference_models/readonly/class.py b/modeling/inference_models/readonly/class.py index 92531af4..98573990 100644 --- a/modeling/inference_models/readonly/class.py +++ b/modeling/inference_models/readonly/class.py @@ -33,7 +33,7 @@ class model_backend(InferenceModel): def is_valid(self, model_name, model_path, menu_path): return model_name == "ReadOnly" - def get_requested_parameters(self, model_name, model_path, menu_path): + def get_requested_parameters(self, model_name, model_path, menu_path, parameters = {}): requested_parameters = [] return requested_parameters