Allow _rebuild_meta_tensor_no_storage

This commit is contained in:
Henk
2023-07-28 15:04:25 +02:00
parent 889fe8d548
commit 37babe1edd
2 changed files with 12 additions and 4 deletions

View File

@@ -1690,8 +1690,11 @@ class RestrictedUnpickler(pickle.Unpickler):
def find_class(self, module, name):
if module == "collections" and name == "OrderedDict":
return collections.OrderedDict
elif module == "torch._utils" and name == "_rebuild_tensor_v2":
return torch._utils._rebuild_tensor_v2
elif module == "torch._utils" and name in (
"_rebuild_tensor_v2",
"_rebuild_meta_tensor_no_storage",
):
return getattr(torch._utils, name)
elif module == "torch._tensor" and name == "_rebuild_from_type_v2":
return torch._tensor._rebuild_from_type_v2
elif module == "torch" and name in (
@@ -1706,6 +1709,7 @@ class RestrictedUnpickler(pickle.Unpickler):
"BoolStorage",
"BFloat16Storage",
"Tensor",
"float16",
):
return getattr(torch, name)
elif module == "numpy.core.multiarray" and name == "scalar":

View File

@@ -272,8 +272,11 @@ class RestrictedUnpickler(pickle.Unpickler):
def find_class(self, module, name):
if module == "collections" and name == "OrderedDict":
return collections.OrderedDict
elif module == "torch._utils" and name == "_rebuild_tensor_v2":
return torch._utils._rebuild_tensor_v2
elif module == "torch._utils" and name in (
"_rebuild_tensor_v2",
"_rebuild_meta_tensor_no_storage",
):
return getattr(torch._utils, name)
elif module == "torch._tensor" and name == "_rebuild_from_type_v2":
return _patched_rebuild_from_type_v2
elif module == "torch" and name in (
@@ -288,6 +291,7 @@ class RestrictedUnpickler(pickle.Unpickler):
"BoolStorage",
"BFloat16Storage",
"Tensor",
"float16",
):
return getattr(torch, name)
elif module == "numpy.core.multiarray" and name == "scalar":