Merge pull request #370 from Zurnaz/tpu_bfloat16

fix: TPU support for models saved in bfloat16
Authored by henk717 on 2023-06-03 14:06:21 +02:00, committed by GitHub


@@ -1255,7 +1255,11 @@ def load_model(path: str, model_type: str, badwordsids=koboldai_settings.badword
     params["cores_per_replica"],
     network.state["params"][spec["module"]][spec["param"]].shape,
 )
-tensor = jnp.array(tensor.detach())
+tensor = tensor.detach()
+# numpy does not support bfloat16
+if tensor.dtype is torch.bfloat16:
+    tensor = tensor.to(torch.float32)
+tensor = jnp.array(tensor)
 if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
     tensor = tensor.bfloat16()
 network.state["params"][spec["module"]][spec["param"]] = move_xmap(
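
For context, here is a minimal standalone sketch of the conversion this diff performs, assuming torch and jax are installed; the helper name torch_to_jnp is hypothetical and not part of the KoboldAI codebase.

import jax.numpy as jnp
import torch

def torch_to_jnp(tensor: torch.Tensor) -> jnp.ndarray:
    # Hypothetical helper (not in the repo) illustrating the diff's workaround.
    tensor = tensor.detach().cpu()
    # numpy has no bfloat16 dtype, so jnp.array() cannot ingest a torch
    # bfloat16 tensor directly; upcast to float32 first. The upcast is
    # lossless: every bfloat16 value is exactly representable in float32.
    if tensor.dtype is torch.bfloat16:
        tensor = tensor.to(torch.float32)
    return jnp.array(tensor.numpy())

# Usage: a bfloat16 checkpoint tensor round-trips through float32, then is
# re-narrowed to bfloat16 on the JAX side, mirroring the intent of the
# diff's final cast back to bfloat16.
weights = torch.randn(4, 4, dtype=torch.bfloat16)
jax_weights = torch_to_jnp(weights).astype(jnp.bfloat16)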