From e1cd34268bc770f7ef7722d7d0fdbe4b29d48331 Mon Sep 17 00:00:00 2001
From: henk717
Date: Sat, 25 Dec 2021 00:44:26 +0100
Subject: [PATCH] AutoTokenizer

Future-proofing for new tokenizers. For now this is not needed, since
everything uses GPT2, but when that changes we want to be prepared.
Not all models ship a proper tokenizer config, so if we can't find
one, we fall back to GPT2.
---
 aiserver.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index 02e718e8..b2fe91eb 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -615,7 +615,7 @@ print("{0}OK!{1}".format(colors.GREEN, colors.END))
 if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
     if(not vars.noai):
         print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
-        from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
+        from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
         try:
             from transformers import GPTJModel
         except:
@@ -885,21 +885,30 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
         # Download model from Huggingface if it does not exist, otherwise load locally
         if(os.path.isdir(vars.custmodpth)):
             with(maybe_use_float16()):
-                tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache/")
+                try:
+                    tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache/")
+                except ValueError as e:
+                    tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache/")
                 try:
                     model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache/", **lowmem)
                 except ValueError as e:
                     model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache/", **lowmem)
         elif(os.path.isdir(vars.model.replace('/', '_'))):
             with(maybe_use_float16()):
-                tokenizer = GPT2TokenizerFast.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/")
+                try:
+                    tokenizer = AutoTokenizer.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/")
+                except ValueError as e:
+                    tokenizer = GPT2TokenizerFast.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/")
                 try:
                     model = AutoModelForCausalLM.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/", **lowmem)
                 except ValueError as e:
                     model = GPTNeoForCausalLM.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/", **lowmem)
         else:
             print("Model does not exist locally, attempting to download from Huggingface...")
-            tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache/")
+            try:
+                tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache/")
+            except ValueError as e:
+                tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache/")
             with(maybe_use_float16()):
                 tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache/")
                 try:
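
Note: below is a minimal, self-contained sketch of the tokenizer fallback
pattern this patch repeats at all three load sites. The helper name
load_tokenizer and the model_path parameter are illustrative only; the
"cache/" default and the ValueError handling are taken from the patch.

    from transformers import AutoTokenizer, GPT2TokenizerFast

    def load_tokenizer(model_path, cache_dir="cache/"):
        """Prefer the tokenizer declared by the checkpoint itself; fall
        back to reading the files as a GPT-2 tokenizer when no usable
        tokenizer config is found."""
        try:
            # AutoTokenizer picks the tokenizer class from the model's config.
            return AutoTokenizer.from_pretrained(model_path, cache_dir=cache_dir)
        except ValueError:
            # Mirrors the patch: with no recognized tokenizer config, assume
            # the checkpoint ships GPT-2 vocab files and load those directly.
            return GPT2TokenizerFast.from_pretrained(model_path, cache_dir=cache_dir)

    # Usage (assuming a local checkpoint directory at this path):
    # tokenizer = load_tokenizer("models/my-model")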