No safetensors for TPU

Henk
2023-08-02 10:01:35 +02:00
parent 16017a3afc
commit c066494c70

@@ -1303,7 +1303,7 @@ def load_model(path: str, model_type: str, badwordsids=koboldai_settings.badword
             except Exception as e:
                 tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=koboldai_vars.revision, cache_dir="cache")
             try:
-                model = AutoModelForCausalLM.from_pretrained(koboldai_vars.custmodpth, revision=koboldai_vars.revision, cache_dir="cache")
+                model = AutoModelForCausalLM.from_pretrained(koboldai_vars.custmodpth, revision=koboldai_vars.revision, cache_dir="cache", use_safetensors=False)
             except Exception as e:
                 model = GPTNeoForCausalLM.from_pretrained(koboldai_vars.custmodpth, revision=koboldai_vars.revision, cache_dir="cache")
         elif(os.path.isdir("models/{}".format(koboldai_vars.model.replace('/', '_')))):
@@ -1318,7 +1318,7 @@ def load_model(path: str, model_type: str, badwordsids=koboldai_settings.badword
             except Exception as e:
                 tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=koboldai_vars.revision, cache_dir="cache")
             try:
-                model = AutoModelForCausalLM.from_pretrained("models/{}".format(koboldai_vars.model.replace('/', '_')), revision=koboldai_vars.revision, cache_dir="cache")
+                model = AutoModelForCausalLM.from_pretrained("models/{}".format(koboldai_vars.model.replace('/', '_')), revision=koboldai_vars.revision, cache_dir="cache", use_safetensors=False)
             except Exception as e:
                 model = GPTNeoForCausalLM.from_pretrained("models/{}".format(koboldai_vars.model.replace('/', '_')), revision=koboldai_vars.revision, cache_dir="cache")
         else:
@@ -1333,7 +1333,7 @@ def load_model(path: str, model_type: str, badwordsids=koboldai_settings.badword
             except Exception as e:
                 tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=koboldai_vars.revision, cache_dir="cache")
             try:
-                model = AutoModelForCausalLM.from_pretrained(koboldai_vars.model, revision=koboldai_vars.revision, cache_dir="cache")
+                model = AutoModelForCausalLM.from_pretrained(koboldai_vars.model, revision=koboldai_vars.revision, cache_dir="cache", use_safetensors=False)
             except Exception as e:
                 model = GPTNeoForCausalLM.from_pretrained(koboldai_vars.model, revision=koboldai_vars.revision, cache_dir="cache")
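
For context, use_safetensors is a standard argument of transformers' from_pretrained: passing False tells the loader to ignore any model.safetensors weights and fall back to the legacy PyTorch .bin checkpoint, which is what the TPU code path touched by this commit needs. A minimal sketch outside of KoboldAI, where the model name is only a placeholder and the cache directory mirrors the calls above:

    from transformers import AutoModelForCausalLM

    # Default behaviour: recent transformers releases prefer *.safetensors
    # weights when the checkpoint ships them.
    model = AutoModelForCausalLM.from_pretrained(
        "EleutherAI/gpt-neo-125m",   # placeholder model, not from the commit
        cache_dir="cache",
    )

    # With use_safetensors=False the safetensors files are skipped and the
    # pytorch_model.bin checkpoint is loaded instead.
    model = AutoModelForCausalLM.from_pretrained(
        "EleutherAI/gpt-neo-125m",
        cache_dir="cache",
        use_safetensors=False,
    )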