GooseAI Fixes

This commit is contained in:
ebolam
2023-05-26 11:08:30 -04:00
parent d2c95bc60f
commit 2c82e9c5e0
5 changed files with 22 additions and 12 deletions

View File

@@ -1,5 +1,5 @@
import torch
import requests
import requests,json
import numpy as np
from typing import List, Optional, Union
import os
@@ -30,10 +30,15 @@ class model_backend(InferenceModel):
def is_valid(self, model_name, model_path, menu_path):
return model_name == "OAI" or model_name == "GooseAI"
def get_requested_parameters(self, model_name, model_path, menu_path):
def get_requested_parameters(self, model_name, model_path, menu_path, parameters = {}):
if os.path.exists("settings/{}.model_backend.settings".format(self.source)) and 'colaburl' not in vars(self):
with open("settings/{}.model_backend.settings".format(self.source), "r") as f:
self.key = json.load(f)['key']
try:
self.key = json.load(f)['key']
except:
pass
if 'key' in parameters:
self.key = parameters['key']
self.source = model_name
requested_parameters = []
requested_parameters.extend([{
@@ -66,7 +71,7 @@ class model_backend(InferenceModel):
def set_input_parameters(self, parameters):
self.key = parameters['key'].strip()
self.model = parameters['model']
self.model_name = parameters['model']
def get_oai_models(self):
if self.key == "":
@@ -94,6 +99,7 @@ class model_backend(InferenceModel):
logger.init_ok("OAI Engines", status="OK")
logger.debug("OAI Engines: {}".format(engines))
return engines
else:
# Something went wrong, print the message and quit since we can't initialize an engine
@@ -134,7 +140,7 @@ class model_backend(InferenceModel):
# Build request JSON data
# GooseAI is a subtype of OAI. So to check if it's this type, we check the configname as a workaround
# as the koboldai_vars.model will always be OAI
if "GooseAI" in utils.koboldai_vars.configname:
if self.source == "GooseAI":
reqdata = {
"prompt": decoded_prompt,
"max_tokens": max_new,
@@ -163,7 +169,7 @@ class model_backend(InferenceModel):
}
req = requests.post(
self.url,
"{}/{}/completions".format(self.url, self.model_name),
json=reqdata,
headers={
"Authorization": "Bearer " + self.key,