Fix for newlines on non-torch

Kinda sucks for now
somebody
2022-12-12 22:58:44 -06:00
parent fb39dccd37
commit 4a4b021132

@@ -5528,6 +5528,10 @@ class GenerationResult:
         # Controls if we should trim output by prompt length
         output_includes_prompt: bool = False,
+        # Lazy filter to cut off extra lines where we can't manipulate
+        # probabilities
+        single_line: bool = False,
     ):
         # Shave prompt off of encoded response when needed (HF). Decoded does
         # not return prompt.
@@ -5541,6 +5545,10 @@ class GenerationResult:
         self.decoded = [utils.decodenewlines(tokenizer.decode(enc)) for enc in self.encoded]
+        if single_line:
+            self.decoded = [x.split("\n", 1)[0] for x in self.decoded]
+            self.encoded = tokenizer(self.decoded).input_ids
 class GenerationSettings:
     def __init__(self, **overrides) -> None:
         for setting in [
@@ -5623,7 +5631,7 @@ def raw_generate(
             gen_settings=gen_settings
         )
         result = GenerationResult(
-            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
+            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, single_line=True
         )
     elif koboldai_vars.model in model_functions:
         batch_encoded = model_functions[koboldai_vars.model](
@@ -5633,7 +5641,7 @@ def raw_generate(
             gen_settings=gen_settings
         )
         result = GenerationResult(
-            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
+            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, single_line=True
        )
     elif koboldai_vars.model.startswith("RWKV"):
         batch_encoded = rwkv_raw_generate(
@@ -5643,7 +5651,7 @@ def raw_generate(
             gen_settings=gen_settings
         )
         result = GenerationResult(
-            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, output_includes_prompt=True
+            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, output_includes_prompt=True, single_line=True
         )
     else:
         # Torch HF
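For context, a minimal standalone sketch of what the new single_line post-filter does; the helper name trim_to_single_line and the bare tokenizer argument are illustrative only, not part of the repository. Each decoded batch entry is cut at its first newline, then re-encoded so the token ids stay in sync with the trimmed text.

# Hypothetical standalone version of the single_line filter above.
# "tokenizer" stands in for the Hugging Face tokenizer GenerationResult
# already uses; the function name is made up for this sketch.
def trim_to_single_line(decoded, tokenizer):
    # Keep only the text before the first newline in each batch entry.
    trimmed = [text.split("\n", 1)[0] for text in decoded]
    # Re-encode so the token ids match the trimmed strings.
    encoded = tokenizer(trimmed).input_ids
    return trimmed, encoded

With single_line=True, a batch entry like "First line\nSecond line" comes back as just "First line".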