Maybe works now...?

somebody
2023-05-31 14:31:08 -05:00
parent d0d215bb37
commit 24b0b32829
3 changed files with 239 additions and 547 deletions


@@ -1,8 +1,9 @@
from __future__ import annotations
import copy
from typing import Union
import requests
from typing import Iterable, List
from typing import Iterable, List, Optional
from tqdm.auto import tqdm
import transformers
@@ -148,6 +149,27 @@ def patch_transformers_for_lazyload() -> None:
"""
import torch
import accelerate
# _old_set_module_tensor_to_device = (
# accelerate.utils.modeling.set_module_tensor_to_device
# )
# def _set_module_tensor_to_device(
# module: torch.nn.Module,
# tensor_name: str,
# device: Union[int, str, torch.device],
# value: Optional[torch.Tensor] = None,
# dtype: Optional[Union[str, torch.dtype]] = None,
# ):
# if isinstance(value, LazyTensor):
# value = value.materialize()
# print("HEY!", dtype)
# return _old_set_module_tensor_to_device(
# module, tensor_name, device, value, dtype
# )
# accelerate.utils.modeling.set_module_tensor_to_device = _set_module_tensor_to_device
from accelerate.utils.modeling import named_module_tensors
from accelerate.utils import set_module_tensor_to_device, offload_weight
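For context, the commented-out block above experiments with wrapping accelerate's `set_module_tensor_to_device` so that lazy tensors are materialized before being placed. A hedged sketch of that pattern with the debug print removed, assuming `LazyTensor` is this module's lazy-load wrapper exposing a `materialize()` method:
import accelerate
import torch
from typing import Optional, Union
# Keep a handle on the stock implementation so the wrapper can delegate to it.
_old_set_module_tensor_to_device = (
    accelerate.utils.modeling.set_module_tensor_to_device
)
def _set_module_tensor_to_device(
    module: torch.nn.Module,
    tensor_name: str,
    device: Union[int, str, torch.device],
    value: Optional[torch.Tensor] = None,
    dtype: Optional[Union[str, torch.dtype]] = None,
):
    # Materialize lazily-loaded tensors before accelerate places them.
    if isinstance(value, LazyTensor):  # LazyTensor: assumed from this module
        value = value.materialize()
    return _old_set_module_tensor_to_device(
        module, tensor_name, device, value, dtype
    )
accelerate.utils.modeling.set_module_tensor_to_device = _set_module_tensor_to_device
Note that rebinding the module attribute only affects call sites that look the function up through `accelerate.utils.modeling` at call time; code that imported the name directly keeps its original reference.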
@@ -205,7 +227,7 @@ def patch_transformers_for_lazyload() -> None:
# BEGIN PATCH
# TODO: Based on config
dtype = torch.float16
# dtype = torch.float16
for param_name, param in sorted(
state_dict.items(),
@@ -316,60 +338,62 @@ def patch_transformers_for_lazyload() -> None:
_load_state_dict_into_meta_model
)
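The hard-coded `dtype = torch.float16` in the previous hunk still carries a `TODO: Based on config`. A hedged sketch of what resolving it might look like, assuming the model being loaded is in scope as `model` and that its config exposes the optional `torch_dtype` field (both are assumptions, not verified against this codebase):
import torch
# Assumption: `model` is the (meta-device) model handed to the patched loader.
config_dtype = getattr(model.config, "torch_dtype", None)
if isinstance(config_dtype, str):
    # Serialized configs may store the dtype as a string such as "float16".
    config_dtype = getattr(torch, config_dtype, None)
dtype = config_dtype if isinstance(config_dtype, torch.dtype) else torch.float16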
# Patch AlignDevicesHook to hack around OPT lm_head
HACK_ZERO_ON_FAIL_TENSORS = ["lm_head.weight"]
# # Patch AlignDevicesHook to hack around OPT lm_head
# HACK_ZERO_ON_FAIL_TENSORS = ["lm_head.weight"]
def _init_hook(self, module):
if not self.offload and self.execution_device is not None:
# BEGIN PATCH
for name, tensor in named_module_tensors(
module, recurse=self.place_submodules
):
try:
set_module_tensor_to_device(module, name, self.execution_device)
except ValueError:
# ValueError: weight is on the meta device, we need a `value` to put in on 0.
# Work around the missing weight by zero-initializing it below (see HACK_ZERO_ON_FAIL_TENSORS).
if name in HACK_ZERO_ON_FAIL_TENSORS:
logger.warning(f"Couldn't find value for weight {name}, zeroing.")
set_module_tensor_to_device(
module,
name,
self.execution_device,
value=torch.zeros(tensor.shape),
)
# END PATCH
elif self.offload:
self.original_devices = {
name: param.device
for name, param in named_module_tensors(
module, recurse=self.place_submodules
)
}
# def _init_hook(self, module):
# if not self.offload and self.execution_device is not None:
# # BEGIN PATCH
# for name, tensor in named_module_tensors(
# module, recurse=self.place_submodules
# ):
# try:
# set_module_tensor_to_device(module, name, self.execution_device)
# except ValueError:
# # ValueError: weight is on the meta device, we need a `value` to put in on 0.
# # Work around the missing weight by zero-initializing it below (see HACK_ZERO_ON_FAIL_TENSORS).
# if name in HACK_ZERO_ON_FAIL_TENSORS:
# logger.warning(
# f"Couldn't find value for weight {name}, zeroing."
# )
# set_module_tensor_to_device(
# module,
# name,
# self.execution_device,
# value=torch.zeros(tensor.shape),
# )
# # END PATCH
# elif self.offload:
# self.original_devices = {
# name: param.device
# for name, param in named_module_tensors(
# module, recurse=self.place_submodules
# )
# }
if self.weights_map is None:
self.weights_map = {
name: param.to("cpu")
for name, param in named_module_tensors(
module,
include_buffers=self.offload_buffers,
recurse=self.place_submodules,
)
}
# if self.weights_map is None:
# self.weights_map = {
# name: param.to("cpu")
# for name, param in named_module_tensors(
# module,
# include_buffers=self.offload_buffers,
# recurse=self.place_submodules,
# )
# }
for name, _ in named_module_tensors(
module,
include_buffers=self.offload_buffers,
recurse=self.place_submodules,
):
set_module_tensor_to_device(module, name, "meta")
# for name, _ in named_module_tensors(
# module,
# include_buffers=self.offload_buffers,
# recurse=self.place_submodules,
# ):
# set_module_tensor_to_device(module, name, "meta")
if not self.offload_buffers and self.execution_device is not None:
for name, _ in module.named_buffers(recurse=self.place_submodules):
set_module_tensor_to_device(module, name, self.execution_device)
return module
# if not self.offload_buffers and self.execution_device is not None:
# for name, _ in module.named_buffers(recurse=self.place_submodules):
# set_module_tensor_to_device(module, name, self.execution_device)
# return module
accelerate.hooks.AlignDevicesHook.init_hook = _init_hook
# accelerate.hooks.AlignDevicesHook.init_hook = _init_hook
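Because `AlignDevicesHook.init_hook` is replaced module-wide here, a hedged sketch of keeping the stock method around so the patch can be reverted later, mirroring the `_old_set_module_tensor_to_device` pattern experimented with above (the `_unpatch_align_devices_hook` helper is hypothetical, not part of this codebase):
import accelerate
# Save the stock hook before monkey-patching so it can be restored.
_old_align_devices_init_hook = accelerate.hooks.AlignDevicesHook.init_hook
accelerate.hooks.AlignDevicesHook.init_hook = _init_hook
def _unpatch_align_devices_hook() -> None:
    # Hypothetical helper: restore accelerate's original init_hook.
    accelerate.hooks.AlignDevicesHook.init_hook = _old_align_devices_init_hook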
def patch_transformers() -> None: