From c2ee30af32c68e3915585feea3d9d19c8ae70e4c Mon Sep 17 00:00:00 2001
From: somebody
Date: Sat, 8 Jul 2023 14:04:46 -0500
Subject: [PATCH] Add --panic to raise when loading fails

---
 aiserver.py                                         | 1 +
 modeling/inference_models/generic_hf_torch/class.py | 2 ++
 modeling/inference_models/hf_torch.py               | 5 +++++
 3 files changed, 8 insertions(+)

diff --git a/aiserver.py b/aiserver.py
index 49223b3a..74b8bca8 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -1400,6 +1400,7 @@ def general_startup(override_args=None):
     parser.add_argument('-f', action='store', help="option for compatability with colab memory profiles")
     parser.add_argument('-v', '--verbosity', action='count', default=0, help="The default logging level is ERROR or higher. This value increases the amount of logging seen in your screen")
     parser.add_argument('-q', '--quiesce', action='count', default=0, help="The default logging level is ERROR or higher. This value decreases the amount of logging seen in your screen")
+    parser.add_argument("--panic", action='store_true', help="Disables falling back when loading fails.")
 
     #args: argparse.Namespace = None
     if "pytest" in sys.modules and override_args is None:
diff --git a/modeling/inference_models/generic_hf_torch/class.py b/modeling/inference_models/generic_hf_torch/class.py
index ad17b85b..8f024ea1 100644
--- a/modeling/inference_models/generic_hf_torch/class.py
+++ b/modeling/inference_models/generic_hf_torch/class.py
@@ -90,6 +90,8 @@ class model_backend(HFTorchInferenceModel):
                 utils.module_names = list(metamodel.state_dict().keys())
                 utils.named_buffers = list(metamodel.named_buffers(recurse=True))
             except Exception as e:
+                if utils.args.panic:
+                    raise e
                 logger.warning(f"Gave up on lazy loading due to {e}")
                 self.lazy_load = False
 
diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py
index 84d3447e..2249a87a 100644
--- a/modeling/inference_models/hf_torch.py
+++ b/modeling/inference_models/hf_torch.py
@@ -363,6 +363,8 @@ class HFTorchInferenceModel(HFInferenceModel):
                 return GPTNeoForCausalLM.from_pretrained(location, **tf_kwargs)
             except Exception as e:
                 logger.warning(f"{self.model_name} is a no-go; {e} - Falling back to auto.")
+                if utils.args.panic:
+                    raise e
 
         # Try to determine model type from either AutoModel or falling back to legacy
         try:
@@ -414,6 +416,9 @@ class HFTorchInferenceModel(HFInferenceModel):
                     logger.error("Invalid load key! Aborting.")
                     raise
 
+                if utils.args.panic:
+                    raise e
+
                 logger.warning(f"Fell back to GPT2LMHeadModel due to {e}")
                 logger.debug(traceback.format_exc())
 
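Note: below is a minimal, standalone sketch of the fallback-with-panic pattern
this patch threads through the loaders, for anyone who wants to try the
behavior in isolation. The load_primary/load_fallback helpers are hypothetical
stand-ins for the real model loaders (e.g. GPTNeoForCausalLM.from_pretrained
and GPT2LMHeadModel.from_pretrained); only the --panic flag itself comes from
the patch.

    import argparse
    import logging

    logging.basicConfig(level=logging.WARNING)
    logger = logging.getLogger(__name__)

    parser = argparse.ArgumentParser()
    parser.add_argument("--panic", action="store_true",
                        help="Disables falling back when loading fails.")
    args = parser.parse_args()

    def load_primary():
        # Hypothetical stand-in for the preferred loader.
        raise RuntimeError("simulated load failure")

    def load_fallback():
        # Hypothetical stand-in for the fallback loader.
        return "fallback model"

    try:
        model = load_primary()
    except Exception as e:
        if args.panic:
            # Surface the real error instead of masking it with a fallback.
            raise
        logger.warning(f"Fell back due to {e}")
        model = load_fallback()

    print(model)

Run without flags and the script logs a warning and prints "fallback model";
run with --panic and the original RuntimeError propagates, matching the
intent of the flag's help text.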