Fix remaining problems in prompt_tuner.py
prompt_tuner.py

@@ -15,11 +15,12 @@ import base64
 import pickle
 import hashlib
 import itertools
+from tqdm.auto import tqdm
 import torch
 import torch.nn.functional as F
 from torch.nn import Embedding, CrossEntropyLoss
 import transformers
-from transformers import AutoTokenizer, GPT2TokenizerFast
+from transformers import AutoTokenizer, GPT2TokenizerFast, AutoConfig
 from mkultra.tuning import GPTPromptTuningMixin, GPTNeoPromptTuningLM
 from mkultra.soft_prompt import SoftPrompt
 from typing import List, Optional, TextIO, Union
@@ -27,6 +28,18 @@ from typing import List, Optional, TextIO, Union
 
 _PromptTuningPreTrainedModel = Union["UniversalPromptTuningMixin", GPTPromptTuningMixin, transformers.PreTrainedModel]
 
+class _WTEDummy:
+    def __init__(self, model: transformers.PreTrainedModel):
+        self.model = model
+
+    @property
+    def wte(self: "_WTEDummy"):
+        return self.model.get_input_embeddings()
+
+    @wte.setter
+    def wte(self: "_WTEDummy", v):
+        self.model.set_input_embeddings(v)
+
 class _WTEMixin:
     @property
     def wte(self: Union["_WTEMixin", transformers.PreTrainedModel]):
@@ -43,7 +56,7 @@ class UniversalPromptTuningMixin:
         model: _PromptTuningPreTrainedModel = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
 
         if not hasattr(model, "transformer"):
-            model.transformer = _WTEMixin()
+            model.transformer = _WTEDummy(model)
         elif not hasattr(model.transformer, "wte"):
             assert isinstance(model.transformer, type)
             model.transformer.__class__ = type("_UniversalPromptTuning" + model.transformer.__class__.__name__, (_WTEMixin, model.transformer.__class__), {})
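
Note (illustration, not part of the commit): the previously assigned bare _WTEMixin() instance carried no reference back to the model, so its wte property had nothing to delegate to; _WTEDummy(model) keeps that reference. A minimal standalone sketch of the delegation pattern, using a made-up TinyModel stand-in rather than anything from this repository:

# Illustrative only: TinyModel stands in for a PreTrainedModel that lacks a
# .transformer attribute; _WTEDummy forwards .wte to the real embeddings.
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)

    def get_input_embeddings(self):
        return self.embed

    def set_input_embeddings(self, value):
        self.embed = value

class _WTEDummy:
    def __init__(self, model):
        self.model = model

    @property
    def wte(self):
        return self.model.get_input_embeddings()

    @wte.setter
    def wte(self, v):
        self.model.set_input_embeddings(v)

model = TinyModel()
if not hasattr(model, "transformer"):
    model.transformer = _WTEDummy(model)

# The trainer can now use model.transformer.wte uniformly:
print(model.transformer.wte.weight.shape)  # torch.Size([10, 4])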
@@ -248,6 +261,26 @@ class TrainerBase(abc.ABC):
         raise ConfigurationError(msg, **kwargs)
 
     def get_hf_checkpoint_metadata(self) -> bool:
+        REVISION = None
+        params = {}
+        if(os.path.isdir(self.data.ckpt_path)):
+            model_config = AutoConfig.from_pretrained(self.data.ckpt_path, revision=REVISION, cache_dir="cache")
+        elif(os.path.isdir("models/{}".format(self.data.ckpt_path.replace('/', '_')))):
+            model_config = AutoConfig.from_pretrained("models/{}".format(self.data.ckpt_path.replace('/', '_')), revision=REVISION, cache_dir="cache")
+        else:
+            model_config = AutoConfig.from_pretrained(self.data.ckpt_path, revision=REVISION, cache_dir="cache")
+        params["tokenizer_id"] = self.data.ckpt_path
+        tokenizer = get_tokenizer(self.data.ckpt_path)
+        params["newlinemode"] = params.get(
+            "newlinemode", "s" if model_config.model_type == "xglm" else "n"
+        )
+        params["max_batch_size"] = 2048
+        with tokenizer._kai_no_prefix():
+            params["eos_token"] = (
+                [50259, 50259] if model_config.model_type == "xglm" and model_config.eos_token_id == 50259 else tokenizer.encode(model_config.eos_token_id)
+            )
+        params["seq"] = 2048
+        self.data.params = params
         return True
 
     def get_tokenizer(self) -> transformers.PreTrainedTokenizerBase:
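
Note (illustration, not from the commit): get_hf_checkpoint_metadata prefers a local checkpoint directory, then a models/<name> mirror, and finally resolves ckpt_path through the Hugging Face hub; the loaded config supplies model_type and eos_token_id for the params dict. A small sketch of that AutoConfig lookup, using an example model id that the commit itself does not mention:

# Example only: any causal-LM checkpoint id works here; the printed values are
# simply what this particular config happens to contain.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-125M", cache_dir="cache")
print(config.model_type)    # "gpt_neo"
print(config.eos_token_id)  # 50256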
@@ -445,6 +478,7 @@ class TrainerBase(abc.ABC):
                 self.raise_configuration_error(
                     "You have not set a soft prompt size.", code=6
                 )
+            step = 0
         else:
             # If we're resuming a soft-tuning session, the soft prompt tensor is
             # already in the save file and we just have to decode it.
@@ -502,7 +536,7 @@ class TrainerBase(abc.ABC):
         if beta1 == 0.0:
             beta1 = None
         optimizer = transformers.Adafactor(
-            params=model.get_soft_params(),
+            params=(model.get_soft_params(),),
             scale_parameter=False,
             relative_step=False,
             warmup_init=False,
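
Note (illustration, not from the commit): torch optimizers, Adafactor included, expect an iterable of parameters; torch's Optimizer base class rejects a bare nn.Parameter with a TypeError, which is why the single soft-prompt parameter is now wrapped in a one-element tuple. A minimal sketch with a made-up parameter shape:

# Sketch only: the tuple gives Adafactor the iterable it expects.
import torch
import transformers

soft_prompt = torch.nn.Parameter(torch.zeros(20, 768))  # illustrative shape

optimizer = transformers.Adafactor(
    params=(soft_prompt,),
    scale_parameter=False,
    relative_step=False,
    warmup_init=False,
    lr=1e-3,
)

soft_prompt.grad = torch.randn_like(soft_prompt)  # pretend backward() ran
optimizer.step()
optimizer.zero_grad()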
@@ -540,19 +574,22 @@ class TrainerBase(abc.ABC):
                 f,
             )
             self.save_data()
 
+        bar1 = tqdm(initial=step + 1, total=steps, desc="CURRENT TRAINING STEP")
+
         while step < steps:
+            step += 1
             model.train()
 
             total_loss = total_grad = total_grad_norm = 0
 
-            for i in range(self.data.gradient_accumulation_steps):
-                # Get the next sequences from the dataset
-                block = torch.tensor(np.int32(self.get_batch(step, self.data.gradient_accumulation_steps))).to(model.transformer.wte.weight.device)
-
+            # Get the next sequence from the dataset
+            block = self.get_batch(step, self.data.gradient_accumulation_steps).to(model.transformer.wte.weight.device)
+
+            for sequence in tqdm(block, desc="GRADIENT ACCUMULATION", leave=False):
                 # input_ids is the context to the model (without the soft prompt) and labels is what we expect the model to generate (the -100s represent soft prompt tokens for which loss is not calculated)
-                input_ids = block[:-1].unsqueeze(0).detach()
-                labels = torch.cat((torch.full((model.get_soft_params().size(0) - 1,), -100, device=block.device), block)).unsqueeze(0).cuda().detach()
+                input_ids = sequence[:-1].unsqueeze(0).detach()
+                labels = torch.cat((torch.full((model.get_soft_params().size(0) - 1,), -100, device=sequence.device), sequence), dim=-1).unsqueeze(0).detach()
 
                 # Give the context to the model and compare the model's output logits with the labels to compute the loss
                 logits = model(input_ids=input_ids, labels=input_ids).logits
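
Note (illustration, not from the commit): the -100 labels cover the positions occupied by the soft prompt, and CrossEntropyLoss ignores index -100 by default, so only the real tokens contribute to the loss. A toy, self-contained sketch of that masking (sizes and tensors are made up; the shift-by-one between logits and labels happens inside the model):

# Toy sizes; logits are random stand-ins for model output.
import torch
from torch.nn import CrossEntropyLoss

soft_len, seq_len, vocab = 4, 6, 11
sequence = torch.randint(0, vocab, (seq_len,))

# Same construction as above: (soft_len - 1) ignored positions, then the sequence.
labels = torch.cat((torch.full((soft_len - 1,), -100), sequence), dim=-1).unsqueeze(0)

logits = torch.randn(1, labels.size(-1), vocab)

loss = CrossEntropyLoss()(logits.view(-1, vocab), labels.view(-1))
print(loss)  # positions labelled -100 are excluded from the average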
@@ -579,12 +616,13 @@ class TrainerBase(abc.ABC):
             scheduler.step()
             optimizer.zero_grad()
 
-            step += 1
-
             # Save checkpoint every few steps
             if step == 1 or step % self.data.stparams["save_every"] == 0:
                 save_mkusp(mean_loss, mean_grad_norm)
 
+            bar1.set_postfix({"loss": mean_loss, "grad_norm": mean_grad_norm, "learning_rate": lr})
+            bar1.update()
+
 
 class BasicTrainer(TrainerBase):
     class TrainerData(TrainerBase.TrainerData):
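
Note (illustration, not from the commit): bar1 is a persistent tqdm bar that survives across training steps; initial and total set where it starts, and set_postfix/update refresh it once per outer step. A minimal sketch with fabricated metric values:

from tqdm.auto import tqdm

step, steps = 0, 5
bar1 = tqdm(initial=step + 1, total=steps, desc="CURRENT TRAINING STEP")
while step < steps:
    step += 1
    # ... one training step would run here ...
    bar1.set_postfix({"loss": 2.31, "grad_norm": 0.74, "learning_rate": 1e-3})
    bar1.update()
bar1.close()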
@@ -652,12 +690,12 @@ class BasicTrainer(TrainerBase):
                 )
             )
             sample_space = [
-                k for k in range(self.data.params["n_vocab"]) if k not in special_tokens
+                k for k in range(model.get_input_embeddings().weight.shape[-2]) if k not in special_tokens
            ]
             sample = rng.choice(sample_space, self.data.soft_in_dim, False)
-            return model.get_input_embeddings()(torch.tensor(sample, dtype=torch.int32))
+            return SoftPrompt.from_inputs_embeds(model.get_input_embeddings()(torch.tensor(sample, dtype=torch.int32)))
         elif self.data.prompt_method == "tokens":
-            return model.get_input_embeddings()(torch.tensor(self.data.initial_softprompt, dtype=torch.int32))
+            return SoftPrompt.from_inputs_embeds(model.get_input_embeddings()(torch.tensor(self.data.initial_softprompt, dtype=torch.int32)))
         self.raise_configuration_error(
             f"Unknown prompt method {repr(self.data.prompt_method)}", code=104
         )
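
Note (illustration, not from the commit): with the "vocab_sample" method the soft prompt is initialized from the embedding rows of randomly chosen non-special tokens, and reading the vocabulary size from the embedding matrix (weight.shape[-2]) instead of params["n_vocab"] keeps it consistent with whatever model was actually loaded. A toy sketch of the sampling step, with made-up sizes and special-token ids and without the SoftPrompt wrapper:

# Toy illustration of "vocab_sample" initialization: pick random non-special
# token ids and use their embedding rows as the initial soft prompt.
import numpy as np
import torch
import torch.nn as nn

embeddings = nn.Embedding(100, 16)   # stand-in for model.get_input_embeddings()
special_tokens = {0, 1, 2}           # stand-in special-token ids
soft_in_dim = 8                      # number of soft-prompt tokens

rng = np.random.default_rng(0)
sample_space = [
    k for k in range(embeddings.weight.shape[-2]) if k not in special_tokens
]
sample = rng.choice(sample_space, soft_in_dim, False)
initial_embeds = embeddings(torch.tensor(sample, dtype=torch.int32))
print(initial_embeds.shape)  # torch.Size([8, 16]): one row per soft-prompt token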