Hello its breaking breakmodel time

somebody
2023-05-27 16:31:53 -05:00
parent 97d2a78899
commit 1546b9efaa
8 changed files with 236 additions and 1097 deletions

@@ -101,6 +101,7 @@ class TorchLazyTensor(LazyTensor):
         stride: Optional[Tuple[int, ...]] = None,
         requires_grad=False,
         backward_hooks: Any = None,
+        file_handle: Any = None
     ):
         self.storage_type = storage_type
         self.key = key
@@ -111,6 +112,7 @@ class TorchLazyTensor(LazyTensor):
         self.stride = stride
         self.requires_grad = requires_grad
         self.backward_hooks = backward_hooks
+        self.file_handle = file_handle

     def __view(self, f: Callable):
         return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, dtype={f(self.dtype)}, seek_offset={f(self.seek_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})"
@@ -120,11 +122,13 @@ class TorchLazyTensor(LazyTensor):

     def materialize(
         self,
-        checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile],
+        checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile] = None,
         map_location=None,
         no_grad=True,
         filename="pytorch_model.bin",
     ) -> torch.Tensor:
+        checkpoint = checkpoint or self.file_handle
+
         filename = os.path.basename(os.path.normpath(filename)).split(".")[0]
         size = reduce(lambda x, y: x * y, self.shape, 1)
         dtype = self.dtype
@@ -237,6 +241,8 @@ class _LazyUnpickler(RestrictedUnpickler):
     lazy_loaded_storages: Dict[str, LazyTensor]

     def __init__(self, *args, **kwargs):
+        # print(args, kwargs)
+        self.file_handle = args[0]
         self.lazy_loaded_storages = {}
         return super().__init__(*args, **kwargs)
@@ -247,7 +253,7 @@ class _LazyUnpickler(RestrictedUnpickler):
             typename == "storage"
         ), f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
         storage_type, key, location, _ = saved_id[1:]
-        return TorchLazyTensor(storage_type, key, location)
+        return TorchLazyTensor(storage_type, key, location, file_handle=self.file_handle)

     def load(self, *args, **kwargs):
         retval = super().load(*args, **kwargs)
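
The net effect of these hunks is that the unpickler captures the handle it was constructed with and threads it into every TorchLazyTensor, so materialize() can fall back to that handle when the caller omits the checkpoint argument. Below is a standalone sketch of that fallback pattern; the class and names (LazyValue and so on) are illustrative stand-ins, not the module's real API.

import io
from typing import Any, Optional


class LazyValue:
    """Illustrative stand-in for a lazy tensor that remembers its source handle."""

    def __init__(self, key: str, file_handle: Any = None):
        self.key = key
        # Captured at construction time, like TorchLazyTensor.file_handle in the diff.
        self.file_handle = file_handle

    def materialize(self, checkpoint: Optional[Any] = None) -> bytes:
        # Same idiom as the diff: fall back to the handle captured earlier.
        checkpoint = checkpoint or self.file_handle
        if checkpoint is None:
            raise ValueError(f"no checkpoint handle available for {self.key}")
        checkpoint.seek(0)
        return checkpoint.read()


# The caller no longer has to thread the file object back in by hand.
handle = io.BytesIO(b"raw storage bytes")
lazy = LazyValue("model.weight", file_handle=handle)
assert lazy.materialize() == b"raw storage bytes"

Because the fallback uses `or`, an explicitly passed checkpoint still takes precedence over the stored handle.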