Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Not quite
@@ -184,6 +184,10 @@ def patch_transformers_for_lazyload() -> None:
             state_dict[new_key] = state_dict.pop(old_key)
 
+        # BEGIN PATCH
+        # TODO: Based on config
+        dtype = torch.float16
+        set_module_kwargs = {"dtype": dtype}
 
         for param_name, param in sorted(
             state_dict.items(),
             # State dict must be ordered in this manner to make the caching in
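The hunk above hardcodes torch.float16, and the TODO notes the dtype should eventually come from the model config instead. A minimal sketch of what that could look like, assuming a transformers PretrainedConfig; the helper name and fallback are hypothetical, not part of this patch:

    import torch

    def dtype_from_config(config, fallback=torch.float16):
        # Hypothetical sketch of the TODO. PretrainedConfig.torch_dtype
        # may be a torch.dtype, a string such as "float16", or None
        # when the checkpoint does not declare one.
        torch_dtype = getattr(config, "torch_dtype", None)
        if isinstance(torch_dtype, str):
            torch_dtype = getattr(torch, torch_dtype)
        return torch_dtype if torch_dtype is not None else fallback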
@@ -211,7 +215,6 @@ def patch_transformers_for_lazyload() -> None:
             param_name = param_name[len(start_prefix) :]
 
             module_name = param_name
-            set_module_kwargs = {}
 
             # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
             # in int/uint/bool and not cast them.
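This hunk deletes the per-parameter `set_module_kwargs = {}` reset: without the deletion, the `{"dtype": dtype}` built before the loop would be clobbered on the first iteration and every tensor would load in its checkpoint dtype. The surviving comment states the casting rule the kwargs implement; a hedged restatement of that rule as a standalone helper (the name `cast_kwargs` is hypothetical):

    import torch

    def cast_kwargs(param: torch.Tensor, dtype: torch.dtype) -> dict:
        # Convert floating dtypes to `dtype`, but keep int/uint/bool
        # buffers and params (token ids, masks, etc.) uncast.
        return {"dtype": dtype} if param.dtype.is_floating_point else {}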
@@ -272,7 +275,7 @@ def patch_transformers_for_lazyload() -> None:
             elif not load_in_8bit:
                 # For backward compatibility with older versions of `accelerate`
                 set_module_tensor_to_device(
-                    model, param_name, param_device, **set_module_kwargs
+                    model, tensor_name=param_name, device=param_device, **set_module_kwargs
                 )
             else:
                 if (
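The last hunk switches the `set_module_tensor_to_device` call to keyword arguments, which guards against positional-parameter changes across accelerate versions. A minimal sketch of the effective load path after this patch, assuming accelerate's `set_module_tensor_to_device(module, tensor_name, device, value=None, dtype=None, ...)` signature; the wrapper function and its defaults are hypothetical:

    import torch
    from accelerate.utils import set_module_tensor_to_device

    def load_params_fp16(model, state_dict, device="cpu", dtype=torch.float16):
        for param_name, param in sorted(state_dict.items()):
            # Only floating-point tensors get the dtype hint, per the
            # casting rule quoted in the second hunk.
            kwargs = {"dtype": dtype} if param.dtype.is_floating_point else {}
            set_module_tensor_to_device(
                model, tensor_name=param_name, device=device, value=param, **kwargs
            )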