Fix remaining problems in prompt_tuner.py

vfbd
2022-08-22 17:30:49 -04:00
parent f79926b73d
commit 584056b6d5

@@ -15,11 +15,12 @@ import base64
 import pickle
 import hashlib
 import itertools
+from tqdm.auto import tqdm
 import torch
 import torch.nn.functional as F
 from torch.nn import Embedding, CrossEntropyLoss
 import transformers
-from transformers import AutoTokenizer, GPT2TokenizerFast
+from transformers import AutoTokenizer, GPT2TokenizerFast, AutoConfig
 from mkultra.tuning import GPTPromptTuningMixin, GPTNeoPromptTuningLM
 from mkultra.soft_prompt import SoftPrompt
 from typing import List, Optional, TextIO, Union
@@ -27,6 +28,18 @@ from typing import List, Optional, TextIO, Union
 _PromptTuningPreTrainedModel = Union["UniversalPromptTuningMixin", GPTPromptTuningMixin, transformers.PreTrainedModel]
 
+class _WTEDummy:
+    def __init__(self, model: transformers.PreTrainedModel):
+        self.model = model
+
+    @property
+    def wte(self: "_WTEDummy"):
+        return self.model.get_input_embeddings()
+
+    @wte.setter
+    def wte(self: "_WTEDummy", v):
+        self.model.set_input_embeddings(v)
+
 class _WTEMixin:
     @property
     def wte(self: Union["_WTEMixin", transformers.PreTrainedModel]):
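A note on the new _WTEDummy class: models that expose no transformer.wte attribute still publish their embedding matrix through get_input_embeddings()/set_input_embeddings(), so a small proxy object can forward the wte property to those methods. A minimal, self-contained sketch of that pattern (the toy model class and names are illustrative, not the repo's code):

    import torch.nn as nn

    class ToyModel:  # hypothetical stand-in for a transformers.PreTrainedModel
        def __init__(self):
            self._emb = nn.Embedding(10, 4)
        def get_input_embeddings(self):
            return self._emb
        def set_input_embeddings(self, value):
            self._emb = value

    class WTEProxy:  # mirrors _WTEDummy above
        def __init__(self, model):
            self.model = model
        @property
        def wte(self):
            return self.model.get_input_embeddings()
        @wte.setter
        def wte(self, value):
            self.model.set_input_embeddings(value)

    model = ToyModel()
    model.transformer = WTEProxy(model)
    print(model.transformer.wte.weight.shape)          # torch.Size([10, 4])
    model.transformer.wte = nn.Embedding(12, 4)        # setter forwards to set_input_embeddings
    print(model.get_input_embeddings().weight.shape)   # torch.Size([12, 4])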
@@ -43,7 +56,7 @@ class UniversalPromptTuningMixin:
         model: _PromptTuningPreTrainedModel = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
         if not hasattr(model, "transformer"):
-            model.transformer = _WTEMixin()
+            model.transformer = _WTEDummy(model)
         elif not hasattr(model.transformer, "wte"):
             assert isinstance(model.transformer, type)
             model.transformer.__class__ = type("_UniversalPromptTuning" + model.transformer.__class__.__name__, (_WTEMixin, model.transformer.__class__), {})
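The elif branch relies on the __class__-swap idiom: build a new class that inherits from both _WTEMixin and the transformer's original class, then retype the existing instance so it gains the wte property in place. A toy sketch of the pattern (class names are illustrative, not from the repo):

    class WTEMixinSketch:      # plays the role of _WTEMixin
        @property
        def wte(self):
            return self.get_input_embeddings()

    class Transformer:         # plays the role of model.transformer
        def get_input_embeddings(self):
            return "embedding-module"

    t = Transformer()
    # Retype the existing instance; it now resolves .wte through the mixin.
    t.__class__ = type("_UniversalPromptTuning" + t.__class__.__name__,
                       (WTEMixinSketch, t.__class__), {})
    print(type(t).__name__)    # _UniversalPromptTuningTransformer
    print(t.wte)               # embedding-module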
@@ -248,6 +261,26 @@ class TrainerBase(abc.ABC):
         raise ConfigurationError(msg, **kwargs)
 
     def get_hf_checkpoint_metadata(self) -> bool:
+        REVISION = None
+        params = {}
+        if(os.path.isdir(self.data.ckpt_path)):
+            model_config = AutoConfig.from_pretrained(self.data.ckpt_path, revision=REVISION, cache_dir="cache")
+        elif(os.path.isdir("models/{}".format(self.data.ckpt_path.replace('/', '_')))):
+            model_config = AutoConfig.from_pretrained("models/{}".format(self.data.ckpt_path.replace('/', '_')), revision=REVISION, cache_dir="cache")
+        else:
+            model_config = AutoConfig.from_pretrained(self.data.ckpt_path, revision=REVISION, cache_dir="cache")
+        params["tokenizer_id"] = self.data.ckpt_path
+        tokenizer = get_tokenizer(self.data.ckpt_path)
+        params["newlinemode"] = params.get(
+            "newlinemode", "s" if model_config.model_type == "xglm" else "n"
+        )
+        params["max_batch_size"] = 2048
+        with tokenizer._kai_no_prefix():
+            params["eos_token"] = (
+                [50259, 50259] if model_config.model_type == "xglm" and model_config.eos_token_id == 50259 else tokenizer.encode(model_config.eos_token_id)
+            )
+        params["seq"] = 2048
+        self.data.params = params
         return True
 
     def get_tokenizer(self) -> transformers.PreTrainedTokenizerBase:
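The new get_hf_checkpoint_metadata body resolves the model config in three steps: an existing local directory, then a sanitized copy under models/, then whatever transformers can fetch for the raw checkpoint name. A hedged, standalone sketch of that resolution order (the helper name and the "gpt2" example are illustrative, not from the commit):

    import os
    from transformers import AutoConfig

    def resolve_config(ckpt_path: str, revision=None):
        # Prefer a local directory, then a "models/<sanitized name>" copy,
        # and only then let transformers resolve the name remotely.
        local_copy = "models/{}".format(ckpt_path.replace("/", "_"))
        if os.path.isdir(ckpt_path):
            source = ckpt_path
        elif os.path.isdir(local_copy):
            source = local_copy
        else:
            source = ckpt_path
        return AutoConfig.from_pretrained(source, revision=revision, cache_dir="cache")

    config = resolve_config("gpt2")
    print(config.model_type, config.eos_token_id)  # gpt2 50256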
@@ -445,6 +478,7 @@ class TrainerBase(abc.ABC):
                 self.raise_configuration_error(
                     "You have not set a soft prompt size.", code=6
                 )
+            step = 0
         else:
             # If we're resuming a soft-tuning session, the soft prompt tensor is
             # already in the save file and we just have to decode it.
@@ -502,7 +536,7 @@ class TrainerBase(abc.ABC):
         if beta1 == 0.0:
             beta1 = None
         optimizer = transformers.Adafactor(
-            params=model.get_soft_params(),
+            params=(model.get_soft_params(),),
             scale_parameter=False,
             relative_step=False,
             warmup_init=False,
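The params=(model.get_soft_params(),) fix matters because torch/transformers optimizers expect an iterable of parameters and reject a bare tensor, so the single soft-prompt parameter has to be wrapped in a one-element tuple. A small sketch with a toy parameter (shapes and the learning rate are placeholders):

    import torch
    import transformers

    # Toy stand-in for model.get_soft_params(): one trainable embedding-like tensor.
    soft_params = torch.nn.Parameter(torch.zeros(20, 768))

    optimizer = transformers.Adafactor(
        params=(soft_params,),   # iterable of parameters, not a bare tensor
        scale_parameter=False,
        relative_step=False,
        warmup_init=False,
        lr=1e-3,
    )

    loss = soft_params.pow(2).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()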
@@ -540,19 +574,22 @@ class TrainerBase(abc.ABC):
                     f,
                 )
             self.save_data()
 
+        bar1 = tqdm(initial=step + 1, total=steps, desc="CURRENT TRAINING STEP")
+
         while step < steps:
+            step += 1
             model.train()
             total_loss = total_grad = total_grad_norm = 0
 
-            for i in range(self.data.gradient_accumulation_steps):
-                # Get the next sequence from the dataset
-                block = self.get_batch(step, self.data.gradient_accumulation_steps).to(model.transformer.wte.weight.device)
+            # Get the next sequences from the dataset
+            block = torch.tensor(np.int32(self.get_batch(step, self.data.gradient_accumulation_steps))).to(model.transformer.wte.weight.device)
+
+            for sequence in tqdm(block, desc="GRADIENT ACCUMULATION", leave=False):
                 # input_ids is the context to the model (without the soft prompt) and labels is what we expect the model to generate (the -100s represent soft prompt tokens for which loss is not calculated)
-                input_ids = block[:-1].unsqueeze(0).detach()
-                labels = torch.cat((torch.full((model.get_soft_params().size(0) - 1,), -100, device=block.device), block)).unsqueeze(0).cuda().detach()
+                input_ids = sequence[:-1].unsqueeze(0).detach()
+                labels = torch.cat((torch.full((model.get_soft_params().size(0) - 1,), -100, device=sequence.device), sequence), dim=-1).unsqueeze(0).detach()
 
                 # Give the context to the model and compare the model's output logits with the labels to compute the loss
                 logits = model(input_ids=input_ids, labels=input_ids).logits
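The restructured loop fetches one block of gradient_accumulation_steps sequences per training step and runs one backward pass per sequence before a single optimizer update. A toy, runnable sketch of that shape (the embedding "model" and the loss are placeholders, not the trainer's):

    import torch
    from tqdm.auto import tqdm

    accumulation_steps, seq_len = 4, 16
    embedding = torch.nn.Embedding(100, 8)                 # placeholder "model"
    optimizer = torch.optim.SGD(embedding.parameters(), lr=0.1)

    # One block per step, shaped (gradient_accumulation_steps, seq_len).
    block = torch.randint(0, 100, (accumulation_steps, seq_len))

    for sequence in tqdm(block, desc="GRADIENT ACCUMULATION", leave=False):
        logits = embedding(sequence[:-1].unsqueeze(0))
        loss = logits.pow(2).mean() / accumulation_steps   # average over accumulated sequences
        loss.backward()                                    # gradients accumulate across iterations

    optimizer.step()       # one update for the whole block
    optimizer.zero_grad()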
@@ -579,12 +616,13 @@ class TrainerBase(abc.ABC):
             scheduler.step()
             optimizer.zero_grad()
 
-            step += 1
 
             # Save checkpoint every few steps
             if step == 1 or step % self.data.stparams["save_every"] == 0:
                 save_mkusp(mean_loss, mean_grad_norm)
 
+            bar1.set_postfix({"loss": mean_loss, "grad_norm": mean_grad_norm, "learning_rate": lr})
+            bar1.update()
 
 class BasicTrainer(TrainerBase):
     class TrainerData(TrainerBase.TrainerData):
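The new bar1 progress bar tracks the step count, with initial= letting a resumed session start the bar at the right position and set_postfix keeping the loss, gradient norm, and learning rate visible. A minimal sketch with placeholder metric values:

    from tqdm.auto import tqdm

    step, steps = 0, 5
    bar1 = tqdm(initial=step, total=steps, desc="CURRENT TRAINING STEP")
    while step < steps:
        step += 1
        mean_loss, mean_grad_norm, lr = 2.31, 0.57, 3e-4   # placeholder values
        bar1.set_postfix({"loss": mean_loss, "grad_norm": mean_grad_norm, "learning_rate": lr})
        bar1.update()
    bar1.close()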
@@ -652,12 +690,12 @@ class BasicTrainer(TrainerBase):
                     )
                 )
             sample_space = [
-                k for k in range(self.data.params["n_vocab"]) if k not in special_tokens
+                k for k in range(model.get_input_embeddings().weight.shape[-2]) if k not in special_tokens
             ]
             sample = rng.choice(sample_space, self.data.soft_in_dim, False)
-            return model.get_input_embeddings()(torch.tensor(sample, dtype=torch.int32))
+            return SoftPrompt.from_inputs_embeds(model.get_input_embeddings()(torch.tensor(sample, dtype=torch.int32)))
         elif self.data.prompt_method == "tokens":
-            return model.get_input_embeddings()(torch.tensor(self.data.initial_softprompt, dtype=torch.int32))
+            return SoftPrompt.from_inputs_embeds(model.get_input_embeddings()(torch.tensor(self.data.initial_softprompt, dtype=torch.int32)))
         self.raise_configuration_error(
             f"Unknown prompt method {repr(self.data.prompt_method)}", code=104
         )
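The vocab_sample and tokens branches now wrap the raw embedding lookup in SoftPrompt.from_inputs_embeds, and the vocabulary size is read from the embedding matrix itself rather than self.data.params["n_vocab"]. A hedged sketch of the sampling branch as a standalone helper (the function name and arguments are illustrative, not the trainer's):

    import numpy as np
    import torch
    from mkultra.soft_prompt import SoftPrompt

    def random_soft_prompt(model, soft_in_dim, special_tokens=(), seed=0):
        rng = np.random.default_rng(seed)
        # Vocabulary size comes straight from the embedding matrix.
        n_vocab = model.get_input_embeddings().weight.shape[-2]
        sample_space = [k for k in range(n_vocab) if k not in special_tokens]
        sample = rng.choice(sample_space, soft_in_dim, False)  # sample without replacement
        inputs_embeds = model.get_input_embeddings()(torch.tensor(sample, dtype=torch.int32))
        return SoftPrompt.from_inputs_embeds(inputs_embeds)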