From 5fe8c71b2ed9132ca591d3797d1deca6f8e8762e Mon Sep 17 00:00:00 2001
From: ebolam
Date: Wed, 24 May 2023 18:55:31 -0400
Subject: [PATCH] TPU Fixes

---
 modeling/inference_models/hf_mtj/class.py | 3 ++-
 tpu_mtj_backend.py                        | 6 +++---
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/modeling/inference_models/hf_mtj/class.py b/modeling/inference_models/hf_mtj/class.py
index 4de3a1b2..876e950e 100644
--- a/modeling/inference_models/hf_mtj/class.py
+++ b/modeling/inference_models/hf_mtj/class.py
@@ -186,6 +186,7 @@ class model_backend(HFInferenceModel):
 
         tpu_mtj_backend.load_model(
             utils.koboldai_vars.model,
+            self.model_type,
             hf_checkpoint=utils.koboldai_vars.model
             not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")
             and utils.koboldai_vars.use_colab_tpu,
@@ -202,7 +203,7 @@ class model_backend(HFInferenceModel):
 
         if (
             utils.koboldai_vars.badwordsids is koboldai_settings.badwordsids_default
-            and utils.koboldai_vars.model_type not in ("gpt2", "gpt_neo", "gptj")
+            and self.model_type not in ("gpt2", "gpt_neo", "gptj")
         ):
             utils.koboldai_vars.badwordsids = [
                 [v]
diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index 07261636..d5a4d1db 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -941,7 +941,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
     koboldai_vars.status_message = ""
 
 
-def load_model(path: str, driver_version="tpu_driver_20221109", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None:
+def load_model(path: str, model_type: str, driver_version="tpu_driver_20221109", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None:
     global thread_resources_env, seq, tokenizer, network, params, pad_token_id
 
     if kwargs.get("pad_token_id"):
@@ -989,9 +989,9 @@ def load_model(path: str, driver_version="tpu_driver_20221109", hf_checkpoint=Fa
 
     # Try to convert HF config.json to MTJ config
     if hf_checkpoint:
-        spec_path = os.path.join("maps", koboldai_vars.model_type + ".json")
+        spec_path = os.path.join("maps", model_type + ".json")
         if not os.path.isfile(spec_path):
-            raise NotImplementedError(f"Unsupported model type {repr(koboldai_vars.model_type)}")
+            raise NotImplementedError(f"Unsupported model type {repr(model_type)}")
         with open(spec_path) as f:
             lazy_load_spec = json.load(f)
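
Note on the signature change (commentary, not part of the patch): load_model() now
takes the model type as an explicit second positional parameter instead of reading
the koboldai_vars.model_type global, so the hf_mtj backend passes self.model_type
through at the call site. A minimal caller sketch under that assumption follows;
the model path and "gpt_neo" value are illustrative placeholders, not values taken
from the patch:

    import tpu_mtj_backend

    # model_type comes right after path, ahead of the keyword defaults
    # (driver_version, hf_checkpoint, ...), and selects the MTJ config
    # spec at maps/<model_type>.json when hf_checkpoint is True.
    tpu_mtj_backend.load_model(
        "EleutherAI/gpt-neo-2.7B",  # path: model name or checkpoint (placeholder)
        "gpt_neo",                  # model_type: maps to maps/gpt_neo.json (placeholder)
        hf_checkpoint=True,
    )

With hf_checkpoint set, load_model resolves the spec from the passed argument and
raises NotImplementedError when no map file exists for that model type.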