Fallback to generic GPT2 Tokenizer

commit 172a548fa1
parent 68281184bf
Author: henk717
Date:   2022-03-10 19:52:15 +01:00


@@ -1384,29 +1384,38 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
     if(os.path.isdir(vars.custmodpth)):
         try:
             tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache")
-        except ValueError as e:
-            tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache")
+        except Exception as e:
+            try:
+                tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache")
+            except Exception as e:
+                tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache")
         try:
             model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem)
-        except ValueError as e:
+        except Exception as e:
             model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem)
     elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))):
         try:
             tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
-        except ValueError as e:
-            tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
+        except Exception as e:
+            try:
+                tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
+            except Exception as e:
+                tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache")
         try:
             model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
-        except ValueError as e:
+        except Exception as e:
             model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
     else:
         try:
             tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache")
-        except ValueError as e:
-            tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache")
+        except Exception as e:
+            try:
+                tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache")
+            except Exception as e:
+                tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache")
         try:
             model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache", **lowmem)
-        except ValueError as e:
+        except Exception as e:
             model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache", **lowmem)
     if not args.colab:
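For reference, the nested try/except pattern this commit repeats in all three branches reduces to the sketch below. It is a minimal illustration, not code from the commit: the helper name load_tokenizer_with_fallback and its parameters are hypothetical, and aiserver.py inlines this logic in each branch rather than wrapping it in a function.

    # Minimal sketch of the fallback chain introduced by this commit.
    # load_tokenizer_with_fallback is a hypothetical helper, not a
    # function in aiserver.py.
    from transformers import AutoTokenizer, GPT2TokenizerFast

    def load_tokenizer_with_fallback(model_path, cache_dir="cache"):
        # First choice: let AutoTokenizer resolve the correct tokenizer
        # class from the model's own config/tokenizer files.
        try:
            return AutoTokenizer.from_pretrained(model_path, cache_dir=cache_dir)
        except Exception:
            pass
        # Second choice: the model directory may still ship GPT-2-style
        # vocab/merges files that GPT2TokenizerFast can read directly.
        try:
            return GPT2TokenizerFast.from_pretrained(model_path, cache_dir=cache_dir)
        except Exception:
            # Last resort: fall back to the stock "gpt2" tokenizer from
            # the Hugging Face Hub so loading can proceed at all.
            return GPT2TokenizerFast.from_pretrained("gpt2", cache_dir=cache_dir)

Widening the handlers from ValueError to Exception means any tokenizer-loading failure, not just an unrecognized config, reaches the fallbacks; the trade-off is that the original error is swallowed rather than surfaced to the user.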