mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Merge pull request #370 from Zurnaz/tpu_bfloat16
fix: tpu support models saved in bfloat16
This commit is contained in:
@@ -1255,7 +1255,11 @@ def load_model(path: str, model_type: str, badwordsids=koboldai_settings.badword
|
||||
params["cores_per_replica"],
|
||||
network.state["params"][spec["module"]][spec["param"]].shape,
|
||||
)
|
||||
tensor = jnp.array(tensor.detach())
|
||||
tensor = tensor.detach()
|
||||
# numpy does not support bfloat16
|
||||
if tensor.dtype is torch.bfloat16:
|
||||
tensor = tensor.to(torch.float32)
|
||||
tensor = jnp.array(tensor)
|
||||
if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
|
||||
tensor = tensor.bfloat16()
|
||||
network.state["params"][spec["module"]][spec["param"]] = move_xmap(
|
||||
|
Reference in New Issue
Block a user