MTJ Fix for trim

somebody
2022-09-22 21:57:01 -05:00
parent ffbe50920e
commit ca356d4d6f

@@ -4766,24 +4766,7 @@ def calcsubmit(txt):
         else:
             subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
-            # if koboldai_vars.model not in (
-            #     "Colab",
-            #     "API",
-            #     "CLUSTER",
-            #     # "TPUMeshTransformerGPTJ",
-            #     # "TPUMeshTransformerGPTNeoX"
-            # ):
             generate(subtxt, min, max, found_entries)
-            # elif koboldai_vars.model == "Colab":
-            #     sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-            # elif koboldai_vars.model == "API":
-            #     sendtoapi(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-            # elif koboldai_vars.model == "CLUSTER":
-            #     sendtocluster(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-            # elif koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
-            #     tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
-            # else:
-            #     print(":(", koboldai_vars.model)
     # For InferKit web API
     else:
@@ -4984,12 +4967,20 @@ class GenerationResult:
         # Controls if generate() does it's looping thing. This should only be
         # done for HF models that use that StoppingCondition
-        is_whole_generation: bool
+        is_whole_generation: bool,
+        # Controls if we should trim output by prompt length
+        output_includes_prompt: bool = False,
     ):
         # Shave prompt off of encoded response. Decoded does not return prompt.
         # TODO: Does MTJ generation shave this off automatically? Test it!
-        print("shape", out_batches.shape)
-        self.encoded = out_batches[:, len(prompt) - 1:]
+        __debug("shape", out_batches.shape)
+        if output_includes_prompt:
+            self.encoded = out_batches[:, len(prompt) - 1:]
+        else:
+            self.encoded = out_batches
         self.prompt = prompt
         self.is_whole_generation = is_whole_generation
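
For reference, a minimal, self-contained sketch of what the trimming above does (the numpy array and token ids are invented for illustration; only the slicing expression comes from the diff): with output_includes_prompt=True the first len(prompt) - 1 columns are dropped, leaving the last prompt token plus the newly generated tokens, while the default False keeps the batches untouched.

import numpy as np

# Invented example data; only the slicing logic mirrors the diff above.
prompt = [101, 102, 103]                             # three prompt token ids
out_batches = np.array([[101, 102, 103, 7, 8, 9]])   # HF-style output that echoes the prompt

output_includes_prompt = True
if output_includes_prompt:
    encoded = out_batches[:, len(prompt) - 1:]   # -> [[103, 7, 8, 9]]
else:
    encoded = out_batches                        # kept as-is

print(encoded)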
@@ -5042,10 +5033,13 @@ def raw_generate(
         max_new=max_length if not bypass_hf_maxlength else int(2e9),
         do_streaming=do_streaming,
         do_dynamic_wi=do_dynamic_wi,
-        batch_count=batch_count
+        batch_count=batch_count,
     )
     return GenerationResult(
-        out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=False
+        out_batches=batch_encoded,
+        prompt=prompt_tokens,
+        is_whole_generation=False,
+        output_includes_prompt=True,
     )

 def tpu_raw_generate(
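
The new keyword is only passed on the HF path above; a TPU/MTJ caller such as tpu_raw_generate (cut off here) would presumably rely on the default output_includes_prompt=False, since the premise of this fix is that MTJ output does not echo the prompt. Below is a hedged, self-contained sketch of such a call site, with a stub generator and a stand-in GenerationResult (fake_mtj_generate, the token ids, and is_whole_generation=True are assumptions, not taken from this diff).

import numpy as np

class GenerationResult:
    # Stand-in for the class modified above, reduced to the trimming logic.
    def __init__(self, out_batches, prompt, is_whole_generation,
                 output_includes_prompt=False):
        if output_includes_prompt:
            self.encoded = out_batches[:, len(prompt) - 1:]
        else:
            self.encoded = out_batches
        self.prompt = prompt
        self.is_whole_generation = is_whole_generation

def fake_mtj_generate(prompt_tokens, max_new, batch_count):
    # Stub: pretend the MTJ backend returns only newly generated token ids.
    return np.tile(np.arange(900, 900 + max_new), (batch_count, 1))

prompt_tokens = [101, 102, 103]
batch_encoded = fake_mtj_generate(prompt_tokens, max_new=4, batch_count=2)

result = GenerationResult(
    out_batches=batch_encoded,
    prompt=prompt_tokens,
    is_whole_generation=True,  # assumption: MTJ returns the whole generation at once
    # output_includes_prompt left at its default (False): no trimming needed
)
print(result.encoded)          # still shape (2, 4); the prompt was never present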