From 16240878bc4dc8dc26f28b58d4f19cde79ea9e79 Mon Sep 17 00:00:00 2001 From: Henk Date: Tue, 4 Jul 2023 20:42:29 +0200 Subject: [PATCH] Restore --peft support --- modeling/inference_models/hf_torch.py | 34 +++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py index fb52bac1..84d3447e 100644 --- a/modeling/inference_models/hf_torch.py +++ b/modeling/inference_models/hf_torch.py @@ -262,6 +262,40 @@ class HFTorchInferenceModel(HFInferenceModel): new_sample.old_sample = transformers.GenerationMixin.sample use_core_manipulations.sample = new_sample + # PEFT Loading. This MUST be done after all save_pretrained calls are + # finished on the main model. + if utils.args.peft: + from peft import PeftModel, PeftConfig + local_peft_dir = os.path.join(m_self.get_local_model_path(), "peft") + + # Make PEFT dir if it doesn't exist + try: + os.makedirs(local_peft_dir) + except FileExistsError: + pass + + peft_local_path = os.path.join(local_peft_dir, utils.args.peft.replace("/", "_")) + logger.debug(f"Loading PEFT '{utils.args.peft}', possible local path is '{peft_local_path}'.") + + peft_installed_locally = True + possible_peft_locations = [peft_local_path, utils.args.peft] + + for i, location in enumerate(possible_peft_locations): + try: + m_self.model = PeftModel.from_pretrained(m_self.model, location) + logger.debug(f"Loaded PEFT at '{location}'") + break + except ValueError: + peft_installed_locally = False + if i == len(possible_peft_locations) - 1: + raise RuntimeError(f"Unable to load PeftModel for given name '{utils.args.peft}'. Does it exist?") + except RuntimeError: + raise RuntimeError("Error while loading PeftModel. Are you using the correct model?") + + if not peft_installed_locally: + logger.debug(f"PEFT not saved to models folder; saving to '{peft_local_path}'") + m_self.model.save_pretrained(peft_local_path) + return super()._post_load() def _raw_generate(