8 bit debug

This commit is contained in:
ebolam
2022-12-01 10:43:59 -05:00
parent fcdfce0373
commit 2505d09181

View File

@@ -316,7 +316,7 @@ def use_lazy_torch_load(bit_8_available=False, enable=True, callback: Optional[C
import bitsandbytes as bnb
def linear_init(self, *args, device=None, **kwargs):
if linear_init.nested_flag:
if linear_init.nested_flag or not linear_init.bit_8_available:
return old_linear_init(self, *args, device=device, **kwargs)
linear_init.nested_flag = True
try:
@@ -325,7 +325,8 @@ def use_lazy_torch_load(bit_8_available=False, enable=True, callback: Optional[C
return bnb.nn.Linear8bitLt.__init__(self, *args, has_fp16_weights=False, threshold=6.0, **kwargs)
finally:
linear_init.nested_flag = False
linear_init.nested_flag = bit_8_available
linear_init.nested_flag = False
linear_init.bit_8_available = bit_8_available
def embedding_init(self, *args, device=None, **kwargs):
return old_embedding_init(self, *args, device="meta", **kwargs)