Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Model: Fix TPU
@@ -330,6 +330,11 @@ class InferenceModel:
                 # Real max length is handled by CoreStopper.
                 bypass_hf_maxlength=utils.koboldai_vars.dynamicscan,
                 is_core=True,
+                tpu_dynamic_inference=utils.koboldai_vars.dynamicscan
+                or (
+                    not utils.koboldai_vars.nogenmod
+                    and utils.koboldai_vars.has_genmod
+                ),
             )
             logger.debug(
                 "core_generate: run raw_generate pass {} {}s".format(
@@ -473,6 +478,7 @@ class InferenceModel:
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        **kwargs,
     ) -> GenerationResult:
         """Lowest level model-agnostic generation function. To be overridden by model implementation.
@@ -501,6 +507,8 @@ class InferenceModel:
         is_core: bool = False,
         single_line: bool = False,
         found_entries: set = (),
+        tpu_dynamic_inference: bool = False,
+        **kwargs,
     ) -> GenerationResult:
         """A wrapper around `_raw_generate()` that handles gen_state and other stuff. Use this to generate text outside of the story.
@@ -563,6 +571,7 @@ class InferenceModel:
             batch_count=batch_count,
             gen_settings=gen_settings,
             single_line=single_line,
+            tpu_dynamic_inference=tpu_dynamic_inference,
         )

         time_end = round(time.time() - time_start, 2)
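Taken together, the InferenceModel hunks above thread a backend-specific flag through the generic generation path: raw_generate() gains a tpu_dynamic_inference parameter and forwards it to _raw_generate(), whose new **kwargs lets backends that do not care simply ignore it. A minimal sketch of that pass-through pattern, with simplified names and signatures (not the full KoboldAI API):

class InferenceModel:
    def raw_generate(self, prompt_tokens, tpu_dynamic_inference: bool = False, **kwargs):
        # Wrapper: forward backend-specific flags to the implementation untouched.
        return self._raw_generate(
            prompt_tokens,
            tpu_dynamic_inference=tpu_dynamic_inference,
            **kwargs,
        )

    def _raw_generate(self, prompt_tokens, **kwargs):
        raise NotImplementedError


class TPUBackend(InferenceModel):
    def _raw_generate(self, prompt_tokens, **kwargs):
        # Only the TPU backend inspects the flag.
        return "dynamic" if kwargs.get("tpu_dynamic_inference", False) else "static"


class TorchBackend(InferenceModel):
    def _raw_generate(self, prompt_tokens, **kwargs):
        # Other backends swallow the flag via **kwargs.
        return "torch generation"


print(TPUBackend().raw_generate([1, 2, 3], tpu_dynamic_inference=True))  # -> dynamic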
@@ -35,6 +35,7 @@ class APIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        **kwargs
     ):
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))
@@ -29,6 +29,7 @@ class ColabInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        **kwargs
     ):
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))
@@ -9,10 +9,16 @@ from typing import Union
 from transformers import AutoModelForCausalLM, GPTNeoForCausalLM

 import utils
-import breakmodel
 import torch_lazy_loader
 import koboldai_settings

+try:
+    import breakmodel
+except ModuleNotFoundError as e:
+    # Breakmodel is only expected to work on GPU
+    if not utils.koboldai_vars.use_colab_tpu:
+        raise e
+
 from modeling.inference_models.hf_torch import HFTorchInferenceModel
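The guarded import above lets the module load on TPU instances, where breakmodel (GPU layer splitting) is neither available nor needed. The same pattern in isolation, with a hypothetical USE_TPU check standing in for utils.koboldai_vars.use_colab_tpu:

import os

# Hypothetical platform check; the diff uses utils.koboldai_vars.use_colab_tpu.
USE_TPU = bool(os.environ.get("COLAB_TPU_ADDR"))

try:
    import breakmodel  # GPU-only helper; may be missing on TPU-only installs
except ModuleNotFoundError:
    if not USE_TPU:
        raise  # on GPU machines a missing breakmodel is still a hard error
    breakmodel = None  # the TPU code path never touches it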
@@ -10,7 +10,11 @@ import utils
 import koboldai_settings
 from logger import logger, Colors

-from modeling.inference_model import ModelCapabilities
+from modeling.inference_model import (
+    GenerationResult,
+    GenerationSettings,
+    ModelCapabilities,
+)
 from modeling.inference_models.hf import HFInferenceModel

 try:
@@ -257,9 +261,14 @@ class HFMTJInferenceModel(HFInferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        **kwargs
     ) -> GenerationResult:
         soft_tokens = self.get_soft_tokens()

+        dynamic_inference = kwargs.get("tpu_dynamic_inference", False)
+        print(f"DYNAMIC_INFERENCE={dynamic_inference} KWARGS={kwargs}")
+
+        if not dynamic_inference:
             genout = tpool.execute(
                 tpu_mtj_backend.infer_static,
                 np.uint32(prompt_tokens),
@@ -279,6 +288,21 @@ class HFMTJInferenceModel(HFInferenceModel):
                 sampler_order=gen_settings.sampler_order,
             )
             genout = np.array(genout)
+        else:
+            genout = tpool.execute(
+                tpu_mtj_backend.infer_dynamic,
+                context=np.uint32(prompt_tokens),
+                numseqs=batch_count,
+                gen_len=max_new,
+                soft_embeddings=utils.koboldai_vars.sp,
+                soft_tokens=soft_tokens,
+                # TODO: Fix Dynamic WI on TPU
+                excluded_world_info=set(),
+                use_callback=True
+            )
+            print(genout)
+            print(type(genout))
+            genout = np.array(genout)

         return GenerationResult(
             self,
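Within the MTJ backend, the flag then selects between the two TPU entry points, each run through eventlet's tpool so the blocking call does not stall the server. A sketch of the dispatch shape only, assuming the tpu_mtj_backend.infer_static / infer_dynamic callables from the diff and abbreviating their argument lists:

import numpy as np
from eventlet import tpool

import tpu_mtj_backend  # KoboldAI's MTJ backend module, as imported in the diff


def run_tpu_generation(prompt_tokens, dynamic_inference: bool, **backend_kwargs):
    """Abbreviated sketch: choose the static or dynamic MTJ inference path."""
    if not dynamic_inference:
        genout = tpool.execute(
            tpu_mtj_backend.infer_static,
            np.uint32(prompt_tokens),
            **backend_kwargs,  # gen_len, sampler settings, soft prompt, ...
        )
    else:
        genout = tpool.execute(
            tpu_mtj_backend.infer_dynamic,
            context=np.uint32(prompt_tokens),
            **backend_kwargs,  # numseqs, gen_len, soft prompt, use_callback, ...
        )
    return np.array(genout)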
@@ -448,6 +448,7 @@ class HFTorchInferenceModel(HFInferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        **kwargs
     ) -> GenerationResult:
         if not isinstance(prompt_tokens, torch.Tensor):
             gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
@@ -35,6 +35,7 @@ class HordeInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        **kwargs
     ) -> GenerationResult:
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))
@@ -28,6 +28,7 @@ class OpenAIAPIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        **kwargs
     ) -> GenerationResult:
         # Taken mainly from oairequest()