Merge commit 'refs/pull/107/head' of https://github.com/ebolam/KoboldAI into united

commit a383ef81b1
Author: ebolam
Date:   2022-09-05 19:49:32 -04:00
1 changed file with 67 additions and 33 deletions


@@ -239,7 +239,8 @@ class vars:
lastact = "" # The last action received from the user
submission = "" # Same as above, but after applying input formatting
lastctx = "" # The last context submitted to the generator
model = "" # Model ID string chosen at startup
model = "ReadOnly" # Model ID string chosen at startup
online_model = "" # Used when Model ID is an online service, and there is a secondary option for the actual model name
model_selected = "" #selected model in UI
model_type = "" # Model Type (Automatically taken from the model config)
noai = False # Runs the script without starting up the transformers pipeline
@@ -380,6 +381,7 @@ class vars:
output_streaming = True
token_stream_queue = TokenStreamQueue() # Queue for the token streaming
show_probs = False # Whether or not to show token probabilities
configname = None
utils.vars = vars
@@ -615,6 +617,18 @@ api_v1 = KoboldAPISpec(
tags=tags,
)
# Returns the expected config filename for the current setup.
# If the model_name is specified, it returns what the settings file would be for that model
def get_config_filename(model_name = None):
if model_name:
return(f"settings/{model_name.replace('/', '_')}.settings")
elif args.configname:
return(f"settings/{args.configname.replace('/', '_')}.settings")
elif vars.configname != '':
return(f"settings/{vars.configname.replace('/', '_')}.settings")
else:
print(f"Empty configfile name sent back. Defaulting to ReadOnly")
return(f"settings/ReadOnly.settings")
#==================================================================#
# Function to get model selection at startup
#==================================================================#
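For orientation, a minimal usage sketch of the new helper; the model and engine names below are illustrative only, and the second call assumes args.configname is unset and vars.configname was already populated by load_model():

    # Explicit model name: '/' is flattened so the file sits directly under settings/.
    get_config_filename("EleutherAI/gpt-neo-2.7B")   # -> "settings/EleutherAI_gpt-neo-2.7B.settings"

    # No argument: falls back to args.configname, then vars.configname.
    # e.g. with vars.configname = "GooseAI_gpt-neo-20b":
    get_config_filename()                            # -> "settings/GooseAI_gpt-neo-20b.settings"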
@@ -722,9 +736,8 @@ def check_if_dir_is_model(path):
# Return Model Name
#==================================================================#
def getmodelname():
if(args.configname):
modelname = args.configname
return modelname
if(vars.online_model != ''):
return(f"{vars.model}/{vars.online_model}")
if(vars.model in ("NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
modelname = os.path.basename(os.path.normpath(vars.custmodpth))
return modelname
@@ -1058,7 +1071,7 @@ def savesettings():
# Write it
if not os.path.exists('settings'):
os.mkdir('settings')
file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "w")
file = open(get_config_filename(), "w")
try:
file.write(json.dumps(js, indent=3))
finally:
@@ -1084,9 +1097,9 @@ def loadsettings():
processsettings(js)
file.close()
if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
if(path.exists(get_config_filename())):
# Read file contents into JSON object
file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "r")
file = open(get_config_filename(), "r")
js = json.load(file)
processsettings(js)
@@ -1446,8 +1459,8 @@ def get_model_info(model, directory=""):
key = True
default_url = 'https://koboldai.net'
multi_online_models = True
if path.exists("settings/{}.settings".format(model)):
with open("settings/{}.settings".format(model), "r") as file:
if path.exists(get_config_filename(model)):
with open(get_config_filename(model), "r") as file:
# Check if API key exists
js = json.load(file)
if("apikey" in js and js["apikey"] != ""):
@@ -1456,8 +1469,8 @@ def get_model_info(model, directory=""):
elif 'oaiapikey' in js and js['oaiapikey'] != "":
key_value = js["oaiapikey"]
elif model in [x[1] for x in model_menu['apilist']]:
if path.exists("settings/{}.settings".format(model)):
with open("settings/{}.settings".format(model), "r") as file:
if path.exists(get_config_filename(model)):
with open(get_config_filename(model), "r") as file:
# Check if API key exists
js = json.load(file)
if("apikey" in js and js["apikey"] != ""):
@@ -1561,8 +1574,8 @@ def get_oai_models(key):
# If the client settings file doesn't exist, create it
# Write API key to file
os.makedirs('settings', exist_ok=True)
if path.exists("settings/{}.settings".format(vars.model_selected)):
with open("settings/{}.settings".format(vars.model_selected), "r") as file:
if path.exists(get_config_filename(vars.model_selected)):
with open(get_config_filename(vars.model_selected), "r") as file:
js = json.load(file)
if 'online_model' in js:
online_model = js['online_model']
@@ -1573,7 +1586,7 @@ def get_oai_models(key):
changed=True
if changed:
js={}
with open("settings/{}.settings".format(vars.model_selected), "w") as file:
with open(get_config_filename(vars.model_selected), "w") as file:
js["apikey"] = key
file.write(json.dumps(js, indent=3))
@@ -1611,8 +1624,8 @@ def get_cluster_models(msg):
# If the client settings file doesn't exist, create it
# Write API key to file
os.makedirs('settings', exist_ok=True)
if path.exists("settings/{}.settings".format(vars.model_selected)):
with open("settings/{}.settings".format(vars.model_selected), "r") as file:
if path.exists(get_config_filename(vars.model_selected)):
with open(get_config_filename(vars.model_selected), "r") as file:
js = json.load(file)
if 'online_model' in js:
online_model = js['online_model']
@@ -1623,7 +1636,7 @@ def get_cluster_models(msg):
changed=True
if changed:
js={}
with open("settings/{}.settings".format(vars.model_selected), "w") as file:
with open(get_config_filename(vars.model_selected), "w") as file:
js["apikey"] = vars.oaiapikey
file.write(json.dumps(js, indent=3))
@@ -2067,6 +2080,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
model = None
generator = None
model_config = None
vars.online_model = ''
with torch.no_grad():
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="torch.distributed.reduce_op is deprecated")
@@ -2085,11 +2099,26 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
#Reload our badwords
vars.badwordsids = vars.badwordsids_default
if online_model == "":
vars.configname = vars.model.replace('/', '_')
#Let's set the GooseAI or OpenAI server URLs if that's applicable
if online_model != "":
if path.exists("settings/{}.settings".format(vars.model)):
else:
vars.online_model = online_model
# Swap OAI Server if GooseAI was selected
if(vars.model == "GooseAI"):
vars.oaiengines = "https://api.goose.ai/v1/engines"
vars.model = "OAI"
vars.configname = f"GooseAI_{online_model.replace('/', '_')}"
elif(vars.model == "CLUSTER") and type(online_model) is list:
if len(online_model) != 1:
vars.configname = vars.model
else:
vars.configname = f"{vars.model}_{online_model[0].replace('/', '_')}"
else:
vars.configname = f"{vars.model}_{online_model.replace('/', '_')}"
if path.exists(get_config_filename()):
changed=False
with open("settings/{}.settings".format(vars.model), "r") as file:
with open(get_config_filename(), "r") as file:
# Check if API key exists
js = json.load(file)
if 'online_model' in js:
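Taken together, the branch above is what now populates vars.configname for online services; a hedged summary of the resulting values (model and engine IDs are examples, not exhaustive):

    # Local/offline model (online_model == ""):   vars.model = "KoboldAI/fairseq-dense-13B"    -> configname "KoboldAI_fairseq-dense-13B"
    # GooseAI (model is rewritten to "OAI"):      online_model = "gpt-neo-20b"                 -> configname "GooseAI_gpt-neo-20b"
    # CLUSTER with several models selected:       online_model = ["m1", "m2"]                  -> configname "CLUSTER"
    # CLUSTER with exactly one model:             online_model = ["m1"]                        -> configname "CLUSTER_m1"
    # Any other online service:                   vars.model = "OAI", online_model = "davinci" -> configname "OAI_davinci"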
@@ -2100,8 +2129,9 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
changed=True
js['online_model'] = online_model
if changed:
with open("settings/{}.settings".format(vars.model), "w") as file:
with open(get_config_filename(), "w") as file:
file.write(json.dumps(js, indent=3))
<<<<<<< HEAD
# Swap OAI Server if GooseAI was selected
if(vars.model == "GooseAI"):
vars.oaiengines = "https://api.goose.ai/v1/engines"
@@ -2109,6 +2139,8 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
args.configname = "GooseAI" + "/" + online_model
elif vars.model != "CLUSTER":
args.configname = vars.model + "/" + online_model
=======
>>>>>>> 296481f4aae46ce3d665537744460f1d3c0947a2
vars.oaiurl = vars.oaiengines + "/{0}/completions".format(online_model)
@@ -2195,12 +2227,12 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
if(vars.model == "GooseAI"):
vars.oaiengines = "https://api.goose.ai/v1/engines"
vars.model = "OAI"
args.configname = "GooseAI"
vars.configname = "GooseAI"
# Ask for API key if OpenAI was selected
if(vars.model == "OAI"):
if not args.configname:
args.configname = "OAI"
if not vars.configname:
vars.configname = "OAI"
if(vars.model == "ReadOnly"):
vars.noai = True
@@ -2786,8 +2818,8 @@ def lua_startup():
global _bridged
global F
global bridged
if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "r")
if(path.exists(get_config_filename())):
file = open(get_config_filename(), "r")
js = json.load(file)
if("userscripts" in js):
vars.userscripts = []
@@ -3847,7 +3879,7 @@ def get_message(msg):
else:
sendModelSelection(menu=msg['data'], folder=msg['path'])
else:
vars.model_selected = msg['data']
vars.model_selected = msg['data']
if 'path' in msg:
vars.custmodpth = msg['path']
get_model_info(msg['data'], directory=msg['path'])
@@ -4061,9 +4093,9 @@ def check_for_backend_compilation():
break
vars.checking = False
def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False, disable_recentrng=False, no_generate=False):
def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False, disable_recentrng=False, no_generate=False, ignore_aibusy=False):
# Ignore new submissions if the AI is currently busy
if(vars.aibusy):
if(not ignore_aibusy and vars.aibusy):
return
while(True):
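The new ignore_aibusy flag lets trusted callers push a submission through even while the AI is busy; a brief sketch of the intended call pattern, mirroring the post_story_end change further down in this diff:

    # Interactive submissions keep the old behaviour and are silently dropped while vars.aibusy is set.
    actionsubmit(data, actionmode=0)

    # The API endpoint forces its prompt through regardless of the busy flag.
    actionsubmit(body.prompt, force_submit=True, no_generate=True, ignore_aibusy=True)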
@@ -6054,7 +6086,9 @@ def oairequest(txt, min, max):
vars.lastctx = txt
# Build request JSON data
if 'GooseAI' in args.configname:
# GooseAI is a subtype of OAI. So to check if it's this type, we check the configname as a workaround
# as the vars.model will always be OAI
if 'GooseAI' in vars.configname:
reqdata = {
'prompt': txt,
'max_tokens': vars.genamt,
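Because vars.model is rewritten to "OAI" for both services, the configname prefix is what distinguishes them here; roughly, under the example values used above:

    # GooseAI engine selected:  vars.model == "OAI", vars.configname == "GooseAI_gpt-neo-20b"  -> GooseAI payload is built
    # OpenAI model selected:    vars.model == "OAI", vars.configname == "OAI_davinci"          -> plain OAI payload is built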
@@ -6884,8 +6918,8 @@ def final_startup():
threading.Thread(target=__preempt_tokenizer).start()
# Load soft prompt specified by the settings file, if applicable
if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "r")
if(path.exists(get_config_filename())):
file = open(get_config_filename(), "r")
js = json.load(file)
if(vars.allowsp and "softprompt" in js and type(js["softprompt"]) is str and all(q not in js["softprompt"] for q in ("..", ":")) and (len(js["softprompt"]) == 0 or all(js["softprompt"][0] not in q for q in ("/", "\\")))):
spRequest(js["softprompt"])
@@ -7757,7 +7791,7 @@ def post_story_end(body: SubmissionInputSchema):
numseqs = vars.numseqs
vars.numseqs = 1
try:
actionsubmit(body.prompt, force_submit=True, no_generate=True)
actionsubmit(body.prompt, force_submit=True, no_generate=True, ignore_aibusy=True)
finally:
vars.disable_set_aibusy = disable_set_aibusy
vars.standalone = _standalone