Apply tokenizer fixes to prompt_tuner.py

This commit is contained in:
vfbd 2022-10-21 17:06:17 -04:00
parent 6758d5b538
commit f7b799be56
1 changed file with 16 additions and 19 deletions

View File

@ -27,7 +27,7 @@ import torch.nn.functional as F
from torch.nn import Embedding, CrossEntropyLoss from torch.nn import Embedding, CrossEntropyLoss
import transformers import transformers
from transformers import __version__ as transformers_version from transformers import __version__ as transformers_version
from transformers import AutoTokenizer, GPT2TokenizerFast, AutoConfig, AutoModelForCausalLM, GPTNeoForCausalLM, PreTrainedModel, modeling_utils from transformers import AutoTokenizer, GPT2Tokenizer, AutoConfig, AutoModelForCausalLM, GPTNeoForCausalLM, PreTrainedModel, modeling_utils
import accelerate import accelerate
import accelerate.utils import accelerate.utils
from mkultra.tuning import GPTPromptTuningMixin, GPTNeoPromptTuningLM from mkultra.tuning import GPTPromptTuningMixin, GPTNeoPromptTuningLM
@ -344,41 +344,38 @@ default_quiet = False
def get_tokenizer(model_id, revision=None) -> transformers.PreTrainedTokenizerBase: def get_tokenizer(model_id, revision=None) -> transformers.PreTrainedTokenizerBase:
if(os.path.isdir(model_id)): if(os.path.isdir(model_id)):
try:
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
except Exception as e:
pass
try: try:
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache", use_fast=False) tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache", use_fast=False)
except Exception as e: except Exception as e:
try: try:
tokenizer = GPT2TokenizerFast.from_pretrained(model_id, revision=revision, cache_dir="cache") tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
except Exception as e: except Exception as e:
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=revision, cache_dir="cache")
elif(os.path.isdir("models/{}".format(model_id.replace('/', '_')))):
try: try:
tokenizer = AutoTokenizer.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache") tokenizer = GPT2Tokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
except Exception as e: except Exception as e:
pass tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=revision, cache_dir="cache")
elif(os.path.isdir("models/{}".format(model_id.replace('/', '_')))):
try: try:
tokenizer = AutoTokenizer.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache", use_fast=False) tokenizer = AutoTokenizer.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache", use_fast=False)
except Exception as e: except Exception as e:
try: try:
tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache") tokenizer = AutoTokenizer.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache")
except Exception as e: except Exception as e:
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=revision, cache_dir="cache")
else:
try: try:
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache") tokenizer = GPT2Tokenizer.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache")
except Exception as e: except Exception as e:
pass tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=revision, cache_dir="cache")
else:
try: try:
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache", use_fast=False) tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache", use_fast=False)
except Exception as e: except Exception as e:
try: try:
tokenizer = GPT2TokenizerFast.from_pretrained(model_id, revision=revision, cache_dir="cache") tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
except Exception as e: except Exception as e:
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=revision, cache_dir="cache") try:
tokenizer = GPT2Tokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
except Exception as e:
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=revision, cache_dir="cache")
@contextlib.contextmanager @contextlib.contextmanager
def _kai_no_prefix(): def _kai_no_prefix():