diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index df37e0be..401d6ccf 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -1255,7 +1255,11 @@ def load_model(path: str, model_type: str, badwordsids=koboldai_settings.badword params["cores_per_replica"], network.state["params"][spec["module"]][spec["param"]].shape, ) - tensor = jnp.array(tensor.detach()) + tensor = tensor.detach() + # numpy does not support bfloat16 + if tensor.dtype is torch.bfloat16: + tensor = tensor.to(torch.float32) + tensor = jnp.array(tensor) if tensor.dtype is torch.float16 or tensor.dtype is torch.float32: tensor = tensor.bfloat16() network.state["params"][spec["module"]][spec["param"]] = move_xmap(