From 01429130602841d0d85ab27f6b38a4f63be0372c Mon Sep 17 00:00:00 2001 From: Nick Perez Date: Tue, 18 Jul 2023 23:29:38 -0400 Subject: [PATCH] 8 bit toggle, fix for broken toggle values --- .../generic_hf_torch/class.py | 21 +++++++++++++++++++ static/koboldai.js | 4 +--- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/modeling/inference_models/generic_hf_torch/class.py b/modeling/inference_models/generic_hf_torch/class.py index 0bb954e3..49c6ca33 100644 --- a/modeling/inference_models/generic_hf_torch/class.py +++ b/modeling/inference_models/generic_hf_torch/class.py @@ -35,6 +35,17 @@ class model_backend(HFTorchInferenceModel): temp = json.load(f) else: temp = {} + requested_parameters.append({ + "uitype": "toggle", + "unit": "bool", + "label": "Use 8-bit", + "id": "use_8_bit", + "default": temp['use_8_bit'] if 'use_8_bit' in temp else False, + "tooltip": "Whether or not to use BnB's 8-bit mode", + "menu_path": "Layers", + "extra_classes": "", + "refresh_model_inputs": False + }) requested_parameters.append({ "uitype": "toggle", "unit": "bool", @@ -53,6 +64,7 @@ class model_backend(HFTorchInferenceModel): def set_input_parameters(self, parameters): super().set_input_parameters(parameters) self.use_4_bit = parameters['use_4_bit'] if 'use_4_bit' in parameters else False + self.use_8_bit = parameters['use_8_bit'] if 'use_8_bit' in parameters else False def _load(self, save_model: bool, initial_load: bool) -> None: utils.koboldai_vars.allowsp = True @@ -82,6 +94,14 @@ class model_backend(HFTorchInferenceModel): "low_cpu_mem_usage": True, } + if self.use_8_bit: + tf_kwargs.update({ + "quantization_config":BitsAndBytesConfig( + load_in_8bit=True, + llm_int8_enable_fp32_cpu_offload=True + ), + }) + if self.use_4_bit or utils.koboldai_vars.colab_arg: tf_kwargs.update({ "quantization_config":BitsAndBytesConfig( @@ -298,6 +318,7 @@ class model_backend(HFTorchInferenceModel): if "disk_layers" in vars(self) else 0, "use_4_bit": self.use_4_bit, + "use_8_bit": self.use_8_bit, }, f, indent="", diff --git a/static/koboldai.js b/static/koboldai.js index 94ac6ce4..8b70dd6a 100644 --- a/static/koboldai.js +++ b/static/koboldai.js @@ -2011,7 +2011,7 @@ function load_model() { data = {} if (settings_area) { for (const element of settings_area.querySelectorAll(".model_settings_input:not(.hidden)")) { - var element_data = element.value; + var element_data = element.getAttribute("data_type") === "bool" ? element.checked : element.value; if ((element.tagName == "SELECT") && (element.multiple)) { element_data = []; for (var i=0, iLen=element.options.length; i