This commit is contained in:
somebody
2023-07-23 20:54:04 -05:00
parent 70d2da55e5
commit 1df03d9a27
3 changed files with 79 additions and 39 deletions

View File

@@ -129,15 +129,34 @@ def patch_transformers_generation() -> None:
class LazyloadPatches:
class StateDictFacade(dict):
    """Dict view of a state dict that materializes entries on subscript access.

    Only ``__getitem__`` is overridden, so ``facade[name]`` returns the result
    of calling ``materialize(map_location="cuda:0")`` on the stored value,
    while ``get()``, ``values()`` and ``items()`` still hand back the stored
    objects untouched.

    NOTE(review): stored values are presumably lazy tensor placeholders; the
    only requirement visible here is that they expose ``materialize``.
    """

    def __init__(self, state_dict):
        # Copy every entry of the wrapped state dict into this dict.
        super().__init__(state_dict)

    def __getitem__(self, name):
        stored_entry = dict.__getitem__(self, name)
        return stored_entry.materialize(map_location="cuda:0")
# References to the unpatched loaders, captured when the class body executes.
# __exit__ assigns these back, and _torch_load_from_state_dict delegates to
# torch_old_load_from_state_dict after wrapping the state dict.
old_load_state_dict = transformers.modeling_utils._load_state_dict_into_meta_model
torch_old_load_from_state_dict = torch.nn.Module._load_from_state_dict
@staticmethod
def __enter__() -> None:
    """Install the lazy-load replacements for the Transformers and PyTorch loaders.

    Declared as a static method so that it works both when called directly
    on the class (``LazyloadPatches.__enter__()``) and when an instance is
    used in a ``with`` statement. As a plain zero-argument function the
    latter raised ``TypeError``, because the instance was passed as an
    unexpected positional argument.
    """
    transformers.modeling_utils._load_state_dict_into_meta_model = (
        LazyloadPatches._load_state_dict_into_meta_model
    )
    torch.nn.Module._load_from_state_dict = (
        LazyloadPatches._torch_load_from_state_dict
    )

@staticmethod
def __exit__(exc_type, exc_value, exc_traceback) -> None:
    """Restore the loaders that were saved on the class before patching.

    Static for the same reason as ``__enter__``: the three exception
    arguments line up whether this is invoked by a ``with`` statement or
    called directly on the class.

    Args:
        exc_type: Exception class raised in the managed block, or None.
        exc_value: Exception instance raised in the managed block, or None.
        exc_traceback: Traceback of that exception, or None.

    Returns:
        None, so an exception raised inside the managed block propagates.
    """
    transformers.modeling_utils._load_state_dict_into_meta_model = (
        LazyloadPatches.old_load_state_dict
    )
    torch.nn.Module._load_from_state_dict = (
        LazyloadPatches.torch_old_load_from_state_dict
    )
def _torch_load_from_state_dict(self, state_dict, *args, **kwargs):
    """Replacement for ``torch.nn.Module._load_from_state_dict``.

    Wraps ``state_dict`` in a ``StateDictFacade`` and forwards it, together
    with every other argument unchanged, to the original implementation
    saved in ``LazyloadPatches.torch_old_load_from_state_dict``.

    Args:
        self: The module whose state is being loaded (this function is
            assigned onto ``torch.nn.Module`` by ``__enter__``).
        state_dict: Mapping to wrap; its values must expose ``materialize``.
        *args: Remaining positional arguments, passed through untouched.
        **kwargs: Remaining keyword arguments, passed through untouched.

    Returns:
        Whatever the original ``_load_from_state_dict`` returns.
    """
    original_loader = LazyloadPatches.torch_old_load_from_state_dict
    materializing_view = LazyloadPatches.StateDictFacade(state_dict)
    return original_loader(self, materializing_view, *args, **kwargs)
def _load_state_dict_into_meta_model(
model,