From 3928d86339b492466594a22c521a0d68f08024b3 Mon Sep 17 00:00:00 2001
From: somebody
Date: Sat, 8 Jul 2023 14:36:45 -0500
Subject: [PATCH] Fall back to unpatched HF

---
 modeling/inference_models/hf_torch.py | 31 ++++++++++----
 modeling/lazy_loader.py               |  4 ++
 modeling/patches.py                   | 60 +++++++++++++++------------
 3 files changed, 61 insertions(+), 34 deletions(-)

diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py
index 2249a87a..fb9fe39e 100644
--- a/modeling/inference_models/hf_torch.py
+++ b/modeling/inference_models/hf_torch.py
@@ -364,7 +364,7 @@ class HFTorchInferenceModel(HFInferenceModel):
         except Exception as e:
             logger.warning(f"{self.model_name} is a no-go; {e} - Falling back to auto.")
             if utils.args.panic:
-                raise e
+                raise
 
         # Try to determine model type from either AutoModel or falling back to legacy
         try:
@@ -383,11 +383,28 @@ class HFTorchInferenceModel(HFInferenceModel):
                     metamodel
                 )
 
-        with lazy_loader.use_lazy_load(
-            enable=self.lazy_load,
-            # DO NOT DEMATERIALIZE MODULES / INIT WEIGHTS EMPTY!!! IT WILL EXPLODE!!!!!!!
-            dematerialized_modules=False,
-        ):
+        try:
+            # Try to load with the lazyloader first...
+            with lazy_loader.use_lazy_load(
+                enable=self.lazy_load,
+                # DO NOT DEMATERIALIZE MODULES / INIT WEIGHTS EMPTY!!! IT WILL EXPLODE!!!!!!!
+                dematerialized_modules=False,
+            ):
+                model = AutoModelForCausalLM.from_pretrained(
+                    location,
+                    offload_folder="accelerate-disk-cache",
+                    torch_dtype=self._get_target_dtype(),
+                    **tf_kwargs,
+                )
+        except Exception as e:
+            # ...but fall back to stock HF if lazyloader fails.
+            if utils.args.panic:
+                raise
+            logger.error("Lazyloader failed, falling back to stock HF load. You may run out of RAM here. Details:")
+            logger.error(e)
+            logger.error(traceback.format_exc())
+            logger.info("Falling back to stock HF load...")
+
             model = AutoModelForCausalLM.from_pretrained(
                 location,
                 offload_folder="accelerate-disk-cache",
@@ -417,7 +434,7 @@ class HFTorchInferenceModel(HFInferenceModel):
                 raise
 
             if utils.args.panic:
-                raise e
+                raise
 
             logger.warning(f"Fell back to GPT2LMHeadModel due to {e}")
             logger.debug(traceback.format_exc())
diff --git a/modeling/lazy_loader.py b/modeling/lazy_loader.py
index 2af0ae51..69e0d948 100644
--- a/modeling/lazy_loader.py
+++ b/modeling/lazy_loader.py
@@ -60,6 +60,7 @@ from typing import Any, Callable, Dict, Optional, Tuple, Type
 from torch import Tensor
 from torch.nn import Module
 from torch.storage import UntypedStorage
+from modeling.patches import LazyloadPatches
 
 # Safetensors is a dependency for the local version, TPU/Colab doesn't
 # support it yet.
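Aside, not part of the patch: two of the hf_torch.py hunks above change "raise e" to a bare "raise" inside an except block. A bare "raise" re-raises the exception currently being handled with its original traceback, whereas "raise e" also records the re-raise line itself as a traceback entry; for the --panic path the original failure location is what matters. A minimal, self-contained Python sketch with illustrative names only (load_model and load_with_panic are stand-ins, not functions from this codebase):

import traceback

def load_model():
    # Stand-in for the lazy-load path that may blow up.
    raise RuntimeError("lazyloader failed")

def load_with_panic(panic: bool = True):
    try:
        load_model()
    except RuntimeError:
        if panic:
            raise  # surface the original error; traceback still points into load_model()
        # a non-panic caller would log the error and fall back here

try:
    load_with_panic()
except RuntimeError as err:
    # The formatted traceback includes the frame inside load_model().
    print("".join(traceback.format_exception(type(err), err, err.__traceback__)))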
@@ -510,6 +511,8 @@ def use_lazy_load(
     begin_time = time.time()
 
     try:
+        LazyloadPatches.__enter__()
+
         old_rebuild_tensor = torch._utils._rebuild_tensor
         torch._utils._rebuild_tensor = _rebuild_tensor
 
@@ -577,6 +580,7 @@ def use_lazy_load(
 
         yield True
     finally:
+        LazyloadPatches.__exit__(None, None, None)
         torch._utils._rebuild_tensor = old_rebuild_tensor
         torch.load = old_torch_load
 
diff --git a/modeling/patches.py b/modeling/patches.py
index 827e997a..a72d533a 100644
--- a/modeling/patches.py
+++ b/modeling/patches.py
@@ -10,7 +10,9 @@ from transformers import (
     PreTrainedModel,
     modeling_utils,
 )
-from modeling.lazy_loader import LazyTensor
+
+import torch
+import modeling
 
 import utils
 
@@ -126,27 +128,16 @@ def patch_transformers_generation() -> None:
     transformers.generation.logits_process.NoBadWordsLogitsProcessor.__init__ = new_init
 
 
-def patch_transformers_for_lazyload() -> None:
-    """
-    Most of the code is modified code from the Accelerate and Transformers
-    projects, made by HuggingFace. The license for these projects are as follows:
-    ---
-    Copyright The HuggingFace Team. All rights reserved.
+class LazyloadPatches:
+    old_load_state_dict = transformers.modeling_utils._load_state_dict_into_meta_model
 
-    Licensed under the Apache License, Version 2.0 (the "License");
-    you may not use this file except in compliance with the License.
-    You may obtain a copy of the License at
+    def __enter__() -> None:
+        transformers.modeling_utils._load_state_dict_into_meta_model = (
+            LazyloadPatches._load_state_dict_into_meta_model
+        )
 
-    http://www.apache.org/licenses/LICENSE-2.0
-
-    Unless required by applicable law or agreed to in writing, software
-    distributed under the License is distributed on an "AS IS" BASIS,
-    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-    See the License for the specific language governing permissions and
-    limitations under the License.
-    """
-    import torch
-    from accelerate.utils import set_module_tensor_to_device, offload_weight
+    def __exit__(exc_type, exc_value, exc_traceback) -> None:
+        transformers.modeling_utils._load_state_dict_into_meta_model = LazyloadPatches.old_load_state_dict
 
     def _load_state_dict_into_meta_model(
         model,
@@ -167,6 +158,26 @@ def patch_transformers_for_lazyload() -> None:
         is_safetensors=False,
         keep_in_fp32_modules=None,
     ):
+        """
+        This is modified code from the Accelerate and Transformers projects,
+        made by HuggingFace. The license for these projects are as follows:
+        ---
+        Copyright The HuggingFace Team. All rights reserved.
+
+        Licensed under the Apache License, Version 2.0 (the "License");
+        you may not use this file except in compliance with the License.
+        You may obtain a copy of the License at
+
+        http://www.apache.org/licenses/LICENSE-2.0
+
+        Unless required by applicable law or agreed to in writing, software
+        distributed under the License is distributed on an "AS IS" BASIS,
+        WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+        See the License for the specific language governing permissions and
+        limitations under the License.
+ """ + from accelerate.utils import offload_weight, set_module_tensor_to_device + is_quantized = is_quantized or load_in_8bit if is_quantized: @@ -201,7 +212,7 @@ def patch_transformers_for_lazyload() -> None: ), ): - if isinstance(param, LazyTensor): + if isinstance(param, modeling.lazy_loader.LazyTensor): # Should always be true param = param.materialize(map_location="cpu") utils.bar.update(1) @@ -296,15 +307,10 @@ def patch_transformers_for_lazyload() -> None: return error_msgs, offload_index, state_dict_index - transformers.modeling_utils._load_state_dict_into_meta_model = ( - _load_state_dict_into_meta_model - ) - def patch_transformers(use_tpu: bool) -> None: patch_transformers_download() patch_transformers_loader() if not use_tpu: - patch_transformers_generation() - patch_transformers_for_lazyload() + patch_transformers_generation() \ No newline at end of file