diff --git a/aiserver.py b/aiserver.py
index 2ebe148e..5dccb291 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -5528,6 +5528,10 @@ class GenerationResult:
 
         # Controls if we should trim output by prompt length
         output_includes_prompt: bool = False,
+
+        # Lazy filter to cut off extra lines where we can't manipulate
+        # probabilities
+        single_line: bool = False,
     ):
         # Shave prompt off of encoded response when needed (HF). Decoded does
         # not return prompt.
@@ -5541,6 +5545,10 @@ class GenerationResult:
 
         self.decoded = [utils.decodenewlines(tokenizer.decode(enc)) for enc in self.encoded]
 
+        if single_line:
+            self.decoded = [x.split("\n", 1)[0] for x in self.decoded]
+            self.encoded = tokenizer(self.decoded).input_ids
+
 class GenerationSettings:
     def __init__(self, **overrides) -> None:
         for setting in [
@@ -5623,7 +5631,7 @@ def raw_generate(
             gen_settings=gen_settings
         )
         result = GenerationResult(
-            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
+            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, single_line=True
         )
     elif koboldai_vars.model in model_functions:
         batch_encoded = model_functions[koboldai_vars.model](
@@ -5633,7 +5641,7 @@ def raw_generate(
             gen_settings=gen_settings
         )
         result = GenerationResult(
-            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
+            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, single_line=True
         )
     elif koboldai_vars.model.startswith("RWKV"):
         batch_encoded = rwkv_raw_generate(
@@ -5643,7 +5651,7 @@ def raw_generate(
             gen_settings=gen_settings
         )
         result = GenerationResult(
-            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, output_includes_prompt=True
+            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, output_includes_prompt=True, single_line=True
         )
     else:
         # Torch HF
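
For reference, the behaviour of the new `single_line` filter can be exercised in isolation. The sketch below is illustrative only: `ToyTokenizer` and `apply_single_line_filter` are hypothetical stand-ins and not part of aiserver.py; only the split-at-first-newline and re-encode steps mirror the diff above.

```python
# Minimal sketch of the lazy single-line filter added to GenerationResult
# above. ToyTokenizer is a hypothetical stand-in for the real tokenizer
# object used in aiserver.py; only the split/re-encode logic mirrors the diff.
from typing import List


class ToyTokenizer:
    """Whitespace 'tokenizer' used only to make the sketch runnable."""

    class _Batch:
        def __init__(self, input_ids: List[List[int]]) -> None:
            self.input_ids = input_ids

    def __call__(self, texts: List[str]) -> "_Batch":
        # Fake ids: one integer per whitespace-separated token.
        return self._Batch([[hash(tok) % 50257 for tok in t.split()] for t in texts])


def apply_single_line_filter(decoded: List[str], tokenizer: ToyTokenizer):
    # Keep only the text before the first newline in each batch member,
    # then re-encode so the token ids stay consistent with the trimmed text.
    decoded = [x.split("\n", 1)[0] for x in decoded]
    encoded = tokenizer(decoded).input_ids
    return decoded, encoded


if __name__ == "__main__":
    decoded, encoded = apply_single_line_filter(
        ["First line.\nSecond line that gets cut.", "Already one line."],
        ToyTokenizer(),
    )
    print(decoded)  # ['First line.', 'Already one line.']
```

Re-encoding after the trim keeps `encoded` consistent with the shortened `decoded` strings, which is what the `if single_line:` block in `GenerationResult.__init__` does after cutting at the first newline.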