mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
GooseAI Fixes
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import torch
|
||||
import requests
|
||||
import requests,json
|
||||
import numpy as np
|
||||
from typing import List, Optional, Union
|
||||
import os
|
||||
@@ -30,10 +30,15 @@ class model_backend(InferenceModel):
|
||||
def is_valid(self, model_name, model_path, menu_path):
|
||||
return model_name == "OAI" or model_name == "GooseAI"
|
||||
|
||||
def get_requested_parameters(self, model_name, model_path, menu_path):
|
||||
def get_requested_parameters(self, model_name, model_path, menu_path, parameters = {}):
|
||||
if os.path.exists("settings/{}.model_backend.settings".format(self.source)) and 'colaburl' not in vars(self):
|
||||
with open("settings/{}.model_backend.settings".format(self.source), "r") as f:
|
||||
self.key = json.load(f)['key']
|
||||
try:
|
||||
self.key = json.load(f)['key']
|
||||
except:
|
||||
pass
|
||||
if 'key' in parameters:
|
||||
self.key = parameters['key']
|
||||
self.source = model_name
|
||||
requested_parameters = []
|
||||
requested_parameters.extend([{
|
||||
@@ -66,7 +71,7 @@ class model_backend(InferenceModel):
|
||||
|
||||
def set_input_parameters(self, parameters):
|
||||
self.key = parameters['key'].strip()
|
||||
self.model = parameters['model']
|
||||
self.model_name = parameters['model']
|
||||
|
||||
def get_oai_models(self):
|
||||
if self.key == "":
|
||||
@@ -94,6 +99,7 @@ class model_backend(InferenceModel):
|
||||
|
||||
|
||||
logger.init_ok("OAI Engines", status="OK")
|
||||
logger.debug("OAI Engines: {}".format(engines))
|
||||
return engines
|
||||
else:
|
||||
# Something went wrong, print the message and quit since we can't initialize an engine
|
||||
@@ -134,7 +140,7 @@ class model_backend(InferenceModel):
|
||||
# Build request JSON data
|
||||
# GooseAI is a subntype of OAI. So to check if it's this type, we check the configname as a workaround
|
||||
# as the koboldai_vars.model will always be OAI
|
||||
if "GooseAI" in utils.koboldai_vars.configname:
|
||||
if self.source == "GooseAI":
|
||||
reqdata = {
|
||||
"prompt": decoded_prompt,
|
||||
"max_tokens": max_new,
|
||||
@@ -163,7 +169,7 @@ class model_backend(InferenceModel):
|
||||
}
|
||||
|
||||
req = requests.post(
|
||||
self.url,
|
||||
"{}/{}/completions".format(self.url, self.model_name),
|
||||
json=reqdata,
|
||||
headers={
|
||||
"Authorization": "Bearer " + self.key,
|
||||
|
Reference in New Issue
Block a user