Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
TPU Fixes

This commit threads the model type explicitly through the TPU code path: tpu_mtj_backend.load_model() gains a model_type parameter, and callers pass self.model_type instead of relying on koboldai_vars.model_type, so the backend resolves the correct maps/<model_type>.json spec.
@@ -186,6 +186,7 @@ class model_backend(HFInferenceModel):
         tpu_mtj_backend.load_model(
             utils.koboldai_vars.model,
+            self.model_type,
             hf_checkpoint=utils.koboldai_vars.model
             not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")
             and utils.koboldai_vars.use_colab_tpu,
@@ -202,7 +203,7 @@ class model_backend(HFInferenceModel):
         if (
             utils.koboldai_vars.badwordsids is koboldai_settings.badwordsids_default
-            and utils.koboldai_vars.model_type not in ("gpt2", "gpt_neo", "gptj")
+            and self.model_type not in ("gpt2", "gpt_neo", "gptj")
         ):
             utils.koboldai_vars.badwordsids = [
                 [v]
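The hunk above truncates the list comprehension that rebuilds the bad-words list; only the shape of its first element, [v], is visible. A minimal sketch of that pattern, assuming a hypothetical vocab mapping as the iteration source (the real source is elided in the diff):

# Hypothetical stand-in for the tokenizer vocabulary; the actual iteration
# source is not visible in this hunk.
vocab = {"[TOKEN_A]": 101, "[TOKEN_B]": 102}

# Each banned token id is wrapped in its own single-element list, which is
# the shape the bad-words filtering downstream expects.
badwordsids = [[v] for v in vocab.values()]
assert badwordsids == [[101], [102]]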
@@ -941,7 +941,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
     koboldai_vars.status_message = ""


-def load_model(path: str, driver_version="tpu_driver_20221109", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None:
+def load_model(path: str, model_type: str, driver_version="tpu_driver_20221109", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None:
     global thread_resources_env, seq, tokenizer, network, params, pad_token_id

     if kwargs.get("pad_token_id"):
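Because model_type is a new positional parameter right after path, every caller must be updated in lockstep; the first hunk shows the in-repo call site. A hedged sketch of the new calling convention (the model id and type here are illustrative, and the remaining keyword parameters keep their defaults):

import tpu_mtj_backend

# Illustrative values; any HF model id and the model_type from its
# config.json would do.
tpu_mtj_backend.load_model(
    "KoboldAI/GPT-J-6B-Adventure",  # path: HF model id or local checkpoint
    "gptj",                         # model_type: now passed explicitly
    hf_checkpoint=True,             # convert HF weights via the MTJ map
)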
@@ -989,9 +989,9 @@ def load_model(path: str, driver_version="tpu_driver_20221109", hf_checkpoint=Fa
     # Try to convert HF config.json to MTJ config
     if hf_checkpoint:
-        spec_path = os.path.join("maps", koboldai_vars.model_type + ".json")
+        spec_path = os.path.join("maps", model_type + ".json")
         if not os.path.isfile(spec_path):
-            raise NotImplementedError(f"Unsupported model type {repr(koboldai_vars.model_type)}")
+            raise NotImplementedError(f"Unsupported model type {repr(model_type)}")
         with open(spec_path) as f:
             lazy_load_spec = json.load(f)
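For reference, the lookup this hunk now keys on the explicit parameter can be read as a small standalone function; a sketch assuming maps/ holds one <model_type>.json spec per supported architecture (the helper name is hypothetical, the body mirrors the hunk):

import json
import os

def resolve_lazy_load_spec(model_type: str, maps_dir: str = "maps") -> dict:
    # One JSON spec per supported HF architecture, e.g. maps/gptj.json,
    # describing how HF checkpoint tensors map onto MTJ parameters.
    spec_path = os.path.join(maps_dir, model_type + ".json")
    if not os.path.isfile(spec_path):
        raise NotImplementedError(f"Unsupported model type {repr(model_type)}")
    with open(spec_path) as f:
        return json.load(f)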