diff --git a/modeling/patches.py b/modeling/patches.py index 2b823b3b..203d98c9 100644 --- a/modeling/patches.py +++ b/modeling/patches.py @@ -290,6 +290,7 @@ def patch_transformers_for_lazyload() -> None: tensor_name=param_name, device=param_device, value=param, + dtype=dtype, ) else: if (