diff --git a/modeling/lazy_loader.py b/modeling/lazy_loader.py
index da2ec989..2af0ae51 100644
--- a/modeling/lazy_loader.py
+++ b/modeling/lazy_loader.py
@@ -57,6 +57,7 @@ import _codecs
 import os
 from typing import Any, Callable, Dict, Optional, Tuple, Type
 
+from torch import Tensor
 from torch.nn import Module
 from torch.storage import UntypedStorage
 
@@ -237,6 +238,29 @@ class SafetensorsLazyTensor(LazyTensor):
             self.checkpoint_file, tensor_key=self.key, device=self.location
         )
 
+def _patched_rebuild_from_type_v2(func, new_type, args, state):
+    """A patched version of torch._tensor._rebuild_from_type_v2 that
+    does not attempt to convert `LazyTensor`s to `torch.Tensor`s."""
+
+    ret = func(*args)
+
+    # BEGIN PATCH
+    transformation_ok = isinstance(ret, LazyTensor) and new_type == Tensor
+    if type(ret) is not new_type and not transformation_ok:
+        # END PATCH
+        ret = ret.as_subclass(new_type)
+
+    # Tensor does define __setstate__ even though it doesn't define
+    # __getstate__. So only use __setstate__ if it is NOT the one defined
+    # on Tensor
+    if (
+        getattr(ret.__class__, "__setstate__", Tensor.__setstate__)
+        is not Tensor.__setstate__
+    ):
+        ret.__setstate__(state)
+    else:
+        ret = torch._utils._set_obj_state(ret, state)
+    return ret
 
 class RestrictedUnpickler(pickle.Unpickler):
     def original_persistent_load(self, saved_id):
@@ -253,7 +277,7 @@ class RestrictedUnpickler(pickle.Unpickler):
         elif module == "torch._utils" and name == "_rebuild_tensor_v2":
             return torch._utils._rebuild_tensor_v2
         elif module == "torch._tensor" and name == "_rebuild_from_type_v2":
-            return torch._tensor._rebuild_from_type_v2
+            return _patched_rebuild_from_type_v2
         elif module == "torch" and name in (
             "DoubleStorage",
             "FloatStorage",
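
For context, here is a minimal sketch (not part of the diff) of why the guard matters: during lazy loading, `func(*args)` can return a `LazyTensor`, which is not a real `torch.Tensor` subclass, so the unpatched `ret.as_subclass(new_type)` call would fail on it. The stand-in `LazyTensor` below is hypothetical and the `__setstate__` handling from the real patch is omitted; the actual class lives in `modeling/lazy_loader.py`.

```python
from torch import Tensor


class LazyTensor:
    """Hypothetical stand-in: records metadata instead of holding storage."""

    def __init__(self, key: str):
        self.key = key


def _rebuild(func, new_type, args, state):
    ret = func(*args)
    # The unpatched torch._tensor._rebuild_from_type_v2 calls as_subclass()
    # whenever type(ret) is not new_type; a plain LazyTensor has no
    # as_subclass() method, so the load would crash without this guard.
    transformation_ok = isinstance(ret, LazyTensor) and new_type == Tensor
    if type(ret) is not new_type and not transformation_ok:
        ret = ret.as_subclass(new_type)
    return ret


lazy = _rebuild(lambda: LazyTensor("model.weight"), Tensor, (), None)
assert isinstance(lazy, LazyTensor)  # stays lazy instead of raising
```

Note that `RestrictedUnpickler` substitutes the patched function only for the `torch._tensor._rebuild_from_type_v2` lookup (third hunk above); every other allowed name still resolves to the stock torch implementation.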