8 bit debug

This commit is contained in:
ebolam
2022-12-01 10:46:13 -05:00
parent 2505d09181
commit 76a0bb71f0

View File

@@ -317,7 +317,7 @@ def use_lazy_torch_load(bit_8_available=False, enable=True, callback: Optional[C
def linear_init(self, *args, device=None, **kwargs):
if linear_init.nested_flag or not linear_init.bit_8_available:
return old_linear_init(self, *args, device=device, **kwargs)
return old_linear_init(self, *args, device="meta", **kwargs)
linear_init.nested_flag = True
try:
self.__class__ = bnb.nn.Linear8bitLt