diff --git a/torch_lazy_loader.py b/torch_lazy_loader.py index c5211050..604f6e69 100644 --- a/torch_lazy_loader.py +++ b/torch_lazy_loader.py @@ -49,7 +49,7 @@ class _LazyUnpickler(pickle.Unpickler): if key not in self.lazy_loaded_storages: self.lazy_loaded_storages[key] = LazyTensor(storage_type, key, location, nelements) - + return self.lazy_loaded_storages[key] def load(self, *args, **kwargs): @@ -72,24 +72,26 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None): yield False return - old_unpickler = pickle.Unpickler - pickle.Unpickler = _LazyUnpickler + try: + old_unpickler = pickle.Unpickler + pickle.Unpickler = _LazyUnpickler - old_rebuild_tensor = torch._utils._rebuild_tensor - torch._utils._rebuild_tensor = _rebuild_tensor + old_rebuild_tensor = torch._utils._rebuild_tensor + torch._utils._rebuild_tensor = _rebuild_tensor - old_torch_load = torch.load + old_torch_load = torch.load - def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args): - retval = old_torch_load(f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args) - if callback is not None: - callback(retval, f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args) - return retval + def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args): + retval = old_torch_load(f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args) + if callback is not None: + callback(retval, f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args) + return retval - torch.load = torch_load + torch.load = torch_load - yield True + yield True - pickle.Unpickler = old_unpickler - torch._utils._rebuild_tensor = old_rebuild_tensor - torch.load = old_torch_load + finally: + pickle.Unpickler = old_unpickler + torch._utils._rebuild_tensor = old_rebuild_tensor + torch.load = old_torch_load