Implement file saving in prompt_tuner.py
This commit is contained in:
parent
4e88b277d4
commit
728e19a7f0
|
@ -7,6 +7,12 @@ import termcolor
|
||||||
import contextlib
|
import contextlib
|
||||||
import traceback
|
import traceback
|
||||||
import random
|
import random
|
||||||
|
import zipfile
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
import datetime
|
||||||
|
import base64
|
||||||
|
import pickle
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.nn import Embedding, CrossEntropyLoss
|
from torch.nn import Embedding, CrossEntropyLoss
|
||||||
|
@ -245,11 +251,71 @@ class TrainerBase(abc.ABC):
|
||||||
def get_tokenizer(self) -> transformers.PreTrainedTokenizerBase:
    """Return the tokenizer produced by ``get_tokenizer`` for ``self.ckpt_path``."""
    checkpoint_path = self.ckpt_path
    tokenizer = get_tokenizer(checkpoint_path)
    return tokenizer
|
||||||
|
|
||||||
def export_to_kobold(self, output_file: str, name: str, author: str, supported: str, description: str):
|
def save_data(self):
    """Hook invoked after a training checkpoint is written.

    This implementation intentionally does nothing and returns ``None``.
    """
    return None
|
||||||
|
|
||||||
|
def export_to_kobold(self, output_file: str, name: str, author: str, supported: str, description: str):
    """Export the saved soft prompt checkpoint as a zip archive.

    Loads the checkpoint at ``self.data.save_file``, validates it, records the
    number of tensor rows in ``self.data.soft_in_dim`` and writes
    ``output_file`` as a zip archive containing ``tensor.npy``
    (LZMA-compressed) and ``meta.json`` (stored uncompressed).

    Args:
        output_file: Path of the archive to create; an existing file is
            overwritten.
        name: Value written to the ``name`` metadata field.
        author: Value written to the ``author`` metadata field; the field is
            omitted entirely when this is empty or whitespace-only.
        supported: Comma-separated string; each item is stripped and the
            resulting list is written to the ``supported`` metadata field.
        description: Value written to the ``description`` metadata field.

    Raises:
        Whatever ``self.raise_configuration_error`` raises (code 14) when the
        checkpoint fails validation.
    """
    checkpoint = torch.load(self.data.save_file)

    # Explicit checks instead of ``assert``: assertions are stripped under
    # ``python -O``, and a missing key used to escape as a bare KeyError
    # instead of being reported as a corrupted file.
    try:
        tensor = checkpoint["tensor"]
        well_formed = (
            checkpoint["step"] > 0
            and tensor.ndim == 2
            and "opt_state" in checkpoint
        )
    except (KeyError, IndexError, TypeError, AttributeError):
        well_formed = False
    if not well_formed or not (tensor.shape[0] < self.data.params["max_batch_size"]):
        self.raise_configuration_error("MTJSP file is corrupted.", code=14)
        # Only reached if the error helper returns instead of raising;
        # never export data that failed validation.
        return

    self.data.soft_in_dim = tensor.shape[0]

    # Build the metadata in the original key order; "author" is optional.
    meta = {"name": name}
    if author.strip():
        meta["author"] = author
    meta["supported"] = [model.strip() for model in supported.split(",")]
    meta["description"] = description

    if isinstance(tensor, torch.Tensor):
        # NumPy cannot convert tensors that require grad or live off-CPU.
        tensor = tensor.detach().cpu().numpy()

    # A single archive handle: the tensor uses the archive default (LZMA),
    # while the small JSON entry overrides the method and is stored as-is.
    with zipfile.ZipFile(output_file, "w", compression=zipfile.ZIP_LZMA) as archive:
        with archive.open("tensor.npy", "w") as f:
            np.save(f, tensor, allow_pickle=False)
        archive.writestr(
            "meta.json",
            json.dumps(meta, indent=2).encode("utf-8"),
            compress_type=zipfile.ZIP_STORED,
        )
|
||||||
|
|
||||||
def export_to_mkultra(self, output_file: str, soft_prompt_name: str, soft_prompt_description: str):
    """Export the saved soft prompt checkpoint as a JSON document.

    Loads the checkpoint at ``self.data.save_file``, validates it, records the
    number of tensor rows in ``self.data.soft_in_dim`` and writes a JSON file
    with a ``metadata`` object (``step``, ``loss``, ``uuid``, ``name``,
    ``description``, ``epoch``) and a ``tensor`` field holding the
    base64-encoded pickle (protocol 4) of the checkpoint tensor.

    Args:
        output_file: Path of the JSON file to create; an existing file is
            overwritten.
        soft_prompt_name: Value written to ``metadata.name``.
        soft_prompt_description: Value written to ``metadata.description``.

    Raises:
        Whatever ``self.raise_configuration_error`` raises (code 14) when the
        checkpoint fails validation.
    """
    checkpoint = torch.load(self.data.save_file)

    # Explicit checks instead of ``assert``: assertions are stripped under
    # ``python -O``, and a missing key used to escape as a bare KeyError
    # instead of being reported as a corrupted file.
    try:
        tensor = checkpoint["tensor"]
        well_formed = (
            checkpoint["step"] > 0
            and tensor.ndim == 2
            and "opt_state" in checkpoint
        )
    except (KeyError, IndexError, TypeError, AttributeError):
        well_formed = False
    if not well_formed or not (tensor.shape[0] < self.data.params["max_batch_size"]):
        self.raise_configuration_error("MTJSP file is corrupted.", code=14)
        # Only reached if the error helper returns instead of raising;
        # never export data that failed validation.
        return

    self.data.soft_in_dim = tensor.shape[0]

    # Build the whole payload before touching the output file so that a
    # serialization failure cannot leave a truncated file behind.
    payload = {
        "metadata": {
            "step": checkpoint["step"],
            # float() accepts plain numbers as well as one-element tensors,
            # whereas ``.item()`` only exists on the latter.
            "loss": float(checkpoint["loss"]),
            "uuid": str(uuid.uuid4()),
            "name": soft_prompt_name,
            "description": soft_prompt_description,
            "epoch": datetime.datetime.now().timestamp(),
        },
        "tensor": base64.b64encode(
            pickle.dumps(tensor, protocol=4),
        ).decode("ascii"),
    }

    with open(output_file, "w") as f:
        json.dump(payload, f)
|
||||||
|
|
||||||
def tokenize_dataset(
|
def tokenize_dataset(
|
||||||
self,
|
self,
|
||||||
|
@ -456,6 +522,23 @@ class TrainerBase(abc.ABC):
|
||||||
optimizer.state['step'] = step
|
optimizer.state['step'] = step
|
||||||
cross_entropy_loss = CrossEntropyLoss()
|
cross_entropy_loss = CrossEntropyLoss()
|
||||||
|
|
||||||
|
def save_mkusp(
    loss,
    grad_norm,
):
    """Write the current training state to ``self.data.save_file``.

    The checkpoint is a dict saved with ``torch.save`` holding the soft
    embedding inputs (``tensor``), the optimizer state (``opt_state``), the
    enclosing scope's current ``step`` and the given ``loss`` and
    ``grad_norm``. ``self.save_data()`` is called once the file is closed.
    """
    with open(self.data.save_file, "wb") as handle:
        # Assemble the snapshot after the file is opened, preserving the
        # original evaluation order.
        snapshot = dict(
            tensor=soft_embeddings.get_inputs_embeds(),
            opt_state=optimizer.state_dict(),
            step=step,
            loss=loss,
            grad_norm=grad_norm,
        )
        torch.save(snapshot, handle)
    self.save_data()
|
||||||
|
|
||||||
while step < steps:
|
while step < steps:
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
|
@ -494,7 +577,8 @@ class TrainerBase(abc.ABC):
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
# Save checkpoint every few steps
|
|
||||||
pass
|
|
||||||
|
|
||||||
step += 1
|
step += 1
|
||||||
|
|
||||||
|
# Save checkpoint every few steps
|
||||||
|
if step == 1 or step % self.data.stparams["save_every"] == 0:
|
||||||
|
save_mkusp(mean_loss, mean_grad_norm)
|
||||||
|
|
Loading…
Reference in New Issue