Custom unpickler to avoid pickle's arbitrary code execution vulnerability
parent b85d74f22c
commit 323f593a96
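Background, not part of the original commit message: unpickling will call any callable the stream names through __reduce__, which is the arbitrary code execution the title refers to. A minimal sketch of the attack, and of how the RestrictedUnpickler added below stops it, assuming a POSIX host where os.system pickles as posix.system:

    import io
    import os
    import pickle

    class Exploit:
        def __reduce__(self):
            # Tells pickle to call os.system("echo pwned") at load time.
            return (os.system, ("echo pwned",))

    payload = pickle.dumps(Exploit())
    # pickle.loads(payload)  # executes the shell command
    # RestrictedUnpickler(io.BytesIO(payload)).load()  # raises instead:
    # pickle.UnpicklingError: `posix.system` is forbidden; the model you are
    # loading probably contains malicious code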
@@ -50,9 +50,12 @@ import itertools
 import zipfile
 import pickle
 import torch
+import numpy as np
+import collections
+import _codecs
 import utils
 from torch.nn import Module
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Type, Union


 _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
@@ -111,8 +114,50 @@ class LazyTensor:
         tensor._backward_hooks = self.backward_hooks
         return tensor


-class _LazyUnpickler(pickle.Unpickler):
+class RestrictedUnpickler(pickle.Unpickler):
+    def original_persistent_load(self, saved_id):
+        return super().persistent_load(saved_id)
+
+    def forced_persistent_load(self, saved_id):
+        if saved_id[0] != "storage":
+            raise pickle.UnpicklingError("`saved_id[0]` must be 'storage'")
+        return self.original_persistent_load(saved_id)
+
+    def find_class(self, module, name):
+        if module == "collections" and name == "OrderedDict":
+            return collections.OrderedDict
+        elif module == "torch._utils" and name == "_rebuild_tensor_v2":
+            return torch._utils._rebuild_tensor_v2
+        elif module == "torch" and name in (
+            "DoubleStorage",
+            "FloatStorage",
+            "HalfStorage",
+            "LongStorage",
+            "IntStorage",
+            "ShortStorage",
+            "CharStorage",
+            "ByteStorage",
+            "BoolStorage",
+            "BFloat16Storage",
+        ):
+            return getattr(torch, name)
+        elif module == "numpy.core.multiarray" and name == "scalar":
+            return np.core.multiarray.scalar
+        elif module == "numpy" and name == "dtype":
+            return np.dtype
+        elif module == "_codecs" and name == "encode":
+            return _codecs.encode
+        else:
+            # Forbid everything else.
+            qualified_name = name if module == "__builtin__" else f"{module}.{name}"
+            raise pickle.UnpicklingError(f"`{qualified_name}` is forbidden; the model you are loading probably contains malicious code")
+
+    def load(self, *args, **kwargs):
+        self.original_persistent_load = getattr(self, "persistent_load", pickle.Unpickler.persistent_load)
+        self.persistent_load = self.forced_persistent_load
+        return super().load(*args, **kwargs)
+
+
+class _LazyUnpickler(RestrictedUnpickler):
     lazy_loaded_storages: Dict[str, LazyTensor]

     def __init__(self, *args, **kwargs):
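Two notes on the design, as reconstructed from the hunk above: find_class is a strict allow-list covering exactly the globals an ordinary PyTorch checkpoint references (collections.OrderedDict, the torch storage classes, torch._utils._rebuild_tensor_v2, numpy's scalar and dtype, _codecs.encode), and original_persistent_load is resolved with getattr inside load() because torch.serialization assigns its storage-aware persistent_load onto the unpickler instance before loading, so forced_persistent_load can validate the "storage" tag and then defer to torch's own handler. A sketch of the benign path, assuming this module is imported:

    import collections
    import io
    import pickle

    data = pickle.dumps(collections.OrderedDict(a=1))
    # collections.OrderedDict is allow-listed, so this loads normally.
    print(RestrictedUnpickler(io.BytesIO(data)).load())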
@@ -127,7 +172,6 @@ class _LazyUnpickler(pickle.Unpickler):
         return LazyTensor(storage_type, key, location)

     def load(self, *args, **kwargs):
-        self.persistent_load = self.forced_persistent_load
         retval = super().load(*args, **kwargs)
         self.lazy_loaded_storages = {}
         return retval
@@ -213,16 +257,33 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss
             unexpected_keys.append(key)


+@contextlib.contextmanager
+def use_custom_unpickler(unpickler: Type[pickle.Unpickler] = RestrictedUnpickler):
+    try:
+        old_unpickler = pickle.Unpickler
+        pickle.Unpickler = unpickler
+
+        old_pickle_load = pickle.load
+
+        def new_pickle_load(*args, **kwargs):
+            return pickle.Unpickler(*args, **kwargs).load()
+
+        pickle.load = new_pickle_load
+
+        yield
+
+    finally:
+        pickle.Unpickler = old_unpickler
+        pickle.load = old_pickle_load
+
+
 @contextlib.contextmanager
 def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False, use_accelerate_init_empty_weights=False):
     if not enable:
-        yield False
+        with use_custom_unpickler(RestrictedUnpickler):
+            yield False
         return

     try:
-        old_unpickler = pickle.Unpickler
-        pickle.Unpickler = _LazyUnpickler
-
         old_rebuild_tensor = torch._utils._rebuild_tensor
         torch._utils._rebuild_tensor = _rebuild_tensor
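use_custom_unpickler swaps both entry points, the pickle.Unpickler class and pickle.load, for the duration of the block and restores them in the finally clause, so a failed load cannot leave the process patched. A hypothetical usage sketch (the file name is a placeholder):

    with use_custom_unpickler(RestrictedUnpickler):
        state_dict = torch.load("pytorch_model.bin", map_location="cpu")

Because torch.load builds its unpickler from the pickle module it is handed (the stdlib pickle by default), it picks up the patched class transparently inside the block.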
@@ -261,10 +322,10 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, demate
         old_load_from_state_dict = torch.nn.Module._load_from_state_dict
         torch.nn.Module._load_from_state_dict = _load_from_state_dict

-        yield True
+        with use_custom_unpickler(_LazyUnpickler):
+            yield True

     finally:
-        pickle.Unpickler = old_unpickler
         torch._utils._rebuild_tensor = old_rebuild_tensor
         torch.load = old_torch_load
         if dematerialized_modules:
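Since _LazyUnpickler now inherits from RestrictedUnpickler, the lazy-loading path gets the same find_class filtering, and the manual pickle.Unpickler swap that previously lived in use_lazy_torch_load collapses into the reusable context manager. Caller's view, with a placeholder file name again:

    with use_lazy_torch_load(enable=True) as enabled:
        # Tensors come back as LazyTensor placeholders, and the pickle
        # stream is still subject to the allow-list.
        model_dict = torch.load("pytorch_model.bin", map_location="cpu")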
@@ -955,6 +955,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):

     import torch
     import torch.utils.dlpack
+    import torch_lazy_loader
     from tqdm.auto import tqdm

     move_xmap = jax.experimental.maps.xmap(
@@ -996,8 +997,9 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
                 continue
             layer = checkpoint_layer - 2
             shards = []
-            for checkpoint_shard in range(checkpoint_shards):
-                shards.append(torch.load(path_template.format(layer=checkpoint_layer, shard=checkpoint_shard), map_location="cpu"))
+            with torch_lazy_loader.use_custom_unpickler(torch_lazy_loader.RestrictedUnpickler):
+                for checkpoint_shard in range(checkpoint_shards):
+                    shards.append(torch.load(path_template.format(layer=checkpoint_layer, shard=checkpoint_shard), map_location="cpu"))
             for key in shards[0]:
                 if key == "attention.rotary_emb.inv_freq":
                     continue
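The read_neox_checkpoint hunks apply the same guard on the TPU path: each NeoX checkpoint shard is now deserialized under torch_lazy_loader.use_custom_unpickler(torch_lazy_loader.RestrictedUnpickler), so a tampered shard aborts with pickle.UnpicklingError instead of executing code.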