diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py
index 14ddd7af..3f7c3967 100644
--- a/modeling/inference_models/hf_torch.py
+++ b/modeling/inference_models/hf_torch.py
@@ -465,19 +465,25 @@ class HFTorchInferenceModel(HFInferenceModel):
 
         device_map: Dict[str, Union[str, int]] = {}
 
         @functools.lru_cache(maxsize=None)
-        def get_original_key(key):
-            return max(
-                (
-                    original_key
-                    for original_key in utils.module_names
-                    if original_key.endswith(key)
-                ),
-                key=len,
-            )
+        def get_original_key(key) -> Optional[str]:
+            key_candidates = [
+                original_key
+                for original_key in utils.module_names
+                if original_key.endswith(key)
+            ]
+
+            if not key_candidates:
+                logger.debug(f"!!! No key candidates for {key}")
+                return None
+
+            return max(key_candidates, key=len)
 
         for key, value in model_dict.items():
             original_key = get_original_key(key)
+            if not original_key:
+                continue
+
             if isinstance(value, lazy_loader.LazyTensor) and not any(
                 original_key.startswith(n) for n in utils.layers_module_names
             ):