https://github.com/KoboldAI/KoboldAI-Client.git
Merge pull request #390 from one-some/accelerate-offloading
Fix bleeding edge model loads and add lazyload fallback
@@ -1400,6 +1400,7 @@ def general_startup(override_args=None):
     parser.add_argument('-f', action='store', help="option for compatability with colab memory profiles")
     parser.add_argument('-v', '--verbosity', action='count', default=0, help="The default logging level is ERROR or higher. This value increases the amount of logging seen in your screen")
     parser.add_argument('-q', '--quiesce', action='count', default=0, help="The default logging level is ERROR or higher. This value decreases the amount of logging seen in your screen")
+    parser.add_argument("--panic", action='store_true', help="Disables falling back when loading fails.")

     #args: argparse.Namespace = None
     if "pytest" in sys.modules and override_args is None:
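The new --panic switch is consulted wherever the loaders would otherwise swallow an exception and fall back. A minimal standalone sketch of that pattern (hypothetical names, not code from this diff):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--panic", action='store_true', help="Disables falling back when loading fails.")
args = parser.parse_args([])  # use parse_args() for real CLI input

def load_model():
    raise RuntimeError("pretend the preferred loader blew up")

try:
    model = load_model()
except Exception:
    if args.panic:
        raise           # surface the real error instead of masking it
    model = "fallback"  # otherwise take the slower / safer path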
@@ -90,6 +90,8 @@ class model_backend(HFTorchInferenceModel):
             utils.module_names = list(metamodel.state_dict().keys())
             utils.named_buffers = list(metamodel.named_buffers(recurse=True))
         except Exception as e:
+            if utils.args.panic:
+                raise e
             logger.warning(f"Gave up on lazy loading due to {e}")
             self.lazy_load = False
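For context, state_dict().keys() and named_buffers(recurse=True) on a meta-device model give the tensor names the lazy loader needs without allocating any weights. A toy illustration on recent PyTorch, not the real HF metamodel:

import torch.nn as nn

# Tiny stand-in for the real metamodel: "meta" parameters carry shape and
# name information only, so this costs essentially no memory.
metamodel = nn.Sequential(
    nn.Linear(4, 4, device="meta"),
    nn.BatchNorm1d(4, device="meta"),
)

module_names = list(metamodel.state_dict().keys())
named_buffers = list(metamodel.named_buffers(recurse=True))

print(module_names)                          # e.g. ['0.weight', '0.bias', '1.weight', '1.bias', '1.running_mean', ...]
print([name for name, _ in named_buffers])   # e.g. ['1.running_mean', '1.running_var', '1.num_batches_tracked']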
@@ -363,6 +363,8 @@ class HFTorchInferenceModel(HFInferenceModel):
                 return GPTNeoForCausalLM.from_pretrained(location, **tf_kwargs)
             except Exception as e:
                 logger.warning(f"{self.model_name} is a no-go; {e} - Falling back to auto.")
+                if utils.args.panic:
+                    raise

         # Try to determine model type from either AutoModel or falling back to legacy
         try:
@@ -381,11 +383,28 @@ class HFTorchInferenceModel(HFInferenceModel):
                 metamodel
             )

-        with lazy_loader.use_lazy_load(
-            enable=self.lazy_load,
-            # DO NOT DEMATERIALIZE MODULES / INIT WEIGHTS EMPTY!!! IT WILL EXPLODE!!!!!!!
-            dematerialized_modules=False,
-        ):
+        try:
+            # Try to load with the lazyloader first...
+            with lazy_loader.use_lazy_load(
+                enable=self.lazy_load,
+                # DO NOT DEMATERIALIZE MODULES / INIT WEIGHTS EMPTY!!! IT WILL EXPLODE!!!!!!!
+                dematerialized_modules=False,
+            ):
+                model = AutoModelForCausalLM.from_pretrained(
+                    location,
+                    offload_folder="accelerate-disk-cache",
+                    torch_dtype=self._get_target_dtype(),
+                    **tf_kwargs,
+                )
+        except Exception as e:
+            # ...but fall back to stock HF if lazyloader fails.
+            if utils.args.panic:
+                raise
+            logger.error("Lazyloader failed, falling back to stock HF load. You may run out of RAM here. Details:")
+            logger.error(e)
+            logger.error(traceback.format_exc())
+            logger.info("Falling back to stock HF load...")
+
             model = AutoModelForCausalLM.from_pretrained(
                 location,
                 offload_folder="accelerate-disk-cache",
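The same try-the-fast-path-first shape, boiled down to a standalone helper (hypothetical function, assuming a normal transformers install; low_cpu_mem_usage merely stands in for the lazy loader):

import traceback
from transformers import AutoModelForCausalLM

def load_with_fallback(location, panic=False, **kwargs):
    try:
        # Preferred path: memory-friendly load.
        return AutoModelForCausalLM.from_pretrained(location, low_cpu_mem_usage=True, **kwargs)
    except Exception:
        if panic:
            raise  # --panic: report the original failure instead of retrying
        traceback.print_exc()
        # Fallback path: plain load, may need a lot more RAM.
        return AutoModelForCausalLM.from_pretrained(location, **kwargs)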
@@ -414,6 +433,9 @@ class HFTorchInferenceModel(HFInferenceModel):
                 logger.error("Invalid load key! Aborting.")
                 raise

+            if utils.args.panic:
+                raise
+
             logger.warning(f"Fell back to GPT2LMHeadModel due to {e}")
             logger.debug(traceback.format_exc())
@@ -57,8 +57,10 @@ import _codecs
 import os
 from typing import Any, Callable, Dict, Optional, Tuple, Type

+from torch import Tensor
 from torch.nn import Module
 from torch.storage import UntypedStorage
+from modeling.patches import LazyloadPatches

 # Safetensors is a dependency for the local version, TPU/Colab doesn't
 # support it yet.
@@ -237,6 +239,29 @@ class SafetensorsLazyTensor(LazyTensor):
             self.checkpoint_file, tensor_key=self.key, device=self.location
         )

+def _patched_rebuild_from_type_v2(func, new_type, args, state):
+    """A patched version of torch._tensor._rebuild_from_type_v2 that
+    does not attempt to convert `LazyTensor`s to `torch.Tensor`s."""
+
+    ret = func(*args)
+
+    # BEGIN PATCH
+    transformation_ok = isinstance(ret, LazyTensor) and new_type == Tensor
+    if type(ret) is not new_type and not transformation_ok:
+        # END PATCH
+        ret = ret.as_subclass(new_type)
+
+        # Tensor does define __setstate__ even though it doesn't define
+        # __getstate__. So only use __setstate__ if it is NOT the one defined
+        # on Tensor
+        if (
+            getattr(ret.__class__, "__setstate__", Tensor.__setstate__)
+            is not Tensor.__setstate__
+        ):
+            ret.__setstate__(state)
+        else:
+            ret = torch._utils._set_obj_state(ret, state)
+    return ret
+
 class RestrictedUnpickler(pickle.Unpickler):
     def original_persistent_load(self, saved_id):
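The reason for the patch is the ret.as_subclass(new_type) call: it only makes sense once ret is a real torch.Tensor, so a not-yet-materialized LazyTensor is left alone when the target type is Tensor. What as_subclass does on an ordinary tensor, in isolation (illustrative only):

import torch

class TaggedTensor(torch.Tensor):
    """Trivial Tensor subclass used only to show as_subclass."""
    pass

t = torch.zeros(2)
s = t.as_subclass(TaggedTensor)            # same storage, reinterpreted under the subclass
print(type(t).__name__, type(s).__name__)  # Tensor TaggedTensor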
@@ -253,7 +278,7 @@ class RestrictedUnpickler(pickle.Unpickler):
         elif module == "torch._utils" and name == "_rebuild_tensor_v2":
             return torch._utils._rebuild_tensor_v2
         elif module == "torch._tensor" and name == "_rebuild_from_type_v2":
-            return torch._tensor._rebuild_from_type_v2
+            return _patched_rebuild_from_type_v2
         elif module == "torch" and name in (
             "DoubleStorage",
             "FloatStorage",
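RestrictedUnpickler resolves only an explicit allow-list of globals, and the change above swaps the torch rebuild helper for the patched one inside that allow-list. The general shape of such an unpickler, as a minimal sketch (hypothetical class, not the repository's implementation):

import pickle

class AllowlistUnpickler(pickle.Unpickler):
    ALLOWED = {
        ("collections", "OrderedDict"),
        # ("torch._utils", "_rebuild_tensor_v2"), ...
    }

    def find_class(self, module, name):
        if (module, name) in self.ALLOWED:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")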
@@ -486,6 +511,8 @@ def use_lazy_load(
     begin_time = time.time()

     try:
+        LazyloadPatches.__enter__()
+
         old_rebuild_tensor = torch._utils._rebuild_tensor
         torch._utils._rebuild_tensor = _rebuild_tensor
@@ -553,6 +580,7 @@ def use_lazy_load(
         yield True

     finally:
+        LazyloadPatches.__exit__(None, None, None)
         torch._utils._rebuild_tensor = old_rebuild_tensor
         torch.load = old_torch_load
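use_lazy_load now brackets the whole load with LazyloadPatches: the patch goes in right after the try and the finally block guarantees it is undone along with the torch hooks even if loading throws. The same discipline as a tiny context manager (hypothetical helper, not the repository's code):

import contextlib

@contextlib.contextmanager
def swapped_attr(owner, name, replacement):
    original = getattr(owner, name)
    setattr(owner, name, replacement)
    try:
        yield
    finally:
        # Always restore, even if the body raised.
        setattr(owner, name, original)

# e.g. temporarily replace a module-level function for the duration of a load:
# with swapped_attr(torch._utils, "_rebuild_tensor", _rebuild_tensor): ...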
@@ -10,7 +10,9 @@ from transformers import (
     PreTrainedModel,
     modeling_utils,
 )
-from modeling.lazy_loader import LazyTensor
+
+import torch
+import modeling

 import utils
@@ -126,27 +128,16 @@ def patch_transformers_generation() -> None:
     transformers.generation.logits_process.NoBadWordsLogitsProcessor.__init__ = new_init


-def patch_transformers_for_lazyload() -> None:
-    """
-    Most of the code is modified code from the Accelerate and Transformers
-    projects, made by HuggingFace. The license for these projects are as follows:
-    ---
-    Copyright The HuggingFace Team. All rights reserved.
-
-    Licensed under the Apache License, Version 2.0 (the "License");
-    you may not use this file except in compliance with the License.
-    You may obtain a copy of the License at
-
-        http://www.apache.org/licenses/LICENSE-2.0
-
-    Unless required by applicable law or agreed to in writing, software
-    distributed under the License is distributed on an "AS IS" BASIS,
-    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-    See the License for the specific language governing permissions and
-    limitations under the License.
-    """
-    import torch
-    from accelerate.utils import set_module_tensor_to_device, offload_weight
+class LazyloadPatches:
+    old_load_state_dict = transformers.modeling_utils._load_state_dict_into_meta_model
+
+    def __enter__() -> None:
+        transformers.modeling_utils._load_state_dict_into_meta_model = (
+            LazyloadPatches._load_state_dict_into_meta_model
+        )
+
+    def __exit__(exc_type, exc_value, exc_traceback) -> None:
+        transformers.modeling_utils._load_state_dict_into_meta_model = LazyloadPatches.old_load_state_dict

     def _load_state_dict_into_meta_model(
         model,
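Note that __enter__ and __exit__ here take no self, so the class acts as a namespace rather than being instantiated; use_lazy_load (above) drives it as LazyloadPatches.__enter__() ... LazyloadPatches.__exit__(None, None, None). The idea in miniature (hypothetical example):

class QuietPatches:
    old_print = print

    def __enter__() -> None:
        import builtins
        builtins.print = lambda *args, **kwargs: None   # patch in

    def __exit__(exc_type, exc_value, exc_traceback) -> None:
        import builtins
        builtins.print = QuietPatches.old_print          # patch out

QuietPatches.__enter__()
try:
    print("this is swallowed while the patch is active")
finally:
    QuietPatches.__exit__(None, None, None)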
@@ -167,6 +158,26 @@ def patch_transformers_for_lazyload() -> None:
         is_safetensors=False,
         keep_in_fp32_modules=None,
     ):
+        """
+        This is modified code from the Accelerate and Transformers projects,
+        made by HuggingFace. The license for these projects are as follows:
+        ---
+        Copyright The HuggingFace Team. All rights reserved.
+
+        Licensed under the Apache License, Version 2.0 (the "License");
+        you may not use this file except in compliance with the License.
+        You may obtain a copy of the License at
+
+            http://www.apache.org/licenses/LICENSE-2.0
+
+        Unless required by applicable law or agreed to in writing, software
+        distributed under the License is distributed on an "AS IS" BASIS,
+        WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+        See the License for the specific language governing permissions and
+        limitations under the License.
+        """
+        from accelerate.utils import offload_weight, set_module_tensor_to_device
+
         is_quantized = is_quantized or load_in_8bit

         if is_quantized:
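Further down, the loop over the checkpoint's tensors is what the accelerate helpers imported here are for: set_module_tensor_to_device materializes a named meta tensor in place once a real value is available. In isolation (illustrative only, not the patched loader itself):

import torch
import torch.nn as nn
from accelerate.utils import set_module_tensor_to_device

module = nn.Linear(4, 4, device="meta")      # parameters exist only as shapes/names
set_module_tensor_to_device(
    module, "weight", "cpu", value=torch.zeros(4, 4)
)
print(module.weight.device)                   # cpu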
@@ -201,7 +212,7 @@ def patch_transformers_for_lazyload() -> None:
             ),
         ):

-            if isinstance(param, LazyTensor):
+            if isinstance(param, modeling.lazy_loader.LazyTensor):
                 # Should always be true
                 param = param.materialize(map_location="cpu")
                 utils.bar.update(1)
@@ -296,15 +307,10 @@ def patch_transformers_for_lazyload() -> None:

         return error_msgs, offload_index, state_dict_index

-    transformers.modeling_utils._load_state_dict_into_meta_model = (
-        _load_state_dict_into_meta_model
-    )
-

 def patch_transformers(use_tpu: bool) -> None:
     patch_transformers_download()
     patch_transformers_loader()

     if not use_tpu:
-        patch_transformers_generation()
-        patch_transformers_for_lazyload()
+        patch_transformers_generation()