mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-01-20 20:38:21 +01:00
Fix remaining problems in prompt_tuner.py
This commit is contained in:
parent
f79926b73d
commit
584056b6d5
@ -15,11 +15,12 @@ import base64
|
||||
import pickle
|
||||
import hashlib
|
||||
import itertools
|
||||
from tqdm.auto import tqdm
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Embedding, CrossEntropyLoss
|
||||
import transformers
|
||||
from transformers import AutoTokenizer, GPT2TokenizerFast
|
||||
from transformers import AutoTokenizer, GPT2TokenizerFast, AutoConfig
|
||||
from mkultra.tuning import GPTPromptTuningMixin, GPTNeoPromptTuningLM
|
||||
from mkultra.soft_prompt import SoftPrompt
|
||||
from typing import List, Optional, TextIO, Union
|
||||
@ -27,6 +28,18 @@ from typing import List, Optional, TextIO, Union
|
||||
|
||||
_PromptTuningPreTrainedModel = Union["UniversalPromptTuningMixin", GPTPromptTuningMixin, transformers.PreTrainedModel]
|
||||
|
||||
class _WTEDummy:
|
||||
def __init__(self, model: transformers.PreTrainedModel):
|
||||
self.model = model
|
||||
|
||||
@property
|
||||
def wte(self: "_WTEDummy"):
|
||||
return self.model.get_input_embeddings()
|
||||
|
||||
@wte.setter
|
||||
def wte(self: "_WTEDummy", v):
|
||||
self.model.set_input_embeddings(v)
|
||||
|
||||
class _WTEMixin:
|
||||
@property
|
||||
def wte(self: Union["_WTEMixin", transformers.PreTrainedModel]):
|
||||
@ -43,7 +56,7 @@ class UniversalPromptTuningMixin:
|
||||
model: _PromptTuningPreTrainedModel = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
if not hasattr(model, "transformer"):
|
||||
model.transformer = _WTEMixin()
|
||||
model.transformer = _WTEDummy(model)
|
||||
elif not hasattr(model.transformer, "wte"):
|
||||
assert isinstance(model.transformer, type)
|
||||
model.transformer.__class__ = type("_UniversalPromptTuning" + model.transformer.__class__.__name__, (_WTEMixin, model.transformer.__class__), {})
|
||||
@ -248,6 +261,26 @@ class TrainerBase(abc.ABC):
|
||||
raise ConfigurationError(msg, **kwargs)
|
||||
|
||||
def get_hf_checkpoint_metadata(self) -> bool:
|
||||
REVISION = None
|
||||
params = {}
|
||||
if(os.path.isdir(self.data.ckpt_path)):
|
||||
model_config = AutoConfig.from_pretrained(self.data.ckpt_path, revision=REVISION, cache_dir="cache")
|
||||
elif(os.path.isdir("models/{}".format(self.data.ckpt_path.replace('/', '_')))):
|
||||
model_config = AutoConfig.from_pretrained("models/{}".format(self.data.ckpt_path.replace('/', '_')), revision=REVISION, cache_dir="cache")
|
||||
else:
|
||||
model_config = AutoConfig.from_pretrained(self.data.ckpt_path, revision=REVISION, cache_dir="cache")
|
||||
params["tokenizer_id"] = self.data.ckpt_path
|
||||
tokenizer = get_tokenizer(self.data.ckpt_path)
|
||||
params["newlinemode"] = params.get(
|
||||
"newlinemode", "s" if model_config.model_type == "xglm" else "n"
|
||||
)
|
||||
params["max_batch_size"] = 2048
|
||||
with tokenizer._kai_no_prefix():
|
||||
params["eos_token"] = (
|
||||
[50259, 50259] if model_config.model_type == "xglm" and model_config.eos_token_id == 50259 else tokenizer.encode(model_config.eos_token_id)
|
||||
)
|
||||
params["seq"] = 2048
|
||||
self.data.params = params
|
||||
return True
|
||||
|
||||
def get_tokenizer(self) -> transformers.PreTrainedTokenizerBase:
|
||||
@ -445,6 +478,7 @@ class TrainerBase(abc.ABC):
|
||||
self.raise_configuration_error(
|
||||
"You have not set a soft prompt size.", code=6
|
||||
)
|
||||
step = 0
|
||||
else:
|
||||
# If we're resuming a soft-tuning session, the soft prompt tensor is
|
||||
# already in the save file and we just have to decode it.
|
||||
@ -502,7 +536,7 @@ class TrainerBase(abc.ABC):
|
||||
if beta1 == 0.0:
|
||||
beta1 = None
|
||||
optimizer = transformers.Adafactor(
|
||||
params=model.get_soft_params(),
|
||||
params=(model.get_soft_params(),),
|
||||
scale_parameter=False,
|
||||
relative_step=False,
|
||||
warmup_init=False,
|
||||
@ -540,19 +574,22 @@ class TrainerBase(abc.ABC):
|
||||
f,
|
||||
)
|
||||
self.save_data()
|
||||
|
||||
bar1 = tqdm(initial=step + 1, total=steps, desc="CURRENT TRAINING STEP")
|
||||
|
||||
while step < steps:
|
||||
step += 1
|
||||
model.train()
|
||||
|
||||
total_loss = total_grad = total_grad_norm = 0
|
||||
|
||||
for i in range(self.data.gradient_accumulation_steps):
|
||||
# Get the next sequence from the dataset
|
||||
block = self.get_batch(step, self.data.gradient_accumulation_steps).to(model.transformer.wte.weight.device)
|
||||
# Get the next sequences from the dataset
|
||||
block = torch.tensor(np.int32(self.get_batch(step, self.data.gradient_accumulation_steps))).to(model.transformer.wte.weight.device)
|
||||
|
||||
for sequence in tqdm(block, desc="GRADIENT ACCUMULATION", leave=False):
|
||||
# input_ids is the context to the model (without the soft prompt) and labels is what we expect the model to generate (the -100s represent soft prompt tokens for which loss is not calculated)
|
||||
input_ids = block[:-1].unsqueeze(0).detach()
|
||||
labels = torch.cat((torch.full((model.get_soft_params().size(0) - 1,), -100, device=block.device), block)).unsqueeze(0).cuda().detach()
|
||||
input_ids = sequence[:-1].unsqueeze(0).detach()
|
||||
labels = torch.cat((torch.full((model.get_soft_params().size(0) - 1,), -100, device=sequence.device), sequence), dim=-1).unsqueeze(0).detach()
|
||||
|
||||
# Give the context to the model and compare the model's output logits with the labels to compute the loss
|
||||
logits = model(input_ids=input_ids, labels=input_ids).logits
|
||||
@ -579,12 +616,13 @@ class TrainerBase(abc.ABC):
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
step += 1
|
||||
|
||||
# Save checkpoint every few steps
|
||||
if step == 1 or step % self.data.stparams["save_every"] == 0:
|
||||
save_mkusp(mean_loss, mean_grad_norm)
|
||||
|
||||
bar1.set_postfix({"loss": mean_loss, "grad_norm": mean_grad_norm, "learning_rate": lr})
|
||||
bar1.update()
|
||||
|
||||
|
||||
class BasicTrainer(TrainerBase):
|
||||
class TrainerData(TrainerBase.TrainerData):
|
||||
@ -652,12 +690,12 @@ class BasicTrainer(TrainerBase):
|
||||
)
|
||||
)
|
||||
sample_space = [
|
||||
k for k in range(self.data.params["n_vocab"]) if k not in special_tokens
|
||||
k for k in range(model.get_input_embeddings().weight.shape[-2]) if k not in special_tokens
|
||||
]
|
||||
sample = rng.choice(sample_space, self.data.soft_in_dim, False)
|
||||
return model.get_input_embeddings()(torch.tensor(sample, dtype=torch.int32))
|
||||
return SoftPrompt.from_inputs_embeds(model.get_input_embeddings()(torch.tensor(sample, dtype=torch.int32)))
|
||||
elif self.data.prompt_method == "tokens":
|
||||
return model.get_input_embeddings()(torch.tensor(self.data.initial_softprompt, dtype=torch.int32))
|
||||
return SoftPrompt.from_inputs_embeds(model.get_input_embeddings()(torch.tensor(self.data.initial_softprompt, dtype=torch.int32)))
|
||||
self.raise_configuration_error(
|
||||
f"Unknown prompt method {repr(self.data.prompt_method)}", code=104
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user