diff --git a/aiserver.py b/aiserver.py
index 299f0349..9160a287 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -5580,8 +5580,10 @@ def raw_generate(
     bypass_hf_maxlength: bool = False,
     generation_settings: Optional[dict] = None,
     is_core: bool = False,
+    single_line: bool = False,
     found_entries: set = ()
 ) -> GenerationResult:
+    # TODO: Support singleline outside of torch
     koboldai_vars.inference_config.do_core = is_core
     gen_settings = GenerationSettings(*(generation_settings or {}))

@@ -5651,8 +5653,9 @@ def raw_generate(
             max_new=max_new if not bypass_hf_maxlength else int(2e9),
             do_streaming=do_streaming,
             do_dynamic_wi=do_dynamic_wi,
+            single_line=single_line,
             batch_count=batch_count,
-            gen_settings=gen_settings
+            gen_settings=gen_settings,
         )
         logger.debug("raw_generate: run torch_raw_generate {}s".format(time.time()-start_time))
         start_time = time.time()
@@ -5712,6 +5715,7 @@ def torch_raw_generate(

     do_streaming: bool = False,
     do_dynamic_wi: bool = False,
+    single_line: bool = False,
     batch_count: int = 1,
 ):
     start_time = time.time()
@@ -5737,6 +5741,8 @@ def torch_raw_generate(
     device = get_auxilary_device()
     gen_in = gen_in.to(device)

+    additional_bad_words_ids = [tokenizer.encode("\n")] if single_line else []
+
     with torch.no_grad():
         start_time = time.time()
         genout = generator(
@@ -5744,7 +5750,7 @@
             do_sample=True,
             max_length=min(len(prompt_tokens) + max_new, koboldai_vars.max_length),
             repetition_penalty=1.0,
-            bad_words_ids=koboldai_vars.badwordsids,
+            bad_words_ids=koboldai_vars.badwordsids + additional_bad_words_ids,
             use_cache=True,
             num_return_sequences=batch_count,
         )
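
For context, the substance of the change is in the last two hunks: when `single_line` is set, the token ids for `"\n"` are appended to the `bad_words_ids` list passed to the Hugging Face `generate()` call, so the sampler is prevented from emitting a newline. Below is a minimal standalone sketch of that mechanism, not the patch itself; it uses a generic GPT-2 model and tokenizer in place of KoboldAI's `generator`, `tokenizer`, and `koboldai_vars` objects, and the variable names are illustrative.

```python
# Minimal sketch (assumptions: plain GPT-2 via transformers, not KoboldAI's wiring)
# of how a single_line flag maps onto Hugging Face's bad_words_ids argument.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

single_line = True
existing_bad_words_ids = []  # stand-in for koboldai_vars.badwordsids

# Same idea as the patch: encode "\n" and append it to the existing ban list.
additional_bad_words_ids = [tokenizer.encode("\n")] if single_line else []
bad_words = existing_bad_words_ids + additional_bad_words_ids

prompt_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids
out = model.generate(
    prompt_ids,
    do_sample=True,
    max_new_tokens=40,
    bad_words_ids=bad_words if bad_words else None,  # generate() rejects an empty list
    num_return_sequences=1,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```

One caveat worth noting: banning the single `"\n"` token does not cover multi-character tokens that happen to contain a newline (many BPE vocabularies have a dedicated double-newline token, for example), so depending on the tokenizer some newlines could still slip through this filter.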