From f7b799be567292931f9b1683c9d55124d3462054 Mon Sep 17 00:00:00 2001 From: vfbd Date: Fri, 21 Oct 2022 17:06:17 -0400 Subject: [PATCH] Apply tokenizer fixes to prompt_tuner.py --- prompt_tuner.py | 35 ++++++++++++++++------------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/prompt_tuner.py b/prompt_tuner.py index 46092eac..f37a8718 100644 --- a/prompt_tuner.py +++ b/prompt_tuner.py @@ -27,7 +27,7 @@ import torch.nn.functional as F from torch.nn import Embedding, CrossEntropyLoss import transformers from transformers import __version__ as transformers_version -from transformers import AutoTokenizer, GPT2TokenizerFast, AutoConfig, AutoModelForCausalLM, GPTNeoForCausalLM, PreTrainedModel, modeling_utils +from transformers import AutoTokenizer, GPT2Tokenizer, AutoConfig, AutoModelForCausalLM, GPTNeoForCausalLM, PreTrainedModel, modeling_utils import accelerate import accelerate.utils from mkultra.tuning import GPTPromptTuningMixin, GPTNeoPromptTuningLM @@ -344,41 +344,38 @@ default_quiet = False def get_tokenizer(model_id, revision=None) -> transformers.PreTrainedTokenizerBase: if(os.path.isdir(model_id)): - try: - tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache") - except Exception as e: - pass try: tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache", use_fast=False) except Exception as e: try: - tokenizer = GPT2TokenizerFast.from_pretrained(model_id, revision=revision, cache_dir="cache") + tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache") except Exception as e: - tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=revision, cache_dir="cache") + try: + tokenizer = GPT2Tokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache") + except Exception as e: + tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=revision, cache_dir="cache") elif(os.path.isdir("models/{}".format(model_id.replace('/', '_')))): - try: - tokenizer = AutoTokenizer.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache") - except Exception as e: - pass try: tokenizer = AutoTokenizer.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache", use_fast=False) except Exception as e: try: - tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache") + tokenizer = AutoTokenizer.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache") except Exception as e: - tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=revision, cache_dir="cache") + try: + tokenizer = GPT2Tokenizer.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache") + except Exception as e: + tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=revision, cache_dir="cache") else: - try: - tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache") - except Exception as e: - pass try: tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache", use_fast=False) except Exception as e: try: - tokenizer = GPT2TokenizerFast.from_pretrained(model_id, revision=revision, cache_dir="cache") + tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache") except Exception as e: - tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=revision, cache_dir="cache") + try: + tokenizer = GPT2Tokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache") + except Exception as e: + tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=revision, cache_dir="cache") @contextlib.contextmanager def _kai_no_prefix():