Clean up when error is thrown in `use_lazy_torch_load`

This commit is contained in:
Gnome Ann 2022-03-01 19:30:22 -05:00
parent a0344b429c
commit 4fa4dbac50
1 changed files with 18 additions and 16 deletions

View File

@ -49,7 +49,7 @@ class _LazyUnpickler(pickle.Unpickler):
if key not in self.lazy_loaded_storages:
self.lazy_loaded_storages[key] = LazyTensor(storage_type, key, location, nelements)
return self.lazy_loaded_storages[key]
def load(self, *args, **kwargs):
@ -72,24 +72,26 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None):
yield False
return
old_unpickler = pickle.Unpickler
pickle.Unpickler = _LazyUnpickler
try:
old_unpickler = pickle.Unpickler
pickle.Unpickler = _LazyUnpickler
old_rebuild_tensor = torch._utils._rebuild_tensor
torch._utils._rebuild_tensor = _rebuild_tensor
old_rebuild_tensor = torch._utils._rebuild_tensor
torch._utils._rebuild_tensor = _rebuild_tensor
old_torch_load = torch.load
old_torch_load = torch.load
def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
retval = old_torch_load(f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args)
if callback is not None:
callback(retval, f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args)
return retval
def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
retval = old_torch_load(f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args)
if callback is not None:
callback(retval, f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args)
return retval
torch.load = torch_load
torch.load = torch_load
yield True
yield True
pickle.Unpickler = old_unpickler
torch._utils._rebuild_tensor = old_rebuild_tensor
torch.load = old_torch_load
finally:
pickle.Unpickler = old_unpickler
torch._utils._rebuild_tensor = old_rebuild_tensor
torch.load = old_torch_load