Final touches

somebody
2022-09-24 12:54:20 -05:00
parent 3a727bc381
commit 5cdeb79752


@@ -4826,9 +4826,6 @@ def calcsubmit(txt):
     # Send it!
     ikrequest(subtxt)
-def __debug(*args):
-    print("[DBG] ", *args)
 def core_generate(text: list, min: int, max: int, found_entries: set):
     # This generation function is tangled with koboldai_vars intentionally. It
     # is meant for the story and nothing else.
@@ -4866,13 +4863,11 @@ def core_generate(text: list, min: int, max: int, found_entries: set):
     koboldai_vars._prompt = koboldai_vars.prompt
-    __debug("generate core", text)
     with torch.no_grad():
         already_generated = 0
         numseqs = koboldai_vars.numseqs
         while True:
-            __debug("generate loop start", text)
             # The reason this is a loop is due to how Dynamic WI works. We
             # cannot simply add the WI to the context mid-generation, so we
             # stop early, and then insert WI, then continue generating. That
@@ -4888,15 +4883,12 @@ def core_generate(text: list, min: int, max: int, found_entries: set):
                 bypass_hf_maxlength=True,
             )
-            __debug("generate result", result.__dict__)
             genout = result.encoded
-            already_generated += len(genout[0]) - 1 # - len(gen_in[0])
+            already_generated += len(genout[0]) - 1
             assert already_generated <= koboldai_vars.genamt
             if result.is_whole_generation:
-                __debug("Outa here")
                 break
             # Generation stopped; why?
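
Note: the comments in the hunks above describe the stop-early / insert-WI / resume pattern that makes this a loop. A minimal sketch of that control flow, with hypothetical helpers (generate_until_stop, find_triggered_wi, splice_wi) standing in for KoboldAI's actual functions:

    # Sketch only: the helpers below are hypothetical stand-ins, not KoboldAI's API.
    def generate_with_dynamic_wi(context_tokens, genamt):
        already_generated = 0
        while True:
            # Generate until the model finishes or a WI keyword stops it early.
            result = generate_until_stop(context_tokens, genamt - already_generated)
            already_generated += result.new_token_count
            if result.is_whole_generation:
                break  # nothing stopped us early; the output is complete
            # A keyword fired mid-generation: splice the triggered world-info
            # entry into the context and continue from where we stopped.
            context_tokens = splice_wi(context_tokens, find_triggered_wi(result.decoded))
        return context_tokens, already_generated
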
@@ -4953,13 +4945,9 @@ def core_generate(text: list, min: int, max: int, found_entries: set):
             )
             genout = torch.cat((soft_tokens.tile(koboldai_vars.numseqs, 1), genout), dim=-1)
             assert genout.shape[-1] + koboldai_vars.genamt - already_generated <= koboldai_vars.max_length
-            # diff = genout.shape[-1] - gen_in.shape[-1]
-            # minimum += diff
-            # maximum += diff
             gen_in = genout
             numseqs = 1
-    __debug("final out", genout, "already_gen", already_generated)
     return genout, already_generated
 class GenerationResult:
@@ -4975,10 +4963,8 @@ class GenerationResult:
         # Controls if we should trim output by prompt length
         output_includes_prompt: bool = False,
     ):
-        # Shave prompt off of encoded response. Decoded does not return prompt.
-        # TODO: Does MTJ generation shave this off automatically? Test it!
-        print("shape", out_batches.shape)
+        # Shave prompt off of encoded response when needed (HF). Decoded does
+        # not return prompt.
         if output_includes_prompt:
             self.encoded = out_batches[:, len(prompt) - 1:]
         else:
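
Note: the updated comment above says Hugging Face output echoes the prompt while the decoded text does not. A tiny stand-alone illustration of the slice applied when output_includes_prompt is true (dummy token IDs, not real model output):

    import torch

    prompt = [11, 12, 13, 14]                                # four prompt token IDs
    out_batches = torch.tensor([[11, 12, 13, 14, 7, 8, 9]])  # HF-style output that echoes the prompt

    # Same slice as in GenerationResult: start at len(prompt) - 1, so the last
    # prompt token is kept in front of the newly generated tokens.
    encoded = out_batches[:, len(prompt) - 1:]
    print(encoded)  # tensor([[14,  7,  8,  9]])
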
@@ -5065,7 +5051,6 @@ def tpu_raw_generate(
     # Mostly lifted from apiactionsubmit_tpumtjgenerate
     soft_tokens = tpumtjgetsofttokens()
-    __debug("we are generating with", prompt_tokens, "batch", batch_count, "soft tokens", soft_tokens)
     genout = tpool.execute(
         tpu_mtj_backend.infer_static,
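
Note: tpool here is presumably eventlet's thread-pool helper (the server is eventlet-based), so the blocking MTJ call runs in a native thread while greenlets keep serving the UI. An isolated illustration of that pattern with a stand-in worker, not the real tpu_mtj_backend:

    import time
    from eventlet import tpool

    def blocking_inference(prompt_tokens, soft_tokens=None):
        # Stand-in for a long, blocking call such as tpu_mtj_backend.infer_static.
        time.sleep(1)
        return prompt_tokens + [0]

    # tpool.execute(func, *args, **kwargs) blocks only this green thread; other
    # greenlets (e.g. socket handlers) keep running while the worker thread works.
    genout = tpool.execute(blocking_inference, [1, 2, 3])
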