Fix some more typos in prompt_tuner.py

This commit is contained in:
vfbd 2022-08-22 16:51:09 -04:00
parent a49a633164
commit f79926b73d
1 changed files with 12 additions and 12 deletions

View File

@ -127,28 +127,28 @@ def get_tokenizer(model_id, revision=None) -> transformers.PreTrainedTokenizerBa
tokenizer = GPT2TokenizerFast.from_pretrained(model_id, revision=revision, cache_dir="cache")
except Exception as e:
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=revision, cache_dir="cache")
elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))):
elif(os.path.isdir("models/{}".format(model_id.replace('/', '_')))):
try:
tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=revision, cache_dir="cache")
tokenizer = AutoTokenizer.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache")
except Exception as e:
pass
try:
tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=revision, cache_dir="cache", use_fast=False)
tokenizer = AutoTokenizer.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache", use_fast=False)
except Exception as e:
try:
tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=revision, cache_dir="cache")
tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache")
except Exception as e:
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=revision, cache_dir="cache")
else:
try:
tokenizer = AutoTokenizer.from_pretrained(vars.model, revision=revision, cache_dir="cache")
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
except Exception as e:
pass
try:
tokenizer = AutoTokenizer.from_pretrained(vars.model, revision=revision, cache_dir="cache", use_fast=False)
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache", use_fast=False)
except Exception as e:
try:
tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, revision=revision, cache_dir="cache")
tokenizer = GPT2TokenizerFast.from_pretrained(model_id, revision=revision, cache_dir="cache")
except Exception as e:
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=revision, cache_dir="cache")
@ -474,20 +474,20 @@ class TrainerBase(abc.ABC):
if("out of memory" in traceback.format_exc().lower()):
raise RuntimeError("One of your GPUs ran out of memory when KoboldAI tried to load your model.")
model = GPTNeoPromptTuningLM.from_pretrained(self.data.ckpt_path, revision=REVISION, cache_dir="cache")
elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))):
elif(os.path.isdir("models/{}".format(self.data.ckpt_path.replace('/', '_')))):
try:
model = AutoPromptTuningLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=REVISION, cache_dir="cache")
model = AutoPromptTuningLM.from_pretrained("models/{}".format(self.data.ckpt_path.replace('/', '_')), revision=REVISION, cache_dir="cache")
except Exception as e:
if("out of memory" in traceback.format_exc().lower()):
raise RuntimeError("One of your GPUs ran out of memory when KoboldAI tried to load your model.")
model = GPTNeoPromptTuningLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=REVISION, cache_dir="cache")
model = GPTNeoPromptTuningLM.from_pretrained("models/{}".format(self.data.ckpt_path.replace('/', '_')), revision=REVISION, cache_dir="cache")
else:
try:
model = AutoPromptTuningLM.from_pretrained(vars.model, revision=REVISION, cache_dir="cache")
model = AutoPromptTuningLM.from_pretrained(self.data.ckpt_path, revision=REVISION, cache_dir="cache")
except Exception as e:
if("out of memory" in traceback.format_exc().lower()):
raise RuntimeError("One of your GPUs ran out of memory when KoboldAI tried to load your model.")
model = GPTNeoPromptTuningLM.from_pretrained(vars.model, revision=REVISION, cache_dir="cache")
model = GPTNeoPromptTuningLM.from_pretrained(self.data.ckpt_path, revision=REVISION, cache_dir="cache")
if step == 0:
soft_embeddings = self.get_initial_soft_embeddings(model)