From 2505d091811a78073246fbc71d0149b7d98e5d0e Mon Sep 17 00:00:00 2001 From: ebolam Date: Thu, 1 Dec 2022 10:43:59 -0500 Subject: [PATCH] 8 bit debug --- torch_lazy_loader.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch_lazy_loader.py b/torch_lazy_loader.py index dc3edfb3..ec99df11 100644 --- a/torch_lazy_loader.py +++ b/torch_lazy_loader.py @@ -316,7 +316,7 @@ def use_lazy_torch_load(bit_8_available=False, enable=True, callback: Optional[C import bitsandbytes as bnb def linear_init(self, *args, device=None, **kwargs): - if linear_init.nested_flag: + if linear_init.nested_flag or not linear_init.bit_8_available: return old_linear_init(self, *args, device=device, **kwargs) linear_init.nested_flag = True try: @@ -325,7 +325,8 @@ def use_lazy_torch_load(bit_8_available=False, enable=True, callback: Optional[C return bnb.nn.Linear8bitLt.__init__(self, *args, has_fp16_weights=False, threshold=6.0, **kwargs) finally: linear_init.nested_flag = False - linear_init.nested_flag = bit_8_available + linear_init.nested_flag = False + linear_init.bit_8_available = bit_8_available def embedding_init(self, *args, device=None, **kwargs): return old_embedding_init(self, *args, device="meta", **kwargs)