diff --git a/torch_lazy_loader.py b/torch_lazy_loader.py index 49a30931..789e56b4 100644 --- a/torch_lazy_loader.py +++ b/torch_lazy_loader.py @@ -89,7 +89,7 @@ class LazyTensor: def __repr__(self): return self.__view(repr) - def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None) -> torch.Tensor: + def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None, no_grad=True) -> torch.Tensor: size = reduce(lambda x, y: x * y, self.shape, 1) dtype = self.dtype nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3) @@ -106,7 +106,7 @@ class LazyTensor: storage = torch.serialization._get_restore_location(map_location)(storage, self.location) tensor = torch.tensor([], dtype=storage.dtype, device=storage.device) tensor.set_(storage, 0, self.shape, self.stride) - tensor.requires_grad = self.requires_grad + tensor.requires_grad = not no_grad and self.requires_grad tensor._backward_hooks = self.backward_hooks return tensor