mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Patch _rebuild_from_type_v2 to not try converting LazyTensors to Tensors
This commit is contained in:
@@ -57,6 +57,7 @@ import _codecs
|
|||||||
import os
|
import os
|
||||||
from typing import Any, Callable, Dict, Optional, Tuple, Type
|
from typing import Any, Callable, Dict, Optional, Tuple, Type
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
from torch.storage import UntypedStorage
|
from torch.storage import UntypedStorage
|
||||||
|
|
||||||
@@ -237,6 +238,29 @@ class SafetensorsLazyTensor(LazyTensor):
|
|||||||
self.checkpoint_file, tensor_key=self.key, device=self.location
|
self.checkpoint_file, tensor_key=self.key, device=self.location
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _patched_rebuild_from_type_v2(func, new_type, args, state):
|
||||||
|
"""A patched version of torch._tensor._rebuild_from_type_v2 that
|
||||||
|
does not attempt to convert `LazyTensor`s to `torch.Tensor`s."""
|
||||||
|
|
||||||
|
ret = func(*args)
|
||||||
|
|
||||||
|
# BEGIN PATCH
|
||||||
|
transformation_ok = isinstance(ret, LazyTensor) and new_type == Tensor
|
||||||
|
if type(ret) is not new_type and not transformation_ok:
|
||||||
|
# END PATCH
|
||||||
|
ret = ret.as_subclass(new_type)
|
||||||
|
|
||||||
|
# Tensor does define __setstate__ even though it doesn't define
|
||||||
|
# __getstate__. So only use __setstate__ if it is NOT the one defined
|
||||||
|
# on Tensor
|
||||||
|
if (
|
||||||
|
getattr(ret.__class__, "__setstate__", Tensor.__setstate__)
|
||||||
|
is not Tensor.__setstate__
|
||||||
|
):
|
||||||
|
ret.__setstate__(state)
|
||||||
|
else:
|
||||||
|
ret = torch._utils._set_obj_state(ret, state)
|
||||||
|
return ret
|
||||||
|
|
||||||
class RestrictedUnpickler(pickle.Unpickler):
|
class RestrictedUnpickler(pickle.Unpickler):
|
||||||
def original_persistent_load(self, saved_id):
|
def original_persistent_load(self, saved_id):
|
||||||
@@ -253,7 +277,7 @@ class RestrictedUnpickler(pickle.Unpickler):
|
|||||||
elif module == "torch._utils" and name == "_rebuild_tensor_v2":
|
elif module == "torch._utils" and name == "_rebuild_tensor_v2":
|
||||||
return torch._utils._rebuild_tensor_v2
|
return torch._utils._rebuild_tensor_v2
|
||||||
elif module == "torch._tensor" and name == "_rebuild_from_type_v2":
|
elif module == "torch._tensor" and name == "_rebuild_from_type_v2":
|
||||||
return torch._tensor._rebuild_from_type_v2
|
return _patched_rebuild_from_type_v2
|
||||||
elif module == "torch" and name in (
|
elif module == "torch" and name in (
|
||||||
"DoubleStorage",
|
"DoubleStorage",
|
||||||
"FloatStorage",
|
"FloatStorage",
|
||||||
|
Reference in New Issue
Block a user