Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
OPT hack
@@ -13,6 +13,7 @@ from transformers import (
 from modeling.lazy_loader import LazyTensor
 
 import utils
+from logger import logger
 
 
 def patch_transformers_download():
@@ -127,8 +128,27 @@ def patch_transformers_generation() -> None:
 
 
 def patch_transformers_for_lazyload() -> None:
+    """
+    Most of the code is modified code from the Accelerate and Transformers
+    projects, made by HuggingFace. The license for these projects is as follows:
+    ---
+    Copyright The HuggingFace Team. All rights reserved.
+
+    Licensed under the Apache License, Version 2.0 (the "License");
+    you may not use this file except in compliance with the License.
+    You may obtain a copy of the License at
+
+        http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
+    """
     import torch
     import inspect
     import accelerate
+    from accelerate.utils.modeling import named_module_tensors
     from accelerate.utils import set_module_tensor_to_device, offload_weight
 
     def _load_state_dict_into_meta_model(
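Context for the function being replaced: _load_state_dict_into_meta_model exists because the model skeleton is first built on PyTorch's "meta" device, where tensors carry shape and dtype but no storage. A minimal illustration (plain PyTorch, nothing KoboldAI-specific):

    import torch

    # A meta tensor has metadata but no backing storage; building a large
    # model this way costs almost no memory until real weights are loaded.
    m = torch.empty(4, 4, device="meta")
    print(m.shape, m.dtype, m.device)  # torch.Size([4, 4]) torch.float32 meta
    # Reading values out (e.g. m.tolist()) fails: there is no data yet.
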
@@ -183,10 +203,9 @@ def patch_transformers_for_lazyload() -> None:
         for old_key, new_key in zip(old_keys, new_keys):
             state_dict[new_key] = state_dict.pop(old_key)
 
         # BEGIN PATCH
         # TODO: Based on config
         dtype = torch.float16
-        set_module_kwargs = {"dtype": dtype}
 
         for param_name, param in sorted(
             state_dict.items(),
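The "# TODO: Based on config" comment above hardcodes float16. One hypothetical way the TODO could be resolved (not part of this commit) is to read torch_dtype from the model config and fall back to float16:

    import torch

    def dtype_from_config(config) -> torch.dtype:
        # `torch_dtype` may be absent, or stored as a string like "float16".
        dt = getattr(config, "torch_dtype", None)
        if isinstance(dt, str):
            dt = getattr(torch, dt, None)
        return dt if isinstance(dt, torch.dtype) else torch.float16
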
@@ -202,7 +221,7 @@ def patch_transformers_for_lazyload() -> None:
             if isinstance(param, LazyTensor):
                 # Should always be true
                 param = param.materialize()
             # END PATCH
 
             # First part of the test is always true as load_state_dict_keys always contains state_dict keys.
             if (
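param.materialize() is where a lazy placeholder finally becomes a real tensor. KoboldAI's actual LazyTensor (modeling/lazy_loader.py) reads the tensor's bytes out of the checkpoint on demand; the toy below only sketches the deferred-loading shape of the API and is not the real implementation:

    import torch
    from typing import Callable

    class ToyLazyTensor:
        """Defers tensor construction until materialize() is called."""

        def __init__(self, loader: Callable[[], torch.Tensor]):
            self._loader = loader  # nothing is read or allocated yet

        def materialize(self) -> torch.Tensor:
            return self._loader()  # do the actual load now

    param = ToyLazyTensor(lambda: torch.zeros(2, 2)).materialize()
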
@@ -228,13 +247,6 @@ def patch_transformers_for_lazyload() -> None:
                 and dtype == torch.float16
             ):
                 param = param.to(torch.float32)
-
-                # For backward compatibility with older versions of `accelerate`
-                # TODO: @sgugger replace this check with version check at the next `accelerate` release
-                if "dtype" in list(
-                    inspect.signature(set_module_tensor_to_device).parameters
-                ):
-                    set_module_kwargs["dtype"] = torch.float32
             else:
                 param = param.to(dtype)
 
@@ -250,8 +262,6 @@ def patch_transformers_for_lazyload() -> None:
             if old_param is not None:
                 param = param.to(old_param.dtype)
 
-            set_module_kwargs["value"] = param
-
             if device_map is None:
                 param_device = "cpu"
             else:
@@ -263,6 +273,7 @@ def patch_transformers_for_lazyload() -> None:
                 # TODO: group all errors and raise at the end.
                 raise ValueError(f"{param_name} doesn't have any device set.")
             param_device = device_map[module_name]
 
             if param_device == "disk":
                 if not is_safetensors:
                     offload_index = offload_weight(
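For the disk branch above: accelerate's offload_weight writes the tensor into offload_folder as a memory-mapped file and records its dtype and shape in the index so it can be reloaded lazily later. Roughly:

    import tempfile

    import torch
    from accelerate.utils import offload_weight

    folder = tempfile.mkdtemp()
    index = offload_weight(torch.ones(2, 3), "some.weight", folder, index={})
    # index now maps "some.weight" to its dtype/shape; the values live in
    # a .dat file under `folder` rather than in RAM.
    print(index["some.weight"])
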
@@ -275,7 +286,10 @@ def patch_transformers_for_lazyload() -> None:
             elif not load_in_8bit:
                 # For backward compatibility with older versions of `accelerate`
                 set_module_tensor_to_device(
-                    model, tensor_name=param_name, device=param_device, **set_module_kwargs
+                    model,
+                    tensor_name=param_name,
+                    device=param_device,
+                    value=param,
                 )
             else:
                 if (
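The call change above drops **set_module_kwargs in favor of passing the materialized tensor as value= explicitly. set_module_tensor_to_device needs that value when the parameter it replaces still lives on the meta device; a minimal usage sketch:

    import torch
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device

    with init_empty_weights():
        layer = torch.nn.Linear(4, 4)  # weight and bias start on "meta"

    # Without value= this raises ValueError, because a meta tensor has no
    # data to copy; with value= the real tensor is installed on the device.
    set_module_tensor_to_device(layer, "weight", "cpu", value=torch.zeros(4, 4))
    print(layer.weight.device)  # cpu (bias is still on meta)
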
@@ -301,6 +315,73 @@ def patch_transformers_for_lazyload() -> None:
         _load_state_dict_into_meta_model
     )
 
+    # Patch AlignDevicesHook to hack around OPT lm_head
+    HACK_ZERO_ON_FAIL_TENSORS = ["lm_head.weight"]
+
+    def _recursed_key_value(module, key):
+        """Gets a tensor from a recursive key (i.e. with dots)"""
+        if "." in key:
+            splits = key.split(".")
+            for split in splits[:-1]:
+                new_module = getattr(module, split)
+                if new_module is None:
+                    raise ValueError(f"{module} has no attribute {split}.")
+                module = new_module
+            key = splits[-1]
+        return getattr(module, key)
+
+    def _init_hook(self, module):
+        if not self.offload and self.execution_device is not None:
+            # BEGIN PATCH
+            for name, tensor in named_module_tensors(
+                module, recurse=self.place_submodules
+            ):
+                try:
+                    set_module_tensor_to_device(module, name, self.execution_device)
+                except ValueError:
+                    # ValueError: weight is on the meta device, we need a `value` to put in on 0.
+                    # bleuuuuuuuuuuuuuuuhhh
+                    if name in HACK_ZERO_ON_FAIL_TENSORS:
+                        logger.warning(f"Couldn't find value for weight {name}, zeroing.")
+                        set_module_tensor_to_device(
+                            module,
+                            name,
+                            self.execution_device,
+                            value=torch.zeros(tensor.shape),
+                        )
+            # END PATCH
+        elif self.offload:
+            self.original_devices = {
+                name: param.device
+                for name, param in named_module_tensors(
+                    module, recurse=self.place_submodules
+                )
+            }
+
+            if self.weights_map is None:
+                self.weights_map = {
+                    name: param.to("cpu")
+                    for name, param in named_module_tensors(
+                        module,
+                        include_buffers=self.offload_buffers,
+                        recurse=self.place_submodules,
+                    )
+                }
+
+            for name, _ in named_module_tensors(
+                module,
+                include_buffers=self.offload_buffers,
+                recurse=self.place_submodules,
+            ):
+                set_module_tensor_to_device(module, name, "meta")
+
+        if not self.offload_buffers and self.execution_device is not None:
+            for name, _ in module.named_buffers(recurse=self.place_submodules):
+                set_module_tensor_to_device(module, name, self.execution_device)
+        return module
+
+    accelerate.hooks.AlignDevicesHook.init_hook = _init_hook
+
 
 def patch_transformers() -> None:
     patch_transformers_download()
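Why the hack targets lm_head.weight specifically: OPT ties the LM head to the input embeddings, so checkpoints carry only one copy of that matrix and lm_head.weight can legitimately be absent from the state dict, which is presumably why init_hook finds it still on the meta device. Zeroing it is harmless as long as tie_weights() later points it back at the shared embedding. A quick way to see the tie (downloads the small OPT config):

    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("facebook/opt-125m")
    model = AutoModelForCausalLM.from_config(config)

    # With tie_word_embeddings=True (OPT's default) both names refer to the
    # same Parameter object, so checkpoints only need to store one of them.
    print(model.lm_head.weight is model.model.decoder.embed_tokens.weight)  # True
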