Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
MTJ Fix for trim
aiserver.py (38 changed lines)
@@ -4766,24 +4766,7 @@ def calcsubmit(txt):
         else:
             subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
 
-            # if koboldai_vars.model not in (
-            #     "Colab",
-            #     "API",
-            #     "CLUSTER",
-            #     # "TPUMeshTransformerGPTJ",
-            #     # "TPUMeshTransformerGPTNeoX"
-            # ):
             generate(subtxt, min, max, found_entries)
-            # elif koboldai_vars.model == "Colab":
-            #     sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-            # elif koboldai_vars.model == "API":
-            #     sendtoapi(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-            # elif koboldai_vars.model == "CLUSTER":
-            #     sendtocluster(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
-            # elif koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
-            #     tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
-            # else:
-            #     print(":(", koboldai_vars.model)
 
     # For InferKit web API
     else:
@@ -4984,12 +4967,20 @@ class GenerationResult:
 
         # Controls if generate() does it's looping thing. This should only be
         # done for HF models that use that StoppingCondition
-        is_whole_generation: bool
+        is_whole_generation: bool,
+
+        # Controls if we should trim output by prompt length
+        output_includes_prompt: bool = False,
     ):
         # Shave prompt off of encoded response. Decoded does not return prompt.
         # TODO: Does MTJ generation shave this off automatically? Test it!
-        print("shape", out_batches.shape)
-        self.encoded = out_batches[:, len(prompt) - 1:]
+        __debug("shape", out_batches.shape)
+
+        if output_includes_prompt:
+            self.encoded = out_batches[:, len(prompt) - 1:]
+        else:
+            self.encoded = out_batches
 
         self.prompt = prompt
         self.is_whole_generation = is_whole_generation
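
For context, the new flag makes the prompt trimming in GenerationResult conditional: the HF generate() path returns the prompt followed by the continuation, while the MTJ/TPU path, per the TODO in the hunk above, is assumed to return only the new tokens. Below is a minimal sketch of that constructor branch, using a simplified stand-in class (ResultSketch is hypothetical, not the real GenerationResult) and numpy arrays in place of the actual token batches.

# Simplified stand-in for GenerationResult; it mirrors only the branch added in
# the diff, including the len(prompt) - 1 slice, which keeps the final prompt token.
import numpy as np

class ResultSketch:
    def __init__(self, out_batches, prompt, is_whole_generation,
                 output_includes_prompt=False):
        if output_includes_prompt:
            # Backend output still contains the prompt: shave it off.
            self.encoded = out_batches[:, len(prompt) - 1:]
        else:
            # Backend output is assumed to already exclude the prompt: keep it whole.
            self.encoded = out_batches
        self.prompt = prompt
        self.is_whole_generation = is_whole_generation

prompt = [11, 12, 13]
hf_batches = np.array([[11, 12, 13, 44, 45]])   # prompt + continuation
mtj_batches = np.array([[44, 45]])              # continuation only (assumed)

print(ResultSketch(hf_batches, prompt, False, output_includes_prompt=True).encoded)
# [[13 44 45]]  (the len(prompt) - 1 slice keeps the last prompt token)
print(ResultSketch(mtj_batches, prompt, False).encoded)
# [[44 45]]

Because the flag defaults to False, a caller now has to opt in to trimming, which is exactly what raw_generate does in the next hunk.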
@@ -5042,10 +5033,13 @@ def raw_generate(
         max_new=max_length if not bypass_hf_maxlength else int(2e9),
         do_streaming=do_streaming,
         do_dynamic_wi=do_dynamic_wi,
-        batch_count=batch_count
+        batch_count=batch_count,
     )
     return GenerationResult(
-        out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=False
+        out_batches=batch_encoded,
+        prompt=prompt_tokens,
+        is_whole_generation=False,
+        output_includes_prompt=True,
     )
 
 def tpu_raw_generate(
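
To see why the flag matters for the MTJ trim fix the commit title refers to: with the old unconditional slice, batches that do not contain the prompt would lose their first len(prompt) - 1 positions. A small self-contained illustration follows; trim and mtj_batches are hypothetical names, not KoboldAI code, and the assumption that MTJ output excludes the prompt is the one the diff's own TODO leaves open.

import numpy as np

def trim(out_batches, prompt, output_includes_prompt):
    # Replays the slicing decision from the diff: only shave the prompt off when
    # the backend flagged that it is present (len(prompt) - 1, as in the commit).
    return out_batches[:, len(prompt) - 1:] if output_includes_prompt else out_batches

prompt = [5, 6, 7]
mtj_batches = np.array([[42, 43]])   # assumed MTJ output: new tokens only

# Old behaviour (unconditional slice): the generated tokens vanish.
print(mtj_batches[:, len(prompt) - 1:])                         # []
# New behaviour: a caller that leaves the flag at its False default keeps them,
# which is presumably what the TPU path shown starting above relies on.
print(trim(mtj_batches, prompt, output_includes_prompt=False))  # [[42 43]]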