TPU Fixes

This commit is contained in:
ebolam
2023-05-24 18:55:31 -04:00
parent f0f646ae7b
commit 5fe8c71b2e
2 changed files with 5 additions and 4 deletions

View File

@@ -186,6 +186,7 @@ class model_backend(HFInferenceModel):
tpu_mtj_backend.load_model( tpu_mtj_backend.load_model(
utils.koboldai_vars.model, utils.koboldai_vars.model,
self.model_type,
hf_checkpoint=utils.koboldai_vars.model hf_checkpoint=utils.koboldai_vars.model
not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")
and utils.koboldai_vars.use_colab_tpu, and utils.koboldai_vars.use_colab_tpu,
@@ -202,7 +203,7 @@ class model_backend(HFInferenceModel):
if ( if (
utils.koboldai_vars.badwordsids is koboldai_settings.badwordsids_default utils.koboldai_vars.badwordsids is koboldai_settings.badwordsids_default
and utils.koboldai_vars.model_type not in ("gpt2", "gpt_neo", "gptj") and self.model_type not in ("gpt2", "gpt_neo", "gptj")
): ):
utils.koboldai_vars.badwordsids = [ utils.koboldai_vars.badwordsids = [
[v] [v]

View File

@@ -941,7 +941,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
koboldai_vars.status_message = "" koboldai_vars.status_message = ""
def load_model(path: str, driver_version="tpu_driver_20221109", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None: def load_model(path: str, model_type: str, driver_version="tpu_driver_20221109", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None:
global thread_resources_env, seq, tokenizer, network, params, pad_token_id global thread_resources_env, seq, tokenizer, network, params, pad_token_id
if kwargs.get("pad_token_id"): if kwargs.get("pad_token_id"):
@@ -989,9 +989,9 @@ def load_model(path: str, driver_version="tpu_driver_20221109", hf_checkpoint=Fa
# Try to convert HF config.json to MTJ config # Try to convert HF config.json to MTJ config
if hf_checkpoint: if hf_checkpoint:
spec_path = os.path.join("maps", koboldai_vars.model_type + ".json") spec_path = os.path.join("maps", model_type + ".json")
if not os.path.isfile(spec_path): if not os.path.isfile(spec_path):
raise NotImplementedError(f"Unsupported model type {repr(koboldai_vars.model_type)}") raise NotImplementedError(f"Unsupported model type {repr(model_type)}")
with open(spec_path) as f: with open(spec_path) as f:
lazy_load_spec = json.load(f) lazy_load_spec = json.load(f)