mirror of https://github.com/KoboldAI/KoboldAI-Client.git
synced 2024-12-12 08:36:28 +01:00

Upload BasicTrainer class

This commit is contained in:
parent 728e19a7f0
commit 05cf9b1dde
@@ -13,6 +13,8 @@ import uuid
 import datetime
 import base64
 import pickle
+import hashlib
+import itertools
 import torch
 import torch.nn.functional as F
 from torch.nn import Embedding, CrossEntropyLoss
@@ -582,3 +584,88 @@ class TrainerBase(abc.ABC):
             # Save checkpoint every few steps
             if step == 1 or step % self.data.stparams["save_every"] == 0:
                 save_mkusp(mean_loss, mean_grad_norm)
+
+
+class BasicTrainer(TrainerBase):
+    class TrainerData(TrainerBase.TrainerData):
+        def __init__(self):
+            super().__init__()
+            self.dataset_file: Optional[str] = None
+            self.initial_softprompt: Optional[List[int]] = None
+
+    data: "BasicTrainer.TrainerData"
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.dataset: Optional[np.ndarray] = None
+
+    def startup(self, step: int) -> None:
+        if self.get_num_sequences() < self.data.gradient_accumulation_steps:
+            self.raise_configuration_error(
+                "Your dataset is too small! gradient_accumulation_steps must be less than or equal to the number of sequences.",
+                code=101,
+            )
+        if (
+            self.data.prompt_method == "tokens"
+            and step < 0
+            and self.data.initial_softprompt is None
+        ):
+            self.raise_configuration_error(
+                "You have not set an initial soft prompt string.", code=103
+            )
+        if self.data.prompt_method == "tokens" and step < 0:
+            self.data.soft_in_dim = len(self.data.initial_softprompt)
+
+    def get_batch(self, step: int, size: int) -> np.ndarray:
+        return self.dataset[(step - 1) * size : step * size]
+
+    def get_num_sequences(self) -> int:
+        if self.dataset is None:
+            if self.data.dataset_file is None or not os.path.exists(
+                self.data.dataset_file
+            ):
+                self.raise_configuration_error(
+                    f"Dataset file not found at {repr(self.data.dataset_file)}",
+                    code=102,
+                )
+            self.dataset = np.load(self.data.dataset_file, mmap_mode="r")
+        assert self.dataset.ndim >= 2
+        assert self.dataset.shape[0] >= 2
+        return self.dataset.shape[0]
+
+    def get_initial_soft_embeddings(self, model: transformers.PreTrainedModel) -> SoftPrompt:
+        if self.data.prompt_method == "vocab_sample":
+            rng = np.random.Generator(
+                np.random.PCG64(
+                    [
+                        self.data.prompt_seed,
+                        int.from_bytes(hashlib.sha256(model.config.model_type.encode("utf8")).digest()[:4], "little"),
+                    ]
+                )
+            )
+            tokenizer = self.get_tokenizer()
+            with tokenizer._kai_no_prefix():
+                special_tokens = set(
+                    itertools.chain.from_iterable(
+                        tokenizer.encode(str(v))
+                        for v in tokenizer.special_tokens_map_extended.values()
+                    )
+                )
+            sample_space = [
+                k for k in range(self.data.params["n_vocab"]) if k not in special_tokens
+            ]
+            sample = rng.choice(sample_space, self.data.soft_in_dim, False)
+            return model.get_input_embeddings()(torch.tensor(sample, dtype=torch.int32))
+        elif self.data.prompt_method == "tokens":
+            return model.get_input_embeddings()(torch.tensor(self.data.initial_softprompt, dtype=torch.int32))
+        self.raise_configuration_error(
+            f"Unknown prompt method {repr(self.data.prompt_method)}", code=104
+        )
+
+    def tokenize_dataset_callback(
+        self, tokenizer: transformers.PreTrainedTokenizerBase, text: str
+    ) -> List[int]:
+        if self.data.newlinemode == "s":
+            text = text.replace("\n", "</s>")
+        with tokenizer._kai_no_prefix():
+            return tokenizer.encode(text) + self.data.params["eos_token"]
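
The "vocab_sample" branch of get_initial_soft_embeddings deterministically seeds NumPy's PCG64 with two integers: the user-chosen prompt seed plus the first four bytes of a SHA-256 hash of the model type, so a fixed seed reproduces the same soft prompt for a given model family while different families diverge. Below is a minimal standalone sketch of that seeding and sampling; the vocabulary size, special-token IDs, and soft prompt length are illustrative assumptions, not values from this commit.

import hashlib

import numpy as np

prompt_seed = 42          # assumed user-chosen seed
model_type = "gptj"       # stands in for model.config.model_type
soft_in_dim = 20          # assumed soft prompt length in tokens
n_vocab = 50257           # assumed vocabulary size
special_tokens = {50256}  # assumed special-token IDs to exclude

# Seed PCG64 with a sequence of integers: the user seed plus the first
# four bytes of SHA-256(model_type), interpreted little-endian.
rng = np.random.Generator(
    np.random.PCG64(
        [
            prompt_seed,
            int.from_bytes(hashlib.sha256(model_type.encode("utf8")).digest()[:4], "little"),
        ]
    )
)

# Sample distinct, non-special token IDs without replacement; their input
# embeddings would become the initial soft prompt.
sample_space = [k for k in range(n_vocab) if k not in special_tokens]
sample = rng.choice(sample_space, soft_in_dim, False)
print(sample[:5])  # identical on every run for a fixed seed and model type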
Loading…
Reference in New Issue
Block a user
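
Two smaller details worth noting: get_batch treats step as 1-indexed, slicing contiguous blocks of rows out of the memory-mapped dataset, and the list concatenation in tokenize_dataset_callback implies params["eos_token"] holds a list of token IDs rather than a single ID. A toy illustration of the slicing, using a made-up array shape:

import numpy as np

dataset = np.arange(24).reshape(8, 3)  # pretend: 8 sequences of 3 tokens each

def get_batch(step: int, size: int) -> np.ndarray:
    # step is 1-indexed, so step 1 yields rows [0, size)
    return dataset[(step - 1) * size : step * size]

print(get_batch(1, 2))  # rows 0-1
print(get_batch(2, 2))  # rows 2-3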