2022-03-03 17:17:59 +01:00
'''
This file is AGPL - licensed .
Some of the code in this file is copied from PyTorch .
The license for PyTorch is shown below :
Copyright ( c ) 2016 - Facebook , Inc ( Adam Paszke )
Copyright ( c ) 2014 - Facebook , Inc ( Soumith Chintala )
Copyright ( c ) 2011 - 2014 Idiap Research Institute ( Ronan Collobert )
Copyright ( c ) 2012 - 2014 Deepmind Technologies ( Koray Kavukcuoglu )
Copyright ( c ) 2011 - 2012 NEC Laboratories America ( Koray Kavukcuoglu )
Copyright ( c ) 2011 - 2013 NYU ( Clement Farabet )
Copyright ( c ) 2006 - 2010 NEC Laboratories America ( Ronan Collobert , Leon Bottou , Iain Melvin , Jason Weston )
Copyright ( c ) 2006 Idiap Research Institute ( Samy Bengio )
Copyright ( c ) 2001 - 2004 Idiap Research Institute ( Ronan Collobert , Samy Bengio , Johnny Mariethoz )
Redistribution and use in source and binary forms , with or without
modification , are permitted provided that the following conditions are met :
1. Redistributions of source code must retain the above copyright
notice , this list of conditions and the following disclaimer .
2. Redistributions in binary form must reproduce the above copyright
notice , this list of conditions and the following disclaimer in the
documentation and / or other materials provided with the distribution .
3. Neither the names of Facebook , Deepmind Technologies , NYU , NEC Laboratories America
and IDIAP Research Institute nor the names of its contributors may be
used to endorse or promote products derived from this software without
specific prior written permission .
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS " AS IS "
AND ANY EXPRESS OR IMPLIED WARRANTIES , INCLUDING , BUT NOT LIMITED TO , THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED . IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT , INDIRECT , INCIDENTAL , SPECIAL , EXEMPLARY , OR
CONSEQUENTIAL DAMAGES ( INCLUDING , BUT NOT LIMITED TO , PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES ; LOSS OF USE , DATA , OR PROFITS ; OR BUSINESS
INTERRUPTION ) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY , WHETHER IN
CONTRACT , STRICT LIABILITY , OR TORT ( INCLUDING NEGLIGENCE OR OTHERWISE )
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE , EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE .
'''
2022-03-01 21:40:44 +01:00
import contextlib
2022-03-02 07:02:35 +01:00
from functools import reduce
2022-03-03 17:17:59 +01:00
import itertools
2022-03-02 07:02:35 +01:00
import zipfile
2022-03-01 21:40:44 +01:00
import pickle
import torch
2022-03-03 17:17:59 +01:00
from torch . nn import Module
2022-07-12 22:48:01 +02:00
from typing import Any , Callable , Dict , Optional , Tuple , Union
2022-03-01 21:40:44 +01:00
2022-03-03 17:17:59 +01:00
_EXTRA_STATE_KEY_SUFFIX = ' _extra_state '
2022-03-17 17:51:41 +01:00
STORAGE_TYPE_MAP = {
torch . float64 : torch . DoubleStorage ,
torch . float32 : torch . FloatStorage ,
torch . float16 : torch . HalfStorage ,
torch . int64 : torch . LongStorage ,
torch . int32 : torch . IntStorage ,
torch . int16 : torch . ShortStorage ,
torch . int8 : torch . CharStorage ,
torch . uint8 : torch . ByteStorage ,
torch . bool : torch . BoolStorage ,
torch . bfloat16 : torch . BFloat16Storage ,
}
2022-03-01 21:40:44 +01:00
class LazyTensor:
    """Description of a tensor whose data has not been read from the checkpoint yet.

    An instance only records where the data lives and how to interpret it;
    :meth:`materialize` reads the bytes and builds the real ``torch.Tensor``.

    Attributes are plain and public:

    * ``storage_type`` - storage class recorded in the pickle stream.
    * ``key`` - name of the storage; the data is read from the zip member
      ``archive/data/<key>``.
    * ``location`` - location tag handed to torch's ``map_location`` handling.
    * ``dtype`` - element type, used to pick the storage class and the
      element size.
    * ``seek_offset`` - number of bytes to skip from the start of the zip
      member before the tensor's data begins.
    * ``shape`` / ``stride`` - geometry applied to the rebuilt storage.
    * ``requires_grad`` / ``backward_hooks`` - copied onto the materialized
      tensor.
    """

    def __init__(
        self,
        storage_type,
        key: str,
        location: str,
        dtype: Optional[torch.dtype] = None,
        seek_offset: Optional[int] = None,
        shape: Optional[Tuple[int, ...]] = None,
        stride: Optional[Tuple[int, ...]] = None,
        requires_grad=False,
        backward_hooks: Any = None,
    ):
        self.storage_type = storage_type
        self.key = key
        self.location = location
        self.dtype = dtype
        self.seek_offset = seek_offset
        self.shape = shape
        self.stride = stride
        self.requires_grad = requires_grad
        self.backward_hooks = backward_hooks

    def __view(self, f: Callable):
        """Render every field through ``f`` (e.g. ``repr``) in constructor-call form."""
        return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, dtype={f(self.dtype)}, seek_offset={f(self.seek_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})"

    def __repr__(self):
        return self.__view(repr)

    def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None, no_grad=True) -> torch.Tensor:
        """Read this tensor's bytes from ``checkpoint`` and return the real tensor.

        Args:
            checkpoint: Either the checkpoint's ``ZipFile`` or an already-open
                stream.  For a ``ZipFile`` the member ``archive/data/<key>`` is
                opened, ``seek_offset`` bytes are skipped, and the member is
                closed again before returning.  Any other object is used as
                is: nothing is skipped, the data is read from its current
                position, and it is left open for the caller.
            map_location: Forwarded to torch's restore-location machinery
                together with ``self.location``.
            no_grad: When true (the default) the result never requires grad;
                otherwise it inherits ``self.requires_grad``.

        Returns:
            A tensor with ``self.shape`` / ``self.stride`` over the loaded
            storage.

        Exactly ``numel * element_size`` bytes are requested from the stream
        (one byte per element for ``torch.bool``).  That count is kept
        unchanged so callers who share one open stream and track their
        position by it are unaffected.
        """
        size = reduce(lambda x, y: x * y, self.shape, 1)
        dtype = self.dtype
        nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)

        owns_handle = isinstance(checkpoint, zipfile.ZipFile)
        f = checkpoint.open(f"archive/data/{self.key}", "r") if owns_handle else checkpoint
        try:
            # The skip happens inside the try block so that the member opened
            # above is closed even if skipping fails.  A missing or zero
            # offset means "start at the beginning": read(None) would instead
            # swallow the whole member.
            if owns_handle and self.seek_offset:
                f.read(self.seek_offset)
            storage = STORAGE_TYPE_MAP[dtype].from_buffer(f.read(nbytes), "little")
        finally:
            if owns_handle:
                f.close()

        storage = torch.serialization._get_restore_location(map_location)(storage, self.location)
        tensor = torch.tensor([], dtype=storage.dtype, device=storage.device)
        tensor.set_(storage, 0, self.shape, self.stride)
        tensor.requires_grad = not no_grad and self.requires_grad
        tensor._backward_hooks = self.backward_hooks
        return tensor
class _LazyUnpickler(pickle.Unpickler):
    """Unpickler that turns every persistent storage reference into a ``LazyTensor``.

    ``load`` installs ``forced_persistent_load`` as the ``persistent_load``
    hook right before unpickling, replacing whatever hook was set on the
    instance beforehand.
    """

    lazy_loaded_storages: Dict[str, LazyTensor]

    def __init__(self, *args, **kwargs):
        self.lazy_loaded_storages = {}
        super().__init__(*args, **kwargs)

    def forced_persistent_load(self, saved_id):
        """Build a ``LazyTensor`` placeholder from a ``("storage", ...)`` persistent id.

        The id must be a tuple whose first item is ``"storage"`` followed by
        exactly four fields: storage type, key, location and one trailing
        field that is ignored here.
        """
        assert isinstance(saved_id, tuple)
        typename = saved_id[0]
        assert typename == "storage", f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
        payload = saved_id[1:]
        storage_type, key, location, _ = payload
        return LazyTensor(storage_type, key, location)

    def load(self, *args, **kwargs):
        """Unpickle with the lazy hook forced on, then reset the storage table."""
        self.persistent_load = self.forced_persistent_load
        result = super().load(*args, **kwargs)
        self.lazy_loaded_storages = {}
        return result
def _rebuild_tensor(lazy_storage: LazyTensor, storage_offset, shape, stride):
    """Record tensor geometry on ``lazy_storage`` instead of building a tensor.

    Fills in ``shape``, ``stride``, ``dtype`` and ``seek_offset`` on the given
    placeholder and returns that same object.  ``seek_offset`` is
    ``storage_offset`` converted from elements to bytes (``torch.bool``
    counts as one byte per element).
    """
    lazy_storage.shape = shape
    lazy_storage.stride = stride

    # The class attribute is not always an actual torch.dtype (presumably on
    # PyTorch versions where it is a property object); in that case an empty
    # storage instance is created just to ask it for its dtype.
    dtype = lazy_storage.storage_type.dtype
    if not isinstance(dtype, torch.dtype):
        dtype = lazy_storage.storage_type(0).dtype
    lazy_storage.dtype = dtype

    if dtype is torch.bool:
        byte_offset = storage_offset
    else:
        type_info = torch.finfo if dtype.is_floating_point else torch.iinfo
        byte_offset = storage_offset * (type_info(dtype).bits >> 3)
    lazy_storage.seek_offset = byte_offset

    return lazy_storage
2022-03-03 17:17:59 +01:00
# Modified version of https://github.com/pytorch/pytorch/blob/v1.11.0-rc4/torch/nn/modules/module.py#L1346-L1438
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
    """Load this module's own parameters and buffers from ``state_dict``.

    Follows the structure of ``torch.nn.Module._load_from_state_dict`` with
    one difference: a matching checkpoint entry is not copied into the
    existing parameter or buffer.  It is wrapped in a new
    ``torch.nn.Parameter`` that replaces the module's entry.

    Problems are reported by appending to ``missing_keys``,
    ``unexpected_keys`` and ``error_msgs``; nothing is returned.
    """
    for pre_hook in self._load_state_dict_pre_hooks.values():
        pre_hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    persistent_buffers = {
        buf_name: buf
        for buf_name, buf in self._buffers.items()
        if buf_name not in self._non_persistent_buffers_set
    }
    local_state = {
        entry_name: entry
        for entry_name, entry in itertools.chain(self._parameters.items(), persistent_buffers.items())
        if entry is not None
    }

    for name, param in local_state.items():
        key = prefix + name
        if key not in state_dict:
            if strict:
                missing_keys.append(key)
            continue

        input_param = state_dict[key]
        if not torch.overrides.is_tensor_like(input_param):
            error_msgs.append('While copying the parameter named "{}", '
                              'expected torch.Tensor or Tensor-like object from checkpoint but '
                              'received {}'
                              .format(key, type(input_param)))
            continue

        # Uninitialized (lazy) parameters have no usable .shape, so every
        # shape-based check below is skipped for them.
        is_param_lazy = torch.nn.parameter.is_lazy(param)

        # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
        if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
            input_param = input_param[0]

        if not is_param_lazy and input_param.shape != param.shape:
            # The local shape has to match the one stored in the checkpoint.
            error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                              'the shape in current model is {}.'
                              .format(key, input_param.shape, param.shape))
            continue

        try:
            with torch.no_grad():
                # Upstream performs param.copy_(input_param) here; this
                # version rebinds the module's entry to a new Parameter.
                new_param = torch.nn.Parameter(input_param, requires_grad=param.requires_grad)
                if name in self._parameters:
                    self._parameters[name] = new_param
                if name in persistent_buffers:
                    self._buffers[name] = new_param
        except Exception as ex:
            error_msgs.append('While copying the parameter named "{}", '
                              'whose dimensions in the model are {} and '
                              'whose dimensions in the checkpoint are {}, '
                              'an exception occurred : {}.'
                              .format(key, param.size(), input_param.size(), ex.args))

    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
    # hasattr() keeps this working when Module has no set_extra_state at all
    # (presumably older PyTorch releases); in that case no class can override it.
    overrides_set_extra_state = (
        hasattr(Module, "set_extra_state")
        and getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state
    )
    if overrides_set_extra_state:
        if extra_state_key in state_dict:
            self.set_extra_state(state_dict[extra_state_key])
        elif strict:
            missing_keys.append(extra_state_key)
    elif strict and (extra_state_key in state_dict):
        unexpected_keys.append(extra_state_key)

    if not strict:
        return

    for key in state_dict.keys():
        if not key.startswith(prefix) or key == extra_state_key:
            continue
        # First dotted component after the prefix: a param, buffer or child name.
        input_name = key[len(prefix):].split('.', 1)[0]
        if input_name not in self._modules and input_name not in local_state:
            unexpected_keys.append(key)
2022-03-01 21:40:44 +01:00
@contextlib.contextmanager
def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False):
    """Temporarily patch pickle/torch so ``torch.load`` returns lazy placeholders.

    While the context is active these process-wide attributes are replaced,
    and they are restored on exit (also when the body raises):

    * ``pickle.Unpickler`` -> ``_LazyUnpickler``
    * ``torch._utils._rebuild_tensor`` -> this module's ``_rebuild_tensor``
    * ``torch.load`` -> a wrapper that calls the real ``torch.load`` and then
      ``callback``

    Args:
        enable: When false nothing is patched and the context yields ``False``.
        callback: Optional callable invoked after every ``torch.load`` as
            ``callback(result, f=..., map_location=..., pickle_module=...,
            **pickle_load_args)``.
        dematerialized_modules: When true, ``torch.nn.Linear``,
            ``torch.nn.Embedding`` and ``torch.nn.LayerNorm`` are additionally
            forced to construct on the ``"meta"`` device (any ``device``
            argument is ignored) and ``torch.nn.Module._load_from_state_dict``
            is replaced by this module's version.

    Yields:
        ``True`` when the patches are active, ``False`` when ``enable`` is false.
    """
    if not enable:
        yield False
        return

    # Capture every original *before* touching anything.  If installing a
    # patch fails part-way, the cleanup below can then restore all of them
    # instead of tripping over a name that was never assigned and hiding the
    # real error.
    old_unpickler = pickle.Unpickler
    old_rebuild_tensor = torch._utils._rebuild_tensor
    old_torch_load = torch.load
    if dematerialized_modules:
        old_linear_init = torch.nn.Linear.__init__
        old_embedding_init = torch.nn.Embedding.__init__
        old_layernorm_init = torch.nn.LayerNorm.__init__
        old_load_from_state_dict = torch.nn.Module._load_from_state_dict

    def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
        retval = old_torch_load(f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args)
        if callback is not None:
            callback(retval, f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args)
        return retval

    def force_meta_device(original_init):
        """Wrap a module ``__init__`` so that it always builds on the meta device."""
        def patched_init(self, *args, device=None, **kwargs):
            # The caller's ``device`` is deliberately discarded.
            return original_init(self, *args, device="meta", **kwargs)
        return patched_init

    try:
        pickle.Unpickler = _LazyUnpickler
        torch._utils._rebuild_tensor = _rebuild_tensor
        torch.load = torch_load

        if dematerialized_modules:
            torch.nn.Linear.__init__ = force_meta_device(old_linear_init)
            torch.nn.Embedding.__init__ = force_meta_device(old_embedding_init)
            torch.nn.LayerNorm.__init__ = force_meta_device(old_layernorm_init)
            torch.nn.Module._load_from_state_dict = _load_from_state_dict

        yield True

    finally:
        pickle.Unpickler = old_unpickler
        torch._utils._rebuild_tensor = old_rebuild_tensor
        torch.load = old_torch_load
        if dematerialized_modules:
            torch.nn.Linear.__init__ = old_linear_init
            torch.nn.Embedding.__init__ = old_embedding_init
            torch.nn.LayerNorm.__init__ = old_layernorm_init
            torch.nn.Module._load_from_state_dict = old_load_from_state_dict