Fix remaining problems in prompt_tuner.py

This commit is contained in:
vfbd 2022-08-22 17:30:49 -04:00
parent f79926b73d
commit 584056b6d5

View File

@ -15,11 +15,12 @@ import base64
import pickle
import hashlib
import itertools
from tqdm.auto import tqdm
import torch
import torch.nn.functional as F
from torch.nn import Embedding, CrossEntropyLoss
import transformers
from transformers import AutoTokenizer, GPT2TokenizerFast
from transformers import AutoTokenizer, GPT2TokenizerFast, AutoConfig
from mkultra.tuning import GPTPromptTuningMixin, GPTNeoPromptTuningLM
from mkultra.soft_prompt import SoftPrompt
from typing import List, Optional, TextIO, Union
@ -27,6 +28,18 @@ from typing import List, Optional, TextIO, Union
_PromptTuningPreTrainedModel = Union["UniversalPromptTuningMixin", GPTPromptTuningMixin, transformers.PreTrainedModel]
class _WTEDummy:
def __init__(self, model: transformers.PreTrainedModel):
self.model = model
@property
def wte(self: "_WTEDummy"):
return self.model.get_input_embeddings()
@wte.setter
def wte(self: "_WTEDummy", v):
self.model.set_input_embeddings(v)
class _WTEMixin:
@property
def wte(self: Union["_WTEMixin", transformers.PreTrainedModel]):
@ -43,7 +56,7 @@ class UniversalPromptTuningMixin:
model: _PromptTuningPreTrainedModel = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
if not hasattr(model, "transformer"):
model.transformer = _WTEMixin()
model.transformer = _WTEDummy(model)
elif not hasattr(model.transformer, "wte"):
assert isinstance(model.transformer, type)
model.transformer.__class__ = type("_UniversalPromptTuning" + model.transformer.__class__.__name__, (_WTEMixin, model.transformer.__class__), {})
@ -248,6 +261,26 @@ class TrainerBase(abc.ABC):
raise ConfigurationError(msg, **kwargs)
def get_hf_checkpoint_metadata(self) -> bool:
REVISION = None
params = {}
if(os.path.isdir(self.data.ckpt_path)):
model_config = AutoConfig.from_pretrained(self.data.ckpt_path, revision=REVISION, cache_dir="cache")
elif(os.path.isdir("models/{}".format(self.data.ckpt_path.replace('/', '_')))):
model_config = AutoConfig.from_pretrained("models/{}".format(self.data.ckpt_path.replace('/', '_')), revision=REVISION, cache_dir="cache")
else:
model_config = AutoConfig.from_pretrained(self.data.ckpt_path, revision=REVISION, cache_dir="cache")
params["tokenizer_id"] = self.data.ckpt_path
tokenizer = get_tokenizer(self.data.ckpt_path)
params["newlinemode"] = params.get(
"newlinemode", "s" if model_config.model_type == "xglm" else "n"
)
params["max_batch_size"] = 2048
with tokenizer._kai_no_prefix():
params["eos_token"] = (
[50259, 50259] if model_config.model_type == "xglm" and model_config.eos_token_id == 50259 else tokenizer.encode(model_config.eos_token_id)
)
params["seq"] = 2048
self.data.params = params
return True
def get_tokenizer(self) -> transformers.PreTrainedTokenizerBase:
@ -445,6 +478,7 @@ class TrainerBase(abc.ABC):
self.raise_configuration_error(
"You have not set a soft prompt size.", code=6
)
step = 0
else:
# If we're resuming a soft-tuning session, the soft prompt tensor is
# already in the save file and we just have to decode it.
@ -502,7 +536,7 @@ class TrainerBase(abc.ABC):
if beta1 == 0.0:
beta1 = None
optimizer = transformers.Adafactor(
params=model.get_soft_params(),
params=(model.get_soft_params(),),
scale_parameter=False,
relative_step=False,
warmup_init=False,
@ -540,19 +574,22 @@ class TrainerBase(abc.ABC):
f,
)
self.save_data()
bar1 = tqdm(initial=step + 1, total=steps, desc="CURRENT TRAINING STEP")
while step < steps:
step += 1
model.train()
total_loss = total_grad = total_grad_norm = 0
for i in range(self.data.gradient_accumulation_steps):
# Get the next sequence from the dataset
block = self.get_batch(step, self.data.gradient_accumulation_steps).to(model.transformer.wte.weight.device)
# Get the next sequences from the dataset
block = torch.tensor(np.int32(self.get_batch(step, self.data.gradient_accumulation_steps))).to(model.transformer.wte.weight.device)
for sequence in tqdm(block, desc="GRADIENT ACCUMULATION", leave=False):
# input_ids is the context to the model (without the soft prompt) and labels is what we expect the model to generate (the -100s represent soft prompt tokens for which loss is not calculated)
input_ids = block[:-1].unsqueeze(0).detach()
labels = torch.cat((torch.full((model.get_soft_params().size(0) - 1,), -100, device=block.device), block)).unsqueeze(0).cuda().detach()
input_ids = sequence[:-1].unsqueeze(0).detach()
labels = torch.cat((torch.full((model.get_soft_params().size(0) - 1,), -100, device=sequence.device), sequence), dim=-1).unsqueeze(0).detach()
# Give the context to the model and compare the model's output logits with the labels to compute the loss
logits = model(input_ids=input_ids, labels=input_ids).logits
@ -579,12 +616,13 @@ class TrainerBase(abc.ABC):
scheduler.step()
optimizer.zero_grad()
step += 1
# Save checkpoint every few steps
if step == 1 or step % self.data.stparams["save_every"] == 0:
save_mkusp(mean_loss, mean_grad_norm)
bar1.set_postfix({"loss": mean_loss, "grad_norm": mean_grad_norm, "learning_rate": lr})
bar1.update()
class BasicTrainer(TrainerBase):
class TrainerData(TrainerBase.TrainerData):
@ -652,12 +690,12 @@ class BasicTrainer(TrainerBase):
)
)
sample_space = [
k for k in range(self.data.params["n_vocab"]) if k not in special_tokens
k for k in range(model.get_input_embeddings().weight.shape[-2]) if k not in special_tokens
]
sample = rng.choice(sample_space, self.data.soft_in_dim, False)
return model.get_input_embeddings()(torch.tensor(sample, dtype=torch.int32))
return SoftPrompt.from_inputs_embeds(model.get_input_embeddings()(torch.tensor(sample, dtype=torch.int32)))
elif self.data.prompt_method == "tokens":
return model.get_input_embeddings()(torch.tensor(self.data.initial_softprompt, dtype=torch.int32))
return SoftPrompt.from_inputs_embeds(model.get_input_embeddings()(torch.tensor(self.data.initial_softprompt, dtype=torch.int32)))
self.raise_configuration_error(
f"Unknown prompt method {repr(self.data.prompt_method)}", code=104
)