Fix materialize function for galactica models

This commit is contained in:
vfbd 2022-12-12 14:11:08 -05:00
parent eeb1774d42
commit 33ba3e7e27
1 changed files with 7 additions and 2 deletions

View File

@ -54,6 +54,7 @@ import numpy as np
import collections import collections
import _codecs import _codecs
import utils import utils
import os
from torch.nn import Module from torch.nn import Module
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
@ -93,12 +94,16 @@ class LazyTensor:
def __repr__(self): def __repr__(self):
return self.__view(repr) return self.__view(repr)
def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None, no_grad=True) -> torch.Tensor: def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None, no_grad=True, filename="pytorch_model.bin") -> torch.Tensor:
filename = os.path.basename(os.path.normpath(filename)).split('.')[0]
size = reduce(lambda x, y: x * y, self.shape, 1) size = reduce(lambda x, y: x * y, self.shape, 1)
dtype = self.dtype dtype = self.dtype
nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3) nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
if isinstance(checkpoint, zipfile.ZipFile): if isinstance(checkpoint, zipfile.ZipFile):
f = checkpoint.open(f"archive/data/{self.key}", "r") try:
f = checkpoint.open(f"archive/data/{self.key}", "r")
except:
f = checkpoint.open(f"{filename}/data/{self.key}", "r")
f.read(self.seek_offset) f.read(self.seek_offset)
else: else:
f = checkpoint f = checkpoint