Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
Eternal gen work
69 aiserver.py
@@ -4534,31 +4534,26 @@ def calcsubmit(txt):
 print("Using Alt Gen")
 else:
 subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
-if(actionlen == 0):
-if(not koboldai_vars.use_colab_tpu and koboldai_vars.model not in ["Colab", "API", "CLUSTER", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
-generate(subtxt, min, max, found_entries=found_entries)
-elif(koboldai_vars.model == "Colab"):
+if not koboldai_vars.use_colab_tpu and koboldai_vars.model not in (
+"Colab",
+"API",
+"CLUSTER",
+"OAI",
+"TPUMeshTransformerGPTJ",
+"TPUMeshTransformerGPTNeoX"
+):
+legacy_generate(subtxt, min, max)
+# generate(subtxt, min, max, found_entries=found_entries)
+elif koboldai_vars.model == "Colab":
 sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-elif(koboldai_vars.model == "API"):
+elif koboldai_vars.model == "API":
 sendtoapi(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-elif(koboldai_vars.model == "CLUSTER"):
+elif koboldai_vars.model == "CLUSTER":
 sendtocluster(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-elif(koboldai_vars.model == "OAI"):
+elif koboldai_vars.model == "OAI":
 oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-elif(koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
-tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
-else:
-if(not koboldai_vars.use_colab_tpu and koboldai_vars.model not in ["Colab", "API", "CLUSTER", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
-generate(subtxt, min, max, found_entries=found_entries)
-elif(koboldai_vars.model == "Colab"):
-sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-elif(koboldai_vars.model == "API"):
-sendtoapi(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-elif(koboldai_vars.model == "CLUSTER"):
-sendtocluster(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-elif(koboldai_vars.model == "OAI"):
-oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-elif(koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
+elif koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
 tpumtjgenerate(subtxt, min, max, found_entries=found_entries)

 # For InferKit web API
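The hunk above collapses the two duplicated dispatch chains (one for `actionlen == 0`, one for the general case) into a single chain, and routes the default local-transformers path through the new `legacy_generate` instead of `generate`. A condensed sketch of the resulting flow, using only names that appear in the diff (the remote branches are abbreviated here):

if not koboldai_vars.use_colab_tpu and koboldai_vars.model not in (
    "Colab", "API", "CLUSTER", "OAI",
    "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX",
):
    legacy_generate(subtxt, min, max)              # new local-model path
elif koboldai_vars.model == "Colab":
    sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif koboldai_vars.model in ("API", "CLUSTER", "OAI"):
    ...  # sendtoapi / sendtocluster / oairequest, unchanged
elif koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
    tpumtjgenerate(subtxt, min, max, found_entries=found_entries)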
@@ -4617,18 +4612,28 @@ def calcsubmit(txt):
 # Send it!
 ikrequest(subtxt)

-def legacy_generate(text):
+def legacy_generate(text: Union[str, list], min: int, max: int):
 # Architected after oairequest

 koboldai_vars.lastctx = text

-outputs = raw_generate(text)
+print("Pregen")
+print(koboldai_vars.max_length)
+outputs = raw_generate(
+text,
+max_length=koboldai_vars.genamt,
+do_streaming=True
+)
+print(f"postgen: {outputs}")

 # Lua bridge, genmod
 for i, output in enumerate(outputs):
 koboldai_vars.lua_koboldbridge.outputs[i + 1] = output

+print("post lua")
+
 execute_genmod()
+print("post genmod")

 if koboldai_vars.lua_koboldbridge.regeneration_required:
 koboldai_vars.lua_koboldbridge.regeneration_required = False
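Without the debug prints, the new body of `legacy_generate` reads roughly as below. This is a sketch assembled from the lines in this hunk; only the `max_length` and `do_streaming` keywords are taken from the diff, and the comment wording is an assumption:

def legacy_generate(text: Union[str, list], min: int, max: int):
    # Architected after oairequest
    koboldai_vars.lastctx = text

    outputs = raw_generate(
        text,
        max_length=koboldai_vars.genamt,  # generation amount requested by the user
        do_streaming=True,
    )

    # Hand each candidate to the Lua bridge (1-indexed) before genmod runs
    for i, output in enumerate(outputs):
        koboldai_vars.lua_koboldbridge.outputs[i + 1] = output
    execute_genmod()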
@@ -4640,9 +4645,13 @@ def legacy_generate(text):
 else:
 genout = [{"generated_text": utils.decodenewlines(x)} for x in outputs]

+print("post assign genout")
+
 koboldai_vars.actions.append_options([applyoutputformatting(x["generated_text"]) for x in genout])
 genout = [{"generated_text": x['text']} for x in koboldai_vars.actions.get_current_options()]

+print("post genout assign")
+
 if len(genout) == 1:
 genresult(genout[0]["generated_text"])
 else:
@@ -4651,7 +4660,9 @@ def legacy_generate(text):
 genresult(genout[restart_seq - 1]["generated_text"])
 else:
 genselect(genout)
+print("post whatever that is")
 set_aibusy(0)
+print("post busy")

 def raw_generate(
 # prompt is either a string (text) or a list (token ids)
@@ -4697,7 +4708,7 @@ def raw_generate(
 return [utils.decodenewlines(tokenizer.decode(x)) for x in batch_out]

 def tpu_raw_generate(
-prompt_tokens: list[int],
+prompt_tokens: List[int],
 max_length: int,
 batch_count: int,
 ):
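The `list[int]` → `List[int]` change in the annotations matters for compatibility: subscripting the built-in `list` only works at runtime on Python 3.9+ (or behind `from __future__ import annotations`), while `typing.List` also works on 3.7/3.8 and needs an explicit import. A minimal sketch, with a hypothetical return annotation added purely for illustration:

from typing import List

def tpu_raw_generate(
    prompt_tokens: List[int],  # typing.List works on Python 3.7/3.8 as well as 3.9+
    max_length: int,
    batch_count: int,
) -> List[List[int]]:          # hypothetical return type, not part of the diff
    ...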
@@ -4725,7 +4736,7 @@ def tpu_raw_generate(
 return genout

 def torch_raw_generate(
-prompt_tokens: list[int],
+prompt_tokens: List[int],
 max_length: int,

 do_streaming: bool = False,
@@ -4748,7 +4759,10 @@ def torch_raw_generate(
 device = breakmodel.primary_device
 gen_in = gen_in.to(device)

+print("okay...")
+
 with torch.no_grad():
+print(f"in {max_length}")
 genout = generator(
 gen_in,
 do_sample=True,
@@ -4758,8 +4772,10 @@ def torch_raw_generate(
 use_cache=True,
 num_return_sequences=batch_count,
 )
+print("out")
+print("wtf")

-return genout[0]
+return genout

 #==================================================================#
 # Send text to generator and deal with output
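For context, the `generator(...)` call above follows the standard Hugging Face transformers `generate` keyword pattern, and switching the return from `genout[0]` to `genout` keeps every sampled sequence so the caller can decode the whole batch. A self-contained sketch under that assumption (the gpt2 checkpoint is a placeholder, not something this commit uses):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")           # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

gen_in = tokenizer("Once upon a time", return_tensors="pt").input_ids
with torch.no_grad():
    genout = model.generate(
        gen_in,
        do_sample=True,
        max_length=gen_in.shape[-1] + 80,   # prompt length plus new tokens
        use_cache=True,
        num_return_sequences=3,             # several candidate continuations
    )

# genout is (num_return_sequences, sequence_length); keeping the whole batch
# lets the caller decode each candidate, as raw_generate does above.
print([tokenizer.decode(seq, skip_special_tokens=True) for seq in genout])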
@@ -7984,7 +8000,6 @@ def UI_2_generate_raw():
 except NotImplementedError as e:
 return Response(json.dumps({"error": str(e)}), status=500)

-print(f"{out=}")
 return out

 #==================================================================#