Merge pull request #200 from db0/online_model_fix
Fixes Horde not saving as expected
Commit c7a6309fa2

Changed file: aiserver.py (96 changed lines)
@@ -239,7 +239,8 @@ class vars:
     lastact = "" # The last action received from the user
     submission = "" # Same as above, but after applying input formatting
     lastctx = "" # The last context submitted to the generator
-    model = "" # Model ID string chosen at startup
+    model = "ReadOnly" # Model ID string chosen at startup
+    online_model = "" # Used when Model ID is an online service, and there is a secondary option for the actual model name
     model_selected = "" #selected model in UI
     model_type = "" # Model Type (Automatically taken from the model config)
     noai = False # Runs the script without starting up the transformers pipeline
@@ -380,6 +381,7 @@ class vars:
     output_streaming = True
     token_stream_queue = TokenStreamQueue() # Queue for the token streaming
     show_probs = False # Whether or not to show token probabilities
+    configname = None
 
 utils.vars = vars
 
@@ -615,6 +617,18 @@ api_v1 = KoboldAPISpec(
     tags=tags,
 )
 
+# Returns the expected config filename for the current setup.
+# If the model_name is specified, it returns what the settings file would be for that model
+def get_config_filename(model_name = None):
+    if model_name:
+        return(f"settings/{model_name.replace('/', '_')}.settings")
+    elif args.configname:
+        return(f"settings/{args.configname}.settings")
+    elif vars.configname != '':
+        return(f"settings/{vars.configname}.settings")
+    else:
+        print(f"Empty configfile name sent back. Defaulting to ReadOnly")
+        return(f"settings/ReadOnly.settings")
 #==================================================================#
 # Function to get model selection at startup
 #==================================================================#
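Note: this new helper is the core of the change; every caller that previously rebuilt "settings/<name>.settings" by hand now goes through it. Below is a minimal, self-contained sketch of its resolution order, with args_configname and vars_configname standing in for the real args.configname and vars.configname globals; all values are purely illustrative, not taken from the diff.

    # Stand-ins for the real globals; values are hypothetical.
    args_configname = None                                # e.g. a --configname CLI override
    vars_configname = "CLUSTER_PygmalionAI_pygmalion-6b"  # e.g. set by load_model() for a Horde model

    def get_config_filename(model_name=None):
        if model_name:                           # explicit model name wins
            return f"settings/{model_name.replace('/', '_')}.settings"
        elif args_configname:                    # command-line override
            return f"settings/{args_configname}.settings"
        elif vars_configname != '':              # per-model name chosen at load time
            return f"settings/{vars_configname}.settings"
        else:                                    # nothing known: fall back to ReadOnly
            return "settings/ReadOnly.settings"

    print(get_config_filename("KoboldAI/fairseq-dense-13B"))  # settings/KoboldAI_fairseq-dense-13B.settings
    print(get_config_filename())                              # settings/CLUSTER_PygmalionAI_pygmalion-6b.settings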
@@ -722,9 +736,8 @@ def check_if_dir_is_model(path):
 # Return Model Name
 #==================================================================#
 def getmodelname():
-    if(args.configname):
-        modelname = args.configname
-        return modelname
+    if(vars.online_model != ''):
+        return(f"{vars.model}/{vars.online_model}")
     if(vars.model in ("NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
         modelname = os.path.basename(os.path.normpath(vars.custmodpth))
         return modelname
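With this change, getmodelname() reports the backing model alongside the service name for online services. A tiny illustrative example (values are hypothetical, not taken from the diff) of the name it returns and how path-building callers flatten it:

    # Illustrative only: the shape of the name getmodelname() now returns for an online service.
    model, online_model = "CLUSTER", "PygmalionAI/pygmalion-6b"     # hypothetical Horde selection
    modelname = f"{model}/{online_model}"                           # "CLUSTER/PygmalionAI/pygmalion-6b"
    print("settings/" + modelname.replace('/', '_') + ".settings")  # settings/CLUSTER_PygmalionAI_pygmalion-6b.settings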
@@ -1058,7 +1071,7 @@ def savesettings():
     # Write it
     if not os.path.exists('settings'):
         os.mkdir('settings')
-    file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "w")
+    file = open(get_config_filename(), "w")
     try:
         file.write(json.dumps(js, indent=3))
     finally:
@@ -1084,9 +1097,9 @@ def loadsettings():
 
         processsettings(js)
         file.close()
-    if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
+    if(path.exists(get_config_filename())):
         # Read file contents into JSON object
-        file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "r")
+        file = open(get_config_filename(), "r")
         js = json.load(file)
 
         processsettings(js)
@@ -1444,8 +1457,8 @@ def get_model_info(model, directory=""):
         url = True
         key = True
         multi_online_models = True
-        if path.exists("settings/{}.settings".format(model)):
-            with open("settings/{}.settings".format(model), "r") as file:
+        if path.exists(get_config_filename(model)):
+            with open(get_config_filename(model), "r") as file:
                 # Check if API key exists
                 js = json.load(file)
                 if("apikey" in js and js["apikey"] != ""):
@@ -1454,8 +1467,8 @@ def get_model_info(model, directory=""):
                 elif 'oaiapikey' in js and js['oaiapikey'] != "":
                     key_value = js["oaiapikey"]
     elif model in [x[1] for x in model_menu['apilist']]:
-        if path.exists("settings/{}.settings".format(model)):
-            with open("settings/{}.settings".format(model), "r") as file:
+        if path.exists(get_config_filename(model)):
+            with open(get_config_filename(model), "r") as file:
                 # Check if API key exists
                 js = json.load(file)
                 if("apikey" in js and js["apikey"] != ""):
@@ -1559,8 +1572,8 @@ def get_oai_models(key):
     # If the client settings file doesn't exist, create it
     # Write API key to file
     os.makedirs('settings', exist_ok=True)
-    if path.exists("settings/{}.settings".format(vars.model_selected)):
-        with open("settings/{}.settings".format(vars.model_selected), "r") as file:
+    if path.exists(get_config_filename(vars.model_selected)):
+        with open(get_config_filename(vars.model_selected), "r") as file:
             js = json.load(file)
             if 'online_model' in js:
                 online_model = js['online_model']
@@ -1571,7 +1584,7 @@ def get_oai_models(key):
                 changed=True
     if changed:
         js={}
-        with open("settings/{}.settings".format(vars.model_selected), "w") as file:
+        with open(get_config_filename(vars.model_selected), "w") as file:
             js["apikey"] = key
             file.write(json.dumps(js, indent=3))
 
@@ -1609,8 +1622,8 @@ def get_cluster_models(msg):
     # If the client settings file doesn't exist, create it
     # Write API key to file
     os.makedirs('settings', exist_ok=True)
-    if path.exists("settings/{}.settings".format(vars.model_selected)):
-        with open("settings/{}.settings".format(vars.model_selected), "r") as file:
+    if path.exists(get_config_filename(vars.model_selected)):
+        with open(get_config_filename(vars.model_selected), "r") as file:
             js = json.load(file)
             if 'online_model' in js:
                 online_model = js['online_model']
@@ -1621,7 +1634,7 @@ def get_cluster_models(msg):
                 changed=True
     if changed:
         js={}
-        with open("settings/{}.settings".format(vars.model_selected), "w") as file:
+        with open(get_config_filename(vars.model_selected), "w") as file:
             js["apikey"] = vars.oaiapikey
             file.write(json.dumps(js, indent=3))
 
@@ -2065,6 +2078,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
     model = None
     generator = None
     model_config = None
+    vars.online_model = ''
     with torch.no_grad():
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", message="torch.distributed.reduce_op is deprecated")
@@ -2083,11 +2097,26 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
     #Reload our badwords
     vars.badwordsids = vars.badwordsids_default
 
+    if online_model == "":
+        vars.configname = vars.model
     #Let's set the GooseAI or OpenAI server URLs if that's applicable
-    if online_model != "":
-        if path.exists("settings/{}.settings".format(vars.model)):
+    else:
+        vars.online_model = online_model
+        # Swap OAI Server if GooseAI was selected
+        if(vars.model == "GooseAI"):
+            vars.oaiengines = "https://api.goose.ai/v1/engines"
+            vars.model = "OAI"
+            vars.configname = f"GooseAI_{online_model.replace('/', '_')}"
+        elif(vars.model == "CLUSTER") and type(online_model) is list:
+            if len(online_model) != 1:
+                vars.configname = vars.model
+            else:
+                vars.configname = f"{vars.model}_{online_model[0].replace('/', '_')}"
+        else:
+            vars.configname = f"{vars.model}_{online_model.replace('/', '_')}"
+        if path.exists(get_config_filename()):
             changed=False
-            with open("settings/{}.settings".format(vars.model), "r") as file:
+            with open(get_config_filename(), "r") as file:
                 # Check if API key exists
                 js = json.load(file)
                 if 'online_model' in js:
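This block is where the Horde saving fix lands: the per-model name is now stored in vars.configname (which get_config_filename reads) rather than args.configname, and CLUSTER selections, which arrive as a list, get their own rule. A minimal sketch of the naming logic only, leaving out the vars.model = "OAI" swap; all input values are hypothetical examples.

    # Sketch of the configname rule above; inputs are illustrative, not from the diff.
    def derive_configname(model, online_model):
        if online_model == "":
            return model                                         # no online backend selected
        if model == "GooseAI":
            return f"GooseAI_{online_model.replace('/', '_')}"
        if model == "CLUSTER" and isinstance(online_model, list):
            if len(online_model) != 1:
                return model                                     # several Horde models: generic name
            return f"{model}_{online_model[0].replace('/', '_')}"
        return f"{model}_{online_model.replace('/', '_')}"

    print(derive_configname("CLUSTER", ["PygmalionAI/pygmalion-6b"]))  # CLUSTER_PygmalionAI_pygmalion-6b
    print(derive_configname("CLUSTER", ["model-a", "model-b"]))        # CLUSTER
    print(derive_configname("OAI", "davinci"))                         # OAI_davinci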
@@ -2098,15 +2127,8 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
                         changed=True
                         js['online_model'] = online_model
             if changed:
-                with open("settings/{}.settings".format(vars.model), "w") as file:
+                with open(get_config_filename(), "w") as file:
                     file.write(json.dumps(js, indent=3))
-        # Swap OAI Server if GooseAI was selected
-        if(vars.model == "GooseAI"):
-            vars.oaiengines = "https://api.goose.ai/v1/engines"
-            vars.model = "OAI"
-            args.configname = "GooseAI" + "/" + online_model
-        else:
-            args.configname = vars.model + "/" + online_model
         vars.oaiurl = vars.oaiengines + "/{0}/completions".format(online_model)
 
 
@@ -2193,12 +2215,12 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
     if(vars.model == "GooseAI"):
         vars.oaiengines = "https://api.goose.ai/v1/engines"
         vars.model = "OAI"
-        args.configname = "GooseAI"
+        vars.configname = "GooseAI"
 
     # Ask for API key if OpenAI was selected
     if(vars.model == "OAI"):
-        if not args.configname:
-            args.configname = "OAI"
+        if not vars.configname:
+            vars.configname = "OAI"
 
     if(vars.model == "ReadOnly"):
         vars.noai = True
@@ -2784,8 +2806,8 @@ def lua_startup():
     global _bridged
     global F
     global bridged
-    if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
-        file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "r")
+    if(path.exists(get_config_filename())):
+        file = open(get_config_filename(), "r")
         js = json.load(file)
         if("userscripts" in js):
             vars.userscripts = []
@@ -6052,7 +6074,9 @@ def oairequest(txt, min, max):
     vars.lastctx = txt
 
     # Build request JSON data
-    if 'GooseAI' in args.configname:
+    # GooseAI is a subtype of OAI. So to check if it's this type, we check the configname as a workaround
+    # as the vars.model will always be OAI
+    if 'GooseAI' in vars.configname:
         reqdata = {
             'prompt': txt,
             'max_tokens': vars.genamt,
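Since vars.model is rewritten to "OAI" whenever GooseAI is selected, oairequest() can only tell the two services apart by the configname prefix. A minimal illustration of that check; the configname value is hypothetical.

    # Illustrative: choosing the request style from the configname prefix.
    configname = "GooseAI_gpt-neo-20b"   # hypothetical value set at model load time
    if 'GooseAI' in configname:
        style = "gooseai"                # build the GooseAI-flavoured request body
    else:
        style = "openai"                 # build the plain OpenAI request body
    print(style)                         # gooseai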
|
@ -6882,8 +6906,8 @@ def final_startup():
|
||||||
threading.Thread(target=__preempt_tokenizer).start()
|
threading.Thread(target=__preempt_tokenizer).start()
|
||||||
|
|
||||||
# Load soft prompt specified by the settings file, if applicable
|
# Load soft prompt specified by the settings file, if applicable
|
||||||
if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
|
if(path.exists(get_config_filename())):
|
||||||
file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "r")
|
file = open(get_config_filename(), "r")
|
||||||
js = json.load(file)
|
js = json.load(file)
|
||||||
if(vars.allowsp and "softprompt" in js and type(js["softprompt"]) is str and all(q not in js["softprompt"] for q in ("..", ":")) and (len(js["softprompt"]) == 0 or all(js["softprompt"][0] not in q for q in ("/", "\\")))):
|
if(vars.allowsp and "softprompt" in js and type(js["softprompt"]) is str and all(q not in js["softprompt"] for q in ("..", ":")) and (len(js["softprompt"]) == 0 or all(js["softprompt"][0] not in q for q in ("/", "\\")))):
|
||||||
spRequest(js["softprompt"])
|
spRequest(js["softprompt"])
|
||||||