possible fix for cache dl thing

This commit is contained in:
somebody
2023-04-14 20:25:03 -05:00
parent 334c09606b
commit 38c53191d3

View File

@@ -171,16 +171,19 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
if utils.num_shards is None: if utils.num_shards is None:
# Save the pytorch_model.bin or model.safetensors of an unsharded model # Save the pytorch_model.bin or model.safetensors of an unsharded model
for possible_weight_name in [ any_success = False
possible_checkpoint_names = [
transformers.modeling_utils.WEIGHTS_NAME, transformers.modeling_utils.WEIGHTS_NAME,
"model.safetensors", "model.safetensors",
]: ]
for possible_checkpoint_name in possible_checkpoint_names:
try: try:
shutil.move( shutil.move(
os.path.realpath( os.path.realpath(
huggingface_hub.hf_hub_download( huggingface_hub.hf_hub_download(
self.model_name, self.model_name,
possible_weight_name, possible_checkpoint_name,
revision=utils.koboldai_vars.revision, revision=utils.koboldai_vars.revision,
cache_dir="cache", cache_dir="cache",
local_files_only=True, local_files_only=True,
@@ -191,12 +194,15 @@ class GenericHFTorchInferenceModel(HFTorchInferenceModel):
self.get_local_model_path( self.get_local_model_path(
ignore_existance=True ignore_existance=True
), ),
possible_weight_name, possible_checkpoint_name,
), ),
) )
any_success = True
except Exception: except Exception:
if possible_weight_name == "model.safetensors": pass
raise
if not any_success:
raise RuntimeError(f"Couldn't find any of {possible_checkpoint_names} in cache for {self.model_name} @ '{utils.koboldai_vars.revisison}'")
else: else:
# Handle saving sharded models # Handle saving sharded models