8 bit debug

This commit is contained in:
ebolam
2022-12-01 10:41:13 -05:00
parent 94953def32
commit fcdfce0373

View File

@@ -316,7 +316,7 @@ def use_lazy_torch_load(bit_8_available=False, enable=True, callback: Optional[C
import bitsandbytes as bnb
def linear_init(self, *args, device=None, **kwargs):
if linear_init.nested_flag or not bit_8_available:
if linear_init.nested_flag:
return old_linear_init(self, *args, device=device, **kwargs)
linear_init.nested_flag = True
try:
@@ -325,7 +325,7 @@ def use_lazy_torch_load(bit_8_available=False, enable=True, callback: Optional[C
return bnb.nn.Linear8bitLt.__init__(self, *args, has_fp16_weights=False, threshold=6.0, **kwargs)
finally:
linear_init.nested_flag = False
linear_init.nested_flag = False
linear_init.nested_flag = bit_8_available
def embedding_init(self, *args, device=None, **kwargs):
return old_embedding_init(self, *args, device="meta", **kwargs)