mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Fall back to unpatched HF
This commit is contained in:
@@ -10,7 +10,9 @@ from transformers import (
|
||||
PreTrainedModel,
|
||||
modeling_utils,
|
||||
)
|
||||
from modeling.lazy_loader import LazyTensor
|
||||
|
||||
import torch
|
||||
import modeling
|
||||
|
||||
import utils
|
||||
|
||||
@@ -126,27 +128,16 @@ def patch_transformers_generation() -> None:
|
||||
transformers.generation.logits_process.NoBadWordsLogitsProcessor.__init__ = new_init
|
||||
|
||||
|
||||
def patch_transformers_for_lazyload() -> None:
|
||||
"""
|
||||
Most of the code is modified code from the Accelerate and Transformers
|
||||
projects, made by HuggingFace. The license for these projects are as follows:
|
||||
---
|
||||
Copyright The HuggingFace Team. All rights reserved.
|
||||
class LazyloadPatches:
|
||||
old_load_state_dict = transformers.modeling_utils._load_state_dict_into_meta_model
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
def __enter__() -> None:
|
||||
transformers.modeling_utils._load_state_dict_into_meta_model = (
|
||||
LazyloadPatches._load_state_dict_into_meta_model
|
||||
)
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
import torch
|
||||
from accelerate.utils import set_module_tensor_to_device, offload_weight
|
||||
def __exit__(exc_type, exc_value, exc_traceback) -> None:
|
||||
transformers.modeling_utils._load_state_dict_into_meta_model = LazyloadPatches.old_load_state_dict
|
||||
|
||||
def _load_state_dict_into_meta_model(
|
||||
model,
|
||||
@@ -167,6 +158,26 @@ def patch_transformers_for_lazyload() -> None:
|
||||
is_safetensors=False,
|
||||
keep_in_fp32_modules=None,
|
||||
):
|
||||
"""
|
||||
This is modified code from the Accelerate and Transformers projects,
|
||||
made by HuggingFace. The license for these projects are as follows:
|
||||
---
|
||||
Copyright The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
from accelerate.utils import offload_weight, set_module_tensor_to_device
|
||||
|
||||
is_quantized = is_quantized or load_in_8bit
|
||||
|
||||
if is_quantized:
|
||||
@@ -201,7 +212,7 @@ def patch_transformers_for_lazyload() -> None:
|
||||
),
|
||||
):
|
||||
|
||||
if isinstance(param, LazyTensor):
|
||||
if isinstance(param, modeling.lazy_loader.LazyTensor):
|
||||
# Should always be true
|
||||
param = param.materialize(map_location="cpu")
|
||||
utils.bar.update(1)
|
||||
@@ -296,15 +307,10 @@ def patch_transformers_for_lazyload() -> None:
|
||||
|
||||
return error_msgs, offload_index, state_dict_index
|
||||
|
||||
transformers.modeling_utils._load_state_dict_into_meta_model = (
|
||||
_load_state_dict_into_meta_model
|
||||
)
|
||||
|
||||
|
||||
def patch_transformers(use_tpu: bool) -> None:
|
||||
patch_transformers_download()
|
||||
patch_transformers_loader()
|
||||
|
||||
if not use_tpu:
|
||||
patch_transformers_generation()
|
||||
patch_transformers_for_lazyload()
|
||||
patch_transformers_generation()
|
Reference in New Issue
Block a user