diff --git a/torch_lazy_loader.py b/torch_lazy_loader.py index dc3edfb3..ec99df11 100644 --- a/torch_lazy_loader.py +++ b/torch_lazy_loader.py @@ -316,7 +316,7 @@ def use_lazy_torch_load(bit_8_available=False, enable=True, callback: Optional[C import bitsandbytes as bnb def linear_init(self, *args, device=None, **kwargs): - if linear_init.nested_flag: + if linear_init.nested_flag or not linear_init.bit_8_available: return old_linear_init(self, *args, device=device, **kwargs) linear_init.nested_flag = True try: @@ -325,7 +325,8 @@ def use_lazy_torch_load(bit_8_available=False, enable=True, callback: Optional[C return bnb.nn.Linear8bitLt.__init__(self, *args, has_fp16_weights=False, threshold=6.0, **kwargs) finally: linear_init.nested_flag = False - linear_init.nested_flag = bit_8_available + linear_init.nested_flag = False + linear_init.bit_8_available = bit_8_available def embedding_init(self, *args, device=None, **kwargs): return old_embedding_init(self, *args, device="meta", **kwargs)