From 33ba3e7e27546bc286f99868bc23607eafeb6bc3 Mon Sep 17 00:00:00 2001 From: vfbd Date: Mon, 12 Dec 2022 14:11:08 -0500 Subject: [PATCH] Fix materialize function for galactica models --- torch_lazy_loader.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/torch_lazy_loader.py b/torch_lazy_loader.py index 1298335d..fae49e51 100644 --- a/torch_lazy_loader.py +++ b/torch_lazy_loader.py @@ -54,6 +54,7 @@ import numpy as np import collections import _codecs import utils +import os from torch.nn import Module from typing import Any, Callable, Dict, Optional, Tuple, Type, Union @@ -93,12 +94,16 @@ class LazyTensor: def __repr__(self): return self.__view(repr) - def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None, no_grad=True) -> torch.Tensor: + def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None, no_grad=True, filename="pytorch_model.bin") -> torch.Tensor: + filename = os.path.basename(os.path.normpath(filename)).split('.')[0] size = reduce(lambda x, y: x * y, self.shape, 1) dtype = self.dtype nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3) if isinstance(checkpoint, zipfile.ZipFile): - f = checkpoint.open(f"archive/data/{self.key}", "r") + try: + f = checkpoint.open(f"archive/data/{self.key}", "r") + except: + f = checkpoint.open(f"{filename}/data/{self.key}", "r") f.read(self.seek_offset) else: f = checkpoint