Clean up when error is thrown in `use_lazy_torch_load`
This commit is contained in:
parent
a0344b429c
commit
4fa4dbac50
|
@ -49,7 +49,7 @@ class _LazyUnpickler(pickle.Unpickler):
|
|||
|
||||
if key not in self.lazy_loaded_storages:
|
||||
self.lazy_loaded_storages[key] = LazyTensor(storage_type, key, location, nelements)
|
||||
|
||||
|
||||
return self.lazy_loaded_storages[key]
|
||||
|
||||
def load(self, *args, **kwargs):
|
||||
|
@ -72,24 +72,26 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None):
|
|||
yield False
|
||||
return
|
||||
|
||||
old_unpickler = pickle.Unpickler
|
||||
pickle.Unpickler = _LazyUnpickler
|
||||
try:
|
||||
old_unpickler = pickle.Unpickler
|
||||
pickle.Unpickler = _LazyUnpickler
|
||||
|
||||
old_rebuild_tensor = torch._utils._rebuild_tensor
|
||||
torch._utils._rebuild_tensor = _rebuild_tensor
|
||||
old_rebuild_tensor = torch._utils._rebuild_tensor
|
||||
torch._utils._rebuild_tensor = _rebuild_tensor
|
||||
|
||||
old_torch_load = torch.load
|
||||
old_torch_load = torch.load
|
||||
|
||||
def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
|
||||
retval = old_torch_load(f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args)
|
||||
if callback is not None:
|
||||
callback(retval, f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args)
|
||||
return retval
|
||||
def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
|
||||
retval = old_torch_load(f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args)
|
||||
if callback is not None:
|
||||
callback(retval, f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args)
|
||||
return retval
|
||||
|
||||
torch.load = torch_load
|
||||
torch.load = torch_load
|
||||
|
||||
yield True
|
||||
yield True
|
||||
|
||||
pickle.Unpickler = old_unpickler
|
||||
torch._utils._rebuild_tensor = old_rebuild_tensor
|
||||
torch.load = old_torch_load
|
||||
finally:
|
||||
pickle.Unpickler = old_unpickler
|
||||
torch._utils._rebuild_tensor = old_rebuild_tensor
|
||||
torch.load = old_torch_load
|
||||
|
|
Loading…
Reference in New Issue