diff --git a/modeling/inference_models/generic_hf_torch/class.py b/modeling/inference_models/generic_hf_torch/class.py
index 40006dab..b51d8f66 100644
--- a/modeling/inference_models/generic_hf_torch/class.py
+++ b/modeling/inference_models/generic_hf_torch/class.py
@@ -6,7 +6,7 @@ import torch
 import shutil
 from typing import Union
 
-from transformers import AutoModelForCausalLM, GPTNeoForCausalLM, GPT2LMHeadModel
+from transformers import AutoModelForCausalLM, GPTNeoForCausalLM, GPT2LMHeadModel, BitsAndBytesConfig
 
 import utils
 import modeling.lazy_loader as lazy_loader
@@ -81,6 +81,12 @@ class model_backend(HFTorchInferenceModel):
             self.lazy_load = False
             tf_kwargs.update({
                 "load_in_4bit": True,
+                "quantization_config":BitsAndBytesConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_compute_dtype=torch.float16,
+                    bnb_4bit_use_double_quant=True,
+                    bnb_4bit_quant_type='nf4'
+                ),
             })
 
         if self.model_type == "gpt2":
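
For context, here is a minimal sketch (not part of the patch) of how a `BitsAndBytesConfig` like the one added above is ultimately consumed by `AutoModelForCausalLM.from_pretrained`. The model id and `device_map` value below are illustrative assumptions, not values taken from this repository:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Same 4-bit settings as the patch: NF4 quantization, fp16 compute,
# and double quantization of the quantization constants.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-neo-1.3B",  # hypothetical model id, for illustration only
    quantization_config=quantization_config,
    device_map="auto",  # requires the accelerate package
)
```

One caveat: recent transformers releases raise a `ValueError` when both a top-level `load_in_4bit=True` kwarg and a `quantization_config` are passed to `from_pretrained` at the same time, so depending on the installed version the pre-existing `"load_in_4bit": True` entry in `tf_kwargs` may need to be dropped once this `quantization_config` is in place.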