From 728e19a7f078550703751b392c6a7c79c08b7137 Mon Sep 17 00:00:00 2001
From: vfbd
Date: Mon, 22 Aug 2022 16:29:39 -0400
Subject: [PATCH] Implement file saving in prompt_tuner.py

---
 prompt_tuner.py | 94 ++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 89 insertions(+), 5 deletions(-)

diff --git a/prompt_tuner.py b/prompt_tuner.py
index a35b12f1..fbecb4c4 100644
--- a/prompt_tuner.py
+++ b/prompt_tuner.py
@@ -7,6 +7,12 @@ import termcolor
 import contextlib
 import traceback
 import random
+import zipfile
+import json
+import uuid
+import datetime
+import base64
+import pickle
 import torch
 import torch.nn.functional as F
 from torch.nn import Embedding, CrossEntropyLoss
@@ -244,12 +250,72 @@ class TrainerBase(abc.ABC):
 
     def get_tokenizer(self) -> transformers.PreTrainedTokenizerBase:
         return get_tokenizer(self.ckpt_path)
+
+    def save_data(self):
+        pass
 
     def export_to_kobold(self, output_file: str, name: str, author: str, supported: str, description: str):
-        pass
+        try:
+            z = torch.load(self.data.save_file)
+            assert z["step"] > 0
+            assert z["tensor"].ndim == 2 and "opt_state" in z
+            assert z["tensor"].shape[0] < self.data.params["max_batch_size"]
+            self.data.soft_in_dim = z["tensor"].shape[0]
+        except AssertionError:
+            self.raise_configuration_error("MTJSP file is corrupted.", code=14)
+
+        tensor = z["tensor"]
+
+        meta = {
+            "name": name,
+            "author": author,
+            "supported": supported,
+            "description": description,
+        }
+        if len(meta["author"].strip()) == 0:
+            meta.pop("author")
+        meta["supported"] = list(map(lambda m: m.strip(), supported.split(",")))
+
+        with zipfile.ZipFile(output_file, "w", compression=zipfile.ZIP_LZMA) as z:
+            with z.open("tensor.npy", "w") as f:
+                np.save(f, tensor, allow_pickle=False)
+        with zipfile.ZipFile(output_file, "a", compression=zipfile.ZIP_STORED) as z:
+            with z.open("meta.json", "w") as f:
+                f.write(json.dumps(meta, indent=2).encode("utf-8"))
 
     def export_to_mkultra(self, output_file: str, soft_prompt_name: str, soft_prompt_description: str):
-        pass
+        try:
+            z = torch.load(self.data.save_file)
+            assert z["step"] > 0
+            assert z["tensor"].ndim == 2 and "opt_state" in z
+            assert z["tensor"].shape[0] < self.data.params["max_batch_size"]
+            self.data.soft_in_dim = z["tensor"].shape[0]
+            _step = z["step"]
+        except AssertionError:
+            self.raise_configuration_error("MTJSP file is corrupted.", code=14)
+
+        tensor = z["tensor"]
+
+        with open(output_file, "w") as f:
+            json.dump(
+                {
+                    "metadata": {
+                        "step": _step,
+                        "loss": float(z["loss"].item()),
+                        "uuid": str(uuid.uuid4()),
+                        "name": soft_prompt_name,
+                        "description": soft_prompt_description,
+                        "epoch": datetime.datetime.now().timestamp(),
+                    },
+                    "tensor": base64.b64encode(
+                        pickle.dumps(
+                            tensor,
+                            protocol=4,
+                        ),
+                    ).decode("ascii"),
+                },
+                f,
+            )
 
     def tokenize_dataset(
         self,
@@ -456,6 +522,23 @@ class TrainerBase(abc.ABC):
         optimizer.state['step'] = step
         cross_entropy_loss = CrossEntropyLoss()
 
+        def save_mkusp(
+            loss,
+            grad_norm,
+        ):
+            with open(self.data.save_file, "wb") as f:
+                torch.save(
+                    {
+                        "tensor": soft_embeddings.get_inputs_embeds(),
+                        "opt_state": optimizer.state_dict(),
+                        "step": step,
+                        "loss": loss,
+                        "grad_norm": grad_norm,
+                    },
+                    f,
+                )
+            self.save_data()
+
         while step < steps:
             model.train()
 
@@ -494,7 +577,8 @@ class TrainerBase(abc.ABC):
             scheduler.step()
             optimizer.zero_grad()
-            # Save checkpoint every few steps
-            pass
-
             step += 1
+
+            # Save checkpoint every few steps
+            if step == 1 or step % self.data.stparams["save_every"] == 0:
+                save_mkusp(mean_loss, mean_grad_norm)
 
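
For reviewers: the KoboldAI export above is a ZIP with two members, an LZMA-compressed tensor.npy holding the two-dimensional soft-prompt matrix and an uncompressed meta.json holding the name/author/supported/description fields. Below is a minimal sketch of reading such a file back, assuming only the layout the patch writes; the helper name load_kobold_softprompt is hypothetical and not part of the patch.

import json
import zipfile

import numpy as np


def load_kobold_softprompt(path):
    # Hypothetical reader for the ZIP layout written by export_to_kobold.
    with zipfile.ZipFile(path, "r") as z:
        with z.open("tensor.npy") as f:
            # The exporter wrote with allow_pickle=False, so mirror that here.
            tensor = np.load(f, allow_pickle=False)
        with z.open("meta.json") as f:
            meta = json.load(f)
    return tensor, meta

Because export_to_kobold asserts a 2-D tensor and a row count below max_batch_size before writing, a reader can reasonably expect tensor.ndim == 2 here.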
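
The mkultra export, by contrast, is a single JSON document: a "metadata" object (step, loss, uuid, name, description, epoch) plus a "tensor" field holding the base64-encoded pickle (protocol 4) of the embedding matrix. A sketch of the inverse operation under the same assumptions follows; the helper name is again hypothetical, and since pickle.loads executes arbitrary code, only exports from a trusted source should ever be decoded this way.

import base64
import json
import pickle


def load_mkultra_softprompt(path):
    # Hypothetical reader for the JSON layout written by export_to_mkultra.
    with open(path, "r") as f:
        data = json.load(f)
    # "tensor" is base64 text wrapping pickle bytes (protocol 4).
    # WARNING: pickle runs arbitrary code; trust the file's origin first.
    tensor = pickle.loads(base64.b64decode(data["tensor"]))
    return tensor, data["metadata"]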
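
Finally, save_mkusp checkpoints everything a resume would plausibly need: the soft embeddings, the optimizer state, the step counter, and the last loss and gradient norm. A hypothetical sketch of the corresponding resume path is below; the optimizer argument stands in for whatever optimizer the trainer actually constructs, and none of this helper exists in the patch itself.

import torch


def resume_from_mkusp(save_file, optimizer):
    # Load the dict written by save_mkusp and restore the optimizer state.
    checkpoint = torch.load(save_file)
    optimizer.load_state_dict(checkpoint["opt_state"])
    # Training would continue from checkpoint["step"] with these embeddings.
    return checkpoint["tensor"], checkpoint["step"]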