Eternal gen work

somebody
2022-09-16 20:16:11 -05:00
parent f075ca9095
commit 8ffc084ef3


@@ -4534,32 +4534,27 @@ def calcsubmit(txt):
print("Using Alt Gen")
else:
subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
if(actionlen == 0):
if(not koboldai_vars.use_colab_tpu and koboldai_vars.model not in ["Colab", "API", "CLUSTER", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
generate(subtxt, min, max, found_entries=found_entries)
elif(koboldai_vars.model == "Colab"):
sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif(koboldai_vars.model == "API"):
sendtoapi(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif(koboldai_vars.model == "CLUSTER"):
sendtocluster(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif(koboldai_vars.model == "OAI"):
oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif(koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
else:
if(not koboldai_vars.use_colab_tpu and koboldai_vars.model not in ["Colab", "API", "CLUSTER", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
generate(subtxt, min, max, found_entries=found_entries)
elif(koboldai_vars.model == "Colab"):
sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif(koboldai_vars.model == "API"):
sendtoapi(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif(koboldai_vars.model == "CLUSTER"):
sendtocluster(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif(koboldai_vars.model == "OAI"):
oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif(koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
if not koboldai_vars.use_colab_tpu and koboldai_vars.model not in (
"Colab",
"API",
"CLUSTER",
"OAI",
"TPUMeshTransformerGPTJ",
"TPUMeshTransformerGPTNeoX"
):
legacy_generate(subtxt, min, max)
# generate(subtxt, min, max, found_entries=found_entries)
elif koboldai_vars.model == "Colab":
sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif koboldai_vars.model == "API":
sendtoapi(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif koboldai_vars.model == "CLUSTER":
sendtocluster(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif koboldai_vars.model == "OAI":
oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
# For InferKit web API
else:
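
The removed block ran the same backend dispatch twice, once under `if(actionlen == 0)` and once under the `else`, and the replacement collapses them into one chain that sends the local (non-Colab/API/CLUSTER/OAI/TPU) path through the new `legacy_generate` instead of `generate`. As an illustration only, the same routing could be written as a table lookup; the helper and table below are not in the KoboldAI source, they just reuse the names that appear in the diff:

```python
# Illustration: the diff's if/elif chain expressed as a lookup table.
# sendtocolab, sendtoapi, sendtocluster, oairequest, tpumtjgenerate,
# legacy_generate, koboldai_vars, utils and tokenizer are the module-level
# names used in the diff; dispatch_generation itself is hypothetical.
REMOTE_SENDERS = {
    "Colab": sendtocolab,
    "API": sendtoapi,
    "CLUSTER": sendtocluster,
    "OAI": oairequest,
}
TPU_MODELS = ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")

def dispatch_generation(subtxt, min_len, max_len, found_entries=None):
    model = koboldai_vars.model
    if (not koboldai_vars.use_colab_tpu
            and model not in REMOTE_SENDERS
            and model not in TPU_MODELS):
        # Local HF/torch path now goes through legacy_generate
        legacy_generate(subtxt, min_len, max_len)
    elif model in REMOTE_SENDERS:
        # Remote backends receive decoded text rather than token ids
        REMOTE_SENDERS[model](utils.decodenewlines(tokenizer.decode(subtxt)), min_len, max_len)
    elif koboldai_vars.use_colab_tpu or model in TPU_MODELS:
        tpumtjgenerate(subtxt, min_len, max_len, found_entries=found_entries)
```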
@@ -4617,18 +4612,28 @@ def calcsubmit(txt):
# Send it!
ikrequest(subtxt)
def legacy_generate(text):
def legacy_generate(text: Union[str, list], min: int, max: int):
# Architected after oairequest
koboldai_vars.lastctx = text
outputs = raw_generate(text)
print("Pregen")
print(koboldai_vars.max_length)
outputs = raw_generate(
text,
max_length=koboldai_vars.genamt,
do_streaming=True
)
print(f"postgen: {outputs}")
# Lua bridge, genmod
for i, output in enumerate(outputs):
koboldai_vars.lua_koboldbridge.outputs[i + 1] = output
print("post lua")
execute_genmod()
print("post genmod")
if koboldai_vars.lua_koboldbridge.regeneration_required:
koboldai_vars.lua_koboldbridge.regeneration_required = False
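
Compared with the removed `outputs = raw_generate(text)`, the new call passes the generation budget explicitly and enables streaming, and the results are mirrored into the Lua bridge before `execute_genmod()` runs the user's generation-modifier scripts; the `i + 1` is there because Lua tables are 1-indexed. Stripped of the temporary prints, the core of this hunk looks like the following sketch (names as in the diff):

```python
outputs = raw_generate(
    text,
    max_length=koboldai_vars.genamt,  # output token budget from the genamt setting
    do_streaming=True,
)

# Lua tables are 1-based, so Python output i lands in bridge slot i + 1
for i, output in enumerate(outputs):
    koboldai_vars.lua_koboldbridge.outputs[i + 1] = output

execute_genmod()
```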
@@ -4639,10 +4644,14 @@ def legacy_generate(text):
assert isinstance(out, str)
else:
genout = [{"generated_text": utils.decodenewlines(x)} for x in outputs]
print("post assign genout")
koboldai_vars.actions.append_options([applyoutputformatting(x["generated_text"]) for x in genout])
genout = [{"generated_text": x['text']} for x in koboldai_vars.actions.get_current_options()]
print("post genout assign")
if len(genout) == 1:
genresult(genout[0]["generated_text"])
else:
@@ -4651,7 +4660,9 @@ def legacy_generate(text):
genresult(genout[restart_seq - 1]["generated_text"])
else:
genselect(genout)
print("post whatever that is")
set_aibusy(0)
print("post busy")
def raw_generate(
# prompt is either a string (text) or a list (token ids)
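
Together with the previous hunk, the tail of `legacy_generate` wraps each raw string the way `oairequest` does, appends the formatted candidates to the story's options, rebuilds `genout` from `actions.get_current_options()` (whose entries expose the text under `x['text']`, as the diff shows), and then either commits a single result with `genresult` or opens the chooser with `genselect` before clearing the busy flag. A condensed sketch of that flow, with the temporary prints and the regeneration/restart branches elided (names as in the diff; the helper itself is hypothetical):

```python
def legacy_generate_tail(outputs):
    # Wrap raw strings the same way oairequest does
    genout = [{"generated_text": utils.decodenewlines(x)} for x in outputs]
    # Formatted candidates become selectable options on the current action
    koboldai_vars.actions.append_options(
        [applyoutputformatting(x["generated_text"]) for x in genout]
    )
    # Re-read the options so genout reflects whatever the actions store kept
    genout = [{"generated_text": x["text"]}
              for x in koboldai_vars.actions.get_current_options()]
    if len(genout) == 1:
        genresult(genout[0]["generated_text"])  # single sequence: commit it
    else:
        genselect(genout)                       # several sequences: let the user pick
    set_aibusy(0)
```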
@@ -4697,7 +4708,7 @@ def raw_generate(
return [utils.decodenewlines(tokenizer.decode(x)) for x in batch_out]
def tpu_raw_generate(
prompt_tokens: list[int],
prompt_tokens: List[int],
max_length: int,
batch_count: int,
):
@@ -4725,7 +4736,7 @@ def tpu_raw_generate(
return genout
def torch_raw_generate(
prompt_tokens: list[int],
prompt_tokens: List[int],
max_length: int,
do_streaming: bool = False,
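
The signature changes in this and the previous hunk swap the built-in generic `list[int]` for `typing.List[int]`. Subscripting `list` in an annotation is only valid at runtime on Python 3.9+ (PEP 585), so `List[int]` keeps the module importable on older interpreters, assuming `List` is imported from `typing` alongside the `Union` used in `legacy_generate`. A minimal example of the difference:

```python
from typing import List

# Works on Python 3.7+. Writing "prompt_tokens: list[int]" instead would raise
# "TypeError: 'type' object is not subscriptable" at definition time on 3.8 and
# below, unless "from __future__ import annotations" defers evaluation.
def count_tokens(prompt_tokens: List[int]) -> int:
    return len(prompt_tokens)
```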
@@ -4748,7 +4759,10 @@ def torch_raw_generate(
device = breakmodel.primary_device
gen_in = gen_in.to(device)
print("okay...")
with torch.no_grad():
print(f"in {max_length}")
genout = generator(
gen_in,
do_sample=True,
@@ -4758,8 +4772,10 @@ def torch_raw_generate(
use_cache=True,
num_return_sequences=batch_count,
)
print("out")
print("wtf")
return genout[0]
return genout
#==================================================================#
# Send text to generator and deal with output
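
In `torch_raw_generate` the sampling call now returns the full batch (`return genout`) instead of only the first sequence, which matches `num_return_sequences=batch_count` and the per-sequence decode loop in `raw_generate`. For reference, a minimal self-contained example of the same Hugging Face `generate` pattern; the model and parameters here are illustrative, not KoboldAI's:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

gen_in = tokenizer("The quick brown fox", return_tensors="pt").input_ids

with torch.no_grad():
    genout = model.generate(
        gen_in,
        do_sample=True,
        max_length=gen_in.shape[-1] + 32,   # prompt length + new tokens
        num_return_sequences=3,             # several candidate continuations
        pad_token_id=tokenizer.eos_token_id,
    )

# genout holds all three sequences; taking genout[0] would keep only the first,
# which is why the diff switches from "return genout[0]" to "return genout".
for seq in genout:
    print(tokenizer.decode(seq, skip_special_tokens=True))
```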
@@ -7984,7 +8000,6 @@ def UI_2_generate_raw():
except NotImplementedError as e:
return Response(json.dumps({"error": str(e)}), status=500)
print(f"{out=}")
return out
#==================================================================#