From 38c53191d31fc48f2c4ee612ced12e7c62a4da73 Mon Sep 17 00:00:00 2001
From: somebody
Date: Fri, 14 Apr 2023 20:25:03 -0500
Subject: [PATCH] possible fix for cache dl thing

---
 modeling/inference_models/generic_hf_torch.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/modeling/inference_models/generic_hf_torch.py b/modeling/inference_models/generic_hf_torch.py
index 6a8964ec..7598a424 100644
--- a/modeling/inference_models/generic_hf_torch.py
+++ b/modeling/inference_models/generic_hf_torch.py
@@ -171,16 +171,19 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
             if utils.num_shards is None:
                 # Save the pytorch_model.bin or model.safetensors of an unsharded model
-                for possible_weight_name in [
+                any_success = False
+                possible_checkpoint_names = [
                     transformers.modeling_utils.WEIGHTS_NAME,
                     "model.safetensors",
-                ]:
+                ]
+
+                for possible_checkpoint_name in possible_checkpoint_names:
                     try:
                         shutil.move(
                             os.path.realpath(
                                 huggingface_hub.hf_hub_download(
                                     self.model_name,
-                                    possible_weight_name,
+                                    possible_checkpoint_name,
                                     revision=utils.koboldai_vars.revision,
                                     cache_dir="cache",
                                     local_files_only=True,
                                 )
                             ),
                             os.path.join(
                                 self.get_local_model_path(
                                     ignore_existance=True
                                 ),
-                                possible_weight_name,
+                                possible_checkpoint_name,
                             ),
                         )
+                        any_success = True
                     except Exception:
-                        if possible_weight_name == "model.safetensors":
-                            raise
+                        pass
+
+                if not any_success:
+                    raise RuntimeError(f"Couldn't find any of {possible_checkpoint_names} in cache for {self.model_name} @ '{utils.koboldai_vars.revision}'")
             else:
                 # Handle saving sharded models
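
As an illustration of the behaviour this patch introduces, the sketch below boils the change down to a standalone helper: try every candidate checkpoint filename and raise only after all of them have failed, instead of aborting on the first miss. The helper name first_available_checkpoint and the fetch callable are hypothetical stand-ins for this sketch only; the patched code wires the same pattern around huggingface_hub.hf_hub_download and shutil.move.

    from typing import Callable, Iterable

    def first_available_checkpoint(
        candidates: Iterable[str],
        fetch: Callable[[str], str],
    ) -> str:
        """Return the path produced by fetch() for the first candidate that succeeds."""
        failures = {}
        for name in candidates:
            try:
                # In the patch, fetch corresponds to resolving the file from the
                # local Hugging Face cache (local_files_only=True) and moving it
                # into the local model directory.
                return fetch(name)
            except Exception as exc:
                # Remember the failure and keep trying the remaining candidates.
                failures[name] = exc
        # Only give up once every candidate filename has been tried.
        raise RuntimeError(f"Couldn't find any of {list(failures)}: {failures}")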