Custom unpickler to avoid pickle's arbitrary code execution vulnerability

This commit is contained in:
vfbd
2022-10-06 20:08:08 -04:00
parent b85d74f22c
commit 323f593a96
2 changed files with 74 additions and 11 deletions

View File

@ -955,6 +955,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
import torch
import torch.utils.dlpack
import torch_lazy_loader
from tqdm.auto import tqdm
move_xmap = jax.experimental.maps.xmap(
@ -996,8 +997,9 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
continue
layer = checkpoint_layer - 2
shards = []
for checkpoint_shard in range(checkpoint_shards):
shards.append(torch.load(path_template.format(layer=checkpoint_layer, shard=checkpoint_shard), map_location="cpu"))
with torch_lazy_loader.use_custom_unpickler(torch_lazy_loader.RestrictedUnpickler):
for checkpoint_shard in range(checkpoint_shards):
shards.append(torch.load(path_template.format(layer=checkpoint_layer, shard=checkpoint_shard), map_location="cpu"))
for key in shards[0]:
if key == "attention.rotary_emb.inv_freq":
continue