mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Hello its breaking breakmodel time
This commit is contained in:
@@ -101,6 +101,7 @@ class TorchLazyTensor(LazyTensor):
|
||||
stride: Optional[Tuple[int, ...]] = None,
|
||||
requires_grad=False,
|
||||
backward_hooks: Any = None,
|
||||
file_handle: Any = None
|
||||
):
|
||||
self.storage_type = storage_type
|
||||
self.key = key
|
||||
@@ -111,6 +112,7 @@ class TorchLazyTensor(LazyTensor):
|
||||
self.stride = stride
|
||||
self.requires_grad = requires_grad
|
||||
self.backward_hooks = backward_hooks
|
||||
self.file_handle = file_handle
|
||||
|
||||
def __view(self, f: Callable):
|
||||
return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, dtype={f(self.dtype)}, seek_offset={f(self.seek_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})"
|
||||
@@ -120,11 +122,13 @@ class TorchLazyTensor(LazyTensor):
|
||||
|
||||
def materialize(
|
||||
self,
|
||||
checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile],
|
||||
checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile] = None,
|
||||
map_location=None,
|
||||
no_grad=True,
|
||||
filename="pytorch_model.bin",
|
||||
) -> torch.Tensor:
|
||||
checkpoint = checkpoint or self.file_handle
|
||||
|
||||
filename = os.path.basename(os.path.normpath(filename)).split(".")[0]
|
||||
size = reduce(lambda x, y: x * y, self.shape, 1)
|
||||
dtype = self.dtype
|
||||
@@ -237,6 +241,8 @@ class _LazyUnpickler(RestrictedUnpickler):
|
||||
lazy_loaded_storages: Dict[str, LazyTensor]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# print(args, kwargs)
|
||||
self.file_handle = args[0]
|
||||
self.lazy_loaded_storages = {}
|
||||
return super().__init__(*args, **kwargs)
|
||||
|
||||
@@ -247,7 +253,7 @@ class _LazyUnpickler(RestrictedUnpickler):
|
||||
typename == "storage"
|
||||
), f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
|
||||
storage_type, key, location, _ = saved_id[1:]
|
||||
return TorchLazyTensor(storage_type, key, location)
|
||||
return TorchLazyTensor(storage_type, key, location, file_handle=self.file_handle)
|
||||
|
||||
def load(self, *args, **kwargs):
|
||||
retval = super().load(*args, **kwargs)
|
||||
|
Reference in New Issue
Block a user