Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
Basic single line support
1 changed file: aiserver.py (10 lines changed: +8 −2)
@@ -5580,8 +5580,10 @@ def raw_generate(
     bypass_hf_maxlength: bool = False,
     generation_settings: Optional[dict] = None,
     is_core: bool = False,
+    single_line: bool = False,
     found_entries: set = ()
 ) -> GenerationResult:
+    # TODO: Support singleline outside of torch
 
     koboldai_vars.inference_config.do_core = is_core
     gen_settings = GenerationSettings(*(generation_settings or {}))
@@ -5651,8 +5653,9 @@ def raw_generate(
         max_new=max_new if not bypass_hf_maxlength else int(2e9),
         do_streaming=do_streaming,
         do_dynamic_wi=do_dynamic_wi,
+        single_line=single_line,
         batch_count=batch_count,
-        gen_settings=gen_settings
+        gen_settings=gen_settings,
     )
     logger.debug("raw_generate: run torch_raw_generate {}s".format(time.time()-start_time))
     start_time = time.time()
@@ -5712,6 +5715,7 @@ def torch_raw_generate(
 
     do_streaming: bool = False,
     do_dynamic_wi: bool = False,
+    single_line: bool = False,
     batch_count: int = 1,
 ):
     start_time = time.time()
@@ -5737,6 +5741,8 @@ def torch_raw_generate(
     device = get_auxilary_device()
     gen_in = gen_in.to(device)
 
+    additional_bad_words_ids = [tokenizer.encode("\n")] if single_line else []
+
     with torch.no_grad():
         start_time = time.time()
         genout = generator(
@@ -5744,7 +5750,7 @@ def torch_raw_generate(
             do_sample=True,
             max_length=min(len(prompt_tokens) + max_new, koboldai_vars.max_length),
             repetition_penalty=1.0,
-            bad_words_ids=koboldai_vars.badwordsids,
+            bad_words_ids=koboldai_vars.badwordsids + additional_bad_words_ids,
             use_cache=True,
             num_return_sequences=batch_count,
         )
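
The mechanism behind this commit is the Hugging Face transformers bad_words_ids argument to generate(): when single_line is set, the token sequence for "\n" is appended to the model's existing banned-word list, so sampling can never emit a line break. Below is a minimal standalone sketch of the same technique against a stock transformers causal LM; the model name and prompt are illustrative stand-ins, not taken from the commit.

    # Minimal sketch (assumption: a stock transformers model, not KoboldAI's
    # own loader). Mirrors the commit's trick: ban the newline token so that
    # sampling can never produce a line break.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "gpt2"  # illustrative stand-in for whatever model is loaded
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    single_line = True
    # tokenizer.encode("\n") returns a list of token ids; bad_words_ids
    # expects a list of such lists, one per banned token sequence.
    additional_bad_words_ids = [tokenizer.encode("\n")] if single_line else []

    inputs = tokenizer("The quick brown fox", return_tensors="pt")
    genout = model.generate(
        inputs.input_ids,
        do_sample=True,
        max_length=inputs.input_ids.shape[-1] + 20,
        bad_words_ids=additional_bad_words_ids or None,  # None = no extra bans
        num_return_sequences=1,
    )
    print(tokenizer.decode(genout[0], skip_special_tokens=True))

Banning the newline at sampling time, rather than truncating the output at the first line break afterwards, means the whole generation budget goes into one line instead of being spent on text past the break.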