Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
Model: Fix TPU
@@ -10,7 +10,11 @@ import utils
 import koboldai_settings
 from logger import logger, Colors
 
-from modeling.inference_model import ModelCapabilities
+from modeling.inference_model import (
+    GenerationResult,
+    GenerationSettings,
+    ModelCapabilities,
+)
 from modeling.inference_models.hf import HFInferenceModel
 
 try:
@@ -257,28 +261,48 @@ class HFMTJInferenceModel(HFInferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        **kwargs
     ) -> GenerationResult:
         soft_tokens = self.get_soft_tokens()
 
-        genout = tpool.execute(
-            tpu_mtj_backend.infer_static,
-            np.uint32(prompt_tokens),
-            gen_len=max_new,
-            temp=gen_settings.temp,
-            top_p=gen_settings.top_p,
-            top_k=gen_settings.top_k,
-            tfs=gen_settings.tfs,
-            typical=gen_settings.typical,
-            top_a=gen_settings.top_a,
-            numseqs=batch_count,
-            repetition_penalty=gen_settings.rep_pen,
-            rpslope=gen_settings.rep_pen_slope,
-            rprange=gen_settings.rep_pen_range,
-            soft_embeddings=utils.koboldai_vars.sp,
-            soft_tokens=soft_tokens,
-            sampler_order=gen_settings.sampler_order,
-        )
-        genout = np.array(genout)
+        dynamic_inference = kwargs.get("tpu_dynamic_inference", False)
+        print(f"DYNAMIC_INFERENCE={dynamic_inference} KWARGS={kwargs}")
+
+        if not dynamic_inference:
+            genout = tpool.execute(
+                tpu_mtj_backend.infer_static,
+                np.uint32(prompt_tokens),
+                gen_len=max_new,
+                temp=gen_settings.temp,
+                top_p=gen_settings.top_p,
+                top_k=gen_settings.top_k,
+                tfs=gen_settings.tfs,
+                typical=gen_settings.typical,
+                top_a=gen_settings.top_a,
+                numseqs=batch_count,
+                repetition_penalty=gen_settings.rep_pen,
+                rpslope=gen_settings.rep_pen_slope,
+                rprange=gen_settings.rep_pen_range,
+                soft_embeddings=utils.koboldai_vars.sp,
+                soft_tokens=soft_tokens,
+                sampler_order=gen_settings.sampler_order,
+            )
+            genout = np.array(genout)
+        else:
+            genout = tpool.execute(
+                tpu_mtj_backend.infer_dynamic,
+                context=np.uint32(prompt_tokens),
+                numseqs=batch_count,
+                gen_len=max_new,
+                soft_embeddings=utils.koboldai_vars.sp,
+                soft_tokens=soft_tokens,
+                # TODO: Fix Dynamic WI on TPU
+                excluded_world_info=set(),
+                use_callback=True
+            )
+            print(genout)
+            print(type(genout))
+            genout = np.array(genout)
 
         return GenerationResult(
             self,
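Note: as a minimal, self-contained sketch of the dispatch the second hunk introduces, the fragment below keeps only the control flow: the static TPU path stays the default, and the dynamic path is opt-in through a keyword argument forwarded via **kwargs. The generate wrapper and the infer_static / infer_dynamic stubs are hypothetical placeholders standing in for tpu_mtj_backend and the surrounding method; they are not part of the commit.

import numpy as np

# Hypothetical stand-ins for tpu_mtj_backend.infer_static / infer_dynamic;
# they just echo the last `gen_len` tokens so the sketch runs anywhere.
def infer_static(tokens, gen_len):
    return [tokens[-gen_len:].tolist()]

def infer_dynamic(context, gen_len, use_callback=True):
    return [context[-gen_len:].tolist()]

def generate(prompt_tokens, max_new, **kwargs):
    # Dynamic inference is opt-in: anything other than an explicit
    # tpu_dynamic_inference=True falls through to the static path,
    # mirroring the `if not dynamic_inference` branch in the hunk.
    if not kwargs.get("tpu_dynamic_inference", False):
        genout = infer_static(np.uint32(prompt_tokens), gen_len=max_new)
    else:
        genout = infer_dynamic(context=np.uint32(prompt_tokens), gen_len=max_new)
    return np.array(genout)

print(generate([1, 2, 3, 4], max_new=2))                               # static path
print(generate([1, 2, 3, 4], max_new=2, tpu_dynamic_inference=True))   # dynamic path

Called without the flag, the sketch takes the static branch, which mirrors how existing callers that never pass tpu_dynamic_inference are unaffected by the change.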