diff --git a/modeling/lazy_loader.py b/modeling/lazy_loader.py index 36a32cea..3dee5bae 100644 --- a/modeling/lazy_loader.py +++ b/modeling/lazy_loader.py @@ -389,7 +389,7 @@ def safetensors_load_tensor_independently( return f.get_tensor(tensor_key) -def patch_safetensors(): +def patch_safetensors(callback): # Safetensors load patch import transformers @@ -494,7 +494,7 @@ def use_lazy_load( torch.load = torch_load if HAS_SAFETENSORS: - patch_safetensors() + patch_safetensors(callback) if dematerialized_modules: if use_accelerate_init_empty_weights: