Implement file saving in prompt_tuner.py

This commit is contained in:
vfbd 2022-08-22 16:29:39 -04:00
parent 4e88b277d4
commit 728e19a7f0
1 changed file with 89 additions and 5 deletions

View File

@ -7,6 +7,12 @@ import termcolor
import contextlib import contextlib
import traceback import traceback
import random import random
import zipfile
import json
import uuid
import datetime
import base64
import pickle
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn import Embedding, CrossEntropyLoss from torch.nn import Embedding, CrossEntropyLoss
@ -245,11 +251,71 @@ class TrainerBase(abc.ABC):
def get_tokenizer(self) -> transformers.PreTrainedTokenizerBase:
    """Return the tokenizer for this trainer's checkpoint path.

    Delegates to the module-level ``get_tokenizer`` helper (defined outside
    this view), passing ``self.ckpt_path``.
    """
    checkpoint_path = self.ckpt_path
    return get_tokenizer(checkpoint_path)
def save_data(self):
    """Persist trainer data; a no-op in this base implementation.

    Invoked after each soft-prompt checkpoint is written during training.
    Presumably meant to be overridden by subclasses -- confirm against them.
    """
def export_to_kobold(self, output_file: str, name: str, author: str, supported: str, description: str):
    """Export the saved soft prompt as a zip containing ``tensor.npy`` and ``meta.json``.

    The checkpoint at ``self.data.save_file`` is loaded and sanity-checked;
    if the check fails, ``self.raise_configuration_error`` is called with
    code 14. On success ``self.data.soft_in_dim`` is set to the number of
    rows in the stored tensor.

    Args:
        output_file: Path of the zip archive to create (overwritten if it exists).
        name: Soft prompt name stored in ``meta.json``.
        author: Author name; omitted from ``meta.json`` when blank.
        supported: Comma-separated list of supported model identifiers;
            stored as a list of stripped strings.
        description: Free-form description stored in ``meta.json``.
    """
    checkpoint = torch.load(self.data.save_file)
    row_limit = self.data.params["max_batch_size"]

    # Explicit checks instead of `assert`: asserts vanish under `python -O`,
    # and a missing key should be reported as a corrupted file rather than
    # escaping as a bare KeyError.
    try:
        tensor = checkpoint["tensor"]
        is_valid = (
            checkpoint["step"] > 0
            and tensor.ndim == 2
            and "opt_state" in checkpoint
            and tensor.shape[0] < row_limit
        )
    except (KeyError, TypeError, AttributeError):
        is_valid = False

    if is_valid:
        self.data.soft_in_dim = tensor.shape[0]
    else:
        # NOTE(review): raise_configuration_error is defined outside this
        # view and is presumed to raise -- confirm.
        self.raise_configuration_error("MTJSP file is corrupted.", code=14)

    meta = {"name": name}
    if author.strip():
        meta["author"] = author
    meta["supported"] = [model.strip() for model in supported.split(",")]
    meta["description"] = description

    # NumPy cannot convert a tensor that requires grad or lives on an
    # accelerator, so detach it and move it to host memory first.
    array = tensor.detach().cpu().numpy()

    # The tensor is LZMA-compressed; the small metadata file is appended
    # uncompressed so it can be read cheaply.
    with zipfile.ZipFile(output_file, "w", compression=zipfile.ZIP_LZMA) as archive:
        with archive.open("tensor.npy", "w") as f:
            np.save(f, array, allow_pickle=False)
    with zipfile.ZipFile(output_file, "a", compression=zipfile.ZIP_STORED) as archive:
        with archive.open("meta.json", "w") as f:
            f.write(json.dumps(meta, indent=2).encode("utf-8"))
def export_to_mkultra(self, output_file: str, soft_prompt_name: str, soft_prompt_description: str):
    """Export the saved soft prompt as a JSON file with metadata and a pickled tensor.

    The checkpoint at ``self.data.save_file`` is loaded and sanity-checked;
    if the check fails, ``self.raise_configuration_error`` is called with
    code 14. On success ``self.data.soft_in_dim`` is set to the number of
    rows in the stored tensor.

    The output JSON has a ``"metadata"`` object (step, loss, a fresh random
    UUID, name, description, and the current Unix timestamp under the key
    ``"epoch"``) and a ``"tensor"`` string holding the base64-encoded
    pickle (protocol 4) of the tensor.

    Args:
        output_file: Path of the JSON file to create (overwritten if it exists).
        soft_prompt_name: Name stored in the metadata.
        soft_prompt_description: Description stored in the metadata.
    """
    checkpoint = torch.load(self.data.save_file)
    row_limit = self.data.params["max_batch_size"]

    # Explicit checks instead of `assert`: asserts vanish under `python -O`,
    # and a missing key should be reported as a corrupted file rather than
    # escaping as a bare KeyError.
    try:
        tensor = checkpoint["tensor"]
        step = checkpoint["step"]
        is_valid = (
            step > 0
            and tensor.ndim == 2
            and "opt_state" in checkpoint
            and tensor.shape[0] < row_limit
        )
    except (KeyError, TypeError, AttributeError):
        is_valid = False

    if is_valid:
        self.data.soft_in_dim = tensor.shape[0]
    else:
        # NOTE(review): raise_configuration_error is defined outside this
        # view and is presumed to raise -- confirm.
        self.raise_configuration_error("MTJSP file is corrupted.", code=14)

    # Detach and move to host memory before pickling so the exported file
    # can be unpickled on machines without the training accelerator and
    # carries no autograd state.
    portable_tensor = tensor.detach().cpu()

    # Build the whole payload before opening the output file, so a failure
    # here (e.g. a missing "loss" entry) does not leave a truncated file.
    payload = {
        "metadata": {
            "step": step,
            "loss": float(checkpoint["loss"].item()),
            "uuid": str(uuid.uuid4()),
            "name": soft_prompt_name,
            "description": soft_prompt_description,
            "epoch": datetime.datetime.now().timestamp(),
        },
        "tensor": base64.b64encode(
            pickle.dumps(portable_tensor, protocol=4),
        ).decode("ascii"),
    }

    with open(output_file, "w") as f:
        json.dump(payload, f)
def tokenize_dataset( def tokenize_dataset(
self, self,
@ -456,6 +522,23 @@ class TrainerBase(abc.ABC):
optimizer.state['step'] = step optimizer.state['step'] = step
cross_entropy_loss = CrossEntropyLoss() cross_entropy_loss = CrossEntropyLoss()
def save_mkusp(loss, grad_norm):
    """Write a training checkpoint to ``self.data.save_file``, then call ``self.save_data()``.

    The checkpoint is a dict with the current soft-prompt embeddings
    (``"tensor"``), the optimizer state (``"opt_state"``), the current
    ``step`` from the enclosing scope, and the given ``loss`` and
    ``grad_norm`` values.
    """
    with open(self.data.save_file, "wb") as checkpoint_file:
        # Built inside the `with` so the file is opened before the
        # embeddings and optimizer state are read, as in the original.
        snapshot = dict(
            tensor=soft_embeddings.get_inputs_embeds(),
            opt_state=optimizer.state_dict(),
            step=step,
            loss=loss,
            grad_norm=grad_norm,
        )
        torch.save(snapshot, checkpoint_file)
    self.save_data()
while step < steps: while step < steps:
model.train() model.train()
@ -494,7 +577,8 @@ class TrainerBase(abc.ABC):
scheduler.step() scheduler.step()
optimizer.zero_grad() optimizer.zero_grad()
# Save checkpoint every few steps
pass
step += 1 step += 1
# Save checkpoint every few steps
if step == 1 or step % self.data.stparams["save_every"] == 0:
save_mkusp(mean_loss, mean_grad_norm)