Implement file saving in prompt_tuner.py
This commit is contained in:
parent
4e88b277d4
commit
728e19a7f0
|
@ -7,6 +7,12 @@ import termcolor
|
|||
import contextlib
|
||||
import traceback
|
||||
import random
|
||||
import zipfile
|
||||
import json
|
||||
import uuid
|
||||
import datetime
|
||||
import base64
|
||||
import pickle
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Embedding, CrossEntropyLoss
|
||||
|
@ -244,12 +250,72 @@ class TrainerBase(abc.ABC):
|
|||
|
||||
def get_tokenizer(self) -> transformers.PreTrainedTokenizerBase:
|
||||
return get_tokenizer(self.ckpt_path)
|
||||
|
||||
def save_data(self):
|
||||
pass
|
||||
|
||||
def export_to_kobold(self, output_file: str, name: str, author: str, supported: str, description: str):
|
||||
pass
|
||||
try:
|
||||
z = torch.load(self.data.save_file)
|
||||
assert z["step"] > 0
|
||||
assert z["tensor"].ndim == 2 and "opt_state" in z
|
||||
assert z["tensor"].shape[0] < self.data.params["max_batch_size"]
|
||||
self.data.soft_in_dim = z["tensor"].shape[0]
|
||||
except AssertionError:
|
||||
self.raise_configuration_error("MTJSP file is corrupted.", code=14)
|
||||
|
||||
tensor = z["tensor"]
|
||||
|
||||
meta = {
|
||||
"name": name,
|
||||
"author": author,
|
||||
"supported": supported,
|
||||
"description": description,
|
||||
}
|
||||
if len(meta["author"].strip()) == 0:
|
||||
meta.pop("author")
|
||||
meta["supported"] = list(map(lambda m: m.strip(), supported.split(",")))
|
||||
|
||||
with zipfile.ZipFile(output_file, "w", compression=zipfile.ZIP_LZMA) as z:
|
||||
with z.open("tensor.npy", "w") as f:
|
||||
np.save(f, tensor, allow_pickle=False)
|
||||
with zipfile.ZipFile(output_file, "a", compression=zipfile.ZIP_STORED) as z:
|
||||
with z.open("meta.json", "w") as f:
|
||||
f.write(json.dumps(meta, indent=2).encode("utf-8"))
|
||||
|
||||
def export_to_mkultra(self, output_file: str, soft_prompt_name: str, soft_prompt_description: str):
|
||||
pass
|
||||
try:
|
||||
z = torch.load(self.data.save_file)
|
||||
assert z["step"] > 0
|
||||
assert z["tensor"].ndim == 2 and "opt_state" in z
|
||||
assert z["tensor"].shape[0] < self.data.params["max_batch_size"]
|
||||
self.data.soft_in_dim = z["tensor"].shape[0]
|
||||
_step = z["step"]
|
||||
except AssertionError:
|
||||
self.raise_configuration_error("MTJSP file is corrupted.", code=14)
|
||||
|
||||
tensor = z["tensor"]
|
||||
|
||||
with open(output_file, "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"metadata": {
|
||||
"step": _step,
|
||||
"loss": float(z["loss"].item()),
|
||||
"uuid": str(uuid.uuid4()),
|
||||
"name": soft_prompt_name,
|
||||
"description": soft_prompt_description,
|
||||
"epoch": datetime.datetime.now().timestamp(),
|
||||
},
|
||||
"tensor": base64.b64encode(
|
||||
pickle.dumps(
|
||||
tensor,
|
||||
protocol=4,
|
||||
),
|
||||
).decode("ascii"),
|
||||
},
|
||||
f,
|
||||
)
|
||||
|
||||
def tokenize_dataset(
|
||||
self,
|
||||
|
@ -456,6 +522,23 @@ class TrainerBase(abc.ABC):
|
|||
optimizer.state['step'] = step
|
||||
cross_entropy_loss = CrossEntropyLoss()
|
||||
|
||||
def save_mkusp(
|
||||
loss,
|
||||
grad_norm,
|
||||
):
|
||||
with open(self.data.save_file, "wb") as f:
|
||||
torch.save(
|
||||
{
|
||||
"tensor": soft_embeddings.get_inputs_embeds(),
|
||||
"opt_state": optimizer.state_dict(),
|
||||
"step": step,
|
||||
"loss": loss,
|
||||
"grad_norm": grad_norm,
|
||||
},
|
||||
f,
|
||||
)
|
||||
self.save_data()
|
||||
|
||||
while step < steps:
|
||||
model.train()
|
||||
|
||||
|
@ -494,7 +577,8 @@ class TrainerBase(abc.ABC):
|
|||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Save checkpoint every few steps
|
||||
pass
|
||||
|
||||
step += 1
|
||||
|
||||
# Save checkpoint every few steps
|
||||
if step == 1 or step % self.data.stparams["save_every"] == 0:
|
||||
save_mkusp(mean_loss, mean_grad_norm)
|
||||
|
|
Loading…
Reference in New Issue