Added alt text gen

This commit is contained in:
ebolam
2022-08-18 12:54:20 -04:00
parent 01f06d5d0a
commit fdcf463c76
5 changed files with 122 additions and 43 deletions

View File

@@ -2881,13 +2881,17 @@ def lua_compute_context(submission, entries, folders, kwargs):
force_use_txt=True,
scan_story=kwargs["scan_story"] if kwargs["scan_story"] != None else True,
)
txt, _, _ = calcsubmitbudget(
len(actions),
winfo,
mem,
anotetxt,
actions,
)
if koboldai_vars.alt_gen:
txt, _, _ = koboldai_vars.calc_ai_text()
print("Using Alt Gen: {}".format(tokenizer.decode(txt)))
else:
txt, _, _ = calcsubmitbudget(
len(actions),
winfo,
mem,
anotetxt,
actions,
)
return utils.decodenewlines(tokenizer.decode(txt))
#==================================================================#
@@ -4458,7 +4462,11 @@ def calcsubmit(txt):
# For all transformers models
if(koboldai_vars.model != "InferKit"):
subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
if koboldai_vars.alt_gen:
subtxt, min, max = koboldai_vars.calc_ai_text(submitted_text=txt)
print("Using Alt Gen: {}".format(tokenizer.decode(subtxt)))
else:
subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
if(actionlen == 0):
if(not koboldai_vars.use_colab_tpu and koboldai_vars.model not in ["Colab", "API", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
generate(subtxt, min, max, found_entries=found_entries)
@@ -4601,7 +4609,11 @@ def _generate(txt, minimum, maximum, found_entries):
txt = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:]))
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=koboldai_vars._actions)
found_entries[i].update(_found_entries)
txt, _, _ = calcsubmitbudget(len(koboldai_vars._actions), winfo, mem, anotetxt, koboldai_vars._actions, submission=txt)
if koboldai_vars.alt_gen:
txt, _, _ = koboldai_vars.calc_ai_text(submitted_text=txt)
print("Using Alt Gen: {}".format(tokenizer.decode(txt)))
else:
txt, _, _ = calcsubmitbudget(len(koboldai_vars._actions), winfo, mem, anotetxt, koboldai_vars._actions, submission=txt)
encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device))
max_length = len(max(encoded, key=len))
encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
@@ -4998,7 +5010,11 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
txt = utils.decodenewlines(tokenizer.decode(past[i]))
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=koboldai_vars._actions)
found_entries[i].update(_found_entries)
txt, _, _ = calcsubmitbudget(len(koboldai_vars._actions), winfo, mem, anotetxt, koboldai_vars._actions, submission=txt)
if koboldai_vars.alt_gen:
txt, _, _ = koboldai_vars.calc_ai_text(submitted_text=txt)
print("Using Alt Gen: {}".format(tokenizer.decode(txt)))
else:
txt, _, _ = calcsubmitbudget(len(koboldai_vars._actions), winfo, mem, anotetxt, koboldai_vars._actions, submission=txt)
encoded.append(np.array(txt, dtype=np.uint32))
max_length = len(max(encoded, key=len))
encoded = np.stack(tuple(np.pad(e, (max_length - len(e), 0), constant_values=tpu_mtj_backend.pad_token_id) for e in encoded))