mirror of https://github.com/KoboldAI/KoboldAI-Client.git

OPT hack
@@ -27,12 +27,12 @@ from ansi2html import Ansi2HTMLConverter
 logging.getLogger("urllib3").setLevel(logging.ERROR)
 
-import attention_bias
-attention_bias.do_patches()
+from modeling import patches
+patches.patch_transformers_for_lazyload()
+
+import attention_bias
+attention_bias.do_patches()
 
 from os import path, getcwd
 import time
 import re
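The ordering above matters: patch_transformers_for_lazyload() has to run this early because it works by monkey-patching loader internals, and later hunks show it reassigning transformers' private _load_state_dict_into_meta_model. A minimal sketch of that pattern, with illustrative names rather than the commit's exact code:

    # Swap a private transformers hook for a wrapper at import time.
    # _load_state_dict_into_meta_model is a private function whose exact
    # signature depends on the installed transformers version.
    import transformers.modeling_utils as modeling_utils

    _original = modeling_utils._load_state_dict_into_meta_model

    def _patched(*args, **kwargs):
        # ...altered loading behavior goes here...
        return _original(*args, **kwargs)

    modeling_utils._load_state_dict_into_meta_model = _patched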
@@ -9,6 +9,7 @@ import functools
 import itertools
 import traceback
 import contextlib
 from accelerate.big_modeling import load_checkpoint_and_dispatch
+from accelerate.utils.modeling import infer_auto_device_map, load_checkpoint_in_model
 from tqdm.auto import tqdm
 from typing import Dict, List, Optional, Union
@@ -263,6 +264,9 @@ class HFTorchInferenceModel(HFInferenceModel):
         tf_kwargs["revision"] = utils.koboldai_vars.revision
         tf_kwargs["cache_dir"] = "cache"
 
+        if self.lazy_load:
+            tf_kwargs.pop("low_cpu_mem_usage", None)
+
         # If we have model hints for legacy model, use them rather than fall back.
         try:
             if self.model_name == "GPT2Custom":
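Dropping low_cpu_mem_usage when lazy_load is active is likely to keep two mechanisms from fighting over the same job: transformers' low_cpu_mem_usage path and KoboldAI's lazy loader both control how tensors are materialized from disk. The kwargs handling itself is just dict.pop with a default (hypothetical values):

    # Sketch: tf_kwargs as it might look before the model is loaded.
    tf_kwargs = {"revision": None, "cache_dir": "cache", "low_cpu_mem_usage": True}
    lazy_load = True

    if lazy_load:
        # pop() with a default is a safe no-op when the key is absent.
        tf_kwargs.pop("low_cpu_mem_usage", None)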
@@ -285,17 +289,25 @@ class HFTorchInferenceModel(HFInferenceModel):
             #     offload_state_dict=True
             # )
             # model.tie_weights()
+            no_split_module_classes = ["GPTJBlock", "OPTDecoderLayer"]
+
+            print("[HUGE SKELETON] MAKING DEVICE MAP")
             device_map = infer_auto_device_map(
                 model,
                 max_memory={0: "10GiB", 1: "7GiB", "cpu": "15GiB"},
-                no_split_module_classes=["GPTJBlock", "OPTDecoderLayer"],
+                no_split_module_classes=no_split_module_classes,
                 dtype="float16",
             )
+            print("[HUGE SKELETON] TYING WEIGHTS")
 
             model.tie_weights()
 
+            print("[HUGE SKELETON] LOADING FROM PRETRAINED")
             return AutoModelForCausalLM.from_pretrained(
-                location, device_map=device_map
-            )  # , **tf_kwargs)
+                location,
+                device_map=device_map,
+                **tf_kwargs,
+            )
         except Exception as e:
             traceback_string = traceback.format_exc().lower()
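For reference, this hunk follows the standard accelerate recipe: build a weightless skeleton, infer a device map against a memory budget, then let from_pretrained() dispatch weights according to that map. A self-contained sketch of the same flow (the checkpoint name and memory figures are illustrative, not the commit's values):

    import torch
    from accelerate import init_empty_weights, infer_auto_device_map
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("facebook/opt-2.7b")
    with init_empty_weights():
        # Meta-device "skeleton": the module tree without real tensor storage.
        model = AutoModelForCausalLM.from_config(config)
    model.tie_weights()

    device_map = infer_auto_device_map(
        model,
        max_memory={0: "10GiB", "cpu": "15GiB"},
        no_split_module_classes=["OPTDecoderLayer"],
        dtype="float16",
    )
    model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-2.7b", device_map=device_map, torch_dtype=torch.float16
    )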
@@ -13,6 +13,7 @@ from transformers import (
 from modeling.lazy_loader import LazyTensor
 
 import utils
+from logger import logger
 
 
 def patch_transformers_download():
@@ -127,8 +128,27 @@ def patch_transformers_generation() -> None:
 
 
 def patch_transformers_for_lazyload() -> None:
+    """
+    Most of the code is modified code from the Accelerate and Transformers
+    projects, made by HuggingFace. The license for these projects are as follows:
+    ---
+    Copyright The HuggingFace Team. All rights reserved.
+
+    Licensed under the Apache License, Version 2.0 (the "License");
+    you may not use this file except in compliance with the License.
+    You may obtain a copy of the License at
+
+        http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
+    """
     import torch
     import inspect
     import accelerate
+    from accelerate.utils.modeling import named_module_tensors
     from accelerate.utils import set_module_tensor_to_device, offload_weight
 
     def _load_state_dict_into_meta_model(
@@ -183,10 +203,9 @@ def patch_transformers_for_lazyload() -> None:
         for old_key, new_key in zip(old_keys, new_keys):
             state_dict[new_key] = state_dict.pop(old_key)
 
         # BEGIN PATCH
         # TODO: Based on config
         dtype = torch.float16
-        set_module_kwargs = {"dtype": dtype}
 
         for param_name, param in sorted(
             state_dict.items(),
@@ -202,7 +221,7 @@ def patch_transformers_for_lazyload() -> None:
             if isinstance(param, LazyTensor):
                 # Should always be true
                 param = param.materialize()
             # END PATCH
 
             # First part of the test is always true as load_state_dict_keys always contains state_dict keys.
             if (
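The LazyTensor objects materialized here come from KoboldAI's modeling/lazy_loader.py: placeholders that defer reading tensor data until the load loop actually needs it. A rough, simplified sketch of the concept (not the real implementation, which seeks to individual tensor storages instead of reading whole files):

    import torch

    class SketchLazyTensor:
        """Placeholder that loads its tensor only when materialize() is called."""

        def __init__(self, checkpoint_path: str, key: str):
            self.checkpoint_path = checkpoint_path
            self.key = key

        def materialize(self) -> torch.Tensor:
            # Naive version: load the whole checkpoint and pick one tensor out.
            state = torch.load(self.checkpoint_path, map_location="cpu")
            return state[self.key]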
@@ -228,13 +247,6 @@ def patch_transformers_for_lazyload() -> None:
                 and dtype == torch.float16
             ):
                 param = param.to(torch.float32)
-
-                # For backward compatibility with older versions of `accelerate`
-                # TODO: @sgugger replace this check with version check at the next `accelerate` release
-                if "dtype" in list(
-                    inspect.signature(set_module_tensor_to_device).parameters
-                ):
-                    set_module_kwargs["dtype"] = torch.float32
             else:
                 param = param.to(dtype)
 
@@ -250,8 +262,6 @@ def patch_transformers_for_lazyload() -> None:
             if old_param is not None:
                 param = param.to(old_param.dtype)
 
-            set_module_kwargs["value"] = param
-
             if device_map is None:
                 param_device = "cpu"
             else:
@@ -263,6 +273,7 @@ def patch_transformers_for_lazyload() -> None:
                     # TODO: group all errors and raise at the end.
                     raise ValueError(f"{param_name} doesn't have any device set.")
                 param_device = device_map[module_name]
+
             if param_device == "disk":
                 if not is_safetensors:
                     offload_index = offload_weight(
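On the "disk" branch, accelerate's offload_weight() writes the tensor into an offload folder as a memory-mappable file and records its dtype and shape in an index for later reloading. A small usage sketch (the folder name and tensor are illustrative):

    import torch
    from accelerate.utils import offload_weight

    index = {}
    # Persists the tensor under "offload_dir" and returns the updated index.
    index = offload_weight(torch.randn(4, 4), "lm_head.weight", "offload_dir", index=index)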
@@ -275,7 +286,10 @@ def patch_transformers_for_lazyload() -> None:
             elif not load_in_8bit:
                 # For backward compatibility with older versions of `accelerate`
                 set_module_tensor_to_device(
-                    model, tensor_name=param_name, device=param_device, **set_module_kwargs
+                    model,
+                    tensor_name=param_name,
+                    device=param_device,
+                    value=param,
                 )
             else:
                 if (
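Passing value=param explicitly (instead of the removed **set_module_kwargs) is what swaps the meta-device placeholder for the materialized tensor. A self-contained illustration of accelerate's set_module_tensor_to_device():

    import torch
    from torch import nn
    from accelerate.utils import set_module_tensor_to_device

    m = nn.Linear(4, 4)
    new_weight = torch.zeros(4, 4)
    # Replaces m.weight with the given tensor, placed on the requested device.
    set_module_tensor_to_device(m, tensor_name="weight", device="cpu", value=new_weight)
    assert torch.equal(m.weight.data, new_weight)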
@@ -301,6 +315,73 @@ def patch_transformers_for_lazyload() -> None:
         _load_state_dict_into_meta_model
     )
 
+    # Patch AlignDevicesHook to hack around OPT lm_head
+    HACK_ZERO_ON_FAIL_TENSORS = ["lm_head.weight"]
+
+    def _recursed_key_value(module, key):
+        """Gets a tensor from a recursive key (ie with .s)"""
+        if "." in key:
+            splits = key.split(".")
+            for split in splits[:-1]:
+                new_module = getattr(module, split)
+                if new_module is None:
+                    raise ValueError(f"{module} has no attribute {split}.")
+                module = new_module
+            key = splits[-1]
+        return getattr(module, key)
+
+    def _init_hook(self, module):
+        if not self.offload and self.execution_device is not None:
+            # BEGIN PATCH
+            for name, tensor in named_module_tensors(
+                module, recurse=self.place_submodules
+            ):
+                try:
+                    set_module_tensor_to_device(module, name, self.execution_device)
+                except ValueError:
+                    # ValueError: weight is on the meta device, we need a `value` to put in on 0.
+                    # bleuuuuuuuuuuuuuuuhhh
+                    if name in HACK_ZERO_ON_FAIL_TENSORS:
+                        logger.warning(f"Couldn't find value for weight {name}, zeroing.")
+                        set_module_tensor_to_device(
+                            module,
+                            name,
+                            self.execution_device,
+                            value=torch.zeros(tensor.shape),
+                        )
+            # END PATCH
+        elif self.offload:
+            self.original_devices = {
+                name: param.device
+                for name, param in named_module_tensors(
+                    module, recurse=self.place_submodules
+                )
+            }
+
+            if self.weights_map is None:
+                self.weights_map = {
+                    name: param.to("cpu")
+                    for name, param in named_module_tensors(
+                        module,
+                        include_buffers=self.offload_buffers,
+                        recurse=self.place_submodules,
+                    )
+                }
+
+            for name, _ in named_module_tensors(
+                module,
+                include_buffers=self.offload_buffers,
+                recurse=self.place_submodules,
+            ):
+                set_module_tensor_to_device(module, name, "meta")
+
+            if not self.offload_buffers and self.execution_device is not None:
+                for name, _ in module.named_buffers(recurse=self.place_submodules):
+                    set_module_tensor_to_device(module, name, self.execution_device)
+        return module
+
+    accelerate.hooks.AlignDevicesHook.init_hook = _init_hook
 
 
 def patch_transformers() -> None:
     patch_transformers_download()
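The zero-fill fallback in _init_hook is tied to how OPT checkpoints are laid out: OPT ties lm_head.weight to the input embedding matrix, so a checkpoint can legitimately omit the tensor, leaving the hook with nothing to place on the execution device. Zeroing presumably just satisfies the hook; the functional values come from the shared embedding once weights are tied. A toy illustration of such tying (a hypothetical module, not OPT's real class):

    import torch
    from torch import nn

    class TinyLM(nn.Module):
        def __init__(self, vocab: int = 16, dim: int = 8):
            super().__init__()
            self.embed_tokens = nn.Embedding(vocab, dim)
            self.lm_head = nn.Linear(dim, vocab, bias=False)
            # Tie: both modules now share one parameter, so a checkpoint
            # only needs to store it once.
            self.lm_head.weight = self.embed_tokens.weight

    m = TinyLM()
    assert m.lm_head.weight.data_ptr() == m.embed_tokens.weight.data_ptr()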