Upload BasicTrainer class
parent 728e19a7f0
commit 05cf9b1dde
@@ -13,6 +13,8 @@ import uuid
import datetime
import base64
import pickle
import hashlib
import itertools
import torch
import torch.nn.functional as F
from torch.nn import Embedding, CrossEntropyLoss
@@ -582,3 +584,88 @@ class TrainerBase(abc.ABC):
        # Save checkpoint every few steps
        if step == 1 or step % self.data.stparams["save_every"] == 0:
            save_mkusp(mean_loss, mean_grad_norm)


class BasicTrainer(TrainerBase):
    class TrainerData(TrainerBase.TrainerData):
        def __init__(self):
            super().__init__()
            self.dataset_file: Optional[str] = None
            self.initial_softprompt: Optional[List[int]] = None

    data: "BasicTrainer.TrainerData"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dataset: Optional[np.ndarray] = None

    def startup(self, step: int) -> None:
        if self.get_num_sequences() < self.data.gradient_accumulation_steps:
            self.raise_configuration_error(
                "Your dataset is too small! gradient_accumulation_steps must be less than or equal to the number of sequences.",
                code=101,
            )
        if (
            self.data.prompt_method == "tokens"
            and step < 0
            and self.data.initial_softprompt is None
        ):
            self.raise_configuration_error(
                "You have not set an initial soft prompt string.", code=103
            )
        if self.data.prompt_method == "tokens" and step < 0:
            self.data.soft_in_dim = len(self.data.initial_softprompt)

    def get_batch(self, step: int, size: int) -> np.ndarray:
        return self.dataset[(step - 1) * size : step * size]

    def get_num_sequences(self) -> int:
        if self.dataset is None:
            if self.data.dataset_file is None or not os.path.exists(
                self.data.dataset_file
            ):
                self.raise_configuration_error(
                    f"Dataset file not found at {repr(self.data.dataset_file)}",
                    code=102,
                )
            self.dataset = np.load(self.data.dataset_file, mmap_mode="r")
        assert self.dataset.ndim >= 2
        assert self.dataset.shape[0] >= 2
        return self.dataset.shape[0]

    def get_initial_soft_embeddings(self, model: transformers.PreTrainedModel) -> SoftPrompt:
        if self.data.prompt_method == "vocab_sample":
            rng = np.random.Generator(
                np.random.PCG64(
                    [
                        self.data.prompt_seed,
                        int.from_bytes(hashlib.sha256(model.config.model_type.encode("utf8")).digest()[:4], "little"),
                    ]
                )
            )
            tokenizer = self.get_tokenizer()
            with tokenizer._kai_no_prefix():
                special_tokens = set(
                    itertools.chain.from_iterable(
                        tokenizer.encode(str(v))
                        for v in tokenizer.special_tokens_map_extended.values()
                    )
                )
            sample_space = [
                k for k in range(self.data.params["n_vocab"]) if k not in special_tokens
            ]
            sample = rng.choice(sample_space, self.data.soft_in_dim, False)
            return model.get_input_embeddings()(torch.tensor(sample, dtype=torch.int32))
        elif self.data.prompt_method == "tokens":
            return model.get_input_embeddings()(torch.tensor(self.data.initial_softprompt, dtype=torch.int32))
        self.raise_configuration_error(
            f"Unknown prompt method {repr(self.data.prompt_method)}", code=104
        )

    def tokenize_dataset_callback(
        self, tokenizer: transformers.PreTrainedTokenizerBase, text: str
    ) -> List[int]:
        if self.data.newlinemode == "s":
            text = text.replace("\n", "</s>")
        with tokenizer._kai_no_prefix():
            return tokenizer.encode(text) + self.data.params["eos_token"]
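For readers skimming the diff, a minimal stand-alone sketch of the batching contract that get_batch and get_num_sequences implement: the dataset is a 2-D array of token IDs (sequences x tokens) loaded with mmap_mode="r", steps are 1-indexed, and each step consumes `size` consecutive sequences. The array below is a hypothetical stand-in for the .npy file that dataset_file would point to; it is not part of the commit.

import numpy as np

def get_batch(dataset: np.ndarray, step: int, size: int) -> np.ndarray:
    # Same slicing rule as BasicTrainer.get_batch: step 1 yields rows [0, size).
    return dataset[(step - 1) * size : step * size]

# Hypothetical stand-in for the memory-mapped dataset: 8 sequences of 16 token IDs.
dataset = np.arange(8 * 16, dtype=np.uint16).reshape(8, 16)
assert dataset.ndim >= 2 and dataset.shape[0] >= 2  # same sanity checks as get_num_sequences

size = 2  # sequences consumed per step
for step in range(1, dataset.shape[0] // size + 1):
    batch = get_batch(dataset, step, size)
    print(f"step {step}: rows {(step - 1) * size}-{step * size - 1}, shape {batch.shape}")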
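The "vocab_sample" branch of get_initial_soft_embeddings seeds its RNG with the user's prompt_seed plus the first four bytes of a SHA-256 hash of the model type, so the same seed draws a different initial token sample for a different model family. A stand-alone sketch of that seeding scheme follows; every concrete value (seed, model type, vocabulary size, special-token IDs, soft prompt length) is chosen purely for illustration.

import hashlib
import numpy as np

prompt_seed = 42           # hypothetical user-chosen seed (data.prompt_seed)
model_type = "gptj"        # hypothetical model.config.model_type
n_vocab = 50400            # hypothetical vocabulary size (data.params["n_vocab"])
soft_in_dim = 20           # hypothetical soft prompt length in tokens
special_tokens = {50256}   # hypothetical special-token IDs excluded from sampling

# Mix the model-type hash into the seed sequence, as in get_initial_soft_embeddings.
rng = np.random.Generator(
    np.random.PCG64(
        [
            prompt_seed,
            int.from_bytes(hashlib.sha256(model_type.encode("utf8")).digest()[:4], "little"),
        ]
    )
)

sample_space = [k for k in range(n_vocab) if k not in special_tokens]
sample = rng.choice(sample_space, soft_in_dim, False)  # token IDs drawn without replacement
print(sample)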