From 18dc7069447e7b7284e461f0d689cf0ad03cbd17 Mon Sep 17 00:00:00 2001 From: Bogdan Drema Date: Sat, 3 Jun 2023 12:21:52 +0100 Subject: [PATCH] fix: TPU support for models saved in bfloat16 --- tpu_mtj_backend.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index df37e0be..401d6ccf 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -1255,7 +1255,11 @@ def load_model(path: str, model_type: str, badwordsids=koboldai_settings.badword params["cores_per_replica"], network.state["params"][spec["module"]][spec["param"]].shape, ) - tensor = jnp.array(tensor.detach()) + tensor = tensor.detach() + # numpy does not support bfloat16 + if tensor.dtype is torch.bfloat16: + tensor = tensor.to(torch.float32) + tensor = jnp.array(tensor) if tensor.dtype is torch.float16 or tensor.dtype is torch.float32: tensor = tensor.bfloat16() network.state["params"][spec["module"]][spec["param"]] = move_xmap(