diff --git a/torch_lazy_loader.py b/torch_lazy_loader.py index ec99df11..474693a7 100644 --- a/torch_lazy_loader.py +++ b/torch_lazy_loader.py @@ -317,7 +317,7 @@ def use_lazy_torch_load(bit_8_available=False, enable=True, callback: Optional[C def linear_init(self, *args, device=None, **kwargs): if linear_init.nested_flag or not linear_init.bit_8_available: - return old_linear_init(self, *args, device=device, **kwargs) + return old_linear_init(self, *args, device="meta", **kwargs) linear_init.nested_flag = True try: self.__class__ = bnb.nn.Linear8bitLt