From 05cf9b1dded8925e6b39bb81cc878e686c3befab Mon Sep 17 00:00:00 2001
From: vfbd
Date: Mon, 22 Aug 2022 16:43:02 -0400
Subject: [PATCH] Upload BasicTrainer class

---
 prompt_tuner.py | 87 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 87 insertions(+)

diff --git a/prompt_tuner.py b/prompt_tuner.py
index fbecb4c4..f172c6a1 100644
--- a/prompt_tuner.py
+++ b/prompt_tuner.py
@@ -13,6 +13,8 @@ import uuid
 import datetime
 import base64
 import pickle
+import hashlib
+import itertools
 import torch
 import torch.nn.functional as F
 from torch.nn import Embedding, CrossEntropyLoss
@@ -582,3 +584,88 @@ class TrainerBase(abc.ABC):
         # Save checkpoint every few steps
         if step == 1 or step % self.data.stparams["save_every"] == 0:
             save_mkusp(mean_loss, mean_grad_norm)
+
+
+class BasicTrainer(TrainerBase):
+    class TrainerData(TrainerBase.TrainerData):
+        def __init__(self):
+            super().__init__()
+            self.dataset_file: Optional[str] = None
+            self.initial_softprompt: Optional[List[int]] = None
+
+    data: "BasicTrainer.TrainerData"
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.dataset: Optional[np.ndarray] = None
+
+    def startup(self, step: int) -> None:
+        if self.get_num_sequences() < self.data.gradient_accumulation_steps:
+            self.raise_configuration_error(
+                "Your dataset is too small! gradient_accumulation_steps must be less than or equal to the number of sequences.",
+                code=101,
+            )
+        if (
+            self.data.prompt_method == "tokens"
+            and step < 0
+            and self.data.initial_softprompt is None
+        ):
+            self.raise_configuration_error(
+                "You have not set an initial soft prompt string.", code=103
+            )
+        if self.data.prompt_method == "tokens" and step < 0:
+            self.data.soft_in_dim = len(self.data.initial_softprompt)
+
+    def get_batch(self, step: int, size: int) -> np.ndarray:
+        return self.dataset[(step - 1) * size : step * size]
+
+    def get_num_sequences(self) -> int:
+        if self.dataset is None:
+            if self.data.dataset_file is None or not os.path.exists(
+                self.data.dataset_file
+            ):
+                self.raise_configuration_error(
+                    f"Dataset file not found at {repr(self.data.dataset_file)}",
+                    code=102,
+                )
+            self.dataset = np.load(self.data.dataset_file, mmap_mode="r")
+        assert self.dataset.ndim >= 2
+        assert self.dataset.shape[0] >= 2
+        return self.dataset.shape[0]
+
+    def get_initial_soft_embeddings(self, model: transformers.PreTrainedModel) -> SoftPrompt:
+        if self.data.prompt_method == "vocab_sample":
+            rng = np.random.Generator(
+                np.random.PCG64(
+                    [
+                        self.data.prompt_seed,
+                        int.from_bytes(hashlib.sha256(model.config.model_type.encode("utf8")).digest()[:4], "little"),
+                    ]
+                )
+            )
+            tokenizer = self.get_tokenizer()
+            with tokenizer._kai_no_prefix():
+                special_tokens = set(
+                    itertools.chain.from_iterable(
+                        tokenizer.encode(str(v))
+                        for v in tokenizer.special_tokens_map_extended.values()
+                    )
+                )
+            sample_space = [
+                k for k in range(self.data.params["n_vocab"]) if k not in special_tokens
+            ]
+            sample = rng.choice(sample_space, self.data.soft_in_dim, False)
+            return model.get_input_embeddings()(torch.tensor(sample, dtype=torch.int32))
+        elif self.data.prompt_method == "tokens":
+            return model.get_input_embeddings()(torch.tensor(self.data.initial_softprompt, dtype=torch.int32))
+        self.raise_configuration_error(
+            f"Unknown prompt method {repr(self.data.prompt_method)}", code=104
+        )
+
+    def tokenize_dataset_callback(
+        self, tokenizer: transformers.PreTrainedTokenizerBase, text: str
+    ) -> List[int]:
+        if self.data.newlinemode == "s":
+            text = text.replace("\n", "")
+        with tokenizer._kai_no_prefix():
+            return tokenizer.encode(text) + self.data.params["eos_token"]
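
Note, outside the patch itself: the "vocab_sample" branch of get_initial_soft_embeddings keys a PCG64 generator on both the user's prompt seed and the first four bytes of the SHA-256 digest of the model type, so the same settings reproduce the same soft-prompt token sample for a given model family. Below is a minimal standalone sketch of just that seeding scheme; prompt_seed, model_type, n_vocab, and soft_in_dim are made-up stand-in values, not values taken from the patch.

    # Standalone illustration only; the values below are hypothetical stand-ins
    # for the trainer's configuration fields referenced in the patch.
    import hashlib
    import numpy as np

    prompt_seed = 42      # stand-in for self.data.prompt_seed
    model_type = "gptj"   # stand-in for model.config.model_type
    n_vocab = 50400       # stand-in for self.data.params["n_vocab"]
    soft_in_dim = 20      # stand-in for self.data.soft_in_dim

    # Seed the generator on (prompt seed, 32-bit digest of the model type),
    # mirroring the patch, so the sampled token ids are deterministic.
    rng = np.random.Generator(
        np.random.PCG64(
            [
                prompt_seed,
                int.from_bytes(hashlib.sha256(model_type.encode("utf8")).digest()[:4], "little"),
            ]
        )
    )
    sample = rng.choice(np.arange(n_vocab), soft_in_dim, False)  # sample ids without replacement
    print(sample)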