From ca356d4d6fc3c4ac171d6beba84d3dad76bfe3c2 Mon Sep 17 00:00:00 2001
From: somebody
Date: Thu, 22 Sep 2022 21:57:01 -0500
Subject: [PATCH] MTJ Fix for trim

---
 aiserver.py | 38 ++++++++++++++++----------------------
 1 file changed, 16 insertions(+), 22 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index cafc5f9f..90575eac 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -4766,24 +4766,7 @@ def calcsubmit(txt):
         else:
             subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
 
-        # if koboldai_vars.model not in (
-        #     "Colab",
-        #     "API",
-        #     "CLUSTER",
-        #     # "TPUMeshTransformerGPTJ",
-        #     # "TPUMeshTransformerGPTNeoX"
-        # ):
         generate(subtxt, min, max, found_entries)
-        # elif koboldai_vars.model == "Colab":
-        #     sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-        # elif koboldai_vars.model == "API":
-        #     sendtoapi(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-        # elif koboldai_vars.model == "CLUSTER":
-        #     sendtocluster(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-        # elif koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
-        #     tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
-        # else:
-        #     print(":(", koboldai_vars.model)
 
     # For InferKit web API
     else:
@@ -4984,12 +4967,20 @@ class GenerationResult:
 
         # Controls if generate() does it's looping thing. This should only be
         # done for HF models that use that StoppingCondition
-        is_whole_generation: bool
+        is_whole_generation: bool,
+
+        # Controls if we should trim output by prompt length
+        output_includes_prompt: bool = False,
     ):
         # Shave prompt off of encoded response. Decoded does not return prompt.
         # TODO: Does MTJ generation shave this off automatically? Test it!
-        print("shape", out_batches.shape)
-        self.encoded = out_batches[:, len(prompt) - 1:]
+        __debug("shape", out_batches.shape)
+
+        if output_includes_prompt:
+            self.encoded = out_batches[:, len(prompt) - 1:]
+        else:
+            self.encoded = out_batches
+
         self.prompt = prompt
         self.is_whole_generation = is_whole_generation
 
@@ -5042,10 +5033,13 @@ def raw_generate(
         max_new=max_length if not bypass_hf_maxlength else int(2e9),
         do_streaming=do_streaming,
         do_dynamic_wi=do_dynamic_wi,
-        batch_count=batch_count
+        batch_count=batch_count,
     )
     return GenerationResult(
-        out_batches=batch_encoded,
-        prompt=prompt_tokens,
-        is_whole_generation=False
+        out_batches=batch_encoded,
+        prompt=prompt_tokens,
+        is_whole_generation=False,
+        output_includes_prompt=True,
     )
 
 def tpu_raw_generate(
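
Note on the trim behaviour introduced above: the sketch below is not the actual KoboldAI class, just a minimal stand-in (hypothetical name ResultSketch, numpy arrays in place of real model tensors) that mirrors the slicing logic of the patch. When output_includes_prompt is True the echoed prompt is sliced off with the same len(prompt) - 1 offset as in GenerationResult; when it is False the batch is kept as-is, which is the case the MTJ path relies on.

import numpy as np

class ResultSketch:
    """Illustrative stand-in for GenerationResult (hypothetical)."""

    def __init__(self, out_batches, prompt, output_includes_prompt=False):
        # out_batches: [batch_size, sequence_length] token IDs.
        if output_includes_prompt:
            # Backend echoed the prompt; keep only the tail of each row.
            # The -1 offset mirrors the patch and keeps one prompt token.
            self.encoded = out_batches[:, len(prompt) - 1:]
        else:
            # Backend (e.g. MTJ) already returns only newly generated tokens.
            self.encoded = out_batches
        self.prompt = prompt

# Example: a 5-token prompt echoed back, followed by 3 new tokens per row.
prompt = [10, 11, 12, 13, 14]
out = np.array([prompt + [20, 21, 22], prompt + [30, 31, 32]])

trimmed = ResultSketch(out, prompt, output_includes_prompt=True)
raw = ResultSketch(out, prompt, output_includes_prompt=False)
print(trimmed.encoded.shape)  # (2, 4) -- prompt sliced off
print(raw.encoded.shape)      # (2, 8) -- left untouched

In the patch itself, only the Hugging Face raw_generate() path passes output_includes_prompt=True, so the TPU/MTJ path no longer has prompt-length trimming applied to output that never contained the prompt.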