Not quite

somebody
2023-05-28 14:57:45 -05:00
parent ed0728188a
commit ceaefa9f5e
4 changed files with 25 additions and 22 deletions


@@ -184,6 +184,10 @@ def patch_transformers_for_lazyload() -> None:
state_dict[new_key] = state_dict.pop(old_key)
+ # BEGIN PATCH
+ # TODO: Based on config
+ dtype = torch.float16
+ set_module_kwargs = {"dtype": dtype}
for param_name, param in sorted(
    state_dict.items(),
    # State dict must be ordered in this manner to make the caching in
@@ -211,7 +215,6 @@ def patch_transformers_for_lazyload() -> None:
param_name = param_name[len(start_prefix) :]
module_name = param_name
- set_module_kwargs = {}
# We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
# in int/uint/bool and not cast them.
@@ -272,7 +275,7 @@ def patch_transformers_for_lazyload() -> None:
elif not load_in_8bit:
    # For backward compatibility with older versions of `accelerate`
    set_module_tensor_to_device(
-       model, param_name, param_device, **set_module_kwargs
+       model, tensor_name=param_name, device=param_device, **set_module_kwargs
    )
else:
    if (
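
Taken together, the hunks hard-code a float16 target dtype (the TODO notes it should eventually come from the config), stop resetting set_module_kwargs for every parameter, and pass the arguments to accelerate's set_module_tensor_to_device by keyword. Below is a minimal sketch of the per-parameter placement this code path performs, assuming an accelerate version whose set_module_tensor_to_device accepts the value and dtype keywords; place_param and its arguments are hypothetical names for illustration, not part of the patched file:

import torch
from accelerate.utils import set_module_tensor_to_device

def place_param(model, param_name, param, param_device, dtype=torch.float16):
    # Hypothetical helper sketching what the patched loop does for one tensor.
    set_module_kwargs = {}
    # Only floating-point tensors are cast; int/uint/bool buffers keep their
    # dtype, per the comment in the hunk above.
    if torch.is_floating_point(param):
        set_module_kwargs["dtype"] = dtype
    # Keyword arguments mirror the new call in the last hunk.
    set_module_tensor_to_device(
        model,
        tensor_name=param_name,
        device=param_device,
        value=param,
        **set_module_kwargs,
    )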