mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Restore --peft support
This commit is contained in:
@@ -262,6 +262,40 @@ class HFTorchInferenceModel(HFInferenceModel):
|
|||||||
new_sample.old_sample = transformers.GenerationMixin.sample
|
new_sample.old_sample = transformers.GenerationMixin.sample
|
||||||
use_core_manipulations.sample = new_sample
|
use_core_manipulations.sample = new_sample
|
||||||
|
|
||||||
|
# PEFT Loading. This MUST be done after all save_pretrained calls are
|
||||||
|
# finished on the main model.
|
||||||
|
if utils.args.peft:
|
||||||
|
from peft import PeftModel, PeftConfig
|
||||||
|
local_peft_dir = os.path.join(m_self.get_local_model_path(), "peft")
|
||||||
|
|
||||||
|
# Make PEFT dir if it doesn't exist
|
||||||
|
try:
|
||||||
|
os.makedirs(local_peft_dir)
|
||||||
|
except FileExistsError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
peft_local_path = os.path.join(local_peft_dir, utils.args.peft.replace("/", "_"))
|
||||||
|
logger.debug(f"Loading PEFT '{utils.args.peft}', possible local path is '{peft_local_path}'.")
|
||||||
|
|
||||||
|
peft_installed_locally = True
|
||||||
|
possible_peft_locations = [peft_local_path, utils.args.peft]
|
||||||
|
|
||||||
|
for i, location in enumerate(possible_peft_locations):
|
||||||
|
try:
|
||||||
|
m_self.model = PeftModel.from_pretrained(m_self.model, location)
|
||||||
|
logger.debug(f"Loaded PEFT at '{location}'")
|
||||||
|
break
|
||||||
|
except ValueError:
|
||||||
|
peft_installed_locally = False
|
||||||
|
if i == len(possible_peft_locations) - 1:
|
||||||
|
raise RuntimeError(f"Unable to load PeftModel for given name '{utils.args.peft}'. Does it exist?")
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError("Error while loading PeftModel. Are you using the correct model?")
|
||||||
|
|
||||||
|
if not peft_installed_locally:
|
||||||
|
logger.debug(f"PEFT not saved to models folder; saving to '{peft_local_path}'")
|
||||||
|
m_self.model.save_pretrained(peft_local_path)
|
||||||
|
|
||||||
return super()._post_load()
|
return super()._post_load()
|
||||||
|
|
||||||
def _raw_generate(
|
def _raw_generate(
|
||||||
|
Reference in New Issue
Block a user