mirror of
				https://github.com/KoboldAI/KoboldAI-Client.git
				synced 2025-06-05 21:59:24 +02:00 
			
		
		
		
	(torch_lazy_loader.py) Add dematerialized modules setting
This commit is contained in:
		| @@ -1,11 +1,62 @@ | |||||||
|  | ''' | ||||||
|  | This file is AGPL-licensed. | ||||||
|  |  | ||||||
|  | Some of the code in this file is copied from PyTorch. | ||||||
|  |  | ||||||
|  | The license for PyTorch is shown below: | ||||||
|  |  | ||||||
|  | Copyright (c) 2016-     Facebook, Inc            (Adam Paszke) | ||||||
|  | Copyright (c) 2014-     Facebook, Inc            (Soumith Chintala) | ||||||
|  | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) | ||||||
|  | Copyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu) | ||||||
|  | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) | ||||||
|  | Copyright (c) 2011-2013 NYU                      (Clement Farabet) | ||||||
|  | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) | ||||||
|  | Copyright (c) 2006      Idiap Research Institute (Samy Bengio) | ||||||
|  | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) | ||||||
|  |  | ||||||
|  | Redistribution and use in source and binary forms, with or without | ||||||
|  | modification, are permitted provided that the following conditions are met: | ||||||
|  |  | ||||||
|  | 1. Redistributions of source code must retain the above copyright | ||||||
|  |    notice, this list of conditions and the following disclaimer. | ||||||
|  |  | ||||||
|  | 2. Redistributions in binary form must reproduce the above copyright | ||||||
|  |    notice, this list of conditions and the following disclaimer in the | ||||||
|  |    documentation and/or other materials provided with the distribution. | ||||||
|  |  | ||||||
|  | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America | ||||||
|  |    and IDIAP Research Institute nor the names of its contributors may be | ||||||
|  |    used to endorse or promote products derived from this software without | ||||||
|  |    specific prior written permission. | ||||||
|  |  | ||||||
|  | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||||||
|  | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||||||
|  | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||||||
|  | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE | ||||||
|  | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | ||||||
|  | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | ||||||
|  | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | ||||||
|  | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | ||||||
|  | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | ||||||
|  | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | ||||||
|  | POSSIBILITY OF SUCH DAMAGE. | ||||||
|  | ''' | ||||||
|  |  | ||||||
|  |  | ||||||
| import contextlib | import contextlib | ||||||
| from functools import reduce | from functools import reduce | ||||||
|  | import itertools | ||||||
| import zipfile | import zipfile | ||||||
| import pickle | import pickle | ||||||
| import torch | import torch | ||||||
|  | from torch.nn import Module | ||||||
| from typing import Any, Callable, Dict, Optional, Tuple, Type, Union | from typing import Any, Callable, Dict, Optional, Tuple, Type, Union | ||||||
|  |  | ||||||
|  |  | ||||||
|  | _EXTRA_STATE_KEY_SUFFIX = '_extra_state' | ||||||
|  |  | ||||||
|  |  | ||||||
| class LazyTensor: | class LazyTensor: | ||||||
|     def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, storage_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None): |     def __init__(self, storage_type: Type[torch._StorageBase], key: str, location: str, storage_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None): | ||||||
|         self.storage_type = storage_type |         self.storage_type = storage_type | ||||||
| @@ -73,8 +124,77 @@ def _rebuild_tensor(lazy_storage: LazyTensor, storage_offset, shape, stride): | |||||||
|     return lazy_storage |     return lazy_storage | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # Modified version of https://github.com/pytorch/pytorch/blob/v1.11.0-rc4/torch/nn/modules/module.py#L1346-L1438 | ||||||
|  | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): | ||||||
|  |     for hook in self._load_state_dict_pre_hooks.values(): | ||||||
|  |         hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) | ||||||
|  |  | ||||||
|  |     persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} | ||||||
|  |     local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) | ||||||
|  |     local_state = {k: v for k, v in local_name_params if v is not None} | ||||||
|  |  | ||||||
|  |     for name, param in local_state.items(): | ||||||
|  |         key = prefix + name | ||||||
|  |         if key in state_dict: | ||||||
|  |             input_param = state_dict[key] | ||||||
|  |             if not torch.overrides.is_tensor_like(input_param): | ||||||
|  |                 error_msgs.append('While copying the parameter named "{}", ' | ||||||
|  |                                     'expected torch.Tensor or Tensor-like object from checkpoint but ' | ||||||
|  |                                     'received {}' | ||||||
|  |                                     .format(key, type(input_param))) | ||||||
|  |                 continue | ||||||
|  |  | ||||||
|  |             # This is used to avoid copying uninitialized parameters into | ||||||
|  |             # non-lazy modules, since they dont have the hook to do the checks | ||||||
|  |             # in such case, it will error when accessing the .shape attribute. | ||||||
|  |             is_param_lazy = torch.nn.parameter.is_lazy(param) | ||||||
|  |             # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ | ||||||
|  |             if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1: | ||||||
|  |                 input_param = input_param[0] | ||||||
|  |  | ||||||
|  |             if not is_param_lazy and input_param.shape != param.shape: | ||||||
|  |                 # local shape should match the one in checkpoint | ||||||
|  |                 error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' | ||||||
|  |                                     'the shape in current model is {}.' | ||||||
|  |                                     .format(key, input_param.shape, param.shape)) | ||||||
|  |                 continue | ||||||
|  |             try: | ||||||
|  |                 with torch.no_grad(): | ||||||
|  |                     #param.copy_(input_param) | ||||||
|  |                     new_param = torch.nn.Parameter(input_param, requires_grad=param.requires_grad)  # This line is new | ||||||
|  |                     if name in self._parameters:  # This line is new | ||||||
|  |                         self._parameters[name] = new_param  # This line is new | ||||||
|  |                     if name in persistent_buffers:  # This line is new | ||||||
|  |                         self._buffers[name] = new_param  # This line is new | ||||||
|  |             except Exception as ex: | ||||||
|  |                 error_msgs.append('While copying the parameter named "{}", ' | ||||||
|  |                                     'whose dimensions in the model are {} and ' | ||||||
|  |                                     'whose dimensions in the checkpoint are {}, ' | ||||||
|  |                                     'an exception occurred : {}.' | ||||||
|  |                                     .format(key, param.size(), input_param.size(), ex.args)) | ||||||
|  |         elif strict: | ||||||
|  |             missing_keys.append(key) | ||||||
|  |  | ||||||
|  |     extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX | ||||||
|  |     if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state: | ||||||
|  |         if extra_state_key in state_dict: | ||||||
|  |             self.set_extra_state(state_dict[extra_state_key]) | ||||||
|  |         elif strict: | ||||||
|  |             missing_keys.append(extra_state_key) | ||||||
|  |     elif strict and (extra_state_key in state_dict): | ||||||
|  |         unexpected_keys.append(extra_state_key) | ||||||
|  |  | ||||||
|  |     if strict: | ||||||
|  |         for key in state_dict.keys(): | ||||||
|  |             if key.startswith(prefix) and key != extra_state_key: | ||||||
|  |                 input_name = key[len(prefix):] | ||||||
|  |                 input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child | ||||||
|  |                 if input_name not in self._modules and input_name not in local_state: | ||||||
|  |                     unexpected_keys.append(key) | ||||||
|  |  | ||||||
|  |  | ||||||
| @contextlib.contextmanager | @contextlib.contextmanager | ||||||
| def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None): | def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False): | ||||||
|     if not enable: |     if not enable: | ||||||
|         yield False |         yield False | ||||||
|         return |         return | ||||||
| @@ -96,9 +216,42 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None): | |||||||
|  |  | ||||||
|         torch.load = torch_load |         torch.load = torch_load | ||||||
|  |  | ||||||
|  |         def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args): | ||||||
|  |             retval = old_torch_load(f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args) | ||||||
|  |             if callback is not None: | ||||||
|  |                 callback(retval, f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args) | ||||||
|  |             return retval | ||||||
|  |  | ||||||
|  |         torch.load = torch_load | ||||||
|  |  | ||||||
|  |         if dematerialized_modules: | ||||||
|  |             old_linear_init = torch.nn.Linear.__init__ | ||||||
|  |             old_embedding_init = torch.nn.Embedding.__init__ | ||||||
|  |             old_layernorm_init = torch.nn.LayerNorm.__init__ | ||||||
|  |  | ||||||
|  |             def linear_init(self, *args, device=None, **kwargs): | ||||||
|  |                 return old_linear_init(self, *args, device="meta", **kwargs) | ||||||
|  |  | ||||||
|  |             def embedding_init(self, *args, device=None, **kwargs): | ||||||
|  |                 return old_embedding_init(self, *args, device="meta", **kwargs) | ||||||
|  |  | ||||||
|  |             def layernorm_init(self, *args, device=None, **kwargs): | ||||||
|  |                 return old_layernorm_init(self, *args, device="meta", **kwargs) | ||||||
|  |  | ||||||
|  |             torch.nn.Linear.__init__ = linear_init | ||||||
|  |             torch.nn.Embedding.__init__ = embedding_init | ||||||
|  |             torch.nn.LayerNorm.__init__ = layernorm_init | ||||||
|  |             old_load_from_state_dict = torch.nn.Module._load_from_state_dict | ||||||
|  |             torch.nn.Module._load_from_state_dict = _load_from_state_dict | ||||||
|  |  | ||||||
|         yield True |         yield True | ||||||
|  |  | ||||||
|     finally: |     finally: | ||||||
|         pickle.Unpickler = old_unpickler |         pickle.Unpickler = old_unpickler | ||||||
|         torch._utils._rebuild_tensor = old_rebuild_tensor |         torch._utils._rebuild_tensor = old_rebuild_tensor | ||||||
|         torch.load = old_torch_load |         torch.load = old_torch_load | ||||||
|  |         if dematerialized_modules: | ||||||
|  |             torch.nn.Linear.__init__ = old_linear_init | ||||||
|  |             torch.nn.Embedding.__init__ = old_embedding_init | ||||||
|  |             torch.nn.LayerNorm.__init__ = old_layernorm_init | ||||||
|  |             torch.nn.Module._load_from_state_dict = old_load_from_state_dict | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Gnome Ann
					Gnome Ann