This commit is contained in:
somebody
2023-05-29 13:34:11 -05:00
parent ceaefa9f5e
commit 58ffad237b
3 changed files with 113 additions and 20 deletions

View File

@@ -9,6 +9,7 @@ import functools
import itertools
import traceback
import contextlib
from accelerate.big_modeling import load_checkpoint_and_dispatch
from accelerate.utils.modeling import infer_auto_device_map, load_checkpoint_in_model
from tqdm.auto import tqdm
from typing import Dict, List, Optional, Union
@@ -263,6 +264,9 @@ class HFTorchInferenceModel(HFInferenceModel):
tf_kwargs["revision"] = utils.koboldai_vars.revision
tf_kwargs["cache_dir"] = "cache"
if self.lazy_load:
tf_kwargs.pop("low_cpu_mem_usage", None)
# If we have model hints for legacy model, use them rather than fall back.
try:
if self.model_name == "GPT2Custom":
@@ -285,17 +289,25 @@ class HFTorchInferenceModel(HFInferenceModel):
# offload_state_dict=True
# )
# model.tie_weights()
no_split_module_classes = ["GPTJBlock", "OPTDecoderLayer"]
print("[HUGE SKELETON] MAKING DEVICE MAP")
device_map = infer_auto_device_map(
model,
max_memory={0: "10GiB", 1: "7GiB", "cpu": "15GiB"},
no_split_module_classes=["GPTJBlock", "OPTDecoderLayer"],
no_split_module_classes=no_split_module_classes,
dtype="float16",
)
print("[HUGE SKELETON] TYING WEIGHTS")
model.tie_weights()
print("[HUGE SKELETON] LOADING FROM PRETRAINED")
return AutoModelForCausalLM.from_pretrained(
location, device_map=device_map
) # , **tf_kwargs)
location,
device_map=device_map,
**tf_kwargs,
)
except Exception as e:
traceback_string = traceback.format_exc().lower()