mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Merge pull request #370 from Zurnaz/tpu_bfloat16
fix: tpu support models saved in bfloat16
@@ -1255,7 +1255,11 @@ def load_model(path: str, model_type: str, badwordsids=koboldai_settings.badword
                 params["cores_per_replica"],
                 network.state["params"][spec["module"]][spec["param"]].shape,
             )
-            tensor = jnp.array(tensor.detach())
+            tensor = tensor.detach()
+            # numpy does not support bfloat16
+            if tensor.dtype is torch.bfloat16:
+                tensor = tensor.to(torch.float32)
+            tensor = jnp.array(tensor)
             if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
                 tensor = tensor.bfloat16()
             network.state["params"][spec["module"]][spec["param"]] = move_xmap(
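Why the fix works: jnp.array converts a torch tensor by going through NumPy, and NumPy has no bfloat16 dtype, so the old jnp.array(tensor.detach()) raises a TypeError for checkpoints saved in bfloat16. Upcasting to float32 first is lossless, because bfloat16 is a truncated float32. Below is an illustrative standalone sketch of the same workaround; the helper name to_jnp is hypothetical, not from the repo, and only torch and jax are assumed to be installed.

    # Illustrative sketch of the workaround in this commit (the helper
    # name to_jnp is hypothetical, not from the repo).
    import torch
    import jax.numpy as jnp

    def to_jnp(tensor: torch.Tensor) -> jnp.ndarray:
        tensor = tensor.detach()
        # NumPy has no bfloat16 dtype, so converting a bfloat16 torch
        # tensor directly raises TypeError. Upcast to float32 first;
        # bfloat16 is a truncated float32, so no precision is lost.
        if tensor.dtype is torch.bfloat16:
            tensor = tensor.to(torch.float32)
        return jnp.array(tensor)

    t = torch.ones(2, 2, dtype=torch.bfloat16)
    print(to_jnp(t).dtype)  # float32

Upcasting rather than reinterpreting the raw bits keeps the conversion dtype-safe, at the cost of temporarily doubling host memory for bfloat16 weights during loading.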