From 636207bfacfe7a82ec0449080ebbc63e62fa82e5 Mon Sep 17 00:00:00 2001
From: somebody
Date: Sat, 17 Sep 2022 20:33:38 -0500
Subject: [PATCH] Gen gen gen

---
 aiserver.py | 31 +++++++++----------------------
 1 file changed, 9 insertions(+), 22 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index b219dfd2..aa93a301 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -1807,11 +1807,13 @@ def patch_transformers():
             scores: torch.FloatTensor,
             **kwargs,
         ) -> bool:
-            if not koboldai_vars.inference_config.do_dynamic_wi:
+            if not koboldai_vars.inference_config.do_streaming:
                 return False
 
             if not koboldai_vars.output_streaming:
                 return False
+
+            print([utils.decodenewlines(tokenizer.decode(x[-1])) for x in input_ids])
 
             koboldai_vars.actions.stream_tokens([utils.decodenewlines(tokenizer.decode(x[-1])) for x in input_ids])
@@ -4617,23 +4619,17 @@ def legacy_generate(text: Union[str, list], min: int, max: int):
 
     koboldai_vars.lastctx = text
 
-    print("Pregen")
-    print(koboldai_vars.max_length)
     outputs = raw_generate(
         text,
         max_length=koboldai_vars.genamt,
         do_streaming=True
     )
-    print(f"postgen: {outputs}")
 
     # Lua bridge, genmod
     for i, output in enumerate(outputs):
         koboldai_vars.lua_koboldbridge.outputs[i + 1] = output
-    print("post lua")
-
 
     execute_genmod()
-    print("post genmod")
 
     if koboldai_vars.lua_koboldbridge.regeneration_required:
         koboldai_vars.lua_koboldbridge.regeneration_required = False
@@ -4644,14 +4640,10 @@
             assert isinstance(out, str)
     else:
         genout = [{"generated_text": utils.decodenewlines(x)} for x in outputs]
-
-    print("post assign genout")
 
     koboldai_vars.actions.append_options([applyoutputformatting(x["generated_text"]) for x in genout])
     genout = [{"generated_text": x['text']} for x in koboldai_vars.actions.get_current_options()]
 
-    print("post genout assign")
-
     if len(genout) == 1:
         genresult(genout[0]["generated_text"])
     else:
@@ -4660,9 +4652,7 @@
             genresult(genout[restart_seq - 1]["generated_text"])
         else:
             genselect(genout)
-    print("post whatever that is")
     set_aibusy(0)
-    print("post busy")
 
 def raw_generate(
     # prompt is either a string (text) or a list (token ids)
@@ -4699,13 +4689,15 @@ def raw_generate(
     else:
         batch_out = torch_raw_generate(
             prompt_tokens=prompt_tokens,
-            max_length=max_length,
+            max_new=max_length,
             do_streaming=do_streaming,
             do_dynamic_wi=do_dynamic_wi,
             batch_count=batch_count
         )
+
+    decoded = tokenizer.batch_decode(batch_out[:, len(prompt_tokens):])
 
-    return [utils.decodenewlines(tokenizer.decode(x)) for x in batch_out]
+    return [utils.decodenewlines(x) for x in decoded]
 
 def tpu_raw_generate(
     prompt_tokens: List[int],
@@ -4737,7 +4729,7 @@
 
 def torch_raw_generate(
     prompt_tokens: List[int],
-    max_length: int,
+    max_new: int,
 
     do_streaming: bool = False,
     do_dynamic_wi: bool = False,
@@ -4759,21 +4751,16 @@
         device = breakmodel.primary_device
     gen_in = gen_in.to(device)
 
-    print("okay...")
-
     with torch.no_grad():
-        print(f"in {max_length}")
         genout = generator(
             gen_in,
            do_sample=True,
-            max_length=max_length,
+            max_length=min(len(prompt_tokens) + max_new, koboldai_vars.max_length),
             repetition_penalty=1.0,
             bad_words_ids=koboldai_vars.badwordsids,
             use_cache=True,
             num_return_sequences=batch_count,
         )
-    print("out")
-    print("wtf")
 
     return genout
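
Note (not part of the patch): the sketch below illustrates the two behaviours this diff introduces in raw_generate/torch_raw_generate, namely clamping total generation length to a context budget and decoding only the newly generated tokens by slicing the prompt off each output row. It assumes a stock Hugging Face transformers setup; the model name "gpt2" and the MAX_CONTEXT/MAX_NEW constants are illustrative stand-ins for koboldai_vars.max_length and the max_new argument, not values taken from the patch.

    # Minimal sketch of the prompt-stripping decode and the max_length clamp.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    MAX_CONTEXT = 1024   # stand-in for koboldai_vars.max_length (total context budget)
    MAX_NEW = 80         # stand-in for max_new / koboldai_vars.genamt

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    prompt_tokens = tokenizer.encode("You are attacked by a group of goblins.")
    gen_in = torch.tensor([prompt_tokens], dtype=torch.long)

    with torch.no_grad():
        batch_out = model.generate(
            gen_in,
            do_sample=True,
            # total length = prompt + new tokens, never past the context budget
            max_length=min(len(prompt_tokens) + MAX_NEW, MAX_CONTEXT),
            num_return_sequences=2,
        )

    # Slice the prompt off every row, then batch-decode only the continuation.
    decoded = tokenizer.batch_decode(batch_out[:, len(prompt_tokens):])
    for text in decoded:
        print(text)

Because only batch_out[:, len(prompt_tokens):] is decoded, the returned strings contain just the continuation, so downstream output formatting no longer has to strip the prompt text itself.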