feature(load model): select control for quantization level

Nick Perez
2023-07-19 07:58:12 -04:00
parent 0142913060
commit 9581e51476


@@ -36,24 +36,14 @@ class model_backend(HFTorchInferenceModel):
         else:
             temp = {}
         requested_parameters.append({
-            "uitype": "toggle",
-            "unit": "bool",
-            "label": "Use 8-bit",
-            "id": "use_8_bit",
-            "default": temp['use_8_bit'] if 'use_8_bit' in temp else False,
-            "tooltip": "Whether or not to use BnB's 8-bit mode",
-            "menu_path": "Layers",
-            "extra_classes": "",
-            "refresh_model_inputs": False
-        })
-        requested_parameters.append({
-            "uitype": "toggle",
-            "unit": "bool",
-            "label": "Use 4-bit",
-            "id": "use_4_bit",
-            "default": temp['use_4_bit'] if 'use_4_bit' in temp else False,
-            "tooltip": "Whether or not to use BnB's 4-bit mode",
+            "uitype": "dropdown",
+            "unit": "text",
+            "label": "Quantization",
+            "id": "quantization",
+            "default": temp['quantization'] if 'quantization' in temp else 'none',
+            "tooltip": "Whether or not to use BnB's 4-bit or 8-bit mode",
             "menu_path": "Layers",
+            "children": [{'text': 'None', 'value':'none'},{'text': '4-bit', 'value': '4bit'}, {'text': '8-bit', 'value': '8bit'}],
             "extra_classes": "",
             "refresh_model_inputs": False
         })
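For reference, a minimal standalone sketch of how the new dropdown entry resolves its default; the field values are copied from the hunk above, and `saved` stands in for `temp`, the previously saved settings:

saved = {}  # stands in for `temp`; empty means no saved settings

dropdown = {
    "uitype": "dropdown",
    "unit": "text",
    "label": "Quantization",
    "id": "quantization",
    # dict.get gives the same result as the conditional expression above
    "default": saved.get("quantization", "none"),
    # "children" supplies the selectable options for the dropdown
    "children": [
        {"text": "None", "value": "none"},
        {"text": "4-bit", "value": "4bit"},
        {"text": "8-bit", "value": "8bit"},
    ],
}

assert dropdown["default"] == "none"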
@@ -63,8 +53,7 @@ class model_backend(HFTorchInferenceModel):
 
     def set_input_parameters(self, parameters):
         super().set_input_parameters(parameters)
-        self.use_4_bit = parameters['use_4_bit'] if 'use_4_bit' in parameters else False
-        self.use_8_bit = parameters['use_8_bit'] if 'use_8_bit' in parameters else False
+        self.quantization = parameters['quantization'] if 'quantization' in parameters else False
 
     def _load(self, save_model: bool, initial_load: bool) -> None:
         utils.koboldai_vars.allowsp = True
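A self-contained sketch of the new lookup in set_input_parameters; note that the committed fallback is False rather than 'none', which the string comparisons in _load simply never match:

def resolve_quantization(parameters: dict):
    # Mirrors the committed line: falls back to False (not 'none')
    # when the key is absent.
    return parameters["quantization"] if "quantization" in parameters else False

assert resolve_quantization({"quantization": "4bit"}) == "4bit"
assert resolve_quantization({}) is False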
@@ -94,7 +83,7 @@ class model_backend(HFTorchInferenceModel):
             "low_cpu_mem_usage": True,
         }
 
-        if self.use_8_bit:
+        if self.quantization == "8bit":
             tf_kwargs.update({
                 "quantization_config":BitsAndBytesConfig(
                     load_in_8bit=True,
@@ -102,7 +91,7 @@ class model_backend(HFTorchInferenceModel):
                 ),
             })
 
-        if self.use_4_bit or utils.koboldai_vars.colab_arg:
+        if self.quantization == "4bit" or utils.koboldai_vars.colab_arg:
             tf_kwargs.update({
                 "quantization_config":BitsAndBytesConfig(
                     load_in_4bit=True,
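Taken together, the two hunks above select at most one BitsAndBytesConfig. A minimal sketch of that branch logic, assuming transformers' BitsAndBytesConfig and omitting the extra config arguments truncated from the hunks; the helper name is hypothetical:

from transformers import BitsAndBytesConfig

def build_tf_kwargs(quantization, colab_arg=False):
    # Hypothetical helper reproducing the two `if` branches above; any
    # value other than "8bit"/"4bit" (including the False fallback)
    # leaves quantization off entirely.
    tf_kwargs = {"low_cpu_mem_usage": True}
    if quantization == "8bit":
        tf_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
    if quantization == "4bit" or colab_arg:
        tf_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
    return tf_kwargs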
@@ -317,8 +306,7 @@ class model_backend(HFTorchInferenceModel):
                     "disk_layers": self.disk_layers
                     if "disk_layers" in vars(self)
                     else 0,
-                    "use_4_bit": self.use_4_bit,
-                    "use_8_bit": self.use_8_bit,
+                    "quantization": self.quantization,
                 },
                 f,
                 indent="",
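The last hunk persists the chosen level in the saved-settings JSON, which is what populates `temp` when the form is rebuilt. A rough sketch of that round trip; the file name and values here are illustrative, not the backend's actual paths:

import json

settings = {"disk_layers": 0, "quantization": "4bit"}  # illustrative values
with open("model_settings.json", "w") as f:
    # indent="" matches the committed call: newlines, no indentation
    json.dump(settings, f, indent="")

with open("model_settings.json") as f:
    temp = json.load(f)
assert temp["quantization"] == "4bit"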