Upload BasicTrainer class

vfbd 2022-08-22 16:43:02 -04:00
parent 728e19a7f0
commit 05cf9b1dde

@@ -13,6 +13,8 @@ import uuid
import datetime
import base64
import pickle
import hashlib
import itertools
import torch
import torch.nn.functional as F
from torch.nn import Embedding, CrossEntropyLoss
@@ -582,3 +584,88 @@ class TrainerBase(abc.ABC):
        # Save checkpoint every few steps
        if step == 1 or step % self.data.stparams["save_every"] == 0:
            save_mkusp(mean_loss, mean_grad_norm)


class BasicTrainer(TrainerBase):
    class TrainerData(TrainerBase.TrainerData):
        def __init__(self):
            super().__init__()
            self.dataset_file: Optional[str] = None
            self.initial_softprompt: Optional[List[int]] = None

    data: "BasicTrainer.TrainerData"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dataset: Optional[np.ndarray] = None
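
    # Pre-flight validation: the dataset must hold at least one full
    # gradient-accumulation batch, and the "tokens" prompt method needs an
    # initial soft prompt when starting fresh (step < 0 appears to mean a
    # new run rather than a resumed checkpoint).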
    def startup(self, step: int) -> None:
        if self.get_num_sequences() < self.data.gradient_accumulation_steps:
            self.raise_configuration_error(
                "Your dataset is too small! gradient_accumulation_steps must be less than or equal to the number of sequences.",
                code=101,
            )
        if (
            self.data.prompt_method == "tokens"
            and step < 0
            and self.data.initial_softprompt is None
        ):
            self.raise_configuration_error(
                "You have not set an initial soft prompt string.", code=103
            )
        if self.data.prompt_method == "tokens" and step < 0:
            self.data.soft_in_dim = len(self.data.initial_softprompt)
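
    # Batches are contiguous slices of the memory-mapped dataset; steps are
    # 1-indexed, so step 1 reads rows [0, size).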
    def get_batch(self, step: int, size: int) -> np.ndarray:
        return self.dataset[(step - 1) * size : step * size]

    def get_num_sequences(self) -> int:
        if self.dataset is None:
            if self.data.dataset_file is None or not os.path.exists(
                self.data.dataset_file
            ):
                self.raise_configuration_error(
                    f"Dataset file not found at {repr(self.data.dataset_file)}",
                    code=102,
                )
            self.dataset = np.load(self.data.dataset_file, mmap_mode="r")
        assert self.dataset.ndim >= 2
        assert self.dataset.shape[0] >= 2
        return self.dataset.shape[0]
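
    # "vocab_sample" initializes the soft prompt from the embeddings of
    # soft_in_dim randomly chosen non-special tokens, seeded by prompt_seed
    # plus a hash of the model type so different architectures get different
    # draws; "tokens" embeds the user-supplied initial_softprompt directly.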
    def get_initial_soft_embeddings(
        self, model: transformers.PreTrainedModel
    ) -> SoftPrompt:
        if self.data.prompt_method == "vocab_sample":
            rng = np.random.Generator(
                np.random.PCG64(
                    [
                        self.data.prompt_seed,
                        int.from_bytes(
                            hashlib.sha256(
                                model.config.model_type.encode("utf8")
                            ).digest()[:4],
                            "little",
                        ),
                    ]
                )
            )
            tokenizer = self.get_tokenizer()
            with tokenizer._kai_no_prefix():
                special_tokens = set(
                    itertools.chain.from_iterable(
                        tokenizer.encode(str(v))
                        for v in tokenizer.special_tokens_map_extended.values()
                    )
                )
            sample_space = [
                k
                for k in range(self.data.params["n_vocab"])
                if k not in special_tokens
            ]
            sample = rng.choice(sample_space, self.data.soft_in_dim, False)
            return model.get_input_embeddings()(
                torch.tensor(sample, dtype=torch.int32)
            )
        elif self.data.prompt_method == "tokens":
            return model.get_input_embeddings()(
                torch.tensor(self.data.initial_softprompt, dtype=torch.int32)
            )
        self.raise_configuration_error(
            f"Unknown prompt method {repr(self.data.prompt_method)}", code=104
        )
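
    # Dataset preprocessing hook: encodes raw text (optionally mapping
    # newlines to </s> under newline mode "s") and appends the EOS token.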
    def tokenize_dataset_callback(
        self, tokenizer: transformers.PreTrainedTokenizerBase, text: str
    ) -> List[int]:
        if self.data.newlinemode == "s":
            text = text.replace("\n", "</s>")
        with tokenizer._kai_no_prefix():
            return tokenizer.encode(text) + self.data.params["eos_token"]
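
A minimal sketch of how this class might be driven, assuming the TrainerBase plumbing implied by the names in this diff; the constructor arguments and any fields not visible above are assumptions, not part of this commit:

    # Hypothetical driver, not part of commit 05cf9b1dde; only dataset_file,
    # prompt_method, prompt_seed, gradient_accumulation_steps, startup, and
    # get_batch appear in the diff above; everything else is assumed.
    trainer = BasicTrainer()
    trainer.data.dataset_file = "dataset.npy"    # pre-tokenized, shape (n_sequences, ctx_len)
    trainer.data.prompt_method = "vocab_sample"  # or "tokens" with initial_softprompt set
    trainer.data.prompt_seed = 42
    trainer.startup(step=-1)  # validates dataset size and prompt config (error codes 101-103)
    batch = trainer.get_batch(step=1, size=trainer.data.gradient_accumulation_steps)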