Custom unpickler to avoid pickle's arbitrary code execution vulnerability

vfbd 2022-10-06 20:08:08 -04:00
parent b85d74f22c
commit 323f593a96
2 changed files with 74 additions and 11 deletions
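Background on the vulnerability this commit addresses: unpickling resolves arbitrary importable callables (via the GLOBAL/STACK_GLOBAL and REDUCE opcodes) and calls them, so loading an untrusted checkpoint with a stock unpickler can execute attacker-chosen code. The snippet below is an illustrative sketch, not part of the commit; AllowlistUnpickler and EvilPayload are hypothetical names, but the allowlist idea is the same one the RestrictedUnpickler added here uses.

import collections
import io
import os
import pickle


class AllowlistUnpickler(pickle.Unpickler):
    # Toy allowlist unpickler: only globals named here may be resolved.
    ALLOWED = {("collections", "OrderedDict"): collections.OrderedDict}

    def find_class(self, module, name):
        try:
            return self.ALLOWED[(module, name)]
        except KeyError:
            raise pickle.UnpicklingError(f"`{module}.{name}` is forbidden")


class EvilPayload:
    # Pickling this object stores a reference to os.system and its argument,
    # so a plain pickle.loads() would import os.system and call it.
    def __reduce__(self):
        return (os.system, ("echo arbitrary code execution",))


payload = pickle.dumps(EvilPayload())
# pickle.loads(payload)   # would execute the shell command above

try:
    AllowlistUnpickler(io.BytesIO(payload)).load()
except pickle.UnpicklingError as exc:
    print(exc)  # rejected before os.system is ever looked up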

torch_lazy_loader.py

@@ -50,9 +50,12 @@ import itertools
 import zipfile
 import pickle
 import torch
+import numpy as np
+import collections
+import _codecs
 import utils
 from torch.nn import Module
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
 
 _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
@@ -111,8 +114,50 @@ class LazyTensor:
         tensor._backward_hooks = self.backward_hooks
         return tensor
 
+
+class RestrictedUnpickler(pickle.Unpickler):
+    def original_persistent_load(self, saved_id):
+        return super().persistent_load(saved_id)
+
+    def forced_persistent_load(self, saved_id):
+        if saved_id[0] != "storage":
+            raise pickle.UnpicklingError("`saved_id[0]` must be 'storage'")
+        return self.original_persistent_load(saved_id)
+
+    def find_class(self, module, name):
+        if module == "collections" and name == "OrderedDict":
+            return collections.OrderedDict
+        elif module == "torch._utils" and name == "_rebuild_tensor_v2":
+            return torch._utils._rebuild_tensor_v2
+        elif module == "torch" and name in (
+            "DoubleStorage",
+            "FloatStorage",
+            "HalfStorage",
+            "LongStorage",
+            "IntStorage",
+            "ShortStorage",
+            "CharStorage",
+            "ByteStorage",
+            "BoolStorage",
+            "BFloat16Storage",
+        ):
+            return getattr(torch, name)
+        elif module == "numpy.core.multiarray" and name == "scalar":
+            return np.core.multiarray.scalar
+        elif module == "numpy" and name == "dtype":
+            return np.dtype
+        elif module == "_codecs" and name == "encode":
+            return _codecs.encode
+        else:
+            # Forbid everything else.
+            qualified_name = name if module == "__builtin__" else f"{module}.{name}"
+            raise pickle.UnpicklingError(f"`{qualified_name}` is forbidden; the model you are loading probably contains malicious code")
+
+    def load(self, *args, **kwargs):
+        self.original_persistent_load = getattr(self, "persistent_load", pickle.Unpickler.persistent_load)
+        self.persistent_load = self.forced_persistent_load
+        return super().load(*args, **kwargs)
+
+
-class _LazyUnpickler(pickle.Unpickler):
+class _LazyUnpickler(RestrictedUnpickler):
     lazy_loaded_storages: Dict[str, LazyTensor]
 
     def __init__(self, *args, **kwargs):
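A quick way to see what the find_class allowlist above does in practice (a usage sketch, assuming this file is importable as torch_lazy_loader, which is how the second changed file below imports it): allowlisted globals resolve normally, while anything else raises pickle.UnpicklingError before it can be instantiated.

import io
import pickle

from torch_lazy_loader import RestrictedUnpickler

unpickler = RestrictedUnpickler(io.BytesIO(b""))  # the stream contents don't matter for this check

# On the allowlist: the torch storage classes, collections.OrderedDict,
# torch._utils._rebuild_tensor_v2, and a few numpy/_codecs helpers.
print(unpickler.find_class("torch", "FloatStorage"))

# Everything else is refused.
try:
    unpickler.find_class("builtins", "eval")
except pickle.UnpicklingError as exc:
    print(exc)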
@@ -127,7 +172,6 @@ class _LazyUnpickler(pickle.Unpickler):
         return LazyTensor(storage_type, key, location)
 
     def load(self, *args, **kwargs):
-        self.persistent_load = self.forced_persistent_load
         retval = super().load(*args, **kwargs)
         self.lazy_loaded_storages = {}
         return retval
@@ -213,16 +257,33 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss
             unexpected_keys.append(key)
 
 
+@contextlib.contextmanager
+def use_custom_unpickler(unpickler: Type[pickle.Unpickler] = RestrictedUnpickler):
+    try:
+        old_unpickler = pickle.Unpickler
+        pickle.Unpickler = unpickler
+
+        old_pickle_load = pickle.load
+
+        def new_pickle_load(*args, **kwargs):
+            return pickle.Unpickler(*args, **kwargs).load()
+
+        pickle.load = new_pickle_load
+
+        yield
+
+    finally:
+        pickle.Unpickler = old_unpickler
+        pickle.load = old_pickle_load
+
+
 @contextlib.contextmanager
 def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False, use_accelerate_init_empty_weights=False):
     if not enable:
-        yield False
+        with use_custom_unpickler(RestrictedUnpickler):
+            yield False
         return
 
     try:
-        old_unpickler = pickle.Unpickler
-        pickle.Unpickler = _LazyUnpickler
-
         old_rebuild_tensor = torch._utils._rebuild_tensor
         torch._utils._rebuild_tensor = _rebuild_tensor
@@ -261,10 +322,10 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, demate
                 old_load_from_state_dict = torch.nn.Module._load_from_state_dict
                 torch.nn.Module._load_from_state_dict = _load_from_state_dict
 
-        yield True
+        with use_custom_unpickler(_LazyUnpickler):
+            yield True
 
     finally:
-        pickle.Unpickler = old_unpickler
         torch._utils._rebuild_tensor = old_rebuild_tensor
         torch.load = old_torch_load
         if dematerialized_modules:
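For callers outside use_lazy_torch_load, the new use_custom_unpickler context manager can be used directly, which is what the second changed file does below. A minimal usage sketch (the checkpoint path is a placeholder): while the with block is active, pickle.Unpickler and pickle.load are swapped for the restricted versions, so torch.load's internal unpickling goes through the restricted class as well; the originals are restored on exit.

import torch

import torch_lazy_loader

with torch_lazy_loader.use_custom_unpickler(torch_lazy_loader.RestrictedUnpickler):
    state_dict = torch.load("untrusted_checkpoint.pt", map_location="cpu")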

tpu_mtj_backend.py

@@ -955,6 +955,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
 
     import torch
     import torch.utils.dlpack
+    import torch_lazy_loader
    from tqdm.auto import tqdm
 
     move_xmap = jax.experimental.maps.xmap(
@@ -996,8 +997,9 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
             continue
         layer = checkpoint_layer - 2
         shards = []
-        for checkpoint_shard in range(checkpoint_shards):
-            shards.append(torch.load(path_template.format(layer=checkpoint_layer, shard=checkpoint_shard), map_location="cpu"))
+        with torch_lazy_loader.use_custom_unpickler(torch_lazy_loader.RestrictedUnpickler):
+            for checkpoint_shard in range(checkpoint_shards):
+                shards.append(torch.load(path_template.format(layer=checkpoint_layer, shard=checkpoint_shard), map_location="cpu"))
         for key in shards[0]:
             if key == "attention.rotary_emb.inv_freq":
                 continue