mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-01-24 22:19:10 +01:00
(torch_lazy_loader.py) Add support for materializing from a ZipExtFile
This commit is contained in:
parent
c338b52d68
commit
1ecc452dc8
@ -3,7 +3,7 @@ from functools import reduce
|
|||||||
import zipfile
|
import zipfile
|
||||||
import pickle
|
import pickle
|
||||||
import torch
|
import torch
|
||||||
from typing import Any, Callable, Dict, Optional, Tuple, Type
|
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
|
|
||||||
class LazyTensor:
|
class LazyTensor:
|
||||||
@ -23,13 +23,20 @@ class LazyTensor:
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return self.__view(repr)
|
return self.__view(repr)
|
||||||
|
|
||||||
def materialize(self, checkpoint: zipfile.ZipFile, map_location=None) -> torch.Tensor:
|
def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None) -> torch.Tensor:
|
||||||
size = reduce(lambda x, y: x * y, self.shape, 1)
|
size = reduce(lambda x, y: x * y, self.shape, 1)
|
||||||
dtype = self.storage_type(0).dtype
|
dtype = self.storage_type(0).dtype
|
||||||
nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
|
nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
|
||||||
with checkpoint.open(f"archive/data/{self.key}", "r") as f:
|
if isinstance(checkpoint, zipfile.ZipFile):
|
||||||
|
f = checkpoint.open(f"archive/data/{self.key}", "r")
|
||||||
f.seek(self.storage_offset)
|
f.seek(self.storage_offset)
|
||||||
|
else:
|
||||||
|
f = checkpoint
|
||||||
|
try:
|
||||||
storage = self.storage_type.from_buffer(f.read(nbytes), "little")
|
storage = self.storage_type.from_buffer(f.read(nbytes), "little")
|
||||||
|
finally:
|
||||||
|
if isinstance(checkpoint, zipfile.ZipFile):
|
||||||
|
f.close()
|
||||||
storage = torch.serialization._get_restore_location(map_location)(storage, self.location)
|
storage = torch.serialization._get_restore_location(map_location)(storage, self.location)
|
||||||
tensor = torch.tensor([], dtype=storage.dtype, device=storage.device)
|
tensor = torch.tensor([], dtype=storage.dtype, device=storage.device)
|
||||||
tensor.set_(storage, 0, self.shape, self.stride)
|
tensor.set_(storage, 0, self.shape, self.stride)
|
||||||
|
Loading…
Reference in New Issue
Block a user