somebody
2023-05-29 13:34:11 -05:00
parent ceaefa9f5e
commit 58ffad237b
3 changed files with 113 additions and 20 deletions

View File

@@ -27,12 +27,12 @@ from ansi2html import Ansi2HTMLConverter
logging.getLogger("urllib3").setLevel(logging.ERROR)
import attention_bias
attention_bias.do_patches()
from modeling import patches
patches.patch_transformers_for_lazyload()
import attention_bias
attention_bias.do_patches()
from os import path, getcwd
import time
import re
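For context, a minimal sketch of the ordering this hunk sets up (the model path is a placeholder): the lazy-load patch is installed at import time, before the existing attention-bias patches and before any checkpoint is loaded, so any load that goes through the meta-device path picks up the patched loader.
from modeling import patches
patches.patch_transformers_for_lazyload()  # install lazy-load hooks first

import attention_bias
attention_bias.do_patches()                # existing patches stay in place

from transformers import AutoModelForCausalLM
# Placeholder path; loads that use the meta-device path now hit the patched loader.
model = AutoModelForCausalLM.from_pretrained("/path/to/model")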

View File

@@ -9,6 +9,7 @@ import functools
import itertools
import traceback
import contextlib
from accelerate.big_modeling import load_checkpoint_and_dispatch
from accelerate.utils.modeling import infer_auto_device_map, load_checkpoint_in_model
from tqdm.auto import tqdm
from typing import Dict, List, Optional, Union
@@ -263,6 +264,9 @@ class HFTorchInferenceModel(HFInferenceModel):
tf_kwargs["revision"] = utils.koboldai_vars.revision
tf_kwargs["cache_dir"] = "cache"
if self.lazy_load:
tf_kwargs.pop("low_cpu_mem_usage", None)
# If we have model hints for a legacy model, use them rather than falling back.
try:
if self.model_name == "GPT2Custom":
@@ -285,17 +289,25 @@ class HFTorchInferenceModel(HFInferenceModel):
# offload_state_dict=True
# )
# model.tie_weights()
no_split_module_classes = ["GPTJBlock", "OPTDecoderLayer"]
print("[HUGE SKELETON] MAKING DEVICE MAP")
device_map = infer_auto_device_map(
model,
max_memory={0: "10GiB", 1: "7GiB", "cpu": "15GiB"},
no_split_module_classes=["GPTJBlock", "OPTDecoderLayer"],
no_split_module_classes=no_split_module_classes,
dtype="float16",
)
print("[HUGE SKELETON] TYING WEIGHTS")
model.tie_weights()
print("[HUGE SKELETON] LOADING FROM PRETRAINED")
return AutoModelForCausalLM.from_pretrained(
location, device_map=device_map
) # , **tf_kwargs)
location,
device_map=device_map,
**tf_kwargs,
)
except Exception as e:
traceback_string = traceback.format_exc().lower()
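A self-contained sketch of the accelerate device-map pattern the hunk above follows; the checkpoint path and the max_memory budgets here are placeholders, not the project's real configuration:
import torch
from accelerate import init_empty_weights
from accelerate.utils.modeling import infer_auto_device_map
from transformers import AutoConfig, AutoModelForCausalLM

location = "/path/to/checkpoint"  # placeholder local model directory

# Build a weight-less skeleton on the meta device so planning costs no RAM.
config = AutoConfig.from_pretrained(location)
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)
model.tie_weights()  # tie before planning so shared tensors are counted once

# Plan which modules go to which device, keeping transformer blocks unsplit.
device_map = infer_auto_device_map(
    model,
    max_memory={0: "10GiB", "cpu": "15GiB"},
    no_split_module_classes=["GPTJBlock", "OPTDecoderLayer"],
    dtype="float16",
)

# Load the real weights according to that plan.
model = AutoModelForCausalLM.from_pretrained(
    location,
    device_map=device_map,
    torch_dtype=torch.float16,
)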

View File

@@ -13,6 +13,7 @@ from transformers import (
from modeling.lazy_loader import LazyTensor
import utils
from logger import logger
def patch_transformers_download():
@@ -127,8 +128,27 @@ def patch_transformers_generation() -> None:
def patch_transformers_for_lazyload() -> None:
"""
Most of this code is modified from the Hugging Face Accelerate and Transformers
projects. The licenses for these projects are as follows:
---
Copyright The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
import inspect
import accelerate
from accelerate.utils.modeling import named_module_tensors
from accelerate.utils import set_module_tensor_to_device, offload_weight
def _load_state_dict_into_meta_model(
@@ -183,10 +203,9 @@ def patch_transformers_for_lazyload() -> None:
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
# BEGIN PATCH
# TODO: Based on config
dtype = torch.float16
set_module_kwargs = {"dtype": dtype}
for param_name, param in sorted(
state_dict.items(),
@@ -202,7 +221,7 @@ def patch_transformers_for_lazyload() -> None:
if isinstance(param, LazyTensor):
# Should always be true
param = param.materialize()
# END PATCH
# First part of the test is always true as load_state_dict_keys always contains state_dict keys.
if (
@@ -228,13 +247,6 @@ def patch_transformers_for_lazyload() -> None:
and dtype == torch.float16
):
param = param.to(torch.float32)
# For backward compatibility with older versions of `accelerate`
# TODO: @sgugger replace this check with version check at the next `accelerate` release
if "dtype" in list(
inspect.signature(set_module_tensor_to_device).parameters
):
set_module_kwargs["dtype"] = torch.float32
else:
param = param.to(dtype)
@@ -250,8 +262,6 @@ def patch_transformers_for_lazyload() -> None:
if old_param is not None:
param = param.to(old_param.dtype)
set_module_kwargs["value"] = param
if device_map is None:
param_device = "cpu"
else:
@@ -263,6 +273,7 @@ def patch_transformers_for_lazyload() -> None:
# TODO: group all errors and raise at the end.
raise ValueError(f"{param_name} doesn't have any device set.")
param_device = device_map[module_name]
if param_device == "disk":
if not is_safetensors:
offload_index = offload_weight(
@@ -275,7 +286,10 @@ def patch_transformers_for_lazyload() -> None:
elif not load_in_8bit:
# For backward compatibility with older versions of `accelerate`
set_module_tensor_to_device(
model, tensor_name=param_name, device=param_device, **set_module_kwargs
model,
tensor_name=param_name,
device=param_device,
value=param,
)
else:
if (
@@ -301,6 +315,73 @@ def patch_transformers_for_lazyload() -> None:
_load_state_dict_into_meta_model
)
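# Condensed sketch of what the patched loader boils down to per checkpoint entry
# (names follow the code above; the fp32 exceptions, disk offload, and 8-bit
# branches are omitted):
#
#     for param_name, param in state_dict.items():
#         if isinstance(param, LazyTensor):
#             param = param.materialize()    # tensor bytes are only read at this point
#         param = param.to(torch.float16)    # dtype hard-coded above (TODO: from config)
#         param_device = "cpu" if device_map is None else device_map[module_name]
#         set_module_tensor_to_device(
#             model, tensor_name=param_name, device=param_device, value=param
#         )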
# Patch AlignDevicesHook to hack around OPT lm_head
HACK_ZERO_ON_FAIL_TENSORS = ["lm_head.weight"]
def _recursed_key_value(module, key):
"""Gets a tensor from a recursive key (ie with .s)"""
if "." in key:
splits = key.split(".")
for split in splits[:-1]:
new_module = getattr(module, split)
if new_module is None:
raise ValueError(f"{module} has no attribute {split}.")
module = new_module
key = splits[-1]
return getattr(module, key)
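# Illustration: for a GPT-2-style module tree, _recursed_key_value(model, "transformer.wte.weight")
# walks model.transformer, then .wte, and returns its "weight" tensor, i.e. the same
# object as model.transformer.wte.weight, addressed by a dotted string.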
def _init_hook(self, module):
if not self.offload and self.execution_device is not None:
# BEGIN PATCH
for name, tensor in named_module_tensors(
module, recurse=self.place_submodules
):
try:
set_module_tensor_to_device(module, name, self.execution_device)
except ValueError:
# ValueError: weight is on the meta device, we need a `value` to put in on 0.
# bleuuuuuuuuuuuuuuuhhh
if name in HACK_ZERO_ON_FAIL_TENSORS:
logger.warning(f"Couldn't find value for weight {name}, zeroing.")
set_module_tensor_to_device(
module,
name,
self.execution_device,
value=torch.zeros(tensor.shape),
)
# END PATCH
elif self.offload:
self.original_devices = {
name: param.device
for name, param in named_module_tensors(
module, recurse=self.place_submodules
)
}
if self.weights_map is None:
self.weights_map = {
name: param.to("cpu")
for name, param in named_module_tensors(
module,
include_buffers=self.offload_buffers,
recurse=self.place_submodules,
)
}
for name, _ in named_module_tensors(
module,
include_buffers=self.offload_buffers,
recurse=self.place_submodules,
):
set_module_tensor_to_device(module, name, "meta")
if not self.offload_buffers and self.execution_device is not None:
for name, _ in module.named_buffers(recurse=self.place_submodules):
set_module_tensor_to_device(module, name, self.execution_device)
return module
accelerate.hooks.AlignDevicesHook.init_hook = _init_hook
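# A condensed view of the hack above (illustrative, following the names in _init_hook):
# for OPT-style checkpoints the tied lm_head.weight can still sit on the meta device
# when the hook runs, so set_module_tensor_to_device() raises the ValueError quoted
# in the except block. The patch zero-fills only the tensors listed in
# HACK_ZERO_ON_FAIL_TENSORS, e.g.
#
#     set_module_tensor_to_device(module, "weight", self.execution_device,
#                                 value=torch.zeros(tensor.shape))
#
# with the expectation that weight tying later supplies the real values from the
# input embeddings rather than this placeholder.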
def patch_transformers() -> None:
patch_transformers_download()