Force grad to be off by default when loading with lazy loader

Gnome Ann 2022-04-19 12:26:02 -04:00
parent a82a165146
commit 6803531384
1 changed file with 2 additions and 2 deletions

@@ -89,7 +89,7 @@ class LazyTensor:
     def __repr__(self):
         return self.__view(repr)
 
-    def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None) -> torch.Tensor:
+    def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None, no_grad=True) -> torch.Tensor:
         size = reduce(lambda x, y: x * y, self.shape, 1)
         dtype = self.dtype
         nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
@@ -106,7 +106,7 @@ class LazyTensor:
         storage = torch.serialization._get_restore_location(map_location)(storage, self.location)
         tensor = torch.tensor([], dtype=storage.dtype, device=storage.device)
         tensor.set_(storage, 0, self.shape, self.stride)
-        tensor.requires_grad = self.requires_grad
+        tensor.requires_grad = not no_grad and self.requires_grad
         tensor._backward_hooks = self.backward_hooks
         return tensor
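
For context, a minimal, self-contained sketch of what the changed line does; this is not the project's loader code, and apply_requires_grad and its arguments are hypothetical names introduced here only to illustrate the new default:

import torch

def apply_requires_grad(tensor: torch.Tensor, saved_requires_grad: bool, no_grad: bool = True) -> torch.Tensor:
    # Mirrors the changed line above: the checkpoint's requires_grad flag is
    # only honoured when the caller explicitly opts back in with no_grad=False.
    tensor.requires_grad = not no_grad and saved_requires_grad
    return tensor

print(apply_requires_grad(torch.zeros(3), saved_requires_grad=True).requires_grad)                 # False (new default)
print(apply_requires_grad(torch.zeros(3), saved_requires_grad=True, no_grad=False).requires_grad)  # True  (previous behaviour)

In other words, after this commit any caller that wants materialized tensors to keep the requires_grad flag stored in the checkpoint has to pass no_grad=False to materialize explicitly.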