Patch _rebuild_from_type_v2 to not try converting LazyTensors to Tensors

This commit is contained in:
somebody
2023-07-08 13:57:05 -05:00
parent 802929f5f2
commit fd6f66a98d

View File

@@ -57,6 +57,7 @@ import _codecs
import os
from typing import Any, Callable, Dict, Optional, Tuple, Type
from torch import Tensor
from torch.nn import Module
from torch.storage import UntypedStorage
@@ -237,6 +238,29 @@ class SafetensorsLazyTensor(LazyTensor):
self.checkpoint_file, tensor_key=self.key, device=self.location
)
def _patched_rebuild_from_type_v2(func, new_type, args, state):
"""A patched version of torch._tensor._rebuild_from_type_v2 that
does not attempt to convert `LazyTensor`s to `torch.Tensor`s."""
ret = func(*args)
# BEGIN PATCH
transformation_ok = isinstance(ret, LazyTensor) and new_type == Tensor
if type(ret) is not new_type and not transformation_ok:
# END PATCH
ret = ret.as_subclass(new_type)
# Tensor does define __setstate__ even though it doesn't define
# __getstate__. So only use __setstate__ if it is NOT the one defined
# on Tensor
if (
getattr(ret.__class__, "__setstate__", Tensor.__setstate__)
is not Tensor.__setstate__
):
ret.__setstate__(state)
else:
ret = torch._utils._set_obj_state(ret, state)
return ret
class RestrictedUnpickler(pickle.Unpickler):
def original_persistent_load(self, saved_id):
@@ -253,7 +277,7 @@ class RestrictedUnpickler(pickle.Unpickler):
elif module == "torch._utils" and name == "_rebuild_tensor_v2":
return torch._utils._rebuild_tensor_v2
elif module == "torch._tensor" and name == "_rebuild_from_type_v2":
return torch._tensor._rebuild_from_type_v2
return _patched_rebuild_from_type_v2
elif module == "torch" and name in (
"DoubleStorage",
"FloatStorage",