mirror of https://github.com/KoboldAI/KoboldAI-Client.git

Download all shards correctly on aria2 and raise on bad load key
@@ -288,11 +288,18 @@ class HFTorchInferenceModel(HFInferenceModel):
         try:
             return AutoModelForCausalLM.from_pretrained(location, **tf_kwargs)
         except Exception as e:
-            if "out of memory" in traceback.format_exc().lower():
+            traceback_string = traceback.format_exc().lower()
+
+            if "out of memory" in traceback_string:
                 raise RuntimeError(
                     "One of your GPUs ran out of memory when KoboldAI tried to load your model."
                 )
 
+            # Model corrupted or serious loading problem. Stop here.
+            if "invalid load key" in traceback_string:
+                logger.error("Invalid load key! Aborting.")
+                raise
+
             logger.warning(f"Fell back to GPT2LMHeadModel due to {e}")
             try:
                 return GPT2LMHeadModel.from_pretrained(location, **tf_kwargs)
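For context, the pattern above in standalone form: torch's pickle layer reports a corrupted or truncated checkpoint as "invalid load key", and before this commit that error was swallowed by the GPT2LMHeadModel fallback. A minimal sketch of the classify-and-reraise logic, assuming hypothetical primary_loader/fallback_loader callables and the standard logging module in place of KoboldAI's logger:

import logging
import traceback

logger = logging.getLogger(__name__)


def load_with_fallback(primary_loader, fallback_loader):
    # primary_loader and fallback_loader are hypothetical zero-argument
    # callables standing in for the AutoModelForCausalLM / GPT2LMHeadModel
    # from_pretrained calls in the diff above.
    try:
        return primary_loader()
    except Exception as e:
        traceback_string = traceback.format_exc().lower()

        # GPU OOM: surface a clear error instead of retrying a doomed load.
        if "out of memory" in traceback_string:
            raise RuntimeError(
                "A GPU ran out of memory while loading the model."
            )

        # "invalid load key" is how torch's pickle layer reports a corrupted
        # or truncated checkpoint; falling back would only mask the problem.
        if "invalid load key" in traceback_string:
            logger.error("Invalid load key! Aborting.")
            raise

        logger.warning(f"Fell back due to {e}")
        return fallback_loader()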
utils.py (6 changed lines)
@@ -285,6 +285,7 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
     # one of these out in the wild yet, probably due to how Safetensors has a
     # lot of benefits of sharding built in
     for possible_filename in [
+        transformers.modeling_utils.SAFE_WEIGHTS_INDEX_NAME,
         transformers.modeling_utils.SAFE_WEIGHTS_NAME,
         transformers.modeling_utils.WEIGHTS_INDEX_NAME,
         transformers.modeling_utils.WEIGHTS_NAME
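The probe order matters: safetensors is preferred over pickle, and within each format the sharded index file over the single-file weights. A rough stand-in that checks a local directory rather than the Hugging Face Hub; the filename literals mirror the transformers.modeling_utils constants named above (values as of transformers 4.x, an assumption worth verifying against your installed version), and pick_weights_file is a hypothetical helper, not the hook's real API:

import os
from typing import Optional

# Literal values of the transformers.modeling_utils constants used above
# (assumed from transformers 4.x; verify for your version).
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"  # sharded safetensors
SAFE_WEIGHTS_NAME = "model.safetensors"                   # single-file safetensors
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"       # sharded pickle
WEIGHTS_NAME = "pytorch_model.bin"                        # single-file pickle


def pick_weights_file(model_dir: str) -> Optional[str]:
    # Probe in the hook's preference order; the first hit wins.
    for possible_filename in [
        SAFE_WEIGHTS_INDEX_NAME,  # the candidate this commit adds
        SAFE_WEIGHTS_NAME,
        WEIGHTS_INDEX_NAME,
        WEIGHTS_NAME,
    ]:
        if os.path.isfile(os.path.join(model_dir, possible_filename)):
            return possible_filename
    return None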
@@ -299,7 +300,10 @@ def _transformers22_aria2_hook(pretrained_model_name_or_path: str, force_downloa
     if not filename:
         return
 
-    if filename not in [transformers.modeling_utils.WEIGHTS_INDEX_NAME]:
+    if filename not in [
+        transformers.modeling_utils.SAFE_WEIGHTS_INDEX_NAME,
+        transformers.modeling_utils.WEIGHTS_INDEX_NAME
+    ]:
         # If the model isn't sharded, there's only one file to download
         filenames = [filename]
     else:
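This hunk is the bug behind the commit title: SAFE_WEIGHTS_INDEX_NAME previously fell into the not-sharded branch, so aria2 fetched only the index JSON and none of the shards it points to. Both index formats map parameter names to shard files under a "weight_map" key, so enumerating the shards to download looks roughly like the sketch below (shard_filenames_from_index is an illustrative helper, not the function in utils.py):

import json
from typing import List


def shard_filenames_from_index(index_path: str) -> List[str]:
    # Both pytorch_model.bin.index.json and model.safetensors.index.json
    # carry a "weight_map" dict from parameter name to shard filename.
    with open(index_path) as f:
        index = json.load(f)
    # Many parameters live in the same shard (e.g.
    # "model-00001-of-00003.safetensors"), so deduplicate before queueing.
    return sorted(set(index["weight_map"].values()))

With the index expanded this way, every shard (plus the index file itself) ends up on the aria2 download list instead of just the JSON.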