This commit is contained in:
somebody
2022-12-18 16:05:03 -06:00
parent bf82f257d1
commit 338e8b4049

View File

@@ -5526,7 +5526,7 @@ class GenerationResult:
         if single_line:
             self.decoded = [x.split("\n", 1)[0] for x in self.decoded]
-            self.encoded = tokenizer(self.decoded).input_ids
+            self.encoded = np.array(tokenizer(self.decoded).input_ids)

 class GenerationSettings:
     def __init__(self, **overrides) -> None:
@@ -5610,7 +5610,7 @@ def raw_generate(
             gen_settings=gen_settings
         )
         result = GenerationResult(
-            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, single_line=True
+            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, single_line=single_line
         )
     elif koboldai_vars.model in model_functions:
         batch_encoded = model_functions[koboldai_vars.model](
@@ -5620,7 +5620,7 @@ def raw_generate(
             gen_settings=gen_settings
         )
         result = GenerationResult(
-            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, single_line=True
+            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, single_line=single_line
         )
     elif koboldai_vars.model.startswith("RWKV"):
         batch_encoded = rwkv_raw_generate(
@@ -5630,7 +5630,7 @@ def raw_generate(
            gen_settings=gen_settings
        )
        result = GenerationResult(
-            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, output_includes_prompt=True, single_line=True
+            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, output_includes_prompt=True, single_line=single_line
        )
    else:
        # Torch HF
@@ -5690,7 +5690,6 @@ def tpu_raw_generate(
        soft_tokens=soft_tokens,
        sampler_order=gen_settings.sampler_order,
    )
-
    genout = np.array(genout)
    return genout
@@ -6974,7 +6973,6 @@ def anotesubmit(data, template=""):
    if(koboldai_vars.authornotetemplate != template):
        koboldai_vars.setauthornotetemplate = template
-        print("anotesubmit")
        settingschanged()
    koboldai_vars.authornotetemplate = template