Patch _rebuild_from_type_v2 to not try converting LazyTensors to Tensors

This commit is contained in:
somebody
2023-07-08 13:57:05 -05:00
parent 802929f5f2
commit fd6f66a98d

View File

@@ -57,6 +57,7 @@ import _codecs
import os import os
from typing import Any, Callable, Dict, Optional, Tuple, Type from typing import Any, Callable, Dict, Optional, Tuple, Type
from torch import Tensor
from torch.nn import Module from torch.nn import Module
from torch.storage import UntypedStorage from torch.storage import UntypedStorage
@@ -237,6 +238,29 @@ class SafetensorsLazyTensor(LazyTensor):
self.checkpoint_file, tensor_key=self.key, device=self.location self.checkpoint_file, tensor_key=self.key, device=self.location
) )
def _patched_rebuild_from_type_v2(func, new_type, args, state):
"""A patched version of torch._tensor._rebuild_from_type_v2 that
does not attempt to convert `LazyTensor`s to `torch.Tensor`s."""
ret = func(*args)
# BEGIN PATCH
transformation_ok = isinstance(ret, LazyTensor) and new_type == Tensor
if type(ret) is not new_type and not transformation_ok:
# END PATCH
ret = ret.as_subclass(new_type)
# Tensor does define __setstate__ even though it doesn't define
# __getstate__. So only use __setstate__ if it is NOT the one defined
# on Tensor
if (
getattr(ret.__class__, "__setstate__", Tensor.__setstate__)
is not Tensor.__setstate__
):
ret.__setstate__(state)
else:
ret = torch._utils._set_obj_state(ret, state)
return ret
class RestrictedUnpickler(pickle.Unpickler): class RestrictedUnpickler(pickle.Unpickler):
def original_persistent_load(self, saved_id): def original_persistent_load(self, saved_id):
@@ -253,7 +277,7 @@ class RestrictedUnpickler(pickle.Unpickler):
elif module == "torch._utils" and name == "_rebuild_tensor_v2": elif module == "torch._utils" and name == "_rebuild_tensor_v2":
return torch._utils._rebuild_tensor_v2 return torch._utils._rebuild_tensor_v2
elif module == "torch._tensor" and name == "_rebuild_from_type_v2": elif module == "torch._tensor" and name == "_rebuild_from_type_v2":
return torch._tensor._rebuild_from_type_v2 return _patched_rebuild_from_type_v2
elif module == "torch" and name in ( elif module == "torch" and name in (
"DoubleStorage", "DoubleStorage",
"FloatStorage", "FloatStorage",