GooseAI Fixes

This commit is contained in:
ebolam
2023-05-26 11:08:30 -04:00
parent d2c95bc60f
commit 2c82e9c5e0
5 changed files with 22 additions and 12 deletions

View File

@@ -32,7 +32,7 @@ class model_backend(InferenceModel):
def is_valid(self, model_name, model_path, menu_path): def is_valid(self, model_name, model_path, menu_path):
return model_name == "API" return model_name == "API"
def get_requested_parameters(self, model_name, model_path, menu_path): def get_requested_parameters(self, model_name, model_path, menu_path, parameters = {}):
if os.path.exists("settings/api.model_backend.settings") and 'base_url' not in vars(self): if os.path.exists("settings/api.model_backend.settings") and 'base_url' not in vars(self):
with open("settings/api.model_backend.settings", "r") as f: with open("settings/api.model_backend.settings", "r") as f:
self.base_url = json.load(f)['base_url'] self.base_url = json.load(f)['base_url']

View File

@@ -33,7 +33,7 @@ class model_backend(InferenceModel):
def is_valid(self, model_name, model_path, menu_path): def is_valid(self, model_name, model_path, menu_path):
return model_name == "Colab" return model_name == "Colab"
def get_requested_parameters(self, model_name, model_path, menu_path): def get_requested_parameters(self, model_name, model_path, menu_path, parameters = {}):
if os.path.exists("settings/api.model_backend.settings") and 'colaburl' not in vars(self): if os.path.exists("settings/api.model_backend.settings") and 'colaburl' not in vars(self):
with open("settings/api.model_backend.settings", "r") as f: with open("settings/api.model_backend.settings", "r") as f:
self.colaburl = json.load(f)['base_url'] self.colaburl = json.load(f)['base_url']

View File

@@ -39,19 +39,23 @@ class model_backend(InferenceModel):
logger.debug("Horde Models: {}".format(self.models)) logger.debug("Horde Models: {}".format(self.models))
return model_name == "CLUSTER" or model_name in [x['value'] for x in self.models] return model_name == "CLUSTER" or model_name in [x['value'] for x in self.models]
def get_requested_parameters(self, model_name, model_path, menu_path): def get_requested_parameters(self, model_name, model_path, menu_path, parameters = {}):
if os.path.exists("settings/api.model_backend.settings") and 'base_url' not in vars(self): if os.path.exists("settings/api.model_backend.settings") and 'base_url' not in vars(self):
with open("settings/horde.model_backend.settings", "r") as f: with open("settings/horde.model_backend.settings", "r") as f:
temp = json.load(f) temp = json.load(f)
self.base_url = temp['url'] self.base_url = temp['url']
self.key = temp['key'] self.key = temp['key']
if 'key' in parameters:
self.key = parameters['key']
if 'url' in parameters:
self.url = parameters['url']
requested_parameters = [] requested_parameters = []
requested_parameters.extend([{ requested_parameters.extend([{
"uitype": "text", "uitype": "text",
"unit": "text", "unit": "text",
"label": "URL", "label": "URL",
"id": "url", "id": "url",
"default": self.url, "default": self.url if 'url' not in parameters else parameters['url'],
"tooltip": "URL to the horde.", "tooltip": "URL to the horde.",
"menu_path": "", "menu_path": "",
"check": {"value": "", 'check': "!="}, "check": {"value": "", 'check': "!="},
@@ -63,7 +67,7 @@ class model_backend(InferenceModel):
"unit": "text", "unit": "text",
"label": "Key", "label": "Key",
"id": "key", "id": "key",
"default": self.key, "default": self.key if 'key' not in parameters else parameters['key'],
"check": {"value": "", 'check': "!="}, "check": {"value": "", 'check': "!="},
"tooltip": "User Key to use when connecting to Horde (0000000000 is anonymous).", "tooltip": "User Key to use when connecting to Horde (0000000000 is anonymous).",
"menu_path": "", "menu_path": "",

View File

@@ -1,5 +1,5 @@
import torch import torch
import requests import requests,json
import numpy as np import numpy as np
from typing import List, Optional, Union from typing import List, Optional, Union
import os import os
@@ -30,10 +30,15 @@ class model_backend(InferenceModel):
def is_valid(self, model_name, model_path, menu_path): def is_valid(self, model_name, model_path, menu_path):
return model_name == "OAI" or model_name == "GooseAI" return model_name == "OAI" or model_name == "GooseAI"
def get_requested_parameters(self, model_name, model_path, menu_path): def get_requested_parameters(self, model_name, model_path, menu_path, parameters = {}):
if os.path.exists("settings/{}.model_backend.settings".format(self.source)) and 'colaburl' not in vars(self): if os.path.exists("settings/{}.model_backend.settings".format(self.source)) and 'colaburl' not in vars(self):
with open("settings/{}.model_backend.settings".format(self.source), "r") as f: with open("settings/{}.model_backend.settings".format(self.source), "r") as f:
self.key = json.load(f)['key'] try:
self.key = json.load(f)['key']
except:
pass
if 'key' in parameters:
self.key = parameters['key']
self.source = model_name self.source = model_name
requested_parameters = [] requested_parameters = []
requested_parameters.extend([{ requested_parameters.extend([{
@@ -66,7 +71,7 @@ class model_backend(InferenceModel):
def set_input_parameters(self, parameters): def set_input_parameters(self, parameters):
self.key = parameters['key'].strip() self.key = parameters['key'].strip()
self.model = parameters['model'] self.model_name = parameters['model']
def get_oai_models(self): def get_oai_models(self):
if self.key == "": if self.key == "":
@@ -94,6 +99,7 @@ class model_backend(InferenceModel):
logger.init_ok("OAI Engines", status="OK") logger.init_ok("OAI Engines", status="OK")
logger.debug("OAI Engines: {}".format(engines))
return engines return engines
else: else:
# Something went wrong, print the message and quit since we can't initialize an engine # Something went wrong, print the message and quit since we can't initialize an engine
@@ -134,7 +140,7 @@ class model_backend(InferenceModel):
# Build request JSON data # Build request JSON data
# GooseAI is a subntype of OAI. So to check if it's this type, we check the configname as a workaround # GooseAI is a subntype of OAI. So to check if it's this type, we check the configname as a workaround
# as the koboldai_vars.model will always be OAI # as the koboldai_vars.model will always be OAI
if "GooseAI" in utils.koboldai_vars.configname: if self.source == "GooseAI":
reqdata = { reqdata = {
"prompt": decoded_prompt, "prompt": decoded_prompt,
"max_tokens": max_new, "max_tokens": max_new,
@@ -163,7 +169,7 @@ class model_backend(InferenceModel):
} }
req = requests.post( req = requests.post(
self.url, "{}/{}/completions".format(self.url, self.model_name),
json=reqdata, json=reqdata,
headers={ headers={
"Authorization": "Bearer " + self.key, "Authorization": "Bearer " + self.key,

View File

@@ -33,7 +33,7 @@ class model_backend(InferenceModel):
def is_valid(self, model_name, model_path, menu_path): def is_valid(self, model_name, model_path, menu_path):
return model_name == "ReadOnly" return model_name == "ReadOnly"
def get_requested_parameters(self, model_name, model_path, menu_path): def get_requested_parameters(self, model_name, model_path, menu_path, parameters = {}):
requested_parameters = [] requested_parameters = []
return requested_parameters return requested_parameters