From 160effb9ead86b184fc5ef3bb7075d2f1b9ec7a5 Mon Sep 17 00:00:00 2001
From: Henk
Date: Sat, 15 Jul 2023 18:20:10 +0200
Subject: [PATCH] Add 4-bit BnB toggle

---
 .../generic_hf_torch/class.py | 28 ++++++++++++++++++-
 1 file changed, 27 insertions(+), 1 deletion(-)

diff --git a/modeling/inference_models/generic_hf_torch/class.py b/modeling/inference_models/generic_hf_torch/class.py
index 8f024ea1..dbe1d038 100644
--- a/modeling/inference_models/generic_hf_torch/class.py
+++ b/modeling/inference_models/generic_hf_torch/class.py
@@ -20,9 +20,29 @@ model_backend_name = "Huggingface"
 model_backend_type = "Huggingface" #This should be a generic name in case multiple model backends are compatible (think Hugging Face Custom and Basic Hugging Face)
 
 class model_backend(HFTorchInferenceModel):
+
     def _initialize_model(self):
         return
 
+    def get_requested_parameters(self, model_name, model_path, menu_path, parameters = {}):
+        requested_parameters = super().get_requested_parameters(model_name, model_path, menu_path, parameters = {})
+        requested_parameters.append({
+            "uitype": "toggle",
+            "unit": "bool",
+            "label": "Use 4-bit",
+            "id": "use_4_bit",
+            "default": False,
+            "tooltip": "Whether or not to use BnB's 4-bit mode",
+            "menu_path": "Layers",
+            "extra_classes": "",
+            "refresh_model_inputs": False
+        })
+        return requested_parameters
+
+    def set_input_parameters(self, parameters):
+        super().set_input_parameters(parameters)
+        self.use_4_bit = parameters['use_4_bit']
+
     def _load(self, save_model: bool, initial_load: bool) -> None:
         utils.koboldai_vars.allowsp = True
 
@@ -32,7 +52,7 @@ class model_backend(HFTorchInferenceModel):
         # behavior consistent with other loading methods - Henk717
         # if utils.koboldai_vars.model not in ["NeoCustom", "GPT2Custom"]:
         #     utils.koboldai_vars.custmodpth = utils.koboldai_vars.model
-
+        
         if self.model_name == "NeoCustom":
             self.model_name = os.path.basename(os.path.normpath(self.path))
             utils.koboldai_vars.model = self.model_name
@@ -50,6 +70,12 @@ class model_backend(HFTorchInferenceModel):
         tf_kwargs = {
             "low_cpu_mem_usage": True,
         }
+
+        if self.use_4_bit:
+            self.lazy_load = False
+            tf_kwargs.update({
+                "load_in_4bit": True,
+            })
 
         if self.model_type == "gpt2":
             # We must disable low_cpu_mem_usage and if using a GPT-2 model
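
Reviewer note, not part of the patch: tf_kwargs is ultimately forwarded to
Hugging Face's from_pretrained, so the toggle maps onto the standard
bitsandbytes 4-bit loading path in transformers (available since ~4.30). A
minimal standalone sketch of the load this enables; the model id and
device_map are illustrative assumptions, not something the patch pins:

    # Sketch of the load path the "Use 4-bit" toggle turns on, assuming
    # transformers >= 4.30 with bitsandbytes installed and a CUDA GPU.
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "EleutherAI/gpt-neo-2.7B",  # hypothetical model id, illustration only
        low_cpu_mem_usage=True,     # matches the default tf_kwargs above
        load_in_4bit=True,          # what the toggle adds to tf_kwargs
        device_map="auto",          # place the quantized weights on the GPU
    )

Setting self.lazy_load = False fits this picture: KoboldAI's lazy loader
materializes tensors itself and would bypass the quantization applied inside
from_pretrained, so a 4-bit load presumably has to take the plain loading path.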