diff --git a/modeling/inference_models/generic_hf_torch/class.py b/modeling/inference_models/generic_hf_torch/class.py
index 9a59650e..fe2308ac 100644
--- a/modeling/inference_models/generic_hf_torch/class.py
+++ b/modeling/inference_models/generic_hf_torch/class.py
@@ -78,7 +78,6 @@ class model_backend(HFTorchInferenceModel):
         }
 
         if self.use_4_bit:
-            self.lazy_load = False
             tf_kwargs.update({
                 "quantization_config":BitsAndBytesConfig(
                     load_in_4bit=True,
diff --git a/modeling/patches.py b/modeling/patches.py
index 8ffc87a3..6e2168f2 100644
--- a/modeling/patches.py
+++ b/modeling/patches.py
@@ -181,7 +181,7 @@ class LazyloadPatches:
         is_quantized = is_quantized or load_in_8bit
 
         if is_quantized:
-            from .utils.bitsandbytes import set_module_8bit_tensor_to_device
+            from transformers.utils.bitsandbytes import set_module_quantized_tensor_to_device
 
         error_msgs = []
 
@@ -299,7 +299,7 @@ class LazyloadPatches:
                         fp16_statistics = None
 
                     if "SCB" not in param_name:
-                        set_module_8bit_tensor_to_device(
+                        set_module_quantized_tensor_to_device(
                             model,
                             param_name,
                             param_device,
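For context, the first hunk stops forcing `self.lazy_load = False` for 4-bit models, so the standard Hugging Face 4-bit path via `BitsAndBytesConfig` now runs under the lazy loader. A minimal sketch of that loading flow, assuming a recent `transformers` with `bitsandbytes` and `accelerate` installed; the model id and the compute dtype below are placeholders for illustration, not values taken from this diff:

    # Minimal sketch of 4-bit loading with BitsAndBytesConfig, the same
    # mechanism the first hunk feeds into tf_kwargs.
    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,                     # as in the diff
        bnb_4bit_compute_dtype=torch.float16,  # assumption: fp16 compute dtype
    )

    model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-350m",                   # placeholder model id
        quantization_config=quantization_config,
        device_map="auto",                     # requires accelerate
    )

The second and third hunks track a rename in transformers: `set_module_8bit_tensor_to_device` became the more general `set_module_quantized_tensor_to_device` when 4-bit support was added, so the lazy-load patch now imports the new name from `transformers.utils.bitsandbytes` instead of the old 8-bit-only helper.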