KoboldAI/KoboldAI-Client (https://github.com/KoboldAI/KoboldAI-Client.git)
MTJ Fix for trim
Changed file: aiserver.py (36 changed lines)
@@ -4766,24 +4766,7 @@ def calcsubmit(txt):
         else:
             subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
 
-            # if koboldai_vars.model not in (
-            #     "Colab",
-            #     "API",
-            #     "CLUSTER",
-            #     # "TPUMeshTransformerGPTJ",
-            #     # "TPUMeshTransformerGPTNeoX"
-            # ):
             generate(subtxt, min, max, found_entries)
-            # elif koboldai_vars.model == "Colab":
-            #     sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-            # elif koboldai_vars.model == "API":
-            #     sendtoapi(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-            # elif koboldai_vars.model == "CLUSTER":
-            #     sendtocluster(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-            # elif koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
-            #     tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
-            # else:
-            #     print(":(", koboldai_vars.model)
 
     # For InferKit web API
    else:
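The block removed above is the old inline backend dispatch that had been kept around as comments. For reference, a minimal, self-contained sketch of the branch conditions those comments spelled out; the function name and return labels below are illustrative only, since in the repository this spot now makes a single generate() call and backend selection happens later (see the raw_generate hunk further down).

    # Minimal, self-contained sketch of the branch conditions the removed
    # comments described. Illustrative only; not repository code.
    def pick_backend(model: str, use_colab_tpu: bool) -> str:
        if use_colab_tpu or model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
            return "tpu_mtj"
        if model in ("Colab", "API", "CLUSTER"):
            return "remote"
        return "local_hf"

    print(pick_backend("TPUMeshTransformerGPTJ", use_colab_tpu=False))  # tpu_mtj
    print(pick_backend("some-local-model", use_colab_tpu=False))        # local_hf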
@@ -4984,12 +4967,20 @@ class GenerationResult:
 
         # Controls if generate() does it's looping thing. This should only be
         # done for HF models that use that StoppingCondition
-        is_whole_generation: bool
+        is_whole_generation: bool,
+
+        # Controls if we should trim output by prompt length
+        output_includes_prompt: bool = False,
     ):
         # Shave prompt off of encoded response. Decoded does not return prompt.
         # TODO: Does MTJ generation shave this off automatically? Test it!
-        print("shape", out_batches.shape)
+        __debug("shape", out_batches.shape)
+
+        if output_includes_prompt:
+            self.encoded = out_batches[:, len(prompt) - 1:]
+        else:
+            self.encoded = out_batches
 
         self.prompt = prompt
         self.is_whole_generation = is_whole_generation
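The new output_includes_prompt flag makes the prompt-trimming step conditional: a backend whose output batches start with an echo of the prompt gets those tokens sliced off, while a backend that returns only newly generated tokens (the TODO above asks whether MTJ generation already behaves this way) is left untouched. Below is a small, self-contained stand-in for that constructor branch, using numpy token-id arrays; the class name and example values are illustrative, not the repository's code.

    import numpy as np

    class GenerationResultSketch:
        # Toy stand-in for the constructor branch added in this commit; it
        # models only the output_includes_prompt behaviour, nothing else.
        def __init__(self, out_batches, prompt, is_whole_generation, output_includes_prompt=False):
            if output_includes_prompt:
                # Same slice as in the diff above: drop the echoed prompt
                # tokens, starting from len(prompt) - 1.
                self.encoded = out_batches[:, len(prompt) - 1:]
            else:
                self.encoded = out_batches
            self.prompt = prompt
            self.is_whole_generation = is_whole_generation

    result = GenerationResultSketch(
        out_batches=np.array([[11, 22, 33, 44]]),  # prompt tokens 11, 22 then two new tokens
        prompt=np.array([11, 22]),
        is_whole_generation=False,
        output_includes_prompt=True,
    )
    print(result.encoded)  # [[22 33 44]]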
@@ -5042,10 +5033,13 @@ def raw_generate(
             max_new=max_length if not bypass_hf_maxlength else int(2e9),
             do_streaming=do_streaming,
             do_dynamic_wi=do_dynamic_wi,
-            batch_count=batch_count
+            batch_count=batch_count,
         )
         return GenerationResult(
-            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=False
+            out_batches=batch_encoded,
+            prompt=prompt_tokens,
+            is_whole_generation=False,
+            output_includes_prompt=True,
         )
 
 def tpu_raw_generate(
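In the hunk above only the HF torch path passes output_includes_prompt=True, since HF generate() output begins with the prompt tokens; tpu_raw_generate is cut off here, so whether the TPU path simply relies on the False default is not shown in this diff. A short, self-contained check of why the flag has to differ per backend (the array values are made up, and the slice is simplified to len(prompt) rather than the diff's len(prompt) - 1):

    import numpy as np

    # Two backends producing the same two new tokens for a two-token prompt.
    prompt = np.array([1, 2])
    hf_style_output = np.array([[1, 2, 7, 9]])   # echoes the prompt, then new tokens
    mtj_style_output = np.array([[7, 9]])        # assumed: new tokens only (see the TODO above)

    def new_tokens(out_batches, prompt, output_includes_prompt):
        # Simplified trim used for illustration only.
        return out_batches[:, len(prompt):] if output_includes_prompt else out_batches

    # Trimming is only correct when the backend actually echoed the prompt.
    assert new_tokens(hf_style_output, prompt, output_includes_prompt=True).tolist() == [[7, 9]]
    assert new_tokens(mtj_style_output, prompt, output_includes_prompt=False).tolist() == [[7, 9]]
    # Applying the trim to output that does not echo the prompt discards real tokens:
    assert new_tokens(mtj_style_output, prompt, output_includes_prompt=True).tolist() == [[]]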