Implement file saving in prompt_tuner.py

This commit is contained in:
vfbd 2022-08-22 16:29:39 -04:00
parent 4e88b277d4
commit 728e19a7f0
1 changed files with 89 additions and 5 deletions

View File

@ -7,6 +7,12 @@ import termcolor
import contextlib
import traceback
import random
import zipfile
import json
import uuid
import datetime
import base64
import pickle
import torch
import torch.nn.functional as F
from torch.nn import Embedding, CrossEntropyLoss
@ -244,12 +250,72 @@ class TrainerBase(abc.ABC):
def get_tokenizer(self) -> transformers.PreTrainedTokenizerBase:
    """Return the tokenizer for this trainer's checkpoint.

    Delegates to the non-method ``get_tokenizer`` callable that is in scope
    where this method is defined, passing it ``self.ckpt_path``.
    """
    checkpoint_path = self.ckpt_path
    return get_tokenizer(checkpoint_path)
def save_data(self):
    """Placeholder save hook: performs no work and returns ``None``."""
    return None
def export_to_kobold(self, output_file: str, name: str, author: str, supported: str, description: str):
    """Export the saved soft-prompt checkpoint as a zip archive.

    Loads the checkpoint at ``self.data.save_file`` with ``torch.load``,
    validates it, and writes ``output_file`` as a zip archive holding
    ``tensor.npy`` (LZMA-compressed) and ``meta.json`` (stored
    uncompressed).

    Args:
        output_file: Path of the zip archive to create; an existing file
            is overwritten.
        name: Stored under ``"name"`` in ``meta.json``.
        author: Stored under ``"author"``; the key is omitted entirely
            when this is empty or whitespace-only.
        supported: Comma-separated string; stored under ``"supported"``
            as a list of the whitespace-stripped pieces.
        description: Stored under ``"description"``.

    Side effects:
        Sets ``self.data.soft_in_dim`` to the tensor's first dimension
        when the checkpoint passes validation. Otherwise calls
        ``self.raise_configuration_error("MTJSP file is corrupted.",
        code=14)``.
    """
    checkpoint = torch.load(self.data.save_file)
    tensor = checkpoint["tensor"]

    # Explicit checks rather than ``assert`` statements, so the validation
    # still runs when Python is started with ``-O``.
    is_valid = (
        checkpoint["step"] > 0
        and tensor.ndim == 2
        and "opt_state" in checkpoint
        and tensor.shape[0] < self.data.params["max_batch_size"]
    )
    if is_valid:
        self.data.soft_in_dim = tensor.shape[0]
    else:
        # NOTE(review): presumably this raises; its definition is not
        # visible here. If it returns, the export proceeds as it did before.
        self.raise_configuration_error("MTJSP file is corrupted.", code=14)

    # np.save() cannot convert a torch tensor that requires grad or lives
    # on a non-CPU device, so hand it a plain NumPy array instead.
    if isinstance(tensor, torch.Tensor):
        tensor = tensor.detach().cpu().numpy()

    # Key order is preserved in meta.json: name, [author], supported,
    # description.
    meta = {"name": name}
    if author.strip():
        meta["author"] = author
    meta["supported"] = [entry.strip() for entry in supported.split(",")]
    meta["description"] = description

    # The two members use different compression methods, so the archive is
    # created for the tensor and then reopened in append mode for the
    # metadata.
    with zipfile.ZipFile(output_file, "w", compression=zipfile.ZIP_LZMA) as archive:
        with archive.open("tensor.npy", "w") as member:
            np.save(member, tensor, allow_pickle=False)
    with zipfile.ZipFile(output_file, "a", compression=zipfile.ZIP_STORED) as archive:
        with archive.open("meta.json", "w") as member:
            member.write(json.dumps(meta, indent=2).encode("utf-8"))
def export_to_mkultra(self, output_file: str, soft_prompt_name: str, soft_prompt_description: str):
    """Export the saved soft-prompt checkpoint as a JSON file.

    Loads the checkpoint at ``self.data.save_file`` with ``torch.load``,
    validates it, and writes a JSON document to ``output_file`` with two
    top-level keys:

    * ``"metadata"``: ``step`` and ``loss`` from the checkpoint, a freshly
      generated ``uuid``, the given ``name`` and ``description``, and
      ``epoch`` (the current time as a POSIX timestamp).
    * ``"tensor"``: the tensor pickled with protocol 4 and base64-encoded
      as ASCII text.

    Args:
        output_file: Path of the JSON file to create; an existing file is
            overwritten.
        soft_prompt_name: Stored under ``metadata["name"]``.
        soft_prompt_description: Stored under ``metadata["description"]``.

    Side effects:
        Sets ``self.data.soft_in_dim`` to the tensor's first dimension
        when the checkpoint passes validation. Otherwise calls
        ``self.raise_configuration_error("MTJSP file is corrupted.",
        code=14)``.
    """
    checkpoint = torch.load(self.data.save_file)
    tensor = checkpoint["tensor"]
    step = checkpoint["step"]

    # Explicit checks rather than ``assert`` statements, so the validation
    # still runs when Python is started with ``-O``.
    is_valid = (
        step > 0
        and tensor.ndim == 2
        and "opt_state" in checkpoint
        and tensor.shape[0] < self.data.params["max_batch_size"]
    )
    if is_valid:
        self.data.soft_in_dim = tensor.shape[0]
    else:
        # NOTE(review): presumably this raises; its definition is not
        # visible here. If it returns, the export proceeds as it did before.
        self.raise_configuration_error("MTJSP file is corrupted.", code=14)

    # Pickle a detached CPU copy so the exported blob does not carry
    # autograd state or a device placement the reader may not have.
    if isinstance(tensor, torch.Tensor):
        tensor = tensor.detach().cpu()

    # NOTE: the tensor is serialized with pickle; whoever reads this file
    # must treat it as trusted input before unpickling.
    encoded_tensor = base64.b64encode(pickle.dumps(tensor, protocol=4)).decode("ascii")

    payload = {
        "metadata": {
            "step": step,
            # float() accepts both a one-element tensor and a plain number.
            "loss": float(checkpoint["loss"]),
            "uuid": str(uuid.uuid4()),
            "name": soft_prompt_name,
            "description": soft_prompt_description,
            "epoch": datetime.datetime.now().timestamp(),
        },
        "tensor": encoded_tensor,
    }

    # The payload is fully built before the file is opened, so a failure
    # above cannot leave a truncated output file behind.
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(payload, f)
def tokenize_dataset(
self,
@ -456,6 +522,23 @@ class TrainerBase(abc.ABC):
optimizer.state['step'] = step
cross_entropy_loss = CrossEntropyLoss()
# Write a checkpoint dictionary to self.data.save_file using torch.save.
#
# Parameters:
#   loss      -- value stored under the "loss" key
#   grad_norm -- value stored under the "grad_norm" key
#
# self, soft_embeddings, optimizer and step are not parameters; they are
# looked up in the surrounding scope each time this function is called.
def save_mkusp(
loss,
grad_norm,
):
# Binary mode: torch.save writes bytes to the open file object.
with open(self.data.save_file, "wb") as f:
torch.save(
{
# Current output of soft_embeddings.get_inputs_embeds().
"tensor": soft_embeddings.get_inputs_embeds(),
# Optimizer state, saved alongside the tensor.
"opt_state": optimizer.state_dict(),
# Training step counter at the moment of saving.
"step": step,
"loss": loss,
"grad_norm": grad_norm,
},
f,
)
self.save_data()
while step < steps:
model.train()
@ -494,7 +577,8 @@ class TrainerBase(abc.ABC):
scheduler.step()
optimizer.zero_grad()
# Save checkpoint every few steps
pass
step += 1
# Save checkpoint every few steps
if step == 1 or step % self.data.stparams["save_every"] == 0:
save_mkusp(mean_loss, mean_grad_norm)