Merge pull request #235 from VE-FORBRYDERNE/patch

Fix materialize function for galactica models
This commit is contained in:
henk717 2022-12-12 20:15:54 +01:00 committed by GitHub
commit 0a926e41e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 2 deletions

View File

@ -54,6 +54,7 @@ import numpy as np
import collections import collections
import _codecs import _codecs
import utils import utils
import os
from torch.nn import Module from torch.nn import Module
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
@ -93,12 +94,16 @@ class LazyTensor:
def __repr__(self): def __repr__(self):
return self.__view(repr) return self.__view(repr)
def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None, no_grad=True) -> torch.Tensor: def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None, no_grad=True, filename="pytorch_model.bin") -> torch.Tensor:
filename = os.path.basename(os.path.normpath(filename)).split('.')[0]
size = reduce(lambda x, y: x * y, self.shape, 1) size = reduce(lambda x, y: x * y, self.shape, 1)
dtype = self.dtype dtype = self.dtype
nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3) nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
if isinstance(checkpoint, zipfile.ZipFile): if isinstance(checkpoint, zipfile.ZipFile):
f = checkpoint.open(f"archive/data/{self.key}", "r") try:
f = checkpoint.open(f"archive/data/{self.key}", "r")
except:
f = checkpoint.open(f"{filename}/data/{self.key}", "r")
f.read(self.seek_offset) f.read(self.seek_offset)
else: else:
f = checkpoint f = checkpoint