Merge pull request #235 from VE-FORBRYDERNE/patch
Fix materialize function for galactica models
This commit is contained in:
commit
0a926e41e4
|
@ -54,6 +54,7 @@ import numpy as np
|
||||||
import collections
|
import collections
|
||||||
import _codecs
|
import _codecs
|
||||||
import utils
|
import utils
|
||||||
|
import os
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
|
@ -93,12 +94,16 @@ class LazyTensor:
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return self.__view(repr)
|
return self.__view(repr)
|
||||||
|
|
||||||
def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None, no_grad=True) -> torch.Tensor:
|
def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None, no_grad=True, filename="pytorch_model.bin") -> torch.Tensor:
|
||||||
|
filename = os.path.basename(os.path.normpath(filename)).split('.')[0]
|
||||||
size = reduce(lambda x, y: x * y, self.shape, 1)
|
size = reduce(lambda x, y: x * y, self.shape, 1)
|
||||||
dtype = self.dtype
|
dtype = self.dtype
|
||||||
nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
|
nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
|
||||||
if isinstance(checkpoint, zipfile.ZipFile):
|
if isinstance(checkpoint, zipfile.ZipFile):
|
||||||
f = checkpoint.open(f"archive/data/{self.key}", "r")
|
try:
|
||||||
|
f = checkpoint.open(f"archive/data/{self.key}", "r")
|
||||||
|
except:
|
||||||
|
f = checkpoint.open(f"{filename}/data/{self.key}", "r")
|
||||||
f.read(self.seek_offset)
|
f.read(self.seek_offset)
|
||||||
else:
|
else:
|
||||||
f = checkpoint
|
f = checkpoint
|
||||||
|
|
Loading…
Reference in New Issue