diff --git a/modeling/inference_model.py b/modeling/inference_model.py
index 8b7b9114..3f054f79 100644
--- a/modeling/inference_model.py
+++ b/modeling/inference_model.py
@@ -330,6 +330,11 @@ class InferenceModel:
                 # Real max length is handled by CoreStopper.
                 bypass_hf_maxlength=utils.koboldai_vars.dynamicscan,
                 is_core=True,
+                tpu_dynamic_inference=utils.koboldai_vars.dynamicscan
+                or (
+                    not utils.koboldai_vars.nogenmod
+                    and utils.koboldai_vars.has_genmod
+                ),
             )
             logger.debug(
                 "core_generate: run raw_generate pass {} {}s".format(
@@ -473,6 +478,7 @@ class InferenceModel:
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        **kwargs,
     ) -> GenerationResult:
         """Lowest level model-agnostic generation function. To be overridden by model implementation.

@@ -501,6 +507,8 @@ class InferenceModel:
         is_core: bool = False,
         single_line: bool = False,
         found_entries: set = (),
+        tpu_dynamic_inference: bool = False,
+        **kwargs,
     ) -> GenerationResult:
         """A wrapper around `_raw_generate()` that handles gen_state and other stuff. Use this to generate text outside of the story.

@@ -563,6 +571,7 @@ class InferenceModel:
             batch_count=batch_count,
             gen_settings=gen_settings,
             single_line=single_line,
+            tpu_dynamic_inference=tpu_dynamic_inference,
         )

         time_end = round(time.time() - time_start, 2)
diff --git a/modeling/inference_models/api.py b/modeling/inference_models/api.py
index 852ec01d..9fc98abb 100644
--- a/modeling/inference_models/api.py
+++ b/modeling/inference_models/api.py
@@ -35,6 +35,7 @@ class APIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        **kwargs
     ):
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))

diff --git a/modeling/inference_models/colab.py b/modeling/inference_models/colab.py
index 87358e41..faf06299 100644
--- a/modeling/inference_models/colab.py
+++ b/modeling/inference_models/colab.py
@@ -29,6 +29,7 @@ class ColabInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        **kwargs
     ):
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))

diff --git a/modeling/inference_models/generic_hf_torch.py b/modeling/inference_models/generic_hf_torch.py
index d7372814..50a64956 100644
--- a/modeling/inference_models/generic_hf_torch.py
+++ b/modeling/inference_models/generic_hf_torch.py
@@ -9,10 +9,16 @@ from typing import Union
 from transformers import AutoModelForCausalLM, GPTNeoForCausalLM

 import utils
-import breakmodel
 import torch_lazy_loader
 import koboldai_settings

+try:
+    import breakmodel
+except ModuleNotFoundError as e:
+    # Breakmodel is only expected to work on GPU
+    if not utils.koboldai_vars.use_colab_tpu:
+        raise e
+
 from modeling.inference_models.hf_torch import HFTorchInferenceModel


diff --git a/modeling/inference_models/hf_mtj.py b/modeling/inference_models/hf_mtj.py
index 19fed474..984123b2 100644
--- a/modeling/inference_models/hf_mtj.py
+++ b/modeling/inference_models/hf_mtj.py
@@ -10,7 +10,11 @@ import utils
 import koboldai_settings
 from logger import logger, Colors

-from modeling.inference_model import ModelCapabilities
+from modeling.inference_model import (
+    GenerationResult,
+    GenerationSettings,
+    ModelCapabilities,
+)
 from modeling.inference_models.hf import HFInferenceModel

 try:
@@ -257,28 +261,48 @@ class HFMTJInferenceModel(HFInferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        **kwargs
     ) -> GenerationResult:
         soft_tokens = self.get_soft_tokens()

-        genout = tpool.execute(
-            tpu_mtj_backend.infer_static,
-            np.uint32(prompt_tokens),
-            gen_len=max_new,
-            temp=gen_settings.temp,
-            top_p=gen_settings.top_p,
-            top_k=gen_settings.top_k,
-            tfs=gen_settings.tfs,
-            typical=gen_settings.typical,
-            top_a=gen_settings.top_a,
-            numseqs=batch_count,
-            repetition_penalty=gen_settings.rep_pen,
-            rpslope=gen_settings.rep_pen_slope,
-            rprange=gen_settings.rep_pen_range,
-            soft_embeddings=utils.koboldai_vars.sp,
-            soft_tokens=soft_tokens,
-            sampler_order=gen_settings.sampler_order,
-        )
-        genout = np.array(genout)
+        dynamic_inference = kwargs.get("tpu_dynamic_inference", False)
+        print(f"DYNAMIC_INFERENCE={dynamic_inference} KWARGS={kwargs}")
+
+        if not dynamic_inference:
+            genout = tpool.execute(
+                tpu_mtj_backend.infer_static,
+                np.uint32(prompt_tokens),
+                gen_len=max_new,
+                temp=gen_settings.temp,
+                top_p=gen_settings.top_p,
+                top_k=gen_settings.top_k,
+                tfs=gen_settings.tfs,
+                typical=gen_settings.typical,
+                top_a=gen_settings.top_a,
+                numseqs=batch_count,
+                repetition_penalty=gen_settings.rep_pen,
+                rpslope=gen_settings.rep_pen_slope,
+                rprange=gen_settings.rep_pen_range,
+                soft_embeddings=utils.koboldai_vars.sp,
+                soft_tokens=soft_tokens,
+                sampler_order=gen_settings.sampler_order,
+            )
+            genout = np.array(genout)
+        else:
+            genout = tpool.execute(
+                tpu_mtj_backend.infer_dynamic,
+                context=np.uint32(prompt_tokens),
+                numseqs=batch_count,
+                gen_len=max_new,
+                soft_embeddings=utils.koboldai_vars.sp,
+                soft_tokens=soft_tokens,
+                # TODO: Fix Dynamic WI on TPU
+                excluded_world_info=set(),
+                use_callback=True
+            )
+            print(genout)
+            print(type(genout))
+            genout = np.array(genout)

         return GenerationResult(
             self,
diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py
index ee83259d..94f57272 100644
--- a/modeling/inference_models/hf_torch.py
+++ b/modeling/inference_models/hf_torch.py
@@ -448,6 +448,7 @@ class HFTorchInferenceModel(HFInferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        **kwargs
     ) -> GenerationResult:
         if not isinstance(prompt_tokens, torch.Tensor):
             gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
diff --git a/modeling/inference_models/horde.py b/modeling/inference_models/horde.py
index 1fea9b56..c5498512 100644
--- a/modeling/inference_models/horde.py
+++ b/modeling/inference_models/horde.py
@@ -35,6 +35,7 @@ class HordeInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        **kwargs
     ) -> GenerationResult:
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))

diff --git a/modeling/inference_models/openai.py b/modeling/inference_models/openai.py
index 56f386e3..4a28d5f4 100644
--- a/modeling/inference_models/openai.py
+++ b/modeling/inference_models/openai.py
@@ -28,6 +28,7 @@ class OpenAIAPIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
+        **kwargs
     ) -> GenerationResult:
         # Taken mainly from oairequest()
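The pattern behind most of these hunks is a plain keyword pass-through: `raw_generate()` gains a `tpu_dynamic_inference` flag and forwards it through `**kwargs`, so only the MTJ/TPU backend has to act on it while every other backend's `_raw_generate()` silently absorbs it. Below is a minimal, self-contained sketch of that dispatch; the `StaticOnlyBackend`/`TPUBackend` class names and the standalone `raw_generate()` helper are hypothetical stand-ins, not KoboldAI code.

from typing import List


class StaticOnlyBackend:
    # Stands in for api.py, colab.py, horde.py, openai.py and hf_torch.py:
    # the new keyword arrives via **kwargs and is simply ignored.
    def _raw_generate(self, prompt_tokens: List[int], **kwargs) -> List[int]:
        return prompt_tokens + [0]


class TPUBackend:
    # Stands in for hf_mtj.py: the flag selects the inference path.
    def _raw_generate(self, prompt_tokens: List[int], **kwargs) -> List[int]:
        if kwargs.get("tpu_dynamic_inference", False):
            return prompt_tokens + [1]  # would call tpu_mtj_backend.infer_dynamic
        return prompt_tokens + [2]      # would call tpu_mtj_backend.infer_static


def raw_generate(backend, prompt_tokens: List[int],
                 tpu_dynamic_inference: bool = False) -> List[int]:
    # The wrapper forwards the flag unconditionally; **kwargs on each backend
    # keeps the signatures compatible without touching every implementation.
    return backend._raw_generate(
        prompt_tokens, tpu_dynamic_inference=tpu_dynamic_inference
    )


print(raw_generate(StaticOnlyBackend(), [7]))                       # [7, 0]
print(raw_generate(TPUBackend(), [7], tpu_dynamic_inference=True))  # [7, 1]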