From bd8658404be91bbaadac28e28f75b291a074e715 Mon Sep 17 00:00:00 2001
From: somebody
Date: Sun, 25 Sep 2022 19:04:09 -0500
Subject: [PATCH] Add time info to generations

---
 aiserver.py | 47 +++++++++++++++++++++++++++++------------------
 1 file changed, 29 insertions(+), 18 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index 0d919712..572d8b5f 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -5178,6 +5178,9 @@ def raw_generate(
     if koboldai_vars.model == "ReadOnly":
         raise NotImplementedError("No loaded model")
 
+    result: GenerationResult
+    time_start = time.time()
+
     if koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
         batch_encoded = tpu_raw_generate(
             prompt_tokens=prompt_tokens,
@@ -5185,7 +5188,7 @@ def raw_generate(
             batch_count=batch_count,
             gen_settings=gen_settings
         )
-        return GenerationResult(
+        result = GenerationResult(
             out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
         )
     elif koboldai_vars.model in model_functions:
@@ -5195,7 +5198,7 @@ def raw_generate(
             batch_count=batch_count,
             gen_settings=gen_settings
         )
-        return GenerationResult(
+        result = GenerationResult(
             out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
         )
     elif koboldai_vars.model.startswith("RWKV"):
@@ -5205,25 +5208,33 @@ def raw_generate(
             batch_count=batch_count,
             gen_settings=gen_settings
         )
-        return GenerationResult(
+        result = GenerationResult(
             out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, output_includes_prompt=True
         )
+    else:
+        # Torch HF
+        batch_encoded = torch_raw_generate(
+            prompt_tokens=prompt_tokens,
+            max_new=max_new if not bypass_hf_maxlength else int(2e9),
+            do_streaming=do_streaming,
+            do_dynamic_wi=do_dynamic_wi,
+            batch_count=batch_count,
+            gen_settings=gen_settings
+        )
+        result = GenerationResult(
+            out_batches=batch_encoded,
+            prompt=prompt_tokens,
+            is_whole_generation=False,
+            output_includes_prompt=True,
+        )
+
+    time_end = round(time.time() - time_start, 2)
+    tokens_per_second = round(len(result.encoded[0]) / time_end, 2)
 
-    # Torch HF
-    batch_encoded = torch_raw_generate(
-        prompt_tokens=prompt_tokens,
-        max_new=max_new if not bypass_hf_maxlength else int(2e9),
-        do_streaming=do_streaming,
-        do_dynamic_wi=do_dynamic_wi,
-        batch_count=batch_count,
-        gen_settings=gen_settings
-    )
-    return GenerationResult(
-        out_batches=batch_encoded,
-        prompt=prompt_tokens,
-        is_whole_generation=False,
-        output_includes_prompt=True,
-    )
+    if not koboldai_vars.quiet:
+        logger.info(f"Generated {len(result.encoded[0])} tokens in {time_end} seconds, for an average rate of {tokens_per_second} tokens per second.")
+
+    return result
 
 def tpu_raw_generate(
     prompt_tokens: List[int],
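
The pattern this patch introduces is: stamp a start time once before the backend dispatch, assign every branch's output to a shared result instead of returning early, then derive elapsed seconds and an average tokens-per-second rate at the single exit point. A minimal standalone sketch of that pattern follows; generate_tokens and timed_generate are hypothetical stand-ins for the real backends (tpu_raw_generate, torch_raw_generate, etc.) and are not part of the patch, and the max(..., 0.01) guard is a sketch-only assumption, not something the patch itself does.

    import logging
    import time
    from typing import List

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    def generate_tokens(count: int) -> List[int]:
        # Hypothetical stand-in for a real backend call such as
        # torch_raw_generate(); burns a little time and returns token ids.
        time.sleep(0.05)
        return list(range(count))

    def timed_generate(count: int, quiet: bool = False) -> List[int]:
        # Same shape as the patch: stamp the start once, dispatch to the
        # backend, then compute elapsed seconds and an average rate.
        time_start = time.time()
        tokens = generate_tokens(count)
        time_end = round(time.time() - time_start, 2)
        # Sketch-only guard: the patch divides by time_end directly, which
        # would raise ZeroDivisionError if generation finished in under 5 ms
        # (the elapsed time rounds to 0.0 at two decimal places).
        tokens_per_second = round(len(tokens) / max(time_end, 0.01), 2)
        if not quiet:
            logger.info(
                f"Generated {len(tokens)} tokens in {time_end} seconds, "
                f"for an average rate of {tokens_per_second} tokens per second."
            )
        return tokens

    if __name__ == "__main__":
        timed_generate(80)

Funneling all branches into one return is also what makes the measurement cheap to add here: the timing and logging code appears exactly once rather than being duplicated into each of the four backend branches.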