import contextlib
import itertools
import pickle
import zipfile
from functools import reduce
from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch
from torch.nn import Module

import utils


_EXTRA_STATE_KEY_SUFFIX = '_extra_state'


# Typed storage class used to decode the raw bytes of each supported dtype.
STORAGE_TYPE_MAP = {
    torch.float64: torch.DoubleStorage,
    torch.float32: torch.FloatStorage,
    torch.float16: torch.HalfStorage,
    torch.int64: torch.LongStorage,
    torch.int32: torch.IntStorage,
    torch.int16: torch.ShortStorage,
    torch.int8: torch.CharStorage,
    torch.uint8: torch.ByteStorage,
    torch.bool: torch.BoolStorage,
    torch.bfloat16: torch.BFloat16Storage,
}


def _dtype_itemsize(dtype: torch.dtype) -> int:
    """Return the number of bytes one element of *dtype* is counted as.

    ``torch.bool`` is special-cased as one byte per element; every other dtype
    uses the bit width reported by ``torch.finfo`` / ``torch.iinfo``.
    """
    if dtype is torch.bool:
        return 1
    type_info = torch.finfo if dtype.is_floating_point else torch.iinfo
    return type_info(dtype).bits >> 3


class LazyTensor:
    """Placeholder describing a checkpoint tensor whose data is not loaded yet.

    Instances are created by ``_LazyUnpickler`` with only the storage
    information filled in; ``_rebuild_tensor`` later fills in ``dtype``,
    ``seek_offset``, ``shape`` and ``stride``.  ``materialize`` reads the
    bytes and builds the real ``torch.Tensor``.
    """

    def __init__(self, storage_type, key: str, location: str, dtype: Optional[torch.dtype] = None, seek_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None):
        self.storage_type = storage_type
        self.key = key
        self.location = location
        self.dtype = dtype
        self.seek_offset = seek_offset
        self.shape = shape
        self.stride = stride
        self.requires_grad = requires_grad
        self.backward_hooks = backward_hooks

    def __view(self, f: Callable):
        """Render every attribute through *f* in constructor-call notation."""
        return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, dtype={f(self.dtype)}, seek_offset={f(self.seek_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})"

    def __repr__(self):
        return self.__view(repr)

    def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None, no_grad=True) -> torch.Tensor:
        """Read this tensor's bytes from *checkpoint* and return a real tensor.

        Args:
            checkpoint: Either a ``zipfile.ZipFile``, in which case the member
                ``archive/data/<key>`` is opened and ``seek_offset`` bytes are
                skipped before reading, or any other object with ``read``,
                which is read from its current position and left open.
            map_location: Passed to
                ``torch.serialization._get_restore_location`` to decide where
                the storage is placed.
            no_grad: When true the result never requires grad; otherwise it
                inherits ``self.requires_grad``.

        Returns:
            A tensor viewing the loaded storage with this object's shape and
            stride.
        """
        size = reduce(lambda x, y: x * y, self.shape, 1)
        dtype = self.dtype
        nbytes = size * _dtype_itemsize(dtype)

        if isinstance(checkpoint, zipfile.ZipFile):
            # ``with`` guarantees the member is closed even if skipping to the
            # offset or reading the payload fails.
            with checkpoint.open(f"archive/data/{self.key}", "r") as f:
                f.read(self.seek_offset)
                raw = f.read(nbytes)
        else:
            raw = checkpoint.read(nbytes)

        storage = STORAGE_TYPE_MAP[dtype].from_buffer(raw, "little")
        storage = torch.serialization._get_restore_location(map_location)(storage, self.location)
        tensor = torch.tensor([], dtype=storage.dtype, device=storage.device)
        tensor.set_(storage, 0, self.shape, self.stride)
        tensor.requires_grad = not no_grad and self.requires_grad
        tensor._backward_hooks = self.backward_hooks
        return tensor


class _LazyUnpickler(pickle.Unpickler):
    """Unpickler that turns persistent storage ids into ``LazyTensor`` objects."""

    lazy_loaded_storages: Dict[str, LazyTensor]

    def __init__(self, *args, **kwargs):
        self.lazy_loaded_storages = {}
        super().__init__(*args, **kwargs)

    def forced_persistent_load(self, saved_id):
        """Build a ``LazyTensor`` from a ``("storage", type, key, location, _)`` id."""
        assert isinstance(saved_id, tuple)
        typename = saved_id[0]
        assert typename == "storage", f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
        storage_type, key, location, _ = saved_id[1:]
        return LazyTensor(storage_type, key, location)

    def load(self, *args, **kwargs):
        # Assigned on the instance immediately before loading so that it
        # replaces any ``persistent_load`` set on this instance beforehand.
        self.persistent_load = self.forced_persistent_load
        retval = super().load(*args, **kwargs)
        self.lazy_loaded_storages = {}
        return retval


def _rebuild_tensor(lazy_storage: LazyTensor, storage_offset, shape, stride):
    """Record tensor geometry on *lazy_storage* instead of building a tensor.

    Fills in ``shape``, ``stride``, ``dtype`` and ``seek_offset`` (the storage
    offset converted from elements to bytes) and returns the same object.
    """
    lazy_storage.shape = shape
    lazy_storage.stride = stride
    dtype = lazy_storage.storage_type.dtype
    if not isinstance(dtype, torch.dtype):
        # The class attribute is not a usable dtype here, so read it from an
        # empty instance of the storage type instead.
        dtype = lazy_storage.storage_type(0).dtype
    lazy_storage.dtype = dtype
    lazy_storage.seek_offset = storage_offset * _dtype_itemsize(dtype)
    return lazy_storage


# Modified version of https://github.com/pytorch/pytorch/blob/v1.11.0-rc4/torch/nn/modules/module.py#L1346-L1438
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
    """Load this module's own parameters and buffers from *state_dict*.

    Instead of copying checkpoint values into the existing parameters, each
    matching entry is replaced by a new ``torch.nn.Parameter`` wrapping the
    checkpoint tensor.  Problems are reported by appending to
    ``missing_keys``, ``unexpected_keys`` and ``error_msgs`` rather than by
    raising.
    """
    for hook in self._load_state_dict_pre_hooks.values():
        hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
    local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
    local_state = {k: v for k, v in local_name_params if v is not None}

    for name, param in local_state.items():
        key = prefix + name
        if key in state_dict:
            input_param = state_dict[key]
            if not torch.overrides.is_tensor_like(input_param):
                error_msgs.append('While copying the parameter named "{}", '
                                  'expected torch.Tensor or Tensor-like object from checkpoint but '
                                  'received {}'
                                  .format(key, type(input_param)))
                continue

            # This is used to avoid copying uninitialized parameters into
            # non-lazy modules, since they dont have the hook to do the checks
            # in such case, it will error when accessing the .shape attribute.
            is_param_lazy = torch.nn.parameter.is_lazy(param)
            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
            if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
                input_param = input_param[0]

            if not is_param_lazy and input_param.shape != param.shape:
                # local shape should match the one in checkpoint
                error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                                  'the shape in current model is {}.'
                                  .format(key, input_param.shape, param.shape))
                continue
            try:
                with torch.no_grad():
                    # Differs from upstream, which does
                    # ``param.copy_(input_param)``: the entry is replaced by a
                    # new Parameter wrapping the checkpoint tensor.
                    new_param = torch.nn.Parameter(input_param, requires_grad=param.requires_grad)
                    if name in self._parameters:
                        self._parameters[name] = new_param
                    if name in persistent_buffers:
                        self._buffers[name] = new_param
            except Exception as ex:
                error_msgs.append('While copying the parameter named "{}", '
                                  'whose dimensions in the model are {} and '
                                  'whose dimensions in the checkpoint are {}, '
                                  'an exception occurred : {}.'
                                  .format(key, param.size(), input_param.size(), ex.args))
        elif strict:
            missing_keys.append(key)

    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
    # The ``hasattr`` guard is an addition to the upstream condition; it keeps
    # this branch from touching ``Module.set_extra_state`` when the installed
    # ``Module`` class does not define it.
    if hasattr(Module, "set_extra_state") and getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
        if extra_state_key in state_dict:
            self.set_extra_state(state_dict[extra_state_key])
        elif strict:
            missing_keys.append(extra_state_key)
    elif strict and (extra_state_key in state_dict):
        unexpected_keys.append(extra_state_key)

    if strict:
        for key in state_dict:
            if key.startswith(prefix) and key != extra_state_key:
                input_name = key[len(prefix):]
                input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child
                if input_name not in self._modules and input_name not in local_state:
                    unexpected_keys.append(key)


def _meta_device_init(original_init):
    """Wrap a module ``__init__`` so that it always receives ``device="meta"``.

    Any ``device`` keyword supplied by the caller is accepted and discarded.
    """
    def init(self, *args, device=None, **kwargs):
        return original_init(self, *args, device="meta", **kwargs)
    return init


@contextlib.contextmanager
def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False, use_accelerate_init_empty_weights=False):
    """Temporarily patch pickle/torch so ``torch.load`` yields ``LazyTensor`` objects.

    While the context is active ``pickle.Unpickler``,
    ``torch._utils._rebuild_tensor`` and ``torch.load`` are replaced.  Every
    patch is undone on exit, including when the body or the setup itself
    raises.

    Args:
        enable: When false nothing is patched and the context yields ``False``.
        callback: If given, called after each ``torch.load`` as
            ``callback(retval, f=..., map_location=..., pickle_module=...,
            **pickle_load_args)``.
        dematerialized_modules: When true, also replaces
            ``torch.nn.Module._load_from_state_dict`` with the version in this
            file, and either enters ``accelerate.init_empty_weights()`` or
            forces ``device="meta"`` in the constructors of ``Linear``,
            ``Embedding`` and ``LayerNorm``.
        use_accelerate_init_empty_weights: Selects the accelerate code path;
            only honoured when ``utils.HAS_ACCELERATE`` is truthy.

    Yields:
        ``True`` when the patches are active, ``False`` when *enable* is false.
    """
    if not enable:
        yield False
        return

    # Each undo step is registered at the moment its patch is applied, so a
    # failure partway through setup restores exactly what was changed and
    # never masks the original exception with an unbound-name error.
    with contextlib.ExitStack() as undo:

        def patch(owner, attr, replacement):
            original = getattr(owner, attr)
            setattr(owner, attr, replacement)
            undo.callback(setattr, owner, attr, original)

        patch(pickle, "Unpickler", _LazyUnpickler)
        patch(torch._utils, "_rebuild_tensor", _rebuild_tensor)

        old_torch_load = torch.load

        def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
            retval = old_torch_load(f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args)
            if callback is not None:
                callback(retval, f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args)
            return retval

        patch(torch, "load", torch_load)

        if dematerialized_modules:
            if use_accelerate_init_empty_weights and utils.HAS_ACCELERATE:
                import accelerate
                init_empty_weights = accelerate.init_empty_weights()
                init_empty_weights.__enter__()
                undo.callback(init_empty_weights.__exit__, None, None, None)
            else:
                for layer_cls in (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm):
                    patch(layer_cls, "__init__", _meta_device_init(layer_cls.__init__))
            patch(torch.nn.Module, "_load_from_state_dict", _load_from_state_dict)

        yield True