Model: Fix TPU

This commit is contained in:
somebody
2023-03-01 19:40:52 -06:00
parent f2974d205e
commit 27b7635c95
8 changed files with 65 additions and 21 deletions

View File

@@ -10,7 +10,11 @@ import utils
import koboldai_settings
from logger import logger, Colors
from modeling.inference_model import ModelCapabilities
from modeling.inference_model import (
GenerationResult,
GenerationSettings,
ModelCapabilities,
)
from modeling.inference_models.hf import HFInferenceModel
try:
@@ -257,28 +261,48 @@ class HFMTJInferenceModel(HFInferenceModel):
gen_settings: GenerationSettings,
single_line: bool = False,
batch_count: int = 1,
**kwargs
) -> GenerationResult:
soft_tokens = self.get_soft_tokens()
genout = tpool.execute(
tpu_mtj_backend.infer_static,
np.uint32(prompt_tokens),
gen_len=max_new,
temp=gen_settings.temp,
top_p=gen_settings.top_p,
top_k=gen_settings.top_k,
tfs=gen_settings.tfs,
typical=gen_settings.typical,
top_a=gen_settings.top_a,
numseqs=batch_count,
repetition_penalty=gen_settings.rep_pen,
rpslope=gen_settings.rep_pen_slope,
rprange=gen_settings.rep_pen_range,
soft_embeddings=utils.koboldai_vars.sp,
soft_tokens=soft_tokens,
sampler_order=gen_settings.sampler_order,
)
genout = np.array(genout)
dynamic_inference = kwargs.get("tpu_dynamic_inference", False)
print(f"DYNAMIC_INFERENCE={dynamic_inference} KWARGS={kwargs}")
if not dynamic_inference:
genout = tpool.execute(
tpu_mtj_backend.infer_static,
np.uint32(prompt_tokens),
gen_len=max_new,
temp=gen_settings.temp,
top_p=gen_settings.top_p,
top_k=gen_settings.top_k,
tfs=gen_settings.tfs,
typical=gen_settings.typical,
top_a=gen_settings.top_a,
numseqs=batch_count,
repetition_penalty=gen_settings.rep_pen,
rpslope=gen_settings.rep_pen_slope,
rprange=gen_settings.rep_pen_range,
soft_embeddings=utils.koboldai_vars.sp,
soft_tokens=soft_tokens,
sampler_order=gen_settings.sampler_order,
)
genout = np.array(genout)
else:
genout = tpool.execute(
tpu_mtj_backend.infer_dynamic,
context=np.uint32(prompt_tokens),
numseqs=batch_count,
gen_len=max_new,
soft_embeddings=utils.koboldai_vars.sp,
soft_tokens=soft_tokens,
# TODO: Fix dynamic World Info (WI) on TPU; until then nothing is excluded
excluded_world_info=set(),
use_callback=True
)
print(genout)
print(type(genout))
genout = np.array(genout)
return GenerationResult(
self,