AutoTokenizer

Futureproofing for future tokenizers, for now this is not needed since everything uses GPT2. But when that changes we want to be prepared. Not all models have a proper tokenizer config, so if we can't find one we fall back to GPT2.
This commit is contained in:
henk717 2021-12-25 00:44:26 +01:00
parent 00a0cea077
commit e1cd34268b
1 changed files with 13 additions and 4 deletions

View File

@ -615,7 +615,7 @@ print("{0}OK!{1}".format(colors.GREEN, colors.END))
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
if(not vars.noai):
print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
try:
from transformers import GPTJModel
except:
@ -885,21 +885,30 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
# Download model from Huggingface if it does not exist, otherwise load locally
if(os.path.isdir(vars.custmodpth)):
with(maybe_use_float16()):
tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache/")
try:
tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache/")
except ValueError as e:
tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache/")
try:
model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache/", **lowmem)
except ValueError as e:
model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache/", **lowmem)
elif(os.path.isdir(vars.model.replace('/', '_'))):
with(maybe_use_float16()):
tokenizer = GPT2TokenizerFast.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/")
try:
tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache/")
except ValueError as e:
tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache/")
try:
model = AutoModelForCausalLM.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/", **lowmem)
except ValueError as e:
model = GPTNeoForCausalLM.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/", **lowmem)
else:
print("Model does not exist locally, attempting to download from Huggingface...")
tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache/")
try:
tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache/")
except ValueError as e:
tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache/")
with(maybe_use_float16()):
tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache/")
try: