mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Patch _rebuild_from_type_v2 to not try converting LazyTensors to Tensors
This commit is contained in:
@@ -57,6 +57,7 @@ import _codecs
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Type
|
||||
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
from torch.storage import UntypedStorage
|
||||
|
||||
@@ -237,6 +238,29 @@ class SafetensorsLazyTensor(LazyTensor):
|
||||
self.checkpoint_file, tensor_key=self.key, device=self.location
|
||||
)
|
||||
|
||||
def _patched_rebuild_from_type_v2(func, new_type, args, state):
|
||||
"""A patched version of torch._tensor._rebuild_from_type_v2 that
|
||||
does not attempt to convert `LazyTensor`s to `torch.Tensor`s."""
|
||||
|
||||
ret = func(*args)
|
||||
|
||||
# BEGIN PATCH
|
||||
transformation_ok = isinstance(ret, LazyTensor) and new_type == Tensor
|
||||
if type(ret) is not new_type and not transformation_ok:
|
||||
# END PATCH
|
||||
ret = ret.as_subclass(new_type)
|
||||
|
||||
# Tensor does define __setstate__ even though it doesn't define
|
||||
# __getstate__. So only use __setstate__ if it is NOT the one defined
|
||||
# on Tensor
|
||||
if (
|
||||
getattr(ret.__class__, "__setstate__", Tensor.__setstate__)
|
||||
is not Tensor.__setstate__
|
||||
):
|
||||
ret.__setstate__(state)
|
||||
else:
|
||||
ret = torch._utils._set_obj_state(ret, state)
|
||||
return ret
|
||||
|
||||
class RestrictedUnpickler(pickle.Unpickler):
|
||||
def original_persistent_load(self, saved_id):
|
||||
@@ -253,7 +277,7 @@ class RestrictedUnpickler(pickle.Unpickler):
|
||||
elif module == "torch._utils" and name == "_rebuild_tensor_v2":
|
||||
return torch._utils._rebuild_tensor_v2
|
||||
elif module == "torch._tensor" and name == "_rebuild_from_type_v2":
|
||||
return torch._tensor._rebuild_from_type_v2
|
||||
return _patched_rebuild_from_type_v2
|
||||
elif module == "torch" and name in (
|
||||
"DoubleStorage",
|
||||
"FloatStorage",
|
||||
|
Reference in New Issue
Block a user