From 8e6e04be5f3e91baa3ac23d0f33c49bcc434c58f Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Thu, 3 Mar 2022 11:17:59 -0500
Subject: [PATCH] (torch_lazy_loader.py) Add dematerialized modules setting

---
 torch_lazy_loader.py | 155 ++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 154 insertions(+), 1 deletion(-)

diff --git a/torch_lazy_loader.py b/torch_lazy_loader.py
index 4a29d0c8..d9358442 100644
--- a/torch_lazy_loader.py
+++ b/torch_lazy_loader.py
@@ -1,11 +1,62 @@
+'''
+This file is AGPL-licensed.
+
+Some of the code in this file is copied from PyTorch.
+
+The license for PyTorch is shown below:
+
+Copyright (c) 2016- Facebook, Inc (Adam Paszke)
+Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
+Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
+Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
+Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+Copyright (c) 2011-2013 NYU (Clement Farabet)
+Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
+Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
+Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
+   and IDIAP Research Institute nor the names of its contributors may be
+   used to endorse or promote products derived from this software without
+   specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+POSSIBILITY OF SUCH DAMAGE.
+'''
+
+
 import contextlib
 from functools import reduce
+import itertools
 import zipfile
 import pickle
 import torch
+from torch.nn import Module
 from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
 
 
+_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
+
+
 class LazyTensor:
     def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, storage_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None):
         self.storage_type = storage_type
@@ -73,8 +124,77 @@ def _rebuild_tensor(lazy_storage: LazyTensor, storage_offset, shape, stride):
     return lazy_storage
 
 
+# Modified version of https://github.com/pytorch/pytorch/blob/v1.11.0-rc4/torch/nn/modules/module.py#L1346-L1438
+def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+    for hook in self._load_state_dict_pre_hooks.values():
+        hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+    persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+    local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+    local_state = {k: v for k, v in local_name_params if v is not None}
+
+    for name, param in local_state.items():
+        key = prefix + name
+        if key in state_dict:
+            input_param = state_dict[key]
+            if not torch.overrides.is_tensor_like(input_param):
+                error_msgs.append('While copying the parameter named "{}", '
+                                  'expected torch.Tensor or Tensor-like object from checkpoint but '
+                                  'received {}'
+                                  .format(key, type(input_param)))
+                continue
+
+            # This is used to avoid copying uninitialized parameters into
+            # non-lazy modules, since they dont have the hook to do the checks
+            # in such case, it will error when accessing the .shape attribute.
+            is_param_lazy = torch.nn.parameter.is_lazy(param)
+            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+            if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
+                input_param = input_param[0]
+
+            if not is_param_lazy and input_param.shape != param.shape:
+                # local shape should match the one in checkpoint
+                error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
+                                  'the shape in current model is {}.'
+                                  .format(key, input_param.shape, param.shape))
+                continue
+            try:
+                with torch.no_grad():
+                    #param.copy_(input_param)
+                    new_param = torch.nn.Parameter(input_param, requires_grad=param.requires_grad)  # This line is new
+                    if name in self._parameters:  # This line is new
+                        self._parameters[name] = new_param  # This line is new
+                    if name in persistent_buffers:  # This line is new
+                        self._buffers[name] = new_param  # This line is new
+            except Exception as ex:
+                error_msgs.append('While copying the parameter named "{}", '
+                                  'whose dimensions in the model are {} and '
+                                  'whose dimensions in the checkpoint are {}, '
+                                  'an exception occurred : {}.'
+                                  .format(key, param.size(), input_param.size(), ex.args))
+        elif strict:
+            missing_keys.append(key)
+
+    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+    if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
+        if extra_state_key in state_dict:
+            self.set_extra_state(state_dict[extra_state_key])
+        elif strict:
+            missing_keys.append(extra_state_key)
+    elif strict and (extra_state_key in state_dict):
+        unexpected_keys.append(extra_state_key)
+
+    if strict:
+        for key in state_dict.keys():
+            if key.startswith(prefix) and key != extra_state_key:
+                input_name = key[len(prefix):]
+                input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child
+                if input_name not in self._modules and input_name not in local_state:
+                    unexpected_keys.append(key)
+
+
 @contextlib.contextmanager
-def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None):
+def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False):
     if not enable:
         yield False
         return
@@ -96,9 +216,42 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None):
 
         torch.load = torch_load
 
+        def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
+            retval = old_torch_load(f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args)
+            if callback is not None:
+                callback(retval, f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args)
+            return retval
+
+        torch.load = torch_load
+
+        if dematerialized_modules:
+            old_linear_init = torch.nn.Linear.__init__
+            old_embedding_init = torch.nn.Embedding.__init__
+            old_layernorm_init = torch.nn.LayerNorm.__init__
+
+            def linear_init(self, *args, device=None, **kwargs):
+                return old_linear_init(self, *args, device="meta", **kwargs)
+
+            def embedding_init(self, *args, device=None, **kwargs):
+                return old_embedding_init(self, *args, device="meta", **kwargs)
+
+            def layernorm_init(self, *args, device=None, **kwargs):
+                return old_layernorm_init(self, *args, device="meta", **kwargs)
+
+            torch.nn.Linear.__init__ = linear_init
+            torch.nn.Embedding.__init__ = embedding_init
+            torch.nn.LayerNorm.__init__ = layernorm_init
+            old_load_from_state_dict = torch.nn.Module._load_from_state_dict
+            torch.nn.Module._load_from_state_dict = _load_from_state_dict
+
         yield True
 
     finally:
         pickle.Unpickler = old_unpickler
         torch._utils._rebuild_tensor = old_rebuild_tensor
         torch.load = old_torch_load
+        if dematerialized_modules:
+            torch.nn.Linear.__init__ = old_linear_init
+            torch.nn.Embedding.__init__ = old_embedding_init
+            torch.nn.LayerNorm.__init__ = old_layernorm_init
+            torch.nn.Module._load_from_state_dict = old_load_from_state_dict
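
A note on the _load_from_state_dict override above: the in-place param.copy_(input_param) is commented out, and each parameter (or persistent buffer) is instead replaced wholesale with a fresh torch.nn.Parameter built from the checkpoint tensor. In-place copying cannot work for dematerialized modules, because a meta-device tensor has no storage to copy into; swapping the parameter object sidesteps that. Below is a minimal standalone sketch of the same swap; the Linear layer and checkpoint dict are illustrative, not taken from this patch:

    import torch

    # A layer whose parameters live on the "meta" device: shape and dtype
    # metadata only, with no backing storage, so weight.copy_(...) would fail.
    lin = torch.nn.Linear(4, 2, device="meta")

    checkpoint = {"weight": torch.randn(2, 4), "bias": torch.randn(2)}

    with torch.no_grad():
        for name, tensor in checkpoint.items():
            old = lin._parameters[name]
            # Replace the meta parameter outright instead of copying in place,
            # mirroring what the patched _load_from_state_dict does.
            lin._parameters[name] = torch.nn.Parameter(tensor, requires_grad=old.requires_grad)

    print(lin(torch.ones(4)))  # the layer is now materialized and usable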
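
And a sketch of how the new dematerialized_modules flag might be exercised, assuming torch_lazy_loader.py is importable; TinyModel is a made-up module, not part of this patch. While the context manager is active, torch.nn.Linear, torch.nn.Embedding, and torch.nn.LayerNorm are constructed on the meta device, so building the model allocates no parameter memory:

    import torch
    import torch_lazy_loader

    class TinyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = torch.nn.Embedding(10, 8)
            self.proj = torch.nn.Linear(8, 8)
            self.norm = torch.nn.LayerNorm(8)

    with torch_lazy_loader.use_lazy_torch_load(dematerialized_modules=True):
        model = TinyModel()

    # All three patched module types were initialized on the meta device.
    print(model.proj.weight.device)  # prints: meta

Because the finally block restores the three __init__ methods and Module._load_from_state_dict, modules constructed after the with block behave normally.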