# KoboldAI-Client/prompt_tuner.py

import abc
import os
import sys
import math
import numpy as np
import termcolor
import contextlib
import traceback
import random
import zipfile
import json
import uuid
import datetime
import base64
import pickle
import hashlib
import itertools
import torch
import torch.nn.functional as F
from torch.nn import Embedding, CrossEntropyLoss
import transformers
from transformers import AutoTokenizer, GPT2TokenizerFast
from mkultra.tuning import GPTPromptTuningMixin, GPTNeoPromptTuningLM
from mkultra.soft_prompt import SoftPrompt
from typing import List, Optional, TextIO, Union
_PromptTuningPreTrainedModel = Union["UniversalPromptTuningMixin", GPTPromptTuningMixin, transformers.PreTrainedModel]
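
# _WTEMixin exposes a model's input embedding layer under the attribute name `wte`
# (backed by get/set_input_embeddings), which is the attribute the prompt tuning code
# below reads and patches. UniversalPromptTuningMixin applies it to models (or their
# `transformer` submodule) that do not already provide one.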
class _WTEMixin:
    @property
    def wte(self: Union["_WTEMixin", transformers.PreTrainedModel]):
        return self.get_input_embeddings()

    @wte.setter
    def wte(self: Union["_WTEMixin", transformers.PreTrainedModel], v):
        self.set_input_embeddings(v)


class UniversalPromptTuningMixin:
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        model: _PromptTuningPreTrainedModel = super().from_pretrained(pretrained_model_name_or_path, **kwargs)

        # Make sure the model exposes `transformer.wte` so the rest of the prompt
        # tuning code can read and overwrite the input embeddings.
        if not hasattr(model, "transformer"):
            model.transformer = _WTEMixin()
        elif not hasattr(model.transformer, "wte"):
            assert not isinstance(model.transformer, type)
            model.transformer.__class__ = type("_UniversalPromptTuning" + model.transformer.__class__.__name__, (_WTEMixin, model.transformer.__class__), {})

        model.__class__ = type("_UniversalPromptTuning" + model.__class__.__name__, (UniversalPromptTuningMixin, model.__class__), {})

        # Freeze all model weights; only the soft prompt is trained.
        for param in model.parameters():
            param.requires_grad = False
        model.initialize_soft_prompt()

        return model
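
    # forward() left-pads the incoming token IDs with one placeholder ID per soft
    # prompt token and then, while the base model runs, swaps the embeddings at those
    # positions for the learned soft prompt (see new_embedding_call below).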
    def forward(
        self: _PromptTuningPreTrainedModel,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        assert input_ids is not None
        assert input_ids.ndim == 2

        input_ids = F.pad(input_ids, (self.learned_embedding.size(0), 0, 0, 0), value=self.transformer.wte.weight.size(0) // 2)

        if labels is not None:
            labels = self._extend_labels(labels)
        if attention_mask is not None:
            attention_mask = self._extend_attention_mask(attention_mask)

        old_embedding_call = Embedding.__call__
        model = self

        def new_embedding_call(self, input_ids, *args, **kwargs):
            inputs_embeds = old_embedding_call(self, input_ids, *args, **kwargs)
            if model.transformer.wte is self:
                assert inputs_embeds.ndim == 3
                inputs_embeds[:, :model.learned_embedding.size(0), :] = model.learned_embedding[None]
            return inputs_embeds

        Embedding.__call__ = new_embedding_call

        try:
            return super().forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
                use_cache=use_cache,
                return_dict=return_dict,
            )
        finally:
            Embedding.__call__ = old_embedding_call


# Copy onto UniversalPromptTuningMixin every attribute of mkultra's
# GPTPromptTuningMixin that UniversalPromptTuningMixin does not already define
# itself (from_pretrained and forward above are kept).
for k in dir(GPTPromptTuningMixin):
    v = getattr(GPTPromptTuningMixin, k)
    _v = getattr(UniversalPromptTuningMixin, k, None)
    if _v is None or (_v is getattr(object, k, None) and callable(_v) and not isinstance(_v, type)):
        setattr(UniversalPromptTuningMixin, k, v)


class AutoPromptTuningLM(UniversalPromptTuningMixin, transformers.AutoModelForCausalLM):
    def __init__(self, config):
        super().__init__(config)

default_quiet = False
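

# Load the tokenizer for a checkpoint. AutoTokenizer is tried first (fast, then
# slow), with GPT2TokenizerFast and finally the stock "gpt2" tokenizer as fallbacks.
# The returned tokenizer gains a `_kai_no_prefix` context manager that temporarily
# disables BOS tokens and prefix spaces while encoding.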
def get_tokenizer(model_id, revision=None) -> transformers.PreTrainedTokenizerBase:
    if(os.path.isdir(model_id)):
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
        except Exception as e:
            pass
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache", use_fast=False)
        except Exception as e:
            try:
                tokenizer = GPT2TokenizerFast.from_pretrained(model_id, revision=revision, cache_dir="cache")
            except Exception as e:
                tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=revision, cache_dir="cache")
    elif(os.path.isdir("models/{}".format(model_id.replace('/', '_')))):
        try:
            tokenizer = AutoTokenizer.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache")
        except Exception as e:
            pass
        try:
            tokenizer = AutoTokenizer.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache", use_fast=False)
        except Exception as e:
            try:
                tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(model_id.replace('/', '_')), revision=revision, cache_dir="cache")
            except Exception as e:
                tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=revision, cache_dir="cache")
    else:
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache")
        except Exception as e:
            pass
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, cache_dir="cache", use_fast=False)
        except Exception as e:
            try:
                tokenizer = GPT2TokenizerFast.from_pretrained(model_id, revision=revision, cache_dir="cache")
            except Exception as e:
                tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=revision, cache_dir="cache")

    @contextlib.contextmanager
    def _kai_no_prefix():
        add_bos_token = getattr(tokenizer, "add_bos_token", False)
        add_prefix_space = getattr(tokenizer, "add_prefix_space", False)
        tokenizer.add_bos_token = False
        tokenizer.add_prefix_space = False
        try:
            yield
        finally:
            tokenizer.add_bos_token = add_bos_token
            tokenizer.add_prefix_space = add_prefix_space

    tokenizer._kai_no_prefix = _kai_no_prefix
    return tokenizer


class ConfigurationError(Exception):
    def __init__(self, msg: str = "Unknown error", code: int = 1, quiet: Optional[bool] = None):
        if quiet is None:
            quiet = default_quiet
        super().__init__(msg)
        self.code = code
        self.quiet = quiet
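

# TrainerBase is the abstract driver for a soft-tuning run: subclasses supply the
# dataset (get_batch / get_num_sequences / tokenize_dataset_callback) and the initial
# soft prompt, while this class handles dataset tokenization, the training loop
# itself, and exporting the trained soft prompt.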
class TrainerBase(abc.ABC):
    @abc.abstractmethod
    def startup(self, step: int) -> None:
        ...

    @abc.abstractmethod
    def get_batch(self, step: int, size: int) -> np.ndarray:
        ...

    @abc.abstractmethod
    def get_num_sequences(self) -> int:
        ...

    @abc.abstractmethod
    def get_initial_soft_embeddings(self, model: transformers.PreTrainedModel) -> SoftPrompt:
        ...

    @abc.abstractmethod
    def tokenize_dataset_callback(self, tokenizer: transformers.PreTrainedTokenizerBase, text: str) -> List[int]:
        ...

    class TrainerData:
        def __init__(self):
            self.__lazy_load_spec: Optional[dict] = None
            self.model_spec: Optional[dict] = None
            self.tokenizer_id: Optional[str] = None
            self.newlinemode: Optional[str] = None
            self.ckpt_path: Optional[str] = None
            self.save_file: Optional[str] = None
            self.params: Optional[dict] = None
            self.stparams: Optional[dict] = None
            self.gradient_accumulation_steps = -1
            self.soft_in_dim = -1
            self.prompt_method = "tokens"
            self.prompt_seed = 42

        @property
        def lazy_load_spec(self):
            print("WARNING: `TrainerData.lazy_load_spec` is currently unused", file=sys.stderr)
            return self.__lazy_load_spec

        @lazy_load_spec.setter
        def lazy_load_spec(self, value: Optional[dict]):
            print("WARNING: `TrainerData.lazy_load_spec` is currently unused", file=sys.stderr)
            self.__lazy_load_spec = value

        @property
        def kaiming_size(self):  # backwards compatibility
            return self.soft_in_dim

        @kaiming_size.setter
        def kaiming_size(self, value: int):  # backwards compatibility
            self.prompt_method = "kaiming"
            self.soft_in_dim = value

    data: TrainerData

    def __init__(self, universe: Optional[int] = None, quiet=False):
        self.quiet = quiet
        self.universe = universe
        self.data = self.TrainerData()
        self._spmodule: Optional[str] = None
        if universe is not None:
            print("WARNING: The `universe` argument of `TrainerBase.__init__` is currently unused", file=sys.stderr)

    def raise_configuration_error(self, msg, **kwargs):
        if "quiet" not in kwargs:
            kwargs["quiet"] = self.quiet
        raise ConfigurationError(msg, **kwargs)

    def get_hf_checkpoint_metadata(self) -> bool:
        return True

    def get_tokenizer(self) -> transformers.PreTrainedTokenizerBase:
        return get_tokenizer(self.data.ckpt_path)

    def save_data(self):
        pass
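
    # export_to_kobold packs the trained soft prompt into a KoboldAI-style soft
    # prompt zip: tensor.npy (the embedding matrix) plus meta.json (name, author,
    # supported models, description).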
    def export_to_kobold(self, output_file: str, name: str, author: str, supported: str, description: str):
        try:
            z = torch.load(self.data.save_file)
            assert z["step"] > 0
            assert z["tensor"].ndim == 2 and "opt_state" in z
            assert z["tensor"].shape[0] < self.data.params["max_batch_size"]
            self.data.soft_in_dim = z["tensor"].shape[0]
        except AssertionError:
            self.raise_configuration_error("MTJSP file is corrupted.", code=14)

        tensor = z["tensor"]

        meta = {
            "name": name,
            "author": author,
            "supported": supported,
            "description": description,
        }
        if len(meta["author"].strip()) == 0:
            meta.pop("author")
        meta["supported"] = list(map(lambda m: m.strip(), supported.split(",")))

        with zipfile.ZipFile(output_file, "w", compression=zipfile.ZIP_LZMA) as z:
            with z.open("tensor.npy", "w") as f:
                np.save(f, tensor, allow_pickle=False)
        with zipfile.ZipFile(output_file, "a", compression=zipfile.ZIP_STORED) as z:
            with z.open("meta.json", "w") as f:
                f.write(json.dumps(meta, indent=2).encode("utf-8"))
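
    # export_to_mkultra writes the soft prompt in mkultra's JSON format: a metadata
    # block plus the embedding tensor pickled and base64-encoded.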
    def export_to_mkultra(self, output_file: str, soft_prompt_name: str, soft_prompt_description: str):
        try:
            z = torch.load(self.data.save_file)
            assert z["step"] > 0
            assert z["tensor"].ndim == 2 and "opt_state" in z
            assert z["tensor"].shape[0] < self.data.params["max_batch_size"]
            self.data.soft_in_dim = z["tensor"].shape[0]
            _step = z["step"]
        except AssertionError:
            self.raise_configuration_error("MTJSP file is corrupted.", code=14)

        tensor = z["tensor"]

        with open(output_file, "w") as f:
            json.dump(
                {
                    "metadata": {
                        "step": _step,
                        "loss": float(z["loss"].item()),
                        "uuid": str(uuid.uuid4()),
                        "name": soft_prompt_name,
                        "description": soft_prompt_description,
                        "epoch": datetime.datetime.now().timestamp(),
                    },
                    "tensor": base64.b64encode(
                        pickle.dumps(
                            tensor,
                            protocol=4,
                        ),
                    ).decode("ascii"),
                },
                f,
            )
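
    # tokenize_dataset reads one or more plain text files, tokenizes them with the
    # checkpoint's tokenizer, concatenates the token IDs, and saves them as a .npy
    # array of shape (num_sequences, batch_size + 1) for use as training data.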
    def tokenize_dataset(
        self,
        dataset_path: Union[str, TextIO],
        output_file: Union[str, TextIO],
        batch_size=2048,
        epochs=1,
        use_ftfy=True,
        shuffle_seed: Optional[Union[int, float, str, bytes, bytearray]] = 1729,
    ):
        if isinstance(dataset_path, str):
            dataset_path = dataset_path.replace("\\", "/")
        if isinstance(output_file, str):
            output_file = output_file.replace("\\", "/")

        if not isinstance(batch_size, int) or batch_size < 1:
            self.raise_configuration_error(
                "batch_size must be an integer greater than zero.", code=9
            )
        if (
            not isinstance(epochs, int) and not isinstance(epochs, float)
        ) or epochs <= 0:
            self.raise_configuration_error(
                "epochs must be an int or float greater than zero.", code=10
            )
        if isinstance(output_file, str) and output_file.endswith("/"):
            self.raise_configuration_error(
                "output_file should be the path to a file, not a directory.", code=11
            )
        if isinstance(dataset_path, str) and not os.path.exists(dataset_path):
            self.raise_configuration_error(
                "dataset_path is not set to a valid file or directory.", code=12
            )

        if use_ftfy:
            import ftfy

        tokenizer = self.get_tokenizer()

        batch_size = min(
            batch_size,
            self.data.params["max_batch_size"] - self.data.soft_in_dim,
        )
        assert batch_size >= 0

        print(
            termcolor.colored(
                "\nIf you see a warning somewhere below about token indices, ignore it. That warning is normal.\n",
                "magenta",
            )
        )
        print("Batch size:", batch_size)
        print(termcolor.colored("Tokenizing your dataset...\n", "magenta"))

        if not isinstance(dataset_path, str):
            files = [dataset_path]
        elif os.path.isfile(dataset_path):
            files = [dataset_path]
        else:
            files = sorted(
                os.path.join(dataset_path, filename)
                for filename in os.listdir(dataset_path)
            )
        if shuffle_seed is not None:
            random.Random(shuffle_seed).shuffle(files)

        tokens = []
        eos = tokenizer.decode(self.data.params["eos_token"])
        for path in files:
            if isinstance(path, str):
                f = open(path)
            else:
                f = path
            try:
                text = f.read()
                if use_ftfy:
                    text = ftfy.fix_text(text)
                text = text.replace("<|endoftext|>", eos)
                tokens.extend(self.tokenize_dataset_callback(tokenizer, text))
            finally:
                if isinstance(path, str):
                    f.close()
        print("Dataset size (in tokens):", len(tokens))

        if len(tokens) < batch_size + 1:
            self.raise_configuration_error(
                "Your dataset is too small! The number of tokens has to be greater than the batch size. Try increasing the epochs.",
                code=13,
            )
        tail = len(tokens) % (batch_size + 1)
        if tail:
            print(
                f"We're removing the last {tail} tokens from your dataset to make the length a multiple of {batch_size+1}."
            )
            tokens = tokens[:-tail]

        tokens = np.array(tokens, dtype=np.uint16).reshape((-1, batch_size + 1))
        sequences_per_epoch = tokens.shape[0]
        _epochs = math.ceil(epochs)
        if _epochs > 1:
            rng = np.random.Generator(np.random.PCG64(1729))
            tokens = np.concatenate(
                (
                    tokens,
                    *(rng.permutation(tokens, axis=0) for i in range(_epochs - 1)),
                ),
                axis=0,
            )
        tokens = tokens[: math.ceil(epochs * sequences_per_epoch)]
        print(f"Total sequences in your dataset: {tokens.shape[0]}")

        if isinstance(output_file, str):
            f = open(output_file, "wb")
        else:
            f = output_file
        try:
            np.save(f, tokens)
        finally:
            if isinstance(output_file, str):
                f.close()
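
    # train runs the soft-tuning loop: it loads (or initializes) the soft prompt,
    # loads the model with the prompt tuning mixin applied, and then repeatedly
    # accumulates gradients over `gradient_accumulation_steps` sequences before each
    # Adafactor step, checkpointing to `save_file` every `save_every` steps.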
    def train(self):
        if self.data.params is not None and "max_batch_size" not in self.data.params:
            self.data.params["max_batch_size"] = 2048

        if not os.path.exists(self.data.save_file):
            print("We are starting a brand new soft-tuning session.\n")
            self.startup(step=-1)
            if self.data.soft_in_dim <= 0:
                self.raise_configuration_error(
                    "You have not set a soft prompt size.", code=6
                )
            step = 0
        else:
            # If we're resuming a soft-tuning session, the soft prompt tensor is
            # already in the save file and we just have to decode it.
            try:
                z = torch.load(self.data.save_file)
                assert z["step"] > 0
                assert z["tensor"].ndim == 2 and "opt_state" in z
                assert z["tensor"].shape[0] < self.data.params["max_batch_size"]
                self.data.soft_in_dim = z["tensor"].shape[0]
                step = z["step"]
                opt_state = z["opt_state"]
            except AssertionError:
                self.raise_configuration_error("MTJSP file is corrupted.", code=14)
            print(f"We're resuming a previous soft-tuning session at step {step+1}.\n")
            self.startup(step=step + 1)
            soft_embeddings = z["tensor"]

        REVISION = None
        tokenizer = self.get_tokenizer()

        model: _PromptTuningPreTrainedModel
        if(os.path.isdir(self.data.ckpt_path)):
            try:
                model = AutoPromptTuningLM.from_pretrained(self.data.ckpt_path, revision=REVISION, cache_dir="cache")
            except Exception as e:
                if("out of memory" in traceback.format_exc().lower()):
                    raise RuntimeError("One of your GPUs ran out of memory when KoboldAI tried to load your model.")
                model = GPTNeoPromptTuningLM.from_pretrained(self.data.ckpt_path, revision=REVISION, cache_dir="cache")
        elif(os.path.isdir("models/{}".format(self.data.ckpt_path.replace('/', '_')))):
            try:
                model = AutoPromptTuningLM.from_pretrained("models/{}".format(self.data.ckpt_path.replace('/', '_')), revision=REVISION, cache_dir="cache")
            except Exception as e:
                if("out of memory" in traceback.format_exc().lower()):
                    raise RuntimeError("One of your GPUs ran out of memory when KoboldAI tried to load your model.")
                model = GPTNeoPromptTuningLM.from_pretrained("models/{}".format(self.data.ckpt_path.replace('/', '_')), revision=REVISION, cache_dir="cache")
        else:
            try:
                model = AutoPromptTuningLM.from_pretrained(self.data.ckpt_path, revision=REVISION, cache_dir="cache")
            except Exception as e:
                if("out of memory" in traceback.format_exc().lower()):
                    raise RuntimeError("One of your GPUs ran out of memory when KoboldAI tried to load your model.")
                model = GPTNeoPromptTuningLM.from_pretrained(self.data.ckpt_path, revision=REVISION, cache_dir="cache")

        if step == 0:
            soft_embeddings = self.get_initial_soft_embeddings(model)
        else:
            soft_embeddings = SoftPrompt.from_inputs_embeds(soft_embeddings)
        model.set_soft_prompt(soft_embeddings)

        steps = self.get_num_sequences() // self.data.gradient_accumulation_steps
        warmup_steps = max(1, round(steps * self.data.stparams["warmup"]))

        beta1: Optional[float] = self.data.stparams.get("beta1", 0.0)
        if beta1 == 0.0:
            beta1 = None
        optimizer = transformers.Adafactor(
            params=model.get_soft_params(),
            scale_parameter=False,
            relative_step=False,
            warmup_init=False,
            lr=self.data.stparams["lr"],
            beta1=beta1,
            decay_rate=self.data.stparams.get("decay_rate", -0.8),
            weight_decay=self.data.stparams.get("weight_decay", 0.1),
        )
        if step != 0:
            optimizer.load_state_dict(opt_state)
        scheduler = transformers.get_cosine_with_hard_restarts_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=steps - warmup_steps,
            num_cycles=(steps - warmup_steps) // self.data.stparams.get("training_steps_per_cycle", 56),
        )
        torch.cuda.empty_cache()
        optimizer.state['step'] = step

        cross_entropy_loss = CrossEntropyLoss()

        def save_mkusp(
            loss,
            grad_norm,
        ):
            # Write a checkpoint containing the soft prompt tensor, optimizer state
            # and training progress so the session can be resumed later.
            with open(self.data.save_file, "wb") as f:
                torch.save(
                    {
                        "tensor": soft_embeddings.get_inputs_embeds(),
                        "opt_state": optimizer.state_dict(),
                        "step": step,
                        "loss": loss,
                        "grad_norm": grad_norm,
                    },
                    f,
                )
            self.save_data()

        while step < steps:
            model.train()
            total_loss = total_grad = total_grad_norm = 0

            for i in range(self.data.gradient_accumulation_steps):
                # Get the next sequence from the dataset
                block = self.get_batch(step, self.data.gradient_accumulation_steps).to(model.transformer.wte.weight.device)
                # input_ids is the context to the model (without the soft prompt) and
                # labels is what we expect the model to generate (the -100s represent
                # soft prompt tokens for which loss is not calculated)
                input_ids = block[:-1].unsqueeze(0).detach()
                labels = torch.cat((torch.full((model.get_soft_params().size(0) - 1,), -100, device=block.device), block)).unsqueeze(0).cuda().detach()
                # Give the context to the model and compare the model's output logits
                # with the labels to compute the loss
                logits = model(input_ids=input_ids, labels=input_ids).logits
                loss: torch.Tensor = cross_entropy_loss(logits.view(-1, logits.size(-1)), labels.view(-1))
                total_loss += loss.detach()
                # Compute the gradient of the loss function and add it to
                # model.get_soft_params().grad (model.get_soft_params().grad += gradient)
                loss.backward()
                total_grad_norm += torch.linalg.norm(model.get_soft_params().grad.detach() - total_grad)
                total_grad = model.get_soft_params().grad.detach()
                del input_ids
                del labels
                del logits
                torch.cuda.empty_cache()

            mean_loss = (total_loss / self.data.gradient_accumulation_steps).item()
            mean_grad_norm = (total_grad_norm / self.data.gradient_accumulation_steps).item()

            # Apply the optimization algorithm using the accumulated gradients, which
            # changes the contents of the soft prompt matrix very slightly to reduce
            # the loss
            optimizer.step()
            lr = optimizer.param_groups[0]["lr"]
            scheduler.step()
            optimizer.zero_grad()
            step += 1

            # Save a checkpoint every few steps
            if step == 1 or step % self.data.stparams["save_every"] == 0:
                save_mkusp(mean_loss, mean_grad_norm)
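

# BasicTrainer is a concrete TrainerBase implementation that reads a pre-tokenized
# .npy dataset (see tokenize_dataset above) and builds the initial soft prompt either
# from a list of token IDs ("tokens") or from a reproducible random sample of the
# vocabulary ("vocab_sample").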
class BasicTrainer(TrainerBase):
    class TrainerData(TrainerBase.TrainerData):
        def __init__(self):
            super().__init__()
            self.dataset_file: Optional[str] = None
            self.initial_softprompt: Optional[List[int]] = None

    data: "BasicTrainer.TrainerData"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dataset: Optional[np.ndarray] = None

    def startup(self, step: int) -> None:
        if self.get_num_sequences() < self.data.gradient_accumulation_steps:
            self.raise_configuration_error(
                "Your dataset is too small! gradient_accumulation_steps must be less than or equal to the number of sequences.",
                code=101,
            )
        if (
            self.data.prompt_method == "tokens"
            and step < 0
            and self.data.initial_softprompt is None
        ):
            self.raise_configuration_error(
                "You have not set an initial soft prompt string.", code=103
            )
        if self.data.prompt_method == "tokens" and step < 0:
            self.data.soft_in_dim = len(self.data.initial_softprompt)

    def get_batch(self, step: int, size: int) -> np.ndarray:
        return self.dataset[(step - 1) * size : step * size]

    def get_num_sequences(self) -> int:
        if self.dataset is None:
            if self.data.dataset_file is None or not os.path.exists(
                self.data.dataset_file
            ):
                self.raise_configuration_error(
                    f"Dataset file not found at {repr(self.data.dataset_file)}",
                    code=102,
                )
            self.dataset = np.load(self.data.dataset_file, mmap_mode="r")
        assert self.dataset.ndim >= 2
        assert self.dataset.shape[0] >= 2
        return self.dataset.shape[0]

    def get_initial_soft_embeddings(self, model: transformers.PreTrainedModel) -> SoftPrompt:
        if self.data.prompt_method == "vocab_sample":
            # Seed the RNG with the prompt seed and a hash of the model type so
            # different architectures get different (but reproducible) samples.
            rng = np.random.Generator(
                np.random.PCG64(
                    [
                        self.data.prompt_seed,
                        int.from_bytes(hashlib.sha256(model.config.model_type.encode("utf8")).digest()[:4], "little"),
                    ]
                )
            )
            tokenizer = self.get_tokenizer()
            with tokenizer._kai_no_prefix():
                special_tokens = set(
                    itertools.chain.from_iterable(
                        tokenizer.encode(str(v))
                        for v in tokenizer.special_tokens_map_extended.values()
                    )
                )
            sample_space = [
                k for k in range(self.data.params["n_vocab"]) if k not in special_tokens
            ]
            sample = rng.choice(sample_space, self.data.soft_in_dim, False)
            return model.get_input_embeddings()(torch.tensor(sample, dtype=torch.int32))
        elif self.data.prompt_method == "tokens":
            return model.get_input_embeddings()(torch.tensor(self.data.initial_softprompt, dtype=torch.int32))
        self.raise_configuration_error(
            f"Unknown prompt method {repr(self.data.prompt_method)}", code=104
        )

    def tokenize_dataset_callback(
        self, tokenizer: transformers.PreTrainedTokenizerBase, text: str
    ) -> List[int]:
        if self.data.newlinemode == "s":
            text = text.replace("\n", "</s>")
        with tokenizer._kai_no_prefix():
            return tokenizer.encode(text) + self.data.params["eos_token"]
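

# Example usage (a minimal sketch; the file paths, hyperparameter values, and the
# params/stparams entries below are illustrative assumptions, not values defined by
# this module):
#
#     trainer = BasicTrainer()
#     trainer.data.ckpt_path = "EleutherAI/gpt-neo-125M"
#     trainer.data.save_file = "my_softprompt.mtjsp"
#     trainer.data.dataset_file = "dataset.npy"
#     trainer.data.initial_softprompt = [50256] * 20  # token IDs for the "tokens" prompt method
#     trainer.data.params = {"max_batch_size": 2048, "eos_token": [50256]}
#     trainer.data.stparams = {"lr": 3e-5, "warmup": 0.1, "save_every": 50}
#     trainer.data.gradient_accumulation_steps = 16
#     trainer.tokenize_dataset("dataset.txt", "dataset.npy")
#     trainer.train()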