Add API for generated tokens and output text

This commit is contained in:
Gnome Ann
2021-12-12 19:27:20 -05:00
parent ceabd2ef7b
commit fbf3e7615b
3 changed files with 261 additions and 16 deletions

View File

@ -654,26 +654,37 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
    """Hand the current logits to the Lua bridge, run the generation
    modifier hook, and return the logits read back from the bridge.

    Args:
        input_ids: 2-D tensor of token ids; the last column of each row is
            copied into the bridge's ``generated`` table.
        scores: 2-D tensor of logits for the next token.

    Returns:
        A new tensor built from ``vars.lua_koboldbridge.logits`` with the
        same shape, device and dtype as ``scores``.

    Side effects:
        Sets ``self.regeneration_required`` and ``self.halt`` from the
        bridge's ``regeneration_required`` / ``generating`` flags, and
        clears the bridge's ``regeneration_required`` flag once consumed.
    """
    assert scores.ndim == 2
    assert input_ids.ndim == 2
    self.regeneration_required = False
    self.halt = False

    scores_shape = scores.shape
    # The logits are the `scores` argument; this object never stores a
    # `self.scores` attribute, so reading it would raise AttributeError.
    scores_list = scores.tolist()

    # Lua tables are 1-indexed, hence the r+1 / i+1 offsets below.
    vars.lua_koboldbridge.logits = vars.lua_state.table()
    for r, row in enumerate(scores_list):
        vars.lua_koboldbridge.logits[r+1] = vars.lua_state.table(*row)
    vars.lua_koboldbridge.vocab_size = scores_shape[-1]

    # Record the most recently produced token of every sequence in the
    # current `generated` column (skipped while generated_cols is 0).
    if(vars.lua_koboldbridge.generated_cols != 0):
        for i in range(vars.numseqs):
            vars.lua_koboldbridge.generated[i+1][vars.lua_koboldbridge.generated_cols] = input_ids[i, -1].item()

    execute_genmod()

    # Capture the flags once, right after the hook has run. A second,
    # identical check further down would be a no-op because the bridge flag
    # is cleared here and `halt` is only ever set to True.
    if(vars.lua_koboldbridge.regeneration_required):
        vars.lua_koboldbridge.regeneration_required = False
        self.regeneration_required = True
    if(not vars.lua_koboldbridge.generating):
        self.halt = True

    # Rebuild the tensor from the bridge's logits table, which the hook may
    # have edited; the shape must be unchanged.
    scores = torch.tensor(
        tuple(tuple(row.values()) for row in vars.lua_koboldbridge.logits.values()),
        device=scores.device,
        dtype=scores.dtype,
    )
    assert scores.shape == scores_shape

    return scores
def new_get_logits_warper(
@ -1748,9 +1759,19 @@ def actionsubmit(data, actionmode=0, force_submit=False):
calcsubmit(data) # Run the first action through the generator
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
else:
execute_outmod()
# Save this first action as the prompt
vars.prompt = data
execute_outmod()
if(vars.lua_koboldbridge.regeneration_required):
vars.lua_koboldbridge.regeneration_required = False
genout = []
for i in range(vars.numseqs):
genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
assert type(genout[-1]["generated_text"]) is str
if(len(genout) == 1):
genresult(genout[0]["generated_text"])
else:
genselect(genout)
refresh_story()
set_aibusy(0)
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
@ -1776,6 +1797,16 @@ def actionsubmit(data, actionmode=0, force_submit=False):
else:
execute_outmod()
set_aibusy(0)
if(vars.lua_koboldbridge.regeneration_required):
vars.lua_koboldbridge.regeneration_required = False
genout = []
for i in range(vars.numseqs):
genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
assert type(genout[-1]["generated_text"]) is str
if(len(genout) == 1):
genresult(genout[0]["generated_text"])
else:
genselect(genout)
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
#==================================================================#
@ -2076,6 +2107,12 @@ def generate(txt, minimum, maximum, found_entries=None):
break
assert genout.ndim >= 2
assert genout.shape[0] == vars.numseqs
# The number of tokens produced so far must agree with the number of
# columns recorded in the Lua bridge's `generated` table.
if(already_generated != vars.lua_koboldbridge.generated_cols):
    raise RuntimeError("WI scanning error")
# Copy the bridge's `generated` table back over the generated tail of
# genout. The generated tokens occupy the last `already_generated`
# positions, so column c (0-based) lives at (length - already_generated + c).
# Subtracting c instead would walk backwards out of the generated tail
# into the tokens that precede it.
for r in range(vars.numseqs):
    for c in range(already_generated):
        assert vars.lua_koboldbridge.generated[r+1][c+1] is not None
        genout[r][genout.shape[-1] - already_generated + c] = vars.lua_koboldbridge.generated[r+1][c+1]
encoded = []
for i in range(vars.numseqs):
txt = tokenizer.decode(genout[i, -already_generated:])
@ -2111,12 +2148,20 @@ def generate(txt, minimum, maximum, found_entries=None):
print("{0}{1}{2}".format(colors.RED, e, colors.END))
set_aibusy(0)
return
# Need to manually strip and decode tokens if we're not using a pipeline
#already_generated = -(len(gen_in[0]) - len(tokens))
genout = [{"generated_text": tokenizer.decode(tokens[-already_generated:])} for tokens in genout]
for i in range(vars.numseqs):
vars.lua_koboldbridge.generated[i+1][vars.lua_koboldbridge.generated_cols] = genout[i, -1].item()
vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(genout[i, -already_generated:])
execute_outmod()
if(vars.lua_koboldbridge.regeneration_required):
vars.lua_koboldbridge.regeneration_required = False
genout = []
for i in range(vars.numseqs):
genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
assert type(genout[-1]["generated_text"]) is str
else:
genout = [{"generated_text": tokenizer.decode(tokens[-already_generated:])} for tokens in genout]
if(len(genout) == 1):
genresult(genout[0]["generated_text"])
@ -2217,8 +2262,17 @@ def sendtocolab(txt, min, max):
else:
genout = js["seqs"]
for i in range(vars.numseqs):
vars.lua_koboldbridge.outputs[i+1] = genout[i]
execute_outmod()
if(vars.lua_koboldbridge.regeneration_required):
vars.lua_koboldbridge.regeneration_required = False
genout = []
for i in range(vars.numseqs):
genout.append(vars.lua_koboldbridge.outputs[i+1])
assert type(genout[-1]) is str
if(len(genout) == 1):
genresult(genout[0])
else:
@ -2297,10 +2351,20 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
print("{0}{1}{2}".format(colors.RED, e, colors.END))
set_aibusy(0)
return
genout = [{"generated_text": txt} for txt in genout]
for i in range(vars.numseqs):
vars.lua_koboldbridge.generated[i+1] = vars.lua_state.table(*genout[i].tolist())
vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(genout[i])
execute_outmod()
if(vars.lua_koboldbridge.regeneration_required):
vars.lua_koboldbridge.regeneration_required = False
genout = []
for i in range(vars.numseqs):
genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
assert type(genout[-1]["generated_text"]) is str
else:
genout = [{"generated_text": tokenizer.decode(txt)} for txt in genout]
if(len(genout) == 1):
genresult(genout[0]["generated_text"])
@ -2888,7 +2952,15 @@ def ikrequest(txt):
# Deal with the response
if(req.status_code == 200):
genout = req.json()["data"]["text"]

# Publish the generated text to the Lua bridge (Lua tables are 1-indexed)
# and run the output modifier hook.
vars.lua_koboldbridge.outputs[1] = genout
execute_outmod()
if(vars.lua_koboldbridge.regeneration_required):
    # Regeneration was flagged: clear the flag and take the text back from
    # the bridge's outputs table.
    vars.lua_koboldbridge.regeneration_required = False
    genout = vars.lua_koboldbridge.outputs[1]
    # `genout is str` compared the value with the `str` type object itself
    # and therefore always failed; check the value's type instead.
    assert isinstance(genout, str)
print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
vars.actions.append(genout)
update_story_chunk('last')
@ -2939,7 +3011,15 @@ def oairequest(txt, min, max):
# Deal with the response
if(req.status_code == 200):
genout = req.json()["choices"][0]["text"]

# Publish the generated text to the Lua bridge (Lua tables are 1-indexed)
# and run the output modifier hook.
vars.lua_koboldbridge.outputs[1] = genout
execute_outmod()
if(vars.lua_koboldbridge.regeneration_required):
    # Regeneration was flagged: clear the flag and take the text back from
    # the bridge's outputs table.
    vars.lua_koboldbridge.regeneration_required = False
    genout = vars.lua_koboldbridge.outputs[1]
    # `genout is str` compared the value with the `str` type object itself
    # and therefore always failed; check the value's type instead.
    assert isinstance(genout, str)
print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
vars.actions.append(genout)
update_story_chunk('last')