Merge pull request #390 from one-some/accelerate-offloading

Fix bleeding edge model loads and add lazyload fallback
henk717 · 2023-07-09 01:02:12 +02:00 · committed by GitHub
5 changed files with 92 additions and 33 deletions
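
Summary of the change: model loading in HFTorchInferenceModel now tries the lazy loader first and falls back to a plain Hugging Face load if that fails, and the new --panic flag disables every such fallback so the original exception surfaces immediately. A minimal sketch of the pattern follows; the load_lazy/load_stock callables are placeholders for illustration, not functions from this PR.

import logging
import traceback

logger = logging.getLogger(__name__)


def load_model(location, load_lazy, load_stock, panic=False):
    """Illustrative only: attempt the lazy-loader path, then fall back to a
    stock Hugging Face load unless --panic was passed on the command line."""
    try:
        # Stands in for the lazy_loader.use_lazy_load(...) block in the diff below.
        return load_lazy(location)
    except Exception as exc:
        if panic:
            # --panic: re-raise immediately instead of falling back.
            raise
        logger.error("Lazyloader failed, falling back to stock HF load: %s", exc)
        logger.debug(traceback.format_exc())
        # Stands in for the bare AutoModelForCausalLM.from_pretrained(...) call.
        return load_stock(location)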

View File

@@ -1400,6 +1400,7 @@ def general_startup(override_args=None):
     parser.add_argument('-f', action='store', help="option for compatability with colab memory profiles")
     parser.add_argument('-v', '--verbosity', action='count', default=0, help="The default logging level is ERROR or higher. This value increases the amount of logging seen in your screen")
     parser.add_argument('-q', '--quiesce', action='count', default=0, help="The default logging level is ERROR or higher. This value decreases the amount of logging seen in your screen")
+    parser.add_argument("--panic", action='store_true', help="Disables falling back when loading fails.")
 
     #args: argparse.Namespace = None
     if "pytest" in sys.modules and override_args is None:

View File

@@ -90,6 +90,8 @@ class model_backend(HFTorchInferenceModel):
                 utils.module_names = list(metamodel.state_dict().keys())
                 utils.named_buffers = list(metamodel.named_buffers(recurse=True))
             except Exception as e:
+                if utils.args.panic:
+                    raise e
                 logger.warning(f"Gave up on lazy loading due to {e}")
                 self.lazy_load = False

View File

@@ -363,6 +363,8 @@ class HFTorchInferenceModel(HFInferenceModel):
                 return GPTNeoForCausalLM.from_pretrained(location, **tf_kwargs)
             except Exception as e:
                 logger.warning(f"{self.model_name} is a no-go; {e} - Falling back to auto.")
+                if utils.args.panic:
+                    raise
 
         # Try to determine model type from either AutoModel or falling back to legacy
         try:
@@ -381,11 +383,28 @@ class HFTorchInferenceModel(HFInferenceModel):
                     metamodel
                 )
 
-        with lazy_loader.use_lazy_load(
-            enable=self.lazy_load,
-            # DO NOT DEMATERIALIZE MODULES / INIT WEIGHTS EMPTY!!! IT WILL EXPLODE!!!!!!!
-            dematerialized_modules=False,
-        ):
+        try:
+            # Try to load with the lazyloader first...
+            with lazy_loader.use_lazy_load(
+                enable=self.lazy_load,
+                # DO NOT DEMATERIALIZE MODULES / INIT WEIGHTS EMPTY!!! IT WILL EXPLODE!!!!!!!
+                dematerialized_modules=False,
+            ):
+                model = AutoModelForCausalLM.from_pretrained(
+                    location,
+                    offload_folder="accelerate-disk-cache",
+                    torch_dtype=self._get_target_dtype(),
+                    **tf_kwargs,
+                )
+        except Exception as e:
+            # ...but fall back to stock HF if lazyloader fails.
+            if utils.args.panic:
+                raise
+
+            logger.error("Lazyloader failed, falling back to stock HF load. You may run out of RAM here. Details:")
+            logger.error(e)
+            logger.error(traceback.format_exc())
+            logger.info("Falling back to stock HF load...")
             model = AutoModelForCausalLM.from_pretrained(
                 location,
                 offload_folder="accelerate-disk-cache",
@@ -414,6 +433,9 @@ class HFTorchInferenceModel(HFInferenceModel):
logger.error("Invalid load key! Aborting.") logger.error("Invalid load key! Aborting.")
raise raise
if utils.args.panic:
raise
logger.warning(f"Fell back to GPT2LMHeadModel due to {e}") logger.warning(f"Fell back to GPT2LMHeadModel due to {e}")
logger.debug(traceback.format_exc()) logger.debug(traceback.format_exc())

View File

@@ -57,8 +57,10 @@ import _codecs
 import os
 from typing import Any, Callable, Dict, Optional, Tuple, Type
 
+from torch import Tensor
 from torch.nn import Module
 from torch.storage import UntypedStorage
 
+from modeling.patches import LazyloadPatches
 # Safetensors is a dependency for the local version, TPU/Colab doesn't
 # support it yet.
@@ -237,6 +239,29 @@ class SafetensorsLazyTensor(LazyTensor):
             self.checkpoint_file, tensor_key=self.key, device=self.location
         )
 
 
+def _patched_rebuild_from_type_v2(func, new_type, args, state):
+    """A patched version of torch._tensor._rebuild_from_type_v2 that
+    does not attempt to convert `LazyTensor`s to `torch.Tensor`s."""
+    ret = func(*args)
+    # BEGIN PATCH
+    transformation_ok = isinstance(ret, LazyTensor) and new_type == Tensor
+    if type(ret) is not new_type and not transformation_ok:
+        # END PATCH
+        ret = ret.as_subclass(new_type)
+
+        # Tensor does define __setstate__ even though it doesn't define
+        # __getstate__. So only use __setstate__ if it is NOT the one defined
+        # on Tensor
+        if (
+            getattr(ret.__class__, "__setstate__", Tensor.__setstate__)
+            is not Tensor.__setstate__
+        ):
+            ret.__setstate__(state)
+        else:
+            ret = torch._utils._set_obj_state(ret, state)
+    return ret
+
+
 class RestrictedUnpickler(pickle.Unpickler):
     def original_persistent_load(self, saved_id):
@@ -253,7 +278,7 @@ class RestrictedUnpickler(pickle.Unpickler):
elif module == "torch._utils" and name == "_rebuild_tensor_v2": elif module == "torch._utils" and name == "_rebuild_tensor_v2":
return torch._utils._rebuild_tensor_v2 return torch._utils._rebuild_tensor_v2
elif module == "torch._tensor" and name == "_rebuild_from_type_v2": elif module == "torch._tensor" and name == "_rebuild_from_type_v2":
return torch._tensor._rebuild_from_type_v2 return _patched_rebuild_from_type_v2
elif module == "torch" and name in ( elif module == "torch" and name in (
"DoubleStorage", "DoubleStorage",
"FloatStorage", "FloatStorage",
@@ -486,6 +511,8 @@ def use_lazy_load(
     begin_time = time.time()
 
     try:
+        LazyloadPatches.__enter__()
+
         old_rebuild_tensor = torch._utils._rebuild_tensor
         torch._utils._rebuild_tensor = _rebuild_tensor
@@ -553,6 +580,7 @@ def use_lazy_load(
         yield True
     finally:
+        LazyloadPatches.__exit__(None, None, None)
         torch._utils._rebuild_tensor = old_rebuild_tensor
         torch.load = old_torch_load

View File

@@ -10,7 +10,9 @@ from transformers import (
     PreTrainedModel,
     modeling_utils,
 )
 
-from modeling.lazy_loader import LazyTensor
+import torch
+import modeling
 import utils
@@ -126,27 +128,16 @@ def patch_transformers_generation() -> None:
     transformers.generation.logits_process.NoBadWordsLogitsProcessor.__init__ = new_init
 
 
-def patch_transformers_for_lazyload() -> None:
-    """
-    Most of the code is modified code from the Accelerate and Transformers
-    projects, made by HuggingFace. The license for these projects are as follows:
-    ---
-    Copyright The HuggingFace Team. All rights reserved.
-
-    Licensed under the Apache License, Version 2.0 (the "License");
-    you may not use this file except in compliance with the License.
-    You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-    Unless required by applicable law or agreed to in writing, software
-    distributed under the License is distributed on an "AS IS" BASIS,
-    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-    See the License for the specific language governing permissions and
-    limitations under the License.
-    """
-    import torch
-    from accelerate.utils import set_module_tensor_to_device, offload_weight
+class LazyloadPatches:
+    old_load_state_dict = transformers.modeling_utils._load_state_dict_into_meta_model
+
+    def __enter__() -> None:
+        transformers.modeling_utils._load_state_dict_into_meta_model = (
+            LazyloadPatches._load_state_dict_into_meta_model
+        )
+
+    def __exit__(exc_type, exc_value, exc_traceback) -> None:
+        transformers.modeling_utils._load_state_dict_into_meta_model = LazyloadPatches.old_load_state_dict
 
     def _load_state_dict_into_meta_model(
         model,
@@ -167,6 +158,26 @@ def patch_transformers_for_lazyload() -> None:
         is_safetensors=False,
         keep_in_fp32_modules=None,
     ):
+        """
+        This is modified code from the Accelerate and Transformers projects,
+        made by HuggingFace. The license for these projects are as follows:
+        ---
+        Copyright The HuggingFace Team. All rights reserved.
+
+        Licensed under the Apache License, Version 2.0 (the "License");
+        you may not use this file except in compliance with the License.
+        You may obtain a copy of the License at
+
+        http://www.apache.org/licenses/LICENSE-2.0
+
+        Unless required by applicable law or agreed to in writing, software
+        distributed under the License is distributed on an "AS IS" BASIS,
+        WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+        See the License for the specific language governing permissions and
+        limitations under the License.
+        """
+        from accelerate.utils import offload_weight, set_module_tensor_to_device
+
         is_quantized = is_quantized or load_in_8bit
         if is_quantized:
@@ -201,7 +212,7 @@ def patch_transformers_for_lazyload() -> None:
             ),
         ):
-            if isinstance(param, LazyTensor):
+            if isinstance(param, modeling.lazy_loader.LazyTensor):
                 # Should always be true
                 param = param.materialize(map_location="cpu")
                 utils.bar.update(1)
@@ -296,15 +307,10 @@ def patch_transformers_for_lazyload() -> None:
         return error_msgs, offload_index, state_dict_index
 
-    transformers.modeling_utils._load_state_dict_into_meta_model = (
-        _load_state_dict_into_meta_model
-    )
 
 
 def patch_transformers(use_tpu: bool) -> None:
     patch_transformers_download()
     patch_transformers_loader()
 
     if not use_tpu:
         patch_transformers_generation()
-        patch_transformers_for_lazyload()
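
For context, the LazyloadPatches swap is driven from use_lazy_load() in lazy_loader.py: __enter__ installs the patched _load_state_dict_into_meta_model (a private transformers helper, referenced directly in the diff above) before loading starts, and __exit__ restores the original in the finally block, so the transformers internals are reset even when loading throws. A minimal sketch of that pairing, with the patched loader body stubbed out (the real one materializes LazyTensors and updates the progress bar):

import transformers.modeling_utils as modeling_utils


class LazyloadPatches:
    """Sketch of the swap-in/swap-out pattern from modeling/patches.py."""

    old_load_state_dict = modeling_utils._load_state_dict_into_meta_model

    def _load_state_dict_into_meta_model(*args, **kwargs):
        # Stub: the real replacement materializes LazyTensors on CPU and
        # reports progress; here it simply delegates to the original.
        return LazyloadPatches.old_load_state_dict(*args, **kwargs)

    def __enter__() -> None:
        modeling_utils._load_state_dict_into_meta_model = (
            LazyloadPatches._load_state_dict_into_meta_model
        )

    def __exit__(exc_type, exc_value, exc_traceback) -> None:
        modeling_utils._load_state_dict_into_meta_model = LazyloadPatches.old_load_state_dict


# use_lazy_load() pairs the calls in try/finally, as the lazy_loader.py hunks show:
try:
    LazyloadPatches.__enter__()
    # ... model loading happens here with the patch active ...
finally:
    LazyloadPatches.__exit__(None, None, None)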