Apply tokenizer fixes to prompt_tuner.py

This commit is contained in:
vfbd 2022-10-21 17:06:17 -04:00
parent 6758d5b538
commit f7b799be56
1 changed files with 16 additions and 19 deletions

View File

@ -27,7 +27,7 @@ import torch.nn.functional as F
from torch.nn import Embedding, CrossEntropyLoss
import transformers
from transformers import __version__ as transformers_version
from transformers import AutoTokenizer, GPT2TokenizerFast, AutoConfig, AutoModelForCausalLM, GPTNeoForCausalLM, PreTrainedModel, modeling_utils
from transformers import AutoTokenizer, GPT2Tokenizer, AutoConfig, AutoModelForCausalLM, GPTNeoForCausalLM, PreTrainedModel, modeling_utils
import accelerate
import accelerate.utils
from mkultra.tuning import GPTPromptTuningMixin, GPTNeoPromptTuningLM
@ -344,41 +344,38 @@ default_quiet = False
def get_tokenizer(model_id, revision=None) -> transformers.PreTrainedTokenizerBase:
if(os.path.isdir(model_id)):
try:
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
except Exception as e:
pass
try:
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache", use_fast=False)
except Exception as e:
try:
tokenizer = GPT2TokenizerFast.from_pretrained(model_id, revision=revision, cache_dir="cache")
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
except Exception as e:
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=revision, cache_dir="cache")
elif(os.path.isdir("models/{}".format(model_id.replace('/', '_')))):
try:
tokenizer = AutoTokenizer.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache")
tokenizer = GPT2Tokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
except Exception as e:
pass
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=revision, cache_dir="cache")
elif(os.path.isdir("models/{}".format(model_id.replace('/', '_')))):
try:
tokenizer = AutoTokenizer.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache", use_fast=False)
except Exception as e:
try:
tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache")
tokenizer = AutoTokenizer.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache")
except Exception as e:
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=revision, cache_dir="cache")
else:
try:
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
tokenizer = GPT2Tokenizer.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache")
except Exception as e:
pass
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=revision, cache_dir="cache")
else:
try:
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache", use_fast=False)
except Exception as e:
try:
tokenizer = GPT2TokenizerFast.from_pretrained(model_id, revision=revision, cache_dir="cache")
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
except Exception as e:
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=revision, cache_dir="cache")
try:
tokenizer = GPT2Tokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
except Exception as e:
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=revision, cache_dir="cache")
@contextlib.contextmanager
def _kai_no_prefix():