Apply tokenizer fixes to prompt_tuner.py
commit f7b799be56
parent 6758d5b538

@@ -27,7 +27,7 @@ import torch.nn.functional as F
 from torch.nn import Embedding, CrossEntropyLoss
 import transformers
 from transformers import __version__ as transformers_version
-from transformers import AutoTokenizer, GPT2TokenizerFast, AutoConfig, AutoModelForCausalLM, GPTNeoForCausalLM, PreTrainedModel, modeling_utils
+from transformers import AutoTokenizer, GPT2Tokenizer, AutoConfig, AutoModelForCausalLM, GPTNeoForCausalLM, PreTrainedModel, modeling_utils
 import accelerate
 import accelerate.utils
 from mkultra.tuning import GPTPromptTuningMixin, GPTNeoPromptTuningLM
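The import change above swaps GPT2TokenizerFast for GPT2Tokenizer: the last-resort fallbacks below now use the slow, pure-Python GPT-2 tokenizer, which can load tokenizer files that the Rust-backed fast converter rejects. A minimal sketch of the relationship between the two classes, assuming an installed transformers package and network access (not part of the commit):

from transformers import GPT2Tokenizer, GPT2TokenizerFast

# The slow (pure-Python) and fast (Rust-backed) GPT-2 tokenizers produce
# identical ids for ordinary text; the slow class makes a safer fallback
# because it is the more forgiving loader.
slow = GPT2Tokenizer.from_pretrained("gpt2")
fast = GPT2TokenizerFast.from_pretrained("gpt2")
assert slow.encode("Hello world") == fast.encode("Hello world")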
@@ -344,41 +344,38 @@ default_quiet = False
 
 def get_tokenizer(model_id, revision=None) -> transformers.PreTrainedTokenizerBase:
     if(os.path.isdir(model_id)):
         try:
-            tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
-        except Exception as e:
-            pass
-        try:
             tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache", use_fast=False)
         except Exception as e:
             try:
-                tokenizer = GPT2TokenizerFast.from_pretrained(model_id, revision=revision, cache_dir="cache")
+                tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
             except Exception as e:
-                tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=revision, cache_dir="cache")
+                try:
+                    tokenizer = GPT2Tokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
+                except Exception as e:
+                    tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=revision, cache_dir="cache")
     elif(os.path.isdir("models/{}".format(model_id.replace('/', '_')))):
         try:
-            tokenizer = AutoTokenizer.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache")
-        except Exception as e:
-            pass
-        try:
             tokenizer = AutoTokenizer.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache", use_fast=False)
         except Exception as e:
             try:
-                tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache")
+                tokenizer = AutoTokenizer.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache")
             except Exception as e:
-                tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=revision, cache_dir="cache")
+                try:
+                    tokenizer = GPT2Tokenizer.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache")
+                except Exception as e:
+                    tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=revision, cache_dir="cache")
     else:
         try:
-            tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
-        except Exception as e:
-            pass
-        try:
             tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache", use_fast=False)
         except Exception as e:
             try:
-                tokenizer = GPT2TokenizerFast.from_pretrained(model_id, revision=revision, cache_dir="cache")
+                tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
             except Exception as e:
-                tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=revision, cache_dir="cache")
+                try:
+                    tokenizer = GPT2Tokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
+                except Exception as e:
+                    tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=revision, cache_dir="cache")
 
 @contextlib.contextmanager
 def _kai_no_prefix():
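After this change all three branches of get_tokenizer run the same four-step fallback chain and differ only in which path they probe. A condensed sketch of that chain, using a hypothetical load_tokenizer_with_fallbacks helper with location standing in for the branch-specific path (the commit itself keeps the three explicit branches):

from transformers import AutoTokenizer, GPT2Tokenizer

def load_tokenizer_with_fallbacks(location, revision=None):
    # Fallback order after the fix: slow AutoTokenizer, fast AutoTokenizer,
    # slow GPT-2 tokenizer on the model's own files, then stock "gpt2".
    attempts = (
        lambda: AutoTokenizer.from_pretrained(location, revision=revision, cache_dir="cache", use_fast=False),
        lambda: AutoTokenizer.from_pretrained(location, revision=revision, cache_dir="cache"),
        lambda: GPT2Tokenizer.from_pretrained(location, revision=revision, cache_dir="cache"),
    )
    for attempt in attempts:
        try:
            return attempt()
        except Exception:
            continue
    # Last resort, allowed to raise just like the original code.
    return GPT2Tokenizer.from_pretrained("gpt2", revision=revision, cache_dir="cache")

Besides reordering the attempts so use_fast=False is tried first, the fix also drops dead code: in the old version the result of the initial plain AutoTokenizer attempt was always overwritten by the try block that followed it.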