mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Basic
@@ -129,15 +129,34 @@ def patch_transformers_generation() -> None:
class LazyloadPatches:
    class StateDictFacade(dict):
        # Wraps a lazy state dict so that every tensor access materializes
        # the underlying lazy tensor directly onto the GPU.
        def __init__(self, state_dict):
            self.update(state_dict)

        def __getitem__(self, name):
            return super().__getitem__(name).materialize(map_location="cuda:0")

    # Keep references to the original implementations so they can be restored.
    old_load_state_dict = transformers.modeling_utils._load_state_dict_into_meta_model
    torch_old_load_from_state_dict = torch.nn.Module._load_from_state_dict

    def __enter__() -> None:
        # Swap in the lazy-load aware replacements.
        transformers.modeling_utils._load_state_dict_into_meta_model = (
            LazyloadPatches._load_state_dict_into_meta_model
        )
        torch.nn.Module._load_from_state_dict = LazyloadPatches._torch_load_from_state_dict
        # torch.nn.Module._load_from_state_dict = _agn

    def __exit__(exc_type, exc_value, exc_traceback) -> None:
        # Restore the original Transformers and torch implementations.
        transformers.modeling_utils._load_state_dict_into_meta_model = LazyloadPatches.old_load_state_dict
        torch.nn.Module._load_from_state_dict = LazyloadPatches.torch_old_load_from_state_dict

    def _torch_load_from_state_dict(self, state_dict, *args, **kwargs):
        # Delegate to the original loader, but hand it the materializing facade
        # so lazy tensors are resolved only when they are actually read.
        return LazyloadPatches.torch_old_load_from_state_dict(
            self,
            LazyloadPatches.StateDictFacade(state_dict),
            *args,
            **kwargs
        )

    def _load_state_dict_into_meta_model(
        model,
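For context, LazyloadPatches is written to be entered and exited explicitly rather than through a with statement (note that __enter__ takes no self). A minimal usage sketch, not part of this diff: the checkpoint name, the try/finally wrapper, and the low_cpu_mem_usage flag are illustrative assumptions about how the patched meta-model loading path would typically be exercised.

    # Illustrative only: assumes transformers is imported and the
    # LazyloadPatches class above is in scope; the model name is a placeholder.
    import transformers

    LazyloadPatches.__enter__()  # install the lazy-load aware loaders
    try:
        model = transformers.AutoModelForCausalLM.from_pretrained(
            "EleutherAI/gpt-neo-125M",  # placeholder checkpoint
            low_cpu_mem_usage=True,     # loads weights through the meta-model path
        )
    finally:
        # Always restore the original transformers / torch implementations.
        LazyloadPatches.__exit__(None, None, None)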