Restore --peft support

This commit is contained in:
Henk
2023-07-04 20:42:29 +02:00
parent 94c62a4f90
commit 16240878bc

View File

@@ -262,6 +262,40 @@ class HFTorchInferenceModel(HFInferenceModel):
new_sample.old_sample = transformers.GenerationMixin.sample
use_core_manipulations.sample = new_sample
# PEFT Loading. This MUST be done after all save_pretrained calls are
# finished on the main model.
if utils.args.peft:
from peft import PeftModel, PeftConfig
local_peft_dir = os.path.join(m_self.get_local_model_path(), "peft")
# Make PEFT dir if it doesn't exist
try:
os.makedirs(local_peft_dir)
except FileExistsError:
pass
peft_local_path = os.path.join(local_peft_dir, utils.args.peft.replace("/", "_"))
logger.debug(f"Loading PEFT '{utils.args.peft}', possible local path is '{peft_local_path}'.")
peft_installed_locally = True
possible_peft_locations = [peft_local_path, utils.args.peft]
for i, location in enumerate(possible_peft_locations):
try:
m_self.model = PeftModel.from_pretrained(m_self.model, location)
logger.debug(f"Loaded PEFT at '{location}'")
break
except ValueError:
peft_installed_locally = False
if i == len(possible_peft_locations) - 1:
raise RuntimeError(f"Unable to load PeftModel for given name '{utils.args.peft}'. Does it exist?")
except RuntimeError:
raise RuntimeError("Error while loading PeftModel. Are you using the correct model?")
if not peft_installed_locally:
logger.debug(f"PEFT not saved to models folder; saving to '{peft_local_path}'")
m_self.model.save_pretrained(peft_local_path)
return super()._post_load()
def _raw_generate(