Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
Model: Fix TPU
@@ -330,6 +330,11 @@ class InferenceModel:
                # Real max length is handled by CoreStopper.
                bypass_hf_maxlength=utils.koboldai_vars.dynamicscan,
                is_core=True,
                tpu_dynamic_inference=utils.koboldai_vars.dynamicscan
                or (
                    not utils.koboldai_vars.nogenmod
                    and utils.koboldai_vars.has_genmod
                ),
            )
            logger.debug(
                "core_generate: run raw_generate pass {} {}s".format(
@@ -473,6 +478,7 @@ class InferenceModel:
        gen_settings: GenerationSettings,
        single_line: bool = False,
        batch_count: int = 1,
        **kwargs,
    ) -> GenerationResult:
        """Lowest level model-agnostic generation function. To be overridden by model implementation.
@@ -501,6 +507,8 @@ class InferenceModel:
        is_core: bool = False,
        single_line: bool = False,
        found_entries: set = (),
        tpu_dynamic_inference: bool = False,
        **kwargs,
    ) -> GenerationResult:
        """A wrapper around `_raw_generate()` that handles gen_state and other stuff. Use this to generate text outside of the story.
@@ -563,6 +571,7 @@ class InferenceModel:
            batch_count=batch_count,
            gen_settings=gen_settings,
            single_line=single_line,
            tpu_dynamic_inference=tpu_dynamic_inference,
        )

        time_end = round(time.time() - time_start, 2)
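The hunks above thread the new tpu_dynamic_inference flag from core_generate through raw_generate and down into each backend's _raw_generate via **kwargs. A minimal sketch of that pass-through pattern, using a simplified stand-in class rather than the real KoboldAI signatures:

    class InferenceModelSketch:
        def raw_generate(self, prompt_tokens, max_new, tpu_dynamic_inference=False, **kwargs):
            # Wrapper: shared bookkeeping happens here, then every flag is
            # forwarded so backend-specific options reach the right override.
            return self._raw_generate(
                prompt_tokens,
                max_new=max_new,
                tpu_dynamic_inference=tpu_dynamic_inference,
                **kwargs,
            )

        def _raw_generate(self, prompt_tokens, max_new, **kwargs):
            # Backends that don't care about a flag simply ignore it via **kwargs;
            # the TPU backend reads kwargs.get("tpu_dynamic_inference", False).
            raise NotImplementedError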
@@ -35,6 +35,7 @@ class APIInferenceModel(InferenceModel):
        gen_settings: GenerationSettings,
        single_line: bool = False,
        batch_count: int = 1,
        **kwargs
    ):
        decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))
@@ -29,6 +29,7 @@ class ColabInferenceModel(InferenceModel):
        gen_settings: GenerationSettings,
        single_line: bool = False,
        batch_count: int = 1,
        **kwargs
    ):
        decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))
@@ -9,10 +9,16 @@ from typing import Union
from transformers import AutoModelForCausalLM, GPTNeoForCausalLM

import utils
import breakmodel
import torch_lazy_loader
import koboldai_settings

try:
    import breakmodel
except ModuleNotFoundError as e:
    # Breakmodel is only expected to work on GPU
    if not utils.koboldai_vars.use_colab_tpu:
        raise e

from modeling.inference_models.hf_torch import HFTorchInferenceModel
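The guarded import above keeps breakmodel (GPU-only layer splitting) optional: a missing module is only fatal when the TPU code path is not in use. The same pattern in isolation, with a plain use_colab_tpu variable standing in for utils.koboldai_vars.use_colab_tpu:

    use_colab_tpu = True  # stand-in for utils.koboldai_vars.use_colab_tpu

    try:
        import breakmodel  # GPU layer splitting; may be absent on TPU setups
    except ModuleNotFoundError as e:
        if not use_colab_tpu:
            # On a GPU install the module is required, so re-raise.
            raise e
        # On TPU, continue without breakmodel.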
@@ -10,7 +10,11 @@ import utils
import koboldai_settings
from logger import logger, Colors

from modeling.inference_model import ModelCapabilities
from modeling.inference_model import (
    GenerationResult,
    GenerationSettings,
    ModelCapabilities,
)
from modeling.inference_models.hf import HFInferenceModel

try:
@@ -257,9 +261,14 @@ class HFMTJInferenceModel(HFInferenceModel):
        gen_settings: GenerationSettings,
        single_line: bool = False,
        batch_count: int = 1,
        **kwargs
    ) -> GenerationResult:
        soft_tokens = self.get_soft_tokens()

        dynamic_inference = kwargs.get("tpu_dynamic_inference", False)
        print(f"DYNAMIC_INFERENCE={dynamic_inference} KWARGS={kwargs}")

        if not dynamic_inference:
            genout = tpool.execute(
                tpu_mtj_backend.infer_static,
                np.uint32(prompt_tokens),
@@ -279,6 +288,21 @@ class HFMTJInferenceModel(HFInferenceModel):
                sampler_order=gen_settings.sampler_order,
            )
            genout = np.array(genout)
        else:
            genout = tpool.execute(
                tpu_mtj_backend.infer_dynamic,
                context=np.uint32(prompt_tokens),
                numseqs=batch_count,
                gen_len=max_new,
                soft_embeddings=utils.koboldai_vars.sp,
                soft_tokens=soft_tokens,
                # TODO: Fix Dynamic WI on TPU
                excluded_world_info=set(),
                use_callback=True
            )
            print(genout)
            print(type(genout))
            genout = np.array(genout)

        return GenerationResult(
            self,
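Taken together, the two HFMTJInferenceModel hunks make _raw_generate branch on the tpu_dynamic_inference kwarg: tpu_mtj_backend.infer_static stays the default, and infer_dynamic (which supports mid-generation callbacks) runs when dynamic scanning or generation modifiers are active. A hedged sketch of that dispatch, with stub functions in place of tpu_mtj_backend and tpool, which only exist inside a KoboldAI TPU runtime:

    import numpy as np

    # Stubs standing in for tpu_mtj_backend.infer_static / infer_dynamic;
    # they only echo the input shape here.
    def infer_static(context, **kwargs):
        return [context.tolist()]

    def infer_dynamic(context, numseqs=1, use_callback=True, **kwargs):
        return [context.tolist() for _ in range(numseqs)]

    def tpu_generate(prompt_tokens, batch_count=1, **kwargs):
        # Same dispatch shape as the hunk above, minus eventlet's tpool.
        dynamic_inference = kwargs.get("tpu_dynamic_inference", False)
        context = np.uint32(prompt_tokens)
        if not dynamic_inference:
            genout = infer_static(context)
        else:
            # The dynamic path can react mid-generation via callbacks
            # (e.g. for dynamic world-info scans).
            genout = infer_dynamic(context, numseqs=batch_count, use_callback=True)
        return np.array(genout)

    print(tpu_generate([1, 2, 3], batch_count=2, tpu_dynamic_inference=True))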
@@ -448,6 +448,7 @@ class HFTorchInferenceModel(HFInferenceModel):
        gen_settings: GenerationSettings,
        single_line: bool = False,
        batch_count: int = 1,
        **kwargs
    ) -> GenerationResult:
        if not isinstance(prompt_tokens, torch.Tensor):
            gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
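For reference, the [None] indexing in the torch hunk above adds a leading batch dimension to the prompt before generation; the token IDs below are arbitrary illustration values:

    import torch

    prompt_tokens = [464, 3290, 318]  # arbitrary example token IDs
    gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
    print(gen_in.shape)  # torch.Size([1, 3]) -- shape (batch, seq_len)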
@@ -35,6 +35,7 @@ class HordeInferenceModel(InferenceModel):
        gen_settings: GenerationSettings,
        single_line: bool = False,
        batch_count: int = 1,
        **kwargs
    ) -> GenerationResult:
        decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))
@@ -28,6 +28,7 @@ class OpenAIAPIInferenceModel(InferenceModel):
        gen_settings: GenerationSettings,
        single_line: bool = False,
        batch_count: int = 1,
        **kwargs
    ) -> GenerationResult:
        # Taken mainly from oairequest()