Model: Fix TPU

somebody
2023-03-01 19:40:52 -06:00
parent f2974d205e
commit 27b7635c95
8 changed files with 65 additions and 21 deletions

View File

@@ -330,6 +330,11 @@ class InferenceModel:
             # Real max length is handled by CoreStopper.
             bypass_hf_maxlength=utils.koboldai_vars.dynamicscan,
             is_core=True,
+            tpu_dynamic_inference=utils.koboldai_vars.dynamicscan
+            or (
+                not utils.koboldai_vars.nogenmod
+                and utils.koboldai_vars.has_genmod
+            ),
         )
         logger.debug(
             "core_generate: run raw_generate pass {} {}s".format(
@@ -473,6 +478,7 @@ class InferenceModel:
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        **kwargs,
     ) -> GenerationResult:
         """Lowest level model-agnostic generation function. To be overridden by model implementation.
@@ -501,6 +507,8 @@ class InferenceModel:
         is_core: bool = False,
         single_line: bool = False,
         found_entries: set = (),
+        tpu_dynamic_inference: bool = False,
+        **kwargs,
     ) -> GenerationResult:
         """A wrapper around `_raw_generate()` that handles gen_state and other stuff. Use this to generate text outside of the story.
@@ -563,6 +571,7 @@ class InferenceModel:
             batch_count=batch_count,
             gen_settings=gen_settings,
             single_line=single_line,
+            tpu_dynamic_inference=tpu_dynamic_inference,
         )
         time_end = round(time.time() - time_start, 2)
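The flag computed in core_generate() is simply threaded through raw_generate() into _raw_generate() as a keyword argument; backends that do not care about it absorb it via **kwargs, which is why every subclass signature below gains that parameter. A minimal sketch of the pass-through pattern, using illustrative class names rather than the real KoboldAI models:

# Sketch only: how a backend-specific flag rides through **kwargs.
# "Backend", "TPUBackend" and "APIBackend" are illustrative stand-ins.
class Backend:
    def raw_generate(self, prompt, **kwargs):
        # The wrapper forwards any extra flags untouched.
        return self._raw_generate(prompt, **kwargs)

    def _raw_generate(self, prompt, **kwargs):
        raise NotImplementedError


class TPUBackend(Backend):
    def _raw_generate(self, prompt, tpu_dynamic_inference=False, **kwargs):
        # Only the TPU backend inspects the flag.
        return "dynamic" if tpu_dynamic_inference else "static"


class APIBackend(Backend):
    def _raw_generate(self, prompt, **kwargs):
        # Other backends accept and ignore unknown keyword arguments.
        return "api"


print(TPUBackend().raw_generate("hi", tpu_dynamic_inference=True))  # -> dynamic
print(APIBackend().raw_generate("hi", tpu_dynamic_inference=True))  # -> api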

View File

@@ -35,6 +35,7 @@ class APIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        **kwargs
     ):
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))

View File

@@ -29,6 +29,7 @@ class ColabInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        **kwargs
     ):
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))

View File

@@ -9,10 +9,16 @@ from typing import Union
 from transformers import AutoModelForCausalLM, GPTNeoForCausalLM
 import utils
-import breakmodel
 import torch_lazy_loader
 import koboldai_settings
+try:
+    import breakmodel
+except ModuleNotFoundError as e:
+    # Breakmodel is only expected to work on GPU
+    if not utils.koboldai_vars.use_colab_tpu:
+        raise e
 from modeling.inference_models.hf_torch import HFTorchInferenceModel
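The guarded import keeps breakmodel optional: a TPU Colab that never installs it can still import this module, while a GPU host still fails loudly if the dependency is missing. A generic sketch of the same pattern, with placeholder names standing in for breakmodel and utils.koboldai_vars.use_colab_tpu:

# Placeholder names; the real code checks utils.koboldai_vars.use_colab_tpu.
RUNNING_ON_TPU = True

try:
    import gpu_only_module  # stand-in for a CUDA-only dependency such as breakmodel
except ModuleNotFoundError:
    if not RUNNING_ON_TPU:
        # On a GPU host the dependency is required, so the failure is fatal.
        raise
    # On TPU the module is never used, so a missing install is acceptable.
    gpu_only_module = None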

View File

@@ -10,7 +10,11 @@ import utils
 import koboldai_settings
 from logger import logger, Colors
-from modeling.inference_model import ModelCapabilities
+from modeling.inference_model import (
+    GenerationResult,
+    GenerationSettings,
+    ModelCapabilities,
+)
 from modeling.inference_models.hf import HFInferenceModel
 try:
@@ -257,9 +261,14 @@ class HFMTJInferenceModel(HFInferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        **kwargs
     ) -> GenerationResult:
         soft_tokens = self.get_soft_tokens()
+        dynamic_inference = kwargs.get("tpu_dynamic_inference", False)
+        print(f"DYNAMIC_INFERENCE={dynamic_inference} KWARGS={kwargs}")
+        if not dynamic_inference:
             genout = tpool.execute(
                 tpu_mtj_backend.infer_static,
                 np.uint32(prompt_tokens),
@@ -279,6 +288,21 @@ class HFMTJInferenceModel(HFInferenceModel):
                 sampler_order=gen_settings.sampler_order,
             )
             genout = np.array(genout)
+        else:
+            genout = tpool.execute(
+                tpu_mtj_backend.infer_dynamic,
+                context=np.uint32(prompt_tokens),
+                numseqs=batch_count,
+                gen_len=max_new,
+                soft_embeddings=utils.koboldai_vars.sp,
+                soft_tokens=soft_tokens,
+                # TODO: Fix Dynamic WI on TPU
+                excluded_world_info=set(),
+                use_callback=True
+            )
+            print(genout)
+            print(type(genout))
+            genout = np.array(genout)
         return GenerationResult(
             self,
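Stripped of the surrounding plumbing, the _raw_generate() change here is a flag-based dispatch between the two TPU entry points: infer_static for plain generation, infer_dynamic when the backend must call back into Python between tokens (generation modifiers now, dynamic WI once the TODO above is resolved). A condensed sketch of that shape, with stub functions standing in for the tpool.execute calls:

import numpy as np

def run_static(tokens, **kwargs):
    # Stub for tpool.execute(tpu_mtj_backend.infer_static, ...).
    return [tokens.tolist()]

def run_dynamic(tokens, use_callback=True, **kwargs):
    # Stub for tpool.execute(tpu_mtj_backend.infer_dynamic, ...).
    return [tokens.tolist()]

def generate_on_tpu(prompt_tokens, dynamic_inference=False, **backend_kwargs):
    """Route to the static or dynamic TPU path based on the flag (sketch only)."""
    tokens = np.uint32(prompt_tokens)
    if not dynamic_inference:
        # Fixed generation loop entirely on the backend side.
        genout = run_static(tokens, **backend_kwargs)
    else:
        # Dynamic path: the backend calls back into Python between tokens,
        # which is what generation modifiers need.
        genout = run_dynamic(tokens, use_callback=True, **backend_kwargs)
    return np.array(genout)

print(generate_on_tpu([1, 2, 3], dynamic_inference=True))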

View File

@@ -448,6 +448,7 @@ class HFTorchInferenceModel(HFInferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        **kwargs
     ) -> GenerationResult:
         if not isinstance(prompt_tokens, torch.Tensor):
             gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]

View File

@@ -35,6 +35,7 @@ class HordeInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        **kwargs
     ) -> GenerationResult:
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))

View File

@@ -28,6 +28,7 @@ class OpenAIAPIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        **kwargs
     ) -> GenerationResult:
         # Taken mainly from oairequest()