diff --git a/torch_lazy_loader.py b/torch_lazy_loader.py index d9358442..5ff9655b 100644 --- a/torch_lazy_loader.py +++ b/torch_lazy_loader.py @@ -216,14 +216,6 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, demate torch.load = torch_load - def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args): - retval = old_torch_load(f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args) - if callback is not None: - callback(retval, f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args) - return retval - - torch.load = torch_load - if dematerialized_modules: old_linear_init = torch.nn.Linear.__init__ old_embedding_init = torch.nn.Embedding.__init__