mirror of https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Added alt text gen
aiserver.py: 36 changed lines
@@ -2881,13 +2881,17 @@ def lua_compute_context(submission, entries, folders, kwargs):
         force_use_txt=True,
         scan_story=kwargs["scan_story"] if kwargs["scan_story"] != None else True,
     )
-    txt, _, _ = calcsubmitbudget(
-        len(actions),
-        winfo,
-        mem,
-        anotetxt,
-        actions,
-    )
+    if koboldai_vars.alt_gen:
+        txt, _, _ = koboldai_vars.calc_ai_text()
+        print("Using Alt Gen: {}".format(tokenizer.decode(txt)))
+    else:
+        txt, _, _ = calcsubmitbudget(
+            len(actions),
+            winfo,
+            mem,
+            anotetxt,
+            actions,
+        )
     return utils.decodenewlines(tokenizer.decode(txt))
 
 #==================================================================#
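Every hunk in this commit repeats one pattern: a boolean dispatch on `koboldai_vars.alt_gen` that routes prompt construction through the new `calc_ai_text()` path instead of `calcsubmitbudget()`. A minimal sketch of that pattern, with `build_context_legacy` and `build_context_alt` as hypothetical stand-ins for the two real builders:

```python
# Minimal sketch of the dispatch introduced here; the builder names
# are hypothetical stand-ins, not real KoboldAI functions.
from typing import List

def build_context_legacy(submission: str) -> List[int]:
    # stands in for calcsubmitbudget(...): token-budgeted prompt ids
    return [1, 2, 3]

def build_context_alt(submission: str) -> List[int]:
    # stands in for koboldai_vars.calc_ai_text(...): the new builder
    return [4, 5, 6]

def build_context(submission: str, alt_gen: bool) -> List[int]:
    # as in the diff, one flag selects the whole prompt-construction pipeline
    if alt_gen:
        return build_context_alt(submission)
    return build_context_legacy(submission)

print(build_context("Once upon a time", alt_gen=True))   # [4, 5, 6]
print(build_context("Once upon a time", alt_gen=False))  # [1, 2, 3]
```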
@@ -4458,7 +4462,11 @@ def calcsubmit(txt):
 
     # For all transformers models
     if(koboldai_vars.model != "InferKit"):
-        subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
+        if koboldai_vars.alt_gen:
+            subtxt, min, max = koboldai_vars.calc_ai_text(submitted_text=txt)
+            print("Using Alt Gen: {}".format(tokenizer.decode(subtxt)))
+        else:
+            subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
         if(actionlen == 0):
             if(not koboldai_vars.use_colab_tpu and koboldai_vars.model not in ["Colab", "API", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
                 generate(subtxt, min, max, found_entries=found_entries)
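Judging by how each call site unpacks the result, `calc_ai_text` appears to mirror the return shape of `calcsubmitbudget`, a `(token_ids, minimum, maximum)` triple, so either branch feeds the unchanged `generate(subtxt, min, max, ...)` call. A hedged sketch of that assumed contract (all names and values illustrative only):

```python
from typing import List, Tuple

def calc_ai_text_sketch(submitted_text: str = "") -> Tuple[List[int], int, int]:
    # Assumed contract, inferred from the call-site unpacking:
    # (prompt token ids, minimum generation length, maximum generation length).
    tokens: List[int] = []          # placeholder; the real method builds the prompt
    min_length, max_length = 1, 80  # illustrative numbers only
    return tokens, min_length, max_length

subtxt, min_len, max_len = calc_ai_text_sketch(submitted_text="...")
```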
@@ -4601,7 +4609,11 @@ def _generate(txt, minimum, maximum, found_entries):
                 txt = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:]))
                 winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=koboldai_vars._actions)
                 found_entries[i].update(_found_entries)
-                txt, _, _ = calcsubmitbudget(len(koboldai_vars._actions), winfo, mem, anotetxt, koboldai_vars._actions, submission=txt)
+                if koboldai_vars.alt_gen:
+                    txt, _, _ = koboldai_vars.calc_ai_text(submitted_text=txt)
+                    print("Using Alt Gen: {}".format(tokenizer.decode(txt)))
+                else:
+                    txt, _, _ = calcsubmitbudget(len(koboldai_vars._actions), winfo, mem, anotetxt, koboldai_vars._actions, submission=txt)
                 encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device))
             max_length = len(max(encoded, key=len))
             encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
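The two context lines after the branch are worth a note: each per-sequence prompt can come out a different length, so the batch is left-padded to the longest sequence before stacking, keeping the most recent context right-aligned where generation continues. A self-contained sketch of that step; the pad id 0 is arbitrary, standing in for `model.config.pad_token_id or model.config.eos_token_id`:

```python
import torch
import torch.nn.functional as F

# Toy batch of unequal-length token-id tensors, as after the loop above.
encoded = [torch.tensor([5, 6, 7]), torch.tensor([9])]
pad_id = 0  # arbitrary; the diff uses pad_token_id or eos_token_id

max_length = len(max(encoded, key=len))
# F.pad with a (left, right) pair on a 1-D tensor; padding only on the
# left right-aligns every sequence before torch.stack builds the batch.
batch = torch.stack(tuple(
    F.pad(e, (max_length - len(e), 0), value=pad_id) for e in encoded
))
print(batch)  # tensor([[5, 6, 7], [0, 0, 9]])
```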
@@ -4998,7 +5010,11 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
                 txt = utils.decodenewlines(tokenizer.decode(past[i]))
                 winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=koboldai_vars._actions)
                 found_entries[i].update(_found_entries)
-                txt, _, _ = calcsubmitbudget(len(koboldai_vars._actions), winfo, mem, anotetxt, koboldai_vars._actions, submission=txt)
+                if koboldai_vars.alt_gen:
+                    txt, _, _ = koboldai_vars.calc_ai_text(submitted_text=txt)
+                    print("Using Alt Gen: {}".format(tokenizer.decode(txt)))
+                else:
+                    txt, _, _ = calcsubmitbudget(len(koboldai_vars._actions), winfo, mem, anotetxt, koboldai_vars._actions, submission=txt)
                 encoded.append(np.array(txt, dtype=np.uint32))
             max_length = len(max(encoded, key=len))
             encoded = np.stack(tuple(np.pad(e, (max_length - len(e), 0), constant_values=tpu_mtj_backend.pad_token_id) for e in encoded))
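The TPU path performs the same left-padding with NumPy instead of torch; a matching sketch, with pad id 0 standing in for `tpu_mtj_backend.pad_token_id`:

```python
import numpy as np

# Same toy batch as above, now as uint32 arrays for the TPU backend.
encoded = [np.array([5, 6, 7], dtype=np.uint32), np.array([9], dtype=np.uint32)]
pad_id = 0  # stands in for tpu_mtj_backend.pad_token_id

max_length = len(max(encoded, key=len))
# np.pad with a (before, after) pair mirrors the torch version above.
batch = np.stack(tuple(
    np.pad(e, (max_length - len(e), 0), constant_values=pad_id) for e in encoded
))
print(batch)  # [[5 6 7]
              #  [0 0 9]]
```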