Basic GooseAI Support
This commit is contained in:
parent
f1b0ea711e
commit
3ddc9647eb
22
aiserver.py
22
aiserver.py
|
@ -126,6 +126,7 @@ xglmlist = [
|
|||
]
|
||||
|
||||
apilist = [
|
||||
["GooseAI API (requires API key)", "GooseAI", ""],
|
||||
["OpenAI API (requires API key)", "OAI", ""],
|
||||
["InferKit API (requires API key)", "InferKit", ""],
|
||||
["KoboldAI Server API (Old Google Colab)", "Colab", ""],
|
||||
|
@ -787,7 +788,7 @@ else:
|
|||
getModelSelection(mainmenu)
|
||||
|
||||
# If transformers model was selected & GPU available, ask to use CPU or GPU
|
||||
if(vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
|
||||
if(vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]):
|
||||
vars.allowsp = True
|
||||
# Test for GPU support
|
||||
import torch
|
||||
|
@ -827,7 +828,7 @@ if(vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
|||
print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
|
||||
vars.model_type = "gpt_neo"
|
||||
|
||||
if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
|
||||
if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]):
|
||||
loadmodelsettings()
|
||||
loadsettings()
|
||||
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
|
||||
|
@ -926,12 +927,20 @@ if(vars.model == "InferKit"):
|
|||
file.write(json.dumps(js, indent=3))
|
||||
finally:
|
||||
file.close()
|
||||
|
||||
# Swap OAI Server if GooseAI was selected
|
||||
if(vars.model == "GooseAI"):
|
||||
vars.oaiengines = "https://api.goose.ai/v1/engines"
|
||||
vars.model = "OAI"
|
||||
args.configname = "GooseAI"
|
||||
|
||||
# Ask for API key if OpenAI was selected
|
||||
if(vars.model == "OAI"):
|
||||
if not args.configname:
|
||||
args.configname = "OAI"
|
||||
if(not path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
|
||||
# If the client settings file doesn't exist, create it
|
||||
print("{0}Please enter your OpenAI API key:{1}\n".format(colors.CYAN, colors.END))
|
||||
print("{0}Please enter your API key:{1}\n".format(colors.CYAN, colors.END))
|
||||
vars.oaiapikey = input("Key> ")
|
||||
# Write API key to file
|
||||
os.makedirs('settings', exist_ok=True)
|
||||
|
@ -952,7 +961,7 @@ if(vars.model == "OAI"):
|
|||
file.close()
|
||||
else:
|
||||
# Get API key, add it to settings object, and write it to disk
|
||||
print("{0}Please enter your OpenAI API key:{1}\n".format(colors.CYAN, colors.END))
|
||||
print("{0}Please enter your API key:{1}\n".format(colors.CYAN, colors.END))
|
||||
vars.oaiapikey = input("Key> ")
|
||||
js["oaiapikey"] = vars.oaiapikey
|
||||
# Write API key to file
|
||||
|
@ -985,7 +994,8 @@ if(vars.model == "OAI"):
|
|||
while(engselected == False):
|
||||
engine = input("Engine #> ")
|
||||
if(engine.isnumeric() and int(engine) < len(engines)):
|
||||
vars.oaiurl = "https://api.openai.com/v1/engines/{0}/completions".format(engines[int(engine)]["id"])
|
||||
vars.oaiurl = vars.oaiengines + "/{0}/completions".format(engines[int(engine)]["id"])
|
||||
args.configname = args.configname + "/" + en["id"]
|
||||
engselected = True
|
||||
else:
|
||||
print("{0}Please enter a valid selection.{1}".format(colors.RED, colors.END))
|
||||
|
@ -1020,7 +1030,7 @@ socketio = SocketIO(app, async_method="eventlet")
|
|||
print("{0}OK!{1}".format(colors.GREEN, colors.END))
|
||||
|
||||
# Start transformers and create pipeline
|
||||
if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
|
||||
if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]):
|
||||
if(not vars.noai):
|
||||
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
|
||||
from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
|
||||
|
|
Loading…
Reference in New Issue