Basic PEFT support

This commit is contained in:
somebody
2023-05-03 18:51:01 -05:00
parent a9ef475142
commit 35b56117e6
6 changed files with 38 additions and 1 deletions

View File

@@ -22,6 +22,7 @@ from transformers import (
AutoModelForCausalLM,
LogitsProcessorList,
)
from peft import PeftModel, PeftConfig
import utils
import modeling.lazy_loader as lazy_loader
@@ -211,6 +212,31 @@ class HFTorchInferenceModel(HFInferenceModel):
new_sample.old_sample = transformers.GenerationMixin.sample
use_core_manipulations.sample = new_sample
# PEFT Loading. This MUST be done after all save_pretrained calls are
# finished on the main model.
if utils.args.peft:
peft_local_path = os.path.join("models/peft", utils.args.peft.replace("/", "_"))
logger.debug(f"Loading PEFT '{utils.args.peft}', possible local path is '{peft_local_path}'.")
peft_installed_locally = True
possible_peft_locations = [peft_local_path, utils.args.peft]
for i, location in enumerate(possible_peft_locations):
try:
m_self.model = PeftModel.from_pretrained(m_self.model, location)
logger.debug(f"Loaded PEFT at '{location}'")
break
except ValueError:
peft_installed_locally = False
if i == len(possible_peft_locations) - 1:
raise RuntimeError(f"Unable to load PeftModel for given name '{utils.args.peft}'. Does it exist?")
except RuntimeError:
raise RuntimeError("Error while loading PeftModel. Are you using the correct model?")
if not peft_installed_locally:
logger.debug(f"PEFT not saved to models folder; saving to '{peft_local_path}'")
m_self.model.save_pretrained(peft_local_path)
return super()._post_load()
def _raw_generate(
@@ -238,8 +264,13 @@ class HFTorchInferenceModel(HFInferenceModel):
with torch.no_grad():
start_time = time.time()
# HEED & BEWARE: All arguments passed to self.model.generate MUST be
# kwargs; see https://github.com/huggingface/peft/issues/232. If they
# aren't, PeftModel will EXPLODE!!!! But nothing will happen without
# a PEFT loaded so it's sneaky.
genout = self.model.generate(
gen_in,
input_ids=gen_in,
do_sample=True,
max_length=min(
len(prompt_tokens) + max_new, utils.koboldai_vars.max_length