Merge pull request #235 from VE-FORBRYDERNE/patch

Fix materialize function for galactica models
This commit is contained in:
henk717 2022-12-12 20:15:54 +01:00 committed by GitHub
commit 0a926e41e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 2 deletions

View File

@ -54,6 +54,7 @@ import numpy as np
import collections
import _codecs
import utils
import os
from torch.nn import Module
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
@ -93,12 +94,16 @@ class LazyTensor:
def __repr__(self):
return self.__view(repr)
def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None, no_grad=True) -> torch.Tensor:
def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None, no_grad=True, filename="pytorch_model.bin") -> torch.Tensor:
filename = os.path.basename(os.path.normpath(filename)).split('.')[0]
size = reduce(lambda x, y: x * y, self.shape, 1)
dtype = self.dtype
nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
if isinstance(checkpoint, zipfile.ZipFile):
f = checkpoint.open(f"archive/data/{self.key}", "r")
try:
f = checkpoint.open(f"archive/data/{self.key}", "r")
except:
f = checkpoint.open(f"{filename}/data/{self.key}", "r")
f.read(self.seek_offset)
else:
f = checkpoint