diff --git a/prompt_tuner.py b/prompt_tuner.py
index ea0efd3b..a958f882 100644
--- a/prompt_tuner.py
+++ b/prompt_tuner.py
@@ -15,11 +15,12 @@ import base64
 import pickle
 import hashlib
 import itertools
+from tqdm.auto import tqdm
 import torch
 import torch.nn.functional as F
 from torch.nn import Embedding, CrossEntropyLoss
 import transformers
-from transformers import AutoTokenizer, GPT2TokenizerFast
+from transformers import AutoTokenizer, GPT2TokenizerFast, AutoConfig
 from mkultra.tuning import GPTPromptTuningMixin, GPTNeoPromptTuningLM
 from mkultra.soft_prompt import SoftPrompt
 from typing import List, Optional, TextIO, Union
@@ -27,6 +28,18 @@ from typing import List, Optional, TextIO, Union
 
 _PromptTuningPreTrainedModel = Union["UniversalPromptTuningMixin", GPTPromptTuningMixin, transformers.PreTrainedModel]
 
+class _WTEDummy:
+    def __init__(self, model: transformers.PreTrainedModel):
+        self.model = model
+
+    @property
+    def wte(self: "_WTEDummy"):
+        return self.model.get_input_embeddings()
+
+    @wte.setter
+    def wte(self: "_WTEDummy", v):
+        self.model.set_input_embeddings(v)
+
 class _WTEMixin:
     @property
     def wte(self: Union["_WTEMixin", transformers.PreTrainedModel]):
@@ -43,7 +56,7 @@ class UniversalPromptTuningMixin:
         model: _PromptTuningPreTrainedModel = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
 
         if not hasattr(model, "transformer"):
-            model.transformer = _WTEMixin()
+            model.transformer = _WTEDummy(model)
         elif not hasattr(model.transformer, "wte"):
             assert isinstance(model.transformer, type)
             model.transformer.__class__ = type("_UniversalPromptTuning" + model.transformer.__class__.__name__, (_WTEMixin, model.transformer.__class__), {})
@@ -248,6 +261,26 @@ class TrainerBase(abc.ABC):
         raise ConfigurationError(msg, **kwargs)
 
     def get_hf_checkpoint_metadata(self) -> bool:
+        REVISION = None
+        params = {}
+        if(os.path.isdir(self.data.ckpt_path)):
+            model_config = AutoConfig.from_pretrained(self.data.ckpt_path, revision=REVISION, cache_dir="cache")
+        elif(os.path.isdir("models/{}".format(self.data.ckpt_path.replace('/', '_')))):
+            model_config = AutoConfig.from_pretrained("models/{}".format(self.data.ckpt_path.replace('/', '_')), revision=REVISION, cache_dir="cache")
+        else:
+            model_config = AutoConfig.from_pretrained(self.data.ckpt_path, revision=REVISION, cache_dir="cache")
+        params["tokenizer_id"] = self.data.ckpt_path
+        tokenizer = get_tokenizer(self.data.ckpt_path)
+        params["newlinemode"] = params.get(
+            "newlinemode", "s" if model_config.model_type == "xglm" else "n"
+        )
+        params["max_batch_size"] = 2048
+        with tokenizer._kai_no_prefix():
+            params["eos_token"] = (
+                [50259, 50259] if model_config.model_type == "xglm" and model_config.eos_token_id == 50259 else tokenizer.encode(model_config.eos_token_id)
+            )
+        params["seq"] = 2048
+        self.data.params = params
         return True
 
     def get_tokenizer(self) -> transformers.PreTrainedTokenizerBase:
@@ -445,6 +478,7 @@ class TrainerBase(abc.ABC):
                 self.raise_configuration_error(
                     "You have not set a soft prompt size.", code=6
                 )
+            step = 0
         else:
             # If we're resuming a soft-tuning session, the soft prompt tensor is
             # already in the save file and we just have to decode it.
@@ -502,7 +536,7 @@ class TrainerBase(abc.ABC):
         if beta1 == 0.0:
             beta1 = None
         optimizer = transformers.Adafactor(
-            params=model.get_soft_params(),
+            params=(model.get_soft_params(),),
             scale_parameter=False,
             relative_step=False,
             warmup_init=False,
@@ -540,19 +574,22 @@ class TrainerBase(abc.ABC):
                 f,
             )
             self.save_data()
+
+        bar1 = tqdm(initial=step + 1, total=steps, desc="CURRENT TRAINING STEP")
 
         while step < steps:
+            step += 1
             model.train()
 
             total_loss = total_grad = total_grad_norm = 0
 
-            for i in range(self.data.gradient_accumulation_steps):
-                # Get the next sequence from the dataset
-                block = self.get_batch(step, self.data.gradient_accumulation_steps).to(model.transformer.wte.weight.device)
+            # Get the next sequences from the dataset
+            block = torch.tensor(np.int32(self.get_batch(step, self.data.gradient_accumulation_steps))).to(model.transformer.wte.weight.device)
 
+            for sequence in tqdm(block, desc="GRADIENT ACCUMULATION", leave=False):
                 # input_ids is the context to the model (without the soft prompt) and labels is what we expect the model to generate (the -100s represent soft prompt tokens for which loss is not calculated)
-                input_ids = block[:-1].unsqueeze(0).detach()
-                labels = torch.cat((torch.full((model.get_soft_params().size(0) - 1,), -100, device=block.device), block)).unsqueeze(0).cuda().detach()
+                input_ids = sequence[:-1].unsqueeze(0).detach()
+                labels = torch.cat((torch.full((model.get_soft_params().size(0) - 1,), -100, device=sequence.device), sequence), dim=-1).unsqueeze(0).detach()
 
                 # Give the context to the model and compare the model's output logits with the labels to compute the loss
                 logits = model(input_ids=input_ids, labels=input_ids).logits
@@ -579,12 +616,13 @@ class TrainerBase(abc.ABC):
             scheduler.step()
             optimizer.zero_grad()
 
-            step += 1
-
             # Save checkpoint every few steps
             if step == 1 or step % self.data.stparams["save_every"] == 0:
                 save_mkusp(mean_loss, mean_grad_norm)
 
+            bar1.set_postfix({"loss": mean_loss, "grad_norm": mean_grad_norm, "learning_rate": lr})
+            bar1.update()
+
 
 class BasicTrainer(TrainerBase):
     class TrainerData(TrainerBase.TrainerData):
@@ -652,12 +690,12 @@ class BasicTrainer(TrainerBase):
                 )
             )
             sample_space = [
-                k for k in range(self.data.params["n_vocab"]) if k not in special_tokens
+                k for k in range(model.get_input_embeddings().weight.shape[-2]) if k not in special_tokens
             ]
             sample = rng.choice(sample_space, self.data.soft_in_dim, False)
-            return model.get_input_embeddings()(torch.tensor(sample, dtype=torch.int32))
+            return SoftPrompt.from_inputs_embeds(model.get_input_embeddings()(torch.tensor(sample, dtype=torch.int32)))
         elif self.data.prompt_method == "tokens":
-            return model.get_input_embeddings()(torch.tensor(self.data.initial_softprompt, dtype=torch.int32))
+            return SoftPrompt.from_inputs_embeds(model.get_input_embeddings()(torch.tensor(self.data.initial_softprompt, dtype=torch.int32)))
         self.raise_configuration_error(
             f"Unknown prompt method {repr(self.data.prompt_method)}", code=104
         )
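
Below is a minimal sketch (not part of the patch) of the indirection that _WTEDummy introduces: reads and writes of `transformer.wte` are forwarded to the wrapped model's `get_input_embeddings()`/`set_input_embeddings()`, so a model without a GPT-2-style `transformer` attribute still exposes the interface the tuner expects. The `ToyModel` class is a made-up stand-in used only so the example runs without downloading a checkpoint; the real code wraps a `transformers.PreTrainedModel`.

    # Standalone illustration of the _WTEDummy forwarding pattern.
    from torch.nn import Embedding, Module

    class ToyModel(Module):
        # Hypothetical stand-in exposing only the two accessors that
        # transformers.PreTrainedModel provides and _WTEDummy relies on.
        def __init__(self):
            super().__init__()
            self._embeddings = Embedding(10, 4)

        def get_input_embeddings(self):
            return self._embeddings

        def set_input_embeddings(self, value):
            self._embeddings = value

    class _WTEDummy:
        # Same shape as the class added in the patch: .wte is a property
        # backed by the wrapped model's input embeddings.
        def __init__(self, model):
            self.model = model

        @property
        def wte(self):
            return self.model.get_input_embeddings()

        @wte.setter
        def wte(self, v):
            self.model.set_input_embeddings(v)

    model = ToyModel()
    model.transformer = _WTEDummy(model)              # mirrors what from_pretrained() now does
    print(model.transformer.wte.weight.shape)         # torch.Size([10, 4])
    model.transformer.wte = Embedding(10, 8)          # setter updates the model itself
    print(model.get_input_embeddings().weight.shape)  # torch.Size([10, 8])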