Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Maybe works now...?
@@ -1,8 +1,9 @@
 from __future__ import annotations

 import copy
+from ctypes import Union
 import requests
-from typing import Iterable, List
+from typing import Iterable, List, Optional
 from tqdm.auto import tqdm

 import transformers
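A note on the new imports: the `Union` used in the type hints of the disabled wrapper in the next hunk is the `typing` construct, while `ctypes.Union` is the base class for C-style unions and is not subscriptable the way `typing.Union[int, str, torch.device]` is. The import those annotations appear to expect, shown here as a sketch rather than as part of the commit, would be:

    # Illustrative only: the typing-based import the commented-out annotations assume.
    from typing import Iterable, List, Optional, Union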
@@ -148,6 +149,27 @@ def patch_transformers_for_lazyload() -> None:
     """
     import torch
     import accelerate

+    # _old_set_module_tensor_to_device = (
+    #     accelerate.utils.modeling.set_module_tensor_to_device
+    # )
+
+    # def _set_module_tensor_to_device(
+    #     module: torch.nn.Module,
+    #     tensor_name: str,
+    #     device: Union[int, str, torch.device],
+    #     value: Optional[torch.Tensor] = None,
+    #     dtype: Optional[Union[str, torch.dtype]] = None,
+    # ):
+    #     if isinstance(value, LazyTensor):
+    #         value = value.materialize()
+    #     print("HEY!", dtype)
+    #     return _old_set_module_tensor_to_device(
+    #         module, tensor_name, device, value, dtype
+    #     )
+
+    # accelerate.utils.modeling.set_module_tensor_to_device = _set_module_tensor_to_device
+
+    from accelerate.utils.modeling import named_module_tensors
+    from accelerate.utils import set_module_tensor_to_device, offload_weight
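For context, the block added here in disabled form wraps accelerate's `set_module_tensor_to_device` so that the loader's `LazyTensor` placeholders are materialized into real tensors before accelerate copies them onto the target device. A minimal sketch of that idea with the `typing` imports it needs and without the debug `print` (the wrapper name is hypothetical, and `LazyTensor` is the lazy loader's own class, assumed to be in scope):

    import accelerate
    import torch
    from typing import Optional, Union

    # Keep a reference to accelerate's real implementation so the wrapper can delegate to it.
    _original_set_module_tensor_to_device = accelerate.utils.modeling.set_module_tensor_to_device

    def _lazy_set_module_tensor_to_device(
        module: torch.nn.Module,
        tensor_name: str,
        device: Union[int, str, torch.device],
        value: Optional[torch.Tensor] = None,
        dtype: Optional[Union[str, torch.dtype]] = None,
    ):
        # LazyTensor is the lazy loader's own placeholder class (assumed in scope here).
        # Materialize it before accelerate places the tensor on the target device.
        if isinstance(value, LazyTensor):
            value = value.materialize()
        return _original_set_module_tensor_to_device(
            module, tensor_name, device, value=value, dtype=dtype
        )

    accelerate.utils.modeling.set_module_tensor_to_device = _lazy_set_module_tensor_to_device

Capturing the original function before reassigning the module attribute is what lets the wrapper still call accelerate's real implementation.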
@@ -205,7 +227,7 @@ def patch_transformers_for_lazyload() -> None:

         # BEGIN PATCH
         # TODO: Based on config
-        dtype = torch.float16
+        # dtype = torch.float16

         for param_name, param in sorted(
             state_dict.items(),
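The `TODO: Based on config` next to the now-disabled `dtype = torch.float16` points at the intended follow-up: read the target dtype from the model's config instead of hardcoding half precision. A rough sketch, with a hypothetical helper name and assuming a loaded transformers config object (its `torch_dtype` field may be absent, `None`, or a string such as `"float16"`):

    import torch

    def dtype_from_config(config, default: torch.dtype = torch.float32) -> torch.dtype:
        # torch_dtype may be missing, None, or a string like "float16" depending on
        # how the config was saved; normalize all three cases.
        dtype = getattr(config, "torch_dtype", None)
        if dtype is None:
            return default
        if isinstance(dtype, str):
            return getattr(torch, dtype)
        return dtype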
@@ -316,60 +338,62 @@ def patch_transformers_for_lazyload() -> None:
         _load_state_dict_into_meta_model
     )

-    # Patch AlignDevicesHook to hack around OPT lm_head
-    HACK_ZERO_ON_FAIL_TENSORS = ["lm_head.weight"]
+    # # Patch AlignDevicesHook to hack around OPT lm_head
+    # HACK_ZERO_ON_FAIL_TENSORS = ["lm_head.weight"]

-    def _init_hook(self, module):
-        if not self.offload and self.execution_device is not None:
-            # BEGIN PATCH
-            for name, tensor in named_module_tensors(
-                module, recurse=self.place_submodules
-            ):
-                try:
-                    set_module_tensor_to_device(module, name, self.execution_device)
-                except ValueError:
-                    # ValueError: weight is on the meta device, we need a `value` to put in on 0.
-                    # bleuuuuuuuuuuuuuuuhhh
-                    if name in HACK_ZERO_ON_FAIL_TENSORS:
-                        logger.warning(f"Couldn't find value for weight {name}, zeroing.")
-                        set_module_tensor_to_device(
-                            module,
-                            name,
-                            self.execution_device,
-                            value=torch.zeros(tensor.shape),
-                        )
-            # END PATCH
-        elif self.offload:
-            self.original_devices = {
-                name: param.device
-                for name, param in named_module_tensors(
-                    module, recurse=self.place_submodules
-                )
-            }
+    # def _init_hook(self, module):
+    #     if not self.offload and self.execution_device is not None:
+    #         # BEGIN PATCH
+    #         for name, tensor in named_module_tensors(
+    #             module, recurse=self.place_submodules
+    #         ):
+    #             try:
+    #                 set_module_tensor_to_device(module, name, self.execution_device)
+    #             except ValueError:
+    #                 # ValueError: weight is on the meta device, we need a `value` to put in on 0.
+    #                 # bleuuuuuuuuuuuuuuuhhh
+    #                 if name in HACK_ZERO_ON_FAIL_TENSORS:
+    #                     logger.warning(
+    #                         f"Couldn't find value for weight {name}, zeroing."
+    #                     )
+    #                     set_module_tensor_to_device(
+    #                         module,
+    #                         name,
+    #                         self.execution_device,
+    #                         value=torch.zeros(tensor.shape),
+    #                     )
+    #         # END PATCH
+    #     elif self.offload:
+    #         self.original_devices = {
+    #             name: param.device
+    #             for name, param in named_module_tensors(
+    #                 module, recurse=self.place_submodules
+    #             )
+    #         }

-            if self.weights_map is None:
-                self.weights_map = {
-                    name: param.to("cpu")
-                    for name, param in named_module_tensors(
-                        module,
-                        include_buffers=self.offload_buffers,
-                        recurse=self.place_submodules,
-                    )
-                }
+    #         if self.weights_map is None:
+    #             self.weights_map = {
+    #                 name: param.to("cpu")
+    #                 for name, param in named_module_tensors(
+    #                     module,
+    #                     include_buffers=self.offload_buffers,
+    #                     recurse=self.place_submodules,
+    #                 )
+    #             }

-            for name, _ in named_module_tensors(
-                module,
-                include_buffers=self.offload_buffers,
-                recurse=self.place_submodules,
-            ):
-                set_module_tensor_to_device(module, name, "meta")
+    #         for name, _ in named_module_tensors(
+    #             module,
+    #             include_buffers=self.offload_buffers,
+    #             recurse=self.place_submodules,
+    #         ):
+    #             set_module_tensor_to_device(module, name, "meta")

-            if not self.offload_buffers and self.execution_device is not None:
-                for name, _ in module.named_buffers(recurse=self.place_submodules):
-                    set_module_tensor_to_device(module, name, self.execution_device)
-        return module
+    #         if not self.offload_buffers and self.execution_device is not None:
+    #             for name, _ in module.named_buffers(recurse=self.place_submodules):
+    #                 set_module_tensor_to_device(module, name, self.execution_device)
+    #     return module

-    accelerate.hooks.AlignDevicesHook.init_hook = _init_hook
+    # accelerate.hooks.AlignDevicesHook.init_hook = _init_hook
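The override switched off here mirrors accelerate's stock `AlignDevicesHook.init_hook`; the only behavioral difference is the `except ValueError` branch, which zero-fills tensors listed in `HACK_ZERO_ON_FAIL_TENSORS` when they exist only on the meta device (OPT checkpoints typically omit `lm_head.weight` because it is tied to the embedding weights). That fallback in isolation, as a sketch with a hypothetical helper name:

    import torch
    from accelerate.utils import set_module_tensor_to_device

    ZERO_ON_FAIL = ("lm_head.weight",)  # tensors that may legitimately be absent from the checkpoint

    def place_or_zero(module, name, tensor, device):
        # Try the normal placement; fall back to a zero tensor for known tied/absent weights.
        try:
            set_module_tensor_to_device(module, name, device)
        except ValueError:
            if name not in ZERO_ON_FAIL:
                raise
            set_module_tensor_to_device(module, name, device, value=torch.zeros(tensor.shape))

Presumably the zeros are acceptable because weight tying re-links `lm_head.weight` to the embedding matrix after loading, replacing the placeholder values.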


 def patch_transformers() -> None:
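With the override commented out, accelerate's stock hook is what runs again. A rough way to confirm at runtime which variant is active (a sketch; it assumes any replacement hook would be defined outside the accelerate package):

    import accelerate

    def lazyload_hook_is_patched() -> bool:
        # Stock accelerate defines init_hook in accelerate.hooks; a monkey-patched
        # replacement reports the module it was defined in instead.
        fn = accelerate.hooks.AlignDevicesHook.init_hook
        return fn.__module__ != "accelerate.hooks"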