Allow KoboldAI to use its own API to generate text
This commit is contained in:
parent
4eff7bf3ba
commit
6853625570
122
aiserver.py
122
aiserver.py
|
@ -214,7 +214,8 @@ model_menu = {
|
||||||
["GooseAI API (requires API key)", "GooseAI", "", False],
|
["GooseAI API (requires API key)", "GooseAI", "", False],
|
||||||
["OpenAI API (requires API key)", "OAI", "", False],
|
["OpenAI API (requires API key)", "OAI", "", False],
|
||||||
["InferKit API (requires API key)", "InferKit", "", False],
|
["InferKit API (requires API key)", "InferKit", "", False],
|
||||||
["KoboldAI Server API (Old Google Colab)", "Colab", "", False],
|
# ["KoboldAI Server API (Old Google Colab)", "Colab", "", False],
|
||||||
|
["KoboldAI API", "API", "", False],
|
||||||
["Return to Main Menu", "mainmenu", "", True],
|
["Return to Main Menu", "mainmenu", "", True],
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
@ -1259,6 +1260,7 @@ def general_startup(override_args=None):
|
||||||
parser.add_argument("--override_rename", action='store_true', help="Renaming stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow renaming stories if using --remote and prevent renaming stories otherwise.")
|
parser.add_argument("--override_rename", action='store_true', help="Renaming stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow renaming stories if using --remote and prevent renaming stories otherwise.")
|
||||||
parser.add_argument("--configname", help="Force a fixed configuration name to aid with config management.")
|
parser.add_argument("--configname", help="Force a fixed configuration name to aid with config management.")
|
||||||
parser.add_argument("--colab", action='store_true', help="Optimize for Google Colab.")
|
parser.add_argument("--colab", action='store_true', help="Optimize for Google Colab.")
|
||||||
|
parser.add_argument("--tokenizer", type=str, help="When using the \"KoboldAI API\" backend option, this controls the tokenizer to use. This can be set to a Hugging Face model ID or the path to a folder under \"models\" in the KoboldAI folder.")
|
||||||
parser.add_argument("--nobreakmodel", action='store_true', help="Disables Breakmodel support completely.")
|
parser.add_argument("--nobreakmodel", action='store_true', help="Disables Breakmodel support completely.")
|
||||||
parser.add_argument("--unblock", action='store_true', default=False, help="Unblocks the KoboldAI port to be accessible from other machines without optimizing for remote play (It is recommended to use --host instead)")
|
parser.add_argument("--unblock", action='store_true', default=False, help="Unblocks the KoboldAI port to be accessible from other machines without optimizing for remote play (It is recommended to use --host instead)")
|
||||||
parser.add_argument("--quiet", action='store_true', default=False, help="If present will suppress any story related text from showing on the console")
|
parser.add_argument("--quiet", action='store_true', default=False, help="If present will suppress any story related text from showing on the console")
|
||||||
|
@ -1432,7 +1434,7 @@ def get_model_info(model, directory=""):
|
||||||
|
|
||||||
|
|
||||||
def get_layer_count(model, directory=""):
|
def get_layer_count(model, directory=""):
|
||||||
if(model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]):
|
if(model not in ["InferKit", "Colab", "API", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]):
|
||||||
if(vars.model == "GPT2Custom"):
|
if(vars.model == "GPT2Custom"):
|
||||||
model_config = open(vars.custmodpth + "/config.json", "r")
|
model_config = open(vars.custmodpth + "/config.json", "r")
|
||||||
# Get the model_type from the config or assume a model type if it isn't present
|
# Get the model_type from the config or assume a model type if it isn't present
|
||||||
|
@ -1973,7 +1975,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
|
||||||
|
|
||||||
|
|
||||||
# If transformers model was selected & GPU available, ask to use CPU or GPU
|
# If transformers model was selected & GPU available, ask to use CPU or GPU
|
||||||
if(vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
|
if(vars.model not in ["InferKit", "Colab", "API", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
|
||||||
vars.allowsp = True
|
vars.allowsp = True
|
||||||
# Test for GPU support
|
# Test for GPU support
|
||||||
|
|
||||||
|
@ -2012,7 +2014,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
|
||||||
print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
|
print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
|
||||||
vars.model_type = "gpt_neo"
|
vars.model_type = "gpt_neo"
|
||||||
|
|
||||||
if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
|
if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "API", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
|
||||||
loadmodelsettings()
|
loadmodelsettings()
|
||||||
loadsettings()
|
loadsettings()
|
||||||
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
|
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
|
||||||
|
@ -2087,7 +2089,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
|
||||||
vars.noai = True
|
vars.noai = True
|
||||||
|
|
||||||
# Start transformers and create pipeline
|
# Start transformers and create pipeline
|
||||||
if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
|
if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "API", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
|
||||||
if(not vars.noai):
|
if(not vars.noai):
|
||||||
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
|
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
|
||||||
for m in ("GPTJModel", "XGLMModel"):
|
for m in ("GPTJModel", "XGLMModel"):
|
||||||
|
@ -2542,6 +2544,26 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
|
||||||
from transformers import GPT2TokenizerFast
|
from transformers import GPT2TokenizerFast
|
||||||
tokenizer = GPT2TokenizerFast.from_pretrained("EleutherAI/gpt-neo-2.7B", revision=vars.revision, cache_dir="cache")
|
tokenizer = GPT2TokenizerFast.from_pretrained("EleutherAI/gpt-neo-2.7B", revision=vars.revision, cache_dir="cache")
|
||||||
loadsettings()
|
loadsettings()
|
||||||
|
elif(vars.model == "API"):
|
||||||
|
tokenizer_id = getattr(args, "tokenizer", None)
|
||||||
|
if tokenizer_id is None:
|
||||||
|
tokenizer_id = "EleutherAI/gpt-neo-2.7B"
|
||||||
|
if(os.path.isdir(tokenizer_id)):
|
||||||
|
try:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, revision=vars.revision, cache_dir="cache")
|
||||||
|
except:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, revision=vars.revision, cache_dir="cache", use_fast=False)
|
||||||
|
elif(os.path.isdir("models/{}".format(args.tokenizer.replace('/', '_')))):
|
||||||
|
try:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("models/{}".format(tokenizer_id.replace('/', '_')), revision=vars.revision, cache_dir="cache")
|
||||||
|
except:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("models/{}".format(tokenizer_id.replace('/', '_')), revision=vars.revision, cache_dir="cache", use_fast=False)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, revision=vars.revision, cache_dir="cache")
|
||||||
|
except:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, revision=vars.revision, cache_dir="cache", use_fast=False)
|
||||||
|
loadsettings()
|
||||||
elif(vars.model == "OAI"):
|
elif(vars.model == "OAI"):
|
||||||
from transformers import GPT2TokenizerFast
|
from transformers import GPT2TokenizerFast
|
||||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
|
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
|
||||||
|
@ -3179,7 +3201,7 @@ def lua_set_chunk(k, v):
|
||||||
def lua_get_modeltype():
|
def lua_get_modeltype():
|
||||||
if(vars.noai):
|
if(vars.noai):
|
||||||
return "readonly"
|
return "readonly"
|
||||||
if(vars.model in ("Colab", "OAI", "InferKit")):
|
if(vars.model in ("Colab", "API", "OAI", "InferKit")):
|
||||||
return "api"
|
return "api"
|
||||||
if(not vars.use_colab_tpu and vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))):
|
if(not vars.use_colab_tpu and vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))):
|
||||||
hidden_size = get_hidden_size_from_model(model)
|
hidden_size = get_hidden_size_from_model(model)
|
||||||
|
@ -3208,7 +3230,7 @@ def lua_get_modeltype():
|
||||||
def lua_get_modelbackend():
|
def lua_get_modelbackend():
|
||||||
if(vars.noai):
|
if(vars.noai):
|
||||||
return "readonly"
|
return "readonly"
|
||||||
if(vars.model in ("Colab", "OAI", "InferKit")):
|
if(vars.model in ("Colab", "API", "OAI", "InferKit")):
|
||||||
return "api"
|
return "api"
|
||||||
if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
|
if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
|
||||||
return "mtj"
|
return "mtj"
|
||||||
|
@ -4113,6 +4135,8 @@ def apiactionsubmit_tpumtjgenerate(txt, minimum, maximum):
|
||||||
def apiactionsubmit(data, use_memory=False, use_world_info=False, use_story=False, use_authors_note=False):
|
def apiactionsubmit(data, use_memory=False, use_world_info=False, use_story=False, use_authors_note=False):
|
||||||
if(vars.model == "Colab"):
|
if(vars.model == "Colab"):
|
||||||
raise NotImplementedError("API generation is not supported in old Colab API mode.")
|
raise NotImplementedError("API generation is not supported in old Colab API mode.")
|
||||||
|
elif(vars.model == "API"):
|
||||||
|
raise NotImplementedError("API generation is not supported in API mode.")
|
||||||
elif(vars.model == "OAI"):
|
elif(vars.model == "OAI"):
|
||||||
raise NotImplementedError("API generation is not supported in OpenAI/GooseAI mode.")
|
raise NotImplementedError("API generation is not supported in OpenAI/GooseAI mode.")
|
||||||
elif(vars.model == "ReadOnly"):
|
elif(vars.model == "ReadOnly"):
|
||||||
|
@ -4161,7 +4185,7 @@ def apiactionsubmit(data, use_memory=False, use_world_info=False, use_story=Fals
|
||||||
minimum = len(tokens) + 1
|
minimum = len(tokens) + 1
|
||||||
maximum = len(tokens) + vars.genamt
|
maximum = len(tokens) + vars.genamt
|
||||||
|
|
||||||
if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
|
if(not vars.use_colab_tpu and vars.model not in ["Colab", "API", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
|
||||||
genout = apiactionsubmit_generate(tokens, minimum, maximum)
|
genout = apiactionsubmit_generate(tokens, minimum, maximum)
|
||||||
elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
|
elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
|
||||||
genout = apiactionsubmit_tpumtjgenerate(tokens, minimum, maximum)
|
genout = apiactionsubmit_tpumtjgenerate(tokens, minimum, maximum)
|
||||||
|
@ -4402,19 +4426,23 @@ def calcsubmit(txt):
|
||||||
if(vars.model != "InferKit"):
|
if(vars.model != "InferKit"):
|
||||||
subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions, submission=txt)
|
subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions, submission=txt)
|
||||||
if(actionlen == 0):
|
if(actionlen == 0):
|
||||||
if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
|
if(not vars.use_colab_tpu and vars.model not in ["Colab", "API", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
|
||||||
generate(subtxt, min, max, found_entries=found_entries)
|
generate(subtxt, min, max, found_entries=found_entries)
|
||||||
elif(vars.model == "Colab"):
|
elif(vars.model == "Colab"):
|
||||||
sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
|
sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
|
||||||
|
elif(vars.model == "API"):
|
||||||
|
sendtoapi(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
|
||||||
elif(vars.model == "OAI"):
|
elif(vars.model == "OAI"):
|
||||||
oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
|
oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
|
||||||
elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
|
elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
|
||||||
tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
|
tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
|
||||||
else:
|
else:
|
||||||
if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
|
if(not vars.use_colab_tpu and vars.model not in ["Colab", "API", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
|
||||||
generate(subtxt, min, max, found_entries=found_entries)
|
generate(subtxt, min, max, found_entries=found_entries)
|
||||||
elif(vars.model == "Colab"):
|
elif(vars.model == "Colab"):
|
||||||
sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
|
sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
|
||||||
|
elif(vars.model == "API"):
|
||||||
|
sendtoapi(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
|
||||||
elif(vars.model == "OAI"):
|
elif(vars.model == "OAI"):
|
||||||
oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
|
oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
|
||||||
elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
|
elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
|
||||||
|
@ -4820,6 +4848,80 @@ def sendtocolab(txt, min, max):
|
||||||
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
|
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
|
||||||
set_aibusy(0)
|
set_aibusy(0)
|
||||||
|
|
||||||
|
|
||||||
|
#==================================================================#
|
||||||
|
# Send transformers-style request to KoboldAI API
|
||||||
|
#==================================================================#
|
||||||
|
def sendtoapi(txt, min, max):
|
||||||
|
# Log request to console
|
||||||
|
if not vars.quiet:
|
||||||
|
print("{0}Tokens:{1}, Txt:{2}{3}".format(colors.YELLOW, min-1, txt, colors.END))
|
||||||
|
|
||||||
|
# Store context in memory to use it for comparison with generated content
|
||||||
|
vars.lastctx = txt
|
||||||
|
|
||||||
|
# Build request JSON data
|
||||||
|
reqdata = {
|
||||||
|
'prompt': txt,
|
||||||
|
'max_length': max - min + 1,
|
||||||
|
'rep_pen': vars.rep_pen,
|
||||||
|
'rep_pen_slope': vars.rep_pen_slope,
|
||||||
|
'rep_pen_range': vars.rep_pen_range,
|
||||||
|
'temperature': vars.temp,
|
||||||
|
'top_p': vars.top_p,
|
||||||
|
'top_k': vars.top_k,
|
||||||
|
'top_a': vars.top_a,
|
||||||
|
'tfs': vars.tfs,
|
||||||
|
'typical': vars.typical,
|
||||||
|
'n': vars.numseqs,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create request
|
||||||
|
while True:
|
||||||
|
req = requests.post(
|
||||||
|
vars.colaburl.replace("/request", "/api/v1/generate"),
|
||||||
|
json=reqdata,
|
||||||
|
)
|
||||||
|
if(req.status_code == 503): # Server is currently generating something else so poll until it's our turn
|
||||||
|
time.sleep(1)
|
||||||
|
continue
|
||||||
|
js = req.json()
|
||||||
|
if(req.status_code != 200):
|
||||||
|
errmsg = "KoboldAI API Error: Failed to get a reply from the server. Please check the console."
|
||||||
|
print("{0}{1}{2}".format(colors.RED, json.dumps(js, indent=2), colors.END))
|
||||||
|
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
|
||||||
|
set_aibusy(0)
|
||||||
|
return
|
||||||
|
|
||||||
|
genout = [obj["text"] for obj in js["results"]]
|
||||||
|
|
||||||
|
for i in range(vars.numseqs):
|
||||||
|
vars.lua_koboldbridge.outputs[i+1] = genout[i]
|
||||||
|
|
||||||
|
execute_outmod()
|
||||||
|
if(vars.lua_koboldbridge.regeneration_required):
|
||||||
|
vars.lua_koboldbridge.regeneration_required = False
|
||||||
|
genout = []
|
||||||
|
for i in range(vars.numseqs):
|
||||||
|
genout.append(vars.lua_koboldbridge.outputs[i+1])
|
||||||
|
assert type(genout[-1]) is str
|
||||||
|
|
||||||
|
if(len(genout) == 1):
|
||||||
|
genresult(genout[0])
|
||||||
|
else:
|
||||||
|
# Convert torch output format to transformers
|
||||||
|
seqs = []
|
||||||
|
for seq in genout:
|
||||||
|
seqs.append({"generated_text": seq})
|
||||||
|
if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
|
||||||
|
genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
|
||||||
|
else:
|
||||||
|
genselect(genout)
|
||||||
|
|
||||||
|
set_aibusy(0)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Send text to TPU mesh transformer backend
|
# Send text to TPU mesh transformer backend
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
|
Loading…
Reference in New Issue