Basic single line support

This commit is contained in:
somebody
2022-12-11 21:30:45 -06:00
parent c85252c13c
commit a60f2d3d53

View File

@@ -5580,8 +5580,10 @@ def raw_generate(
bypass_hf_maxlength: bool = False,
generation_settings: Optional[dict] = None,
is_core: bool = False,
single_line: bool = False,
found_entries: set = ()
) -> GenerationResult:
# TODO: Support the single_line option for non-torch generation backends as well
koboldai_vars.inference_config.do_core = is_core
gen_settings = GenerationSettings(*(generation_settings or {}))
@@ -5651,8 +5653,9 @@ def raw_generate(
max_new=max_new if not bypass_hf_maxlength else int(2e9),
do_streaming=do_streaming,
do_dynamic_wi=do_dynamic_wi,
single_line=single_line,
batch_count=batch_count,
gen_settings=gen_settings
gen_settings=gen_settings,
)
logger.debug("raw_generate: run torch_raw_generate {}s".format(time.time()-start_time))
start_time = time.time()
@@ -5712,6 +5715,7 @@ def torch_raw_generate(
do_streaming: bool = False,
do_dynamic_wi: bool = False,
single_line: bool = False,
batch_count: int = 1,
):
start_time = time.time()
@@ -5737,6 +5741,8 @@ def torch_raw_generate(
device = get_auxilary_device()
gen_in = gen_in.to(device)
additional_bad_words_ids = [tokenizer.encode("\n")] if single_line else []
with torch.no_grad():
start_time = time.time()
genout = generator(
@@ -5744,7 +5750,7 @@ def torch_raw_generate(
do_sample=True,
max_length=min(len(prompt_tokens) + max_new, koboldai_vars.max_length),
repetition_penalty=1.0,
bad_words_ids=koboldai_vars.badwordsids,
bad_words_ids=koboldai_vars.badwordsids + additional_bad_words_ids,
use_cache=True,
num_return_sequences=batch_count,
)