Apply tokenizer fixes to prompt_tuner.py
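Bring get_tokenizer() in prompt_tuner.py in line with the tokenizer loading fixes. GPT2Tokenizer is imported in place of GPT2TokenizerFast, the initial try/except-pass load (whose result was always overwritten by the block that followed) is dropped, and each branch is restructured into a single nested fallback chain: AutoTokenizer with use_fast=False first, then the default AutoTokenizer, then GPT2Tokenizer for the requested model, and finally the stock gpt2 tokenizer.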
parent 6758d5b538
commit f7b799be56
prompt_tuner.py
@@ -27,7 +27,7 @@ import torch.nn.functional as F
 from torch.nn import Embedding, CrossEntropyLoss
 import transformers
 from transformers import __version__ as transformers_version
-from transformers import AutoTokenizer, GPT2TokenizerFast, AutoConfig, AutoModelForCausalLM, GPTNeoForCausalLM, PreTrainedModel, modeling_utils
+from transformers import AutoTokenizer, GPT2Tokenizer, AutoConfig, AutoModelForCausalLM, GPTNeoForCausalLM, PreTrainedModel, modeling_utils
 import accelerate
 import accelerate.utils
 from mkultra.tuning import GPTPromptTuningMixin, GPTNeoPromptTuningLM
@@ -344,41 +344,38 @@ default_quiet = False
 
 def get_tokenizer(model_id, revision=None) -> transformers.PreTrainedTokenizerBase:
     if(os.path.isdir(model_id)):
-        try:
-            tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
-        except Exception as e:
-            pass
         try:
             tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache", use_fast=False)
         except Exception as e:
             try:
-                tokenizer = GPT2TokenizerFast.from_pretrained(model_id, revision=revision, cache_dir="cache")
+                tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
             except Exception as e:
-                tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=revision, cache_dir="cache")
+                try:
+                    tokenizer = GPT2Tokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
+                except Exception as e:
+                    tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=revision, cache_dir="cache")
     elif(os.path.isdir("models/{}".format(model_id.replace('/', '_')))):
-        try:
-            tokenizer = AutoTokenizer.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache")
-        except Exception as e:
-            pass
         try:
             tokenizer = AutoTokenizer.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache", use_fast=False)
         except Exception as e:
             try:
-                tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache")
+                tokenizer = AutoTokenizer.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache")
             except Exception as e:
-                tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=revision, cache_dir="cache")
+                try:
+                    tokenizer = GPT2Tokenizer.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache")
+                except Exception as e:
+                    tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=revision, cache_dir="cache")
     else:
-        try:
-            tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
-        except Exception as e:
-            pass
         try:
             tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache", use_fast=False)
         except Exception as e:
             try:
-                tokenizer = GPT2TokenizerFast.from_pretrained(model_id, revision=revision, cache_dir="cache")
+                tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
             except Exception as e:
-                tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=revision, cache_dir="cache")
+                try:
+                    tokenizer = GPT2Tokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
+                except Exception as e:
+                    tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=revision, cache_dir="cache")
 
 @contextlib.contextmanager
 def _kai_no_prefix():
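For context, each branch of the new get_tokenizer() implements the same four-step fallback. Below is a minimal standalone sketch of that order, assuming only that transformers is installed; the load_tokenizer name and the flattened early-return structure are illustrative, while the committed code keeps the nested try/except form, assigns to a local tokenizer variable, and resolves the models/ directory per branch.

from transformers import AutoTokenizer, GPT2Tokenizer

def load_tokenizer(model_id, revision=None):
    # 1. Prefer the slow, pure-Python tokenizer (the use_fast=False path above).
    try:
        return AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache", use_fast=False)
    except Exception:
        pass
    # 2. Fall back to whatever AutoTokenizer resolves to by default
    #    (usually the Rust-backed fast tokenizer).
    try:
        return AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
    except Exception:
        pass
    # 3. Try reading the model's files as a GPT-2 style vocab.json/merges.txt pair.
    try:
        return GPT2Tokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
    except Exception:
        pass
    # 4. Last resort, mirroring the diff: the stock gpt2 tokenizer. Note that
    #    revision is still forwarded here exactly as in the committed code,
    #    even though it names a revision of the originally requested model.
    return GPT2Tokenizer.from_pretrained("gpt2", revision=revision, cache_dir="cache")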