Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Merge pull request #165 from one-some/ui2-speeeeeeeed
Add time info to generations in log
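In summary, this change rewrites the tail of raw_generate in aiserver.py: instead of each backend branch (TPU, model_functions, RWKV, Torch HF) returning its own GenerationResult immediately, every branch now assigns to a single result variable, the Torch HF fallback moves into an else: block, and the dispatch is bracketed with time.time() calls so the token count and average tokens-per-second rate can be logged (unless koboldai_vars.quiet is set) before the single return result. A minimal sketch of the pattern follows the diff below.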
aiserver.py (47 changed lines)
@@ -5178,6 +5178,9 @@ def raw_generate(
     if koboldai_vars.model == "ReadOnly":
         raise NotImplementedError("No loaded model")
 
+    result: GenerationResult
+    time_start = time.time()
+
     if koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
         batch_encoded = tpu_raw_generate(
             prompt_tokens=prompt_tokens,
@@ -5185,7 +5188,7 @@ def raw_generate(
             batch_count=batch_count,
             gen_settings=gen_settings
         )
-        return GenerationResult(
+        result = GenerationResult(
             out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
         )
     elif koboldai_vars.model in model_functions:
@@ -5195,7 +5198,7 @@ def raw_generate(
             batch_count=batch_count,
             gen_settings=gen_settings
         )
-        return GenerationResult(
+        result = GenerationResult(
             out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
         )
     elif koboldai_vars.model.startswith("RWKV"):
@@ -5205,25 +5208,33 @@ def raw_generate(
             batch_count=batch_count,
             gen_settings=gen_settings
         )
-        return GenerationResult(
+        result = GenerationResult(
             out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, output_includes_prompt=True
         )
+    else:
+        # Torch HF
+        batch_encoded = torch_raw_generate(
+            prompt_tokens=prompt_tokens,
+            max_new=max_new if not bypass_hf_maxlength else int(2e9),
+            do_streaming=do_streaming,
+            do_dynamic_wi=do_dynamic_wi,
+            batch_count=batch_count,
+            gen_settings=gen_settings
+        )
+        result = GenerationResult(
+            out_batches=batch_encoded,
+            prompt=prompt_tokens,
+            is_whole_generation=False,
+            output_includes_prompt=True,
+        )
+
+    time_end = round(time.time() - time_start, 2)
+    tokens_per_second = round(len(result.encoded[0]) / time_end, 2)
 
-    # Torch HF
-    batch_encoded = torch_raw_generate(
-        prompt_tokens=prompt_tokens,
-        max_new=max_new if not bypass_hf_maxlength else int(2e9),
-        do_streaming=do_streaming,
-        do_dynamic_wi=do_dynamic_wi,
-        batch_count=batch_count,
-        gen_settings=gen_settings
-    )
-    return GenerationResult(
-        out_batches=batch_encoded,
-        prompt=prompt_tokens,
-        is_whole_generation=False,
-        output_includes_prompt=True,
-    )
+    if not koboldai_vars.quiet:
+        logger.info(f"Generated {len(result.encoded[0])} tokens in {time_end} seconds, for an average rate of {tokens_per_second} tokens per second.")
+
+    return result
 
 def tpu_raw_generate(
     prompt_tokens: List[int],
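The timing pattern added here is small enough to show on its own. The sketch below is a minimal, hypothetical distillation, not the committed code: timed_generate and generate_fn are made-up names standing in for raw_generate's per-backend branches, and plain token-batch lists stand in for GenerationResult. Like the commit, it treats the length of the first batch as the generation size.

import time
import logging
from typing import Callable, List

logger = logging.getLogger(__name__)

def timed_generate(
    generate_fn: Callable[[List[int]], List[List[int]]],
    prompt_tokens: List[int],
    quiet: bool = False,
) -> List[List[int]]:
    # Hypothetical stand-in for the backend dispatch in raw_generate:
    # time the call, then log an average generation rate.
    time_start = time.time()
    out_batches = generate_fn(prompt_tokens)
    time_end = round(time.time() - time_start, 2)
    # As in the committed code, the rate divides by the *rounded* elapsed
    # time; a call that finishes in under ~5 ms rounds to 0.0 and would
    # raise ZeroDivisionError.
    tokens_per_second = round(len(out_batches[0]) / time_end, 2)
    if not quiet:
        logger.info(
            f"Generated {len(out_batches[0])} tokens in {time_end} seconds, "
            f"for an average rate of {tokens_per_second} tokens per second."
        )
    return out_batches

For example, a call that produces a 64-token batch in 3.2 seconds would log: Generated 64 tokens in 3.2 seconds, for an average rate of 20.0 tokens per second.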