ebolam
2022-09-25 20:30:21 -04:00


@@ -5178,6 +5178,9 @@ def raw_generate(
 if koboldai_vars.model == "ReadOnly":
     raise NotImplementedError("No loaded model")
+result: GenerationResult
+time_start = time.time()
 if koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"):
     batch_encoded = tpu_raw_generate(
         prompt_tokens=prompt_tokens,
@@ -5185,7 +5188,7 @@ def raw_generate(
         batch_count=batch_count,
         gen_settings=gen_settings
     )
-    return GenerationResult(
+    result = GenerationResult(
         out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
     )
 elif koboldai_vars.model in model_functions:
@@ -5195,7 +5198,7 @@ def raw_generate(
         batch_count=batch_count,
         gen_settings=gen_settings
     )
-    return GenerationResult(
+    result = GenerationResult(
         out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
     )
 elif koboldai_vars.model.startswith("RWKV"):
@@ -5205,25 +5208,33 @@ def raw_generate(
         batch_count=batch_count,
         gen_settings=gen_settings
     )
-    return GenerationResult(
+    result = GenerationResult(
         out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, output_includes_prompt=True
     )
 else:
-    # Torch HF
-    batch_encoded = torch_raw_generate(
-        prompt_tokens=prompt_tokens,
-        max_new=max_new if not bypass_hf_maxlength else int(2e9),
-        do_streaming=do_streaming,
-        do_dynamic_wi=do_dynamic_wi,
-        batch_count=batch_count,
-        gen_settings=gen_settings
-    )
-    return GenerationResult(
-        out_batches=batch_encoded,
-        prompt=prompt_tokens,
-        is_whole_generation=False,
-        output_includes_prompt=True,
-    )
+    # Torch HF
+    batch_encoded = torch_raw_generate(
+        prompt_tokens=prompt_tokens,
+        max_new=max_new if not bypass_hf_maxlength else int(2e9),
+        do_streaming=do_streaming,
+        do_dynamic_wi=do_dynamic_wi,
+        batch_count=batch_count,
+        gen_settings=gen_settings
+    )
+    result = GenerationResult(
+        out_batches=batch_encoded,
+        prompt=prompt_tokens,
+        is_whole_generation=False,
+        output_includes_prompt=True,
+    )
+time_end = round(time.time() - time_start, 2)
+tokens_per_second = round(len(result.encoded[0]) / time_end, 2)
+if not koboldai_vars.quiet:
+    logger.info(f"Generated {len(result.encoded[0])} tokens in {time_end} seconds, for an average rate of {tokens_per_second} tokens per second.")
+return result

 def tpu_raw_generate(
     prompt_tokens: List[int],
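
Below is a minimal, self-contained sketch of the throughput-logging pattern this diff introduces: every backend branch assigns to a single result variable instead of returning directly, so one timing block can compute and log tokens per second before the final return. Here, fake_generate and Result are hypothetical stand-ins for torch_raw_generate and GenerationResult, and the max(time_end, 0.01) clamp is an extra safeguard not present in the commit, which divides by the rounded elapsed time directly.

import time
from dataclasses import dataclass
from typing import List

@dataclass
class Result:
    # Hypothetical stand-in for GenerationResult: one token list per batch.
    encoded: List[List[int]]

def fake_generate(prompt_tokens: List[int], max_new: int) -> Result:
    # Hypothetical stand-in for a backend call such as torch_raw_generate.
    time.sleep(0.05)
    return Result(encoded=[prompt_tokens + [0] * max_new])

time_start = time.time()
result = fake_generate([1, 2, 3], max_new=40)
time_end = round(time.time() - time_start, 2)
# Clamping avoids a ZeroDivisionError when the rounded elapsed time is 0.0.
tokens_per_second = round(len(result.encoded[0]) / max(time_end, 0.01), 2)
print(f"Generated {len(result.encoded[0])} tokens in {time_end} seconds, "
      f"for an average rate of {tokens_per_second} tokens per second.")

Funneling every branch through one result variable is what makes the single timing block possible; with the old per-branch returns, each backend would have needed its own copy of the logging code.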