Basic GooseAI Support
This commit is contained in:
parent
f1b0ea711e
commit
3ddc9647eb
22
aiserver.py
22
aiserver.py
|
@ -126,6 +126,7 @@ xglmlist = [
|
||||||
]
|
]
|
||||||
|
|
||||||
apilist = [
|
apilist = [
|
||||||
|
["GooseAI API (requires API key)", "GooseAI", ""],
|
||||||
["OpenAI API (requires API key)", "OAI", ""],
|
["OpenAI API (requires API key)", "OAI", ""],
|
||||||
["InferKit API (requires API key)", "InferKit", ""],
|
["InferKit API (requires API key)", "InferKit", ""],
|
||||||
["KoboldAI Server API (Old Google Colab)", "Colab", ""],
|
["KoboldAI Server API (Old Google Colab)", "Colab", ""],
|
||||||
|
@ -787,7 +788,7 @@ else:
|
||||||
getModelSelection(mainmenu)
|
getModelSelection(mainmenu)
|
||||||
|
|
||||||
# If transformers model was selected & GPU available, ask to use CPU or GPU
|
# If transformers model was selected & GPU available, ask to use CPU or GPU
|
||||||
if(vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
|
if(vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]):
|
||||||
vars.allowsp = True
|
vars.allowsp = True
|
||||||
# Test for GPU support
|
# Test for GPU support
|
||||||
import torch
|
import torch
|
||||||
|
@ -827,7 +828,7 @@ if(vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
||||||
print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
|
print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
|
||||||
vars.model_type = "gpt_neo"
|
vars.model_type = "gpt_neo"
|
||||||
|
|
||||||
if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
|
if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]):
|
||||||
loadmodelsettings()
|
loadmodelsettings()
|
||||||
loadsettings()
|
loadsettings()
|
||||||
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
|
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
|
||||||
|
@ -927,11 +928,19 @@ if(vars.model == "InferKit"):
|
||||||
finally:
|
finally:
|
||||||
file.close()
|
file.close()
|
||||||
|
|
||||||
|
# Swap OAI Server if GooseAI was selected
|
||||||
|
if(vars.model == "GooseAI"):
|
||||||
|
vars.oaiengines = "https://api.goose.ai/v1/engines"
|
||||||
|
vars.model = "OAI"
|
||||||
|
args.configname = "GooseAI"
|
||||||
|
|
||||||
# Ask for API key if OpenAI was selected
|
# Ask for API key if OpenAI was selected
|
||||||
if(vars.model == "OAI"):
|
if(vars.model == "OAI"):
|
||||||
|
if not args.configname:
|
||||||
|
args.configname = "OAI"
|
||||||
if(not path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
|
if(not path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
|
||||||
# If the client settings file doesn't exist, create it
|
# If the client settings file doesn't exist, create it
|
||||||
print("{0}Please enter your OpenAI API key:{1}\n".format(colors.CYAN, colors.END))
|
print("{0}Please enter your API key:{1}\n".format(colors.CYAN, colors.END))
|
||||||
vars.oaiapikey = input("Key> ")
|
vars.oaiapikey = input("Key> ")
|
||||||
# Write API key to file
|
# Write API key to file
|
||||||
os.makedirs('settings', exist_ok=True)
|
os.makedirs('settings', exist_ok=True)
|
||||||
|
@ -952,7 +961,7 @@ if(vars.model == "OAI"):
|
||||||
file.close()
|
file.close()
|
||||||
else:
|
else:
|
||||||
# Get API key, add it to settings object, and write it to disk
|
# Get API key, add it to settings object, and write it to disk
|
||||||
print("{0}Please enter your OpenAI API key:{1}\n".format(colors.CYAN, colors.END))
|
print("{0}Please enter your API key:{1}\n".format(colors.CYAN, colors.END))
|
||||||
vars.oaiapikey = input("Key> ")
|
vars.oaiapikey = input("Key> ")
|
||||||
js["oaiapikey"] = vars.oaiapikey
|
js["oaiapikey"] = vars.oaiapikey
|
||||||
# Write API key to file
|
# Write API key to file
|
||||||
|
@ -985,7 +994,8 @@ if(vars.model == "OAI"):
|
||||||
while(engselected == False):
|
while(engselected == False):
|
||||||
engine = input("Engine #> ")
|
engine = input("Engine #> ")
|
||||||
if(engine.isnumeric() and int(engine) < len(engines)):
|
if(engine.isnumeric() and int(engine) < len(engines)):
|
||||||
vars.oaiurl = "https://api.openai.com/v1/engines/{0}/completions".format(engines[int(engine)]["id"])
|
vars.oaiurl = vars.oaiengines + "/{0}/completions".format(engines[int(engine)]["id"])
|
||||||
|
args.configname = args.configname + "/" + en["id"]
|
||||||
engselected = True
|
engselected = True
|
||||||
else:
|
else:
|
||||||
print("{0}Please enter a valid selection.{1}".format(colors.RED, colors.END))
|
print("{0}Please enter a valid selection.{1}".format(colors.RED, colors.END))
|
||||||
|
@ -1020,7 +1030,7 @@ socketio = SocketIO(app, async_method="eventlet")
|
||||||
print("{0}OK!{1}".format(colors.GREEN, colors.END))
|
print("{0}OK!{1}".format(colors.GREEN, colors.END))
|
||||||
|
|
||||||
# Start transformers and create pipeline
|
# Start transformers and create pipeline
|
||||||
if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
|
if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]):
|
||||||
if(not vars.noai):
|
if(not vars.noai):
|
||||||
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
|
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
|
||||||
from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
|
from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
Loading…
Reference in New Issue