From 323f593a9662ebef59e3fb94e2339a4c8066d14e Mon Sep 17 00:00:00 2001
From: vfbd
Date: Thu, 6 Oct 2022 20:08:08 -0400
Subject: [PATCH] Custom unpickler to avoid pickle's arbitrary code execution
 vulnerability

---
 torch_lazy_loader.py | 79 +++++++++++++++++++++++++++++++++++++++-----
 tpu_mtj_backend.py   |  6 ++--
 2 files changed, 74 insertions(+), 11 deletions(-)

diff --git a/torch_lazy_loader.py b/torch_lazy_loader.py
index 9e411261..1298335d 100644
--- a/torch_lazy_loader.py
+++ b/torch_lazy_loader.py
@@ -50,9 +50,12 @@ import itertools
 import zipfile
 import pickle
 import torch
+import numpy as np
+import collections
+import _codecs
 import utils
 from torch.nn import Module
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
 
 _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
 
@@ -111,8 +114,50 @@ class LazyTensor:
         tensor._backward_hooks = self.backward_hooks
         return tensor
 
+class RestrictedUnpickler(pickle.Unpickler):
+    def original_persistent_load(self, saved_id):
+        return super().persistent_load(saved_id)
 
-class _LazyUnpickler(pickle.Unpickler):
+    def forced_persistent_load(self, saved_id):
+        if saved_id[0] != "storage":
+            raise pickle.UnpicklingError("`saved_id[0]` must be 'storage'")
+        return self.original_persistent_load(saved_id)
+
+    def find_class(self, module, name):
+        if module == "collections" and name == "OrderedDict":
+            return collections.OrderedDict
+        elif module == "torch._utils" and name == "_rebuild_tensor_v2":
+            return torch._utils._rebuild_tensor_v2
+        elif module == "torch" and name in (
+            "DoubleStorage",
+            "FloatStorage",
+            "HalfStorage",
+            "LongStorage",
+            "IntStorage",
+            "ShortStorage",
+            "CharStorage",
+            "ByteStorage",
+            "BoolStorage",
+            "BFloat16Storage",
+        ):
+            return getattr(torch, name)
+        elif module == "numpy.core.multiarray" and name == "scalar":
+            return np.core.multiarray.scalar
+        elif module == "numpy" and name == "dtype":
+            return np.dtype
+        elif module == "_codecs" and name == "encode":
+            return _codecs.encode
+        else:
+            # Forbid everything else.
+            qualified_name = name if module == "__builtin__" else f"{module}.{name}"
+            raise pickle.UnpicklingError(f"`{qualified_name}` is forbidden; the model you are loading probably contains malicious code")
+
+    def load(self, *args, **kwargs):
+        self.original_persistent_load = getattr(self, "persistent_load", pickle.Unpickler.persistent_load)
+        self.persistent_load = self.forced_persistent_load
+        return super().load(*args, **kwargs)
+
+class _LazyUnpickler(RestrictedUnpickler):
     lazy_loaded_storages: Dict[str, LazyTensor]
 
     def __init__(self, *args, **kwargs):
@@ -127,7 +172,6 @@ class _LazyUnpickler(pickle.Unpickler):
         return LazyTensor(storage_type, key, location)
 
     def load(self, *args, **kwargs):
-        self.persistent_load = self.forced_persistent_load
         retval = super().load(*args, **kwargs)
         self.lazy_loaded_storages = {}
         return retval
@@ -213,16 +257,33 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss
             unexpected_keys.append(key)
 
 
+@contextlib.contextmanager
+def use_custom_unpickler(unpickler: Type[pickle.Unpickler] = RestrictedUnpickler):
+    try:
+        old_unpickler = pickle.Unpickler
+        pickle.Unpickler = unpickler
+
+        old_pickle_load = pickle.load
+
+        def new_pickle_load(*args, **kwargs):
+            return pickle.Unpickler(*args, **kwargs).load()
+
+        pickle.load = new_pickle_load
+
+        yield
+
+    finally:
+        pickle.Unpickler = old_unpickler
+        pickle.load = old_pickle_load
+
 @contextlib.contextmanager
 def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False, use_accelerate_init_empty_weights=False):
     if not enable:
-        yield False
+        with use_custom_unpickler(RestrictedUnpickler):
+            yield False
         return
 
     try:
-        old_unpickler = pickle.Unpickler
-        pickle.Unpickler = _LazyUnpickler
-
         old_rebuild_tensor = torch._utils._rebuild_tensor
         torch._utils._rebuild_tensor = _rebuild_tensor
 
@@ -261,10 +322,10 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, demate
         old_load_from_state_dict = torch.nn.Module._load_from_state_dict
         torch.nn.Module._load_from_state_dict = _load_from_state_dict
 
-        yield True
+        with use_custom_unpickler(_LazyUnpickler):
+            yield True
 
     finally:
-        pickle.Unpickler = old_unpickler
         torch._utils._rebuild_tensor = old_rebuild_tensor
         torch.load = old_torch_load
         if dematerialized_modules:
diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index 2642943b..d992ba45 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -955,6 +955,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
 
     import torch
     import torch.utils.dlpack
+    import torch_lazy_loader
    from tqdm.auto import tqdm
 
     move_xmap = jax.experimental.maps.xmap(
@@ -996,8 +997,9 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
             continue
         layer = checkpoint_layer - 2
         shards = []
-        for checkpoint_shard in range(checkpoint_shards):
-            shards.append(torch.load(path_template.format(layer=checkpoint_layer, shard=checkpoint_shard), map_location="cpu"))
+        with torch_lazy_loader.use_custom_unpickler(torch_lazy_loader.RestrictedUnpickler):
+            for checkpoint_shard in range(checkpoint_shards):
+                shards.append(torch.load(path_template.format(layer=checkpoint_layer, shard=checkpoint_shard), map_location="cpu"))
         for key in shards[0]:
             if key == "attention.rotary_emb.inv_freq":
                 continue
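
A minimal usage sketch of the new use_custom_unpickler context manager (the
checkpoint path below is hypothetical): while the block is active,
pickle.Unpickler and pickle.load are temporarily swapped for the restricted
versions, so a torch.load call resolves globals through
RestrictedUnpickler.find_class and raises pickle.UnpicklingError for anything
outside the whitelist.

    import torch
    import torch_lazy_loader

    # Hypothetical checkpoint path; only the whitelisted globals (OrderedDict,
    # the torch storage classes, _rebuild_tensor_v2, numpy scalar/dtype and
    # _codecs.encode) can be unpickled while the context manager is active.
    with torch_lazy_loader.use_custom_unpickler(torch_lazy_loader.RestrictedUnpickler):
        state_dict = torch.load("models/pytorch_model.bin", map_location="cpu")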