Support for loading HF models on TPU with --colab_tpu
@@ -32,6 +32,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
import progressbar
import time
import os
import sys
import json
import zipfile
import requests
import random
import jax
@@ -41,9 +44,10 @@ import jax.numpy as jnp
import numpy as np
import optax
import haiku as hk
import transformers
from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM
from mesh_transformer.checkpoint import read_ckpt_lowmem
from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerShard
from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerShard, PlaceholderTensor
from mesh_transformer.util import to_bf16


params: Dict[str, Any] = {}
@@ -776,7 +780,26 @@ def infer_static(
    return samples


def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) -> None:
def reshard_reverse(x, total_shards, old_shape):
    assert len(x.shape) != 1
    if len(x.shape) == 2:
        if old_shape[1] == x.shape[1]:
            out = x[0:1].tile((total_shards, 1))
        else:
            out = x.reshape(old_shape)
    elif len(x.shape) == 3:
        if x.shape[0] * x.shape[2] == old_shape[2]:
            out = x.reshape(old_shape)
        elif x.shape[0] * x.shape[1] == old_shape[1]:
            out = x.reshape((old_shape[1], old_shape[0], old_shape[2])).permute((1, 0, 2))
        else:
            assert False
    else:
        assert False
    return out

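A note on the helper above: reshard_reverse takes a checkpoint tensor that has already been transposed and given a leading singleton axis (see the loading loop later in this diff) and expands it to the sharded shape MTJ stores in network.state, either by tiling a replicated vector once per shard or by splitting one axis across the shards. A minimal sketch of the three paths, assuming 8 shards and small illustrative dimensions (not values from this commit):

# Example only -- not part of the commit; shapes are illustrative.
import torch

shards, d_model, d_ff = 8, 16, 32

# Replicated bias: one copy per shard.
bias = torch.zeros(1, d_model)
assert reshard_reverse(bias, shards, (shards, d_model)).shape == (shards, d_model)

# Row-sharded weight: the d_model axis is split across shards.
w = torch.zeros(1, d_model, d_ff)
assert reshard_reverse(w, shards, (shards, d_model // shards, d_ff)).shape == (shards, d_model // shards, d_ff)

# Column-sharded weight: the d_ff axis is split, then axes are permuted back.
assert reshard_reverse(w, shards, (shards, d_model, d_ff // shards)).shape == (shards, d_model, d_ff // shards)
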
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
    global thread_resources_env, seq, tokenizer, network, params

    default_params = {
@@ -795,6 +818,53 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
        "tokenizer": "gpt2",
    }
    params = kwargs

    # Try to convert HF config.json to MTJ config
    if hf_checkpoint:
        spec_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "maps", vars.model_type + ".json")
        if not os.path.isfile(spec_path):
            raise NotImplementedError(f"Unsupported model type {repr(vars.model_type)}")
        with open(spec_path) as f:
            lazy_load_spec = json.load(f)

        if "mtj_compat" in lazy_load_spec:
            params["compat"] = lazy_load_spec["mtj_compat"]
        if "mtj_pe" in lazy_load_spec:
            params["pe"] = lazy_load_spec["mtj_pe"]
        for k, v in lazy_load_spec.get("mtj_config_map", {}).items():
            if type(v) is not list:
                params[k] = params[v]
                continue
            for i in range(len(v)):
                if i == len(v) - 1:
                    params[k] = v[i]
                elif v[i] in params:
                    params[k] = params[v[i]]
                    break

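        # A hypothetical maps/<model_type>.json, inferred from the fields this
        # function and the loading code below read; the field names are real,
        # the concrete values are illustrative guesses, not the shipped file.
        # Note the mtj_config_map list semantics above: earlier entries are HF
        # config keys tried in order, and the last entry is a literal default.
        #
        # {
        #     "mtj_compat": "j",
        #     "mtj_pe": "rotary",
        #     "mtj_config_map": {
        #         "d_model": "n_embd",
        #         "n_heads": "n_head",
        #         "layers": ["num_layers", "n_layer", 28]
        #     },
        #     "static_weights": {
        #         "transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w", "transforms": ["no_transpose", "vocab_pad"]}}
        #     },
        #     "layer_weights": {
        #         "transformer.h.{layer}.attn.out_proj.weight": {"mtj": {"module": "layer_{layer}/~/linear_3", "param": "w"}}
        #     }
        # }
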
params["n_vocab"] = params["vocab_size"]
|
||||
|
||||
if "activation_function" in params:
|
||||
params["activation"] = params["activation_function"]
|
||||
|
||||
# Both the number of attention heads in the model and the embedding
|
||||
# dimension of the model need to be divisible by the number of TPU cores
|
||||
# that we use, and JAX also requires the number of TPU cores used to be
|
||||
# an even number if we're using more than one core, so logically we try
|
||||
# to pick the largest possible even number of TPU cores such that the
|
||||
# number of attention heads and embedding dimension are both divisible
|
||||
# by the number of TPU cores, and fall back to one core if an even
|
||||
# number of TPU cores is not possible.
|
||||
for c in (8, 6, 4, 2, 1):
|
||||
if 0 == params["n_heads"] % c == params["d_model"] % c:
|
||||
params["cores_per_replica"] = c
|
||||
break
|
||||
|
||||
# The vocabulary size of the model also has to be divisible by the
|
||||
# number of TPU cores, so we pad the vocabulary with the minimum
|
||||
# possible number of dummy tokens such that it's divisible.
|
||||
params["n_vocab_padding"] = -(params["n_vocab"] % -params["cores_per_replica"])
|
||||
|
||||
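        # Worked example (illustrative numbers, not from the commit): with
        # n_heads=16 and d_model=4096 the loop above picks
        # cores_per_replica=8; with n_vocab=50257, -(50257 % -8) == 7, and
        # 50257 + 7 == 50264 is evenly divisible by 8.  Python's % takes the
        # sign of the divisor, which is what makes this negative-modulo
        # round-up trick work.
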
if "compat" in params:
|
||||
default_params["compat"] = params["compat"]
|
||||
if default_params["compat"] == "fairseq_lm":
|
||||
@ -804,10 +874,11 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
|
||||
params[param] = default_params[param]
|
||||
|
||||
# Load tokenizer
|
||||
if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")):
|
||||
raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'")
|
||||
tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"])
|
||||
tokenizer = tokenizer_class.from_pretrained(params["tokenizer"])
|
||||
if not hf_checkpoint:
|
||||
if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")):
|
||||
raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'")
|
||||
tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"])
|
||||
tokenizer = tokenizer_class.from_pretrained(params["tokenizer"])
|
||||
|
||||
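    # For reference: with "tokenizer": "gpt2" from the defaults above and a
    # tokenizer_class of, say, "GPT2TokenizerFast" (the exact default is
    # elided from this hunk), the getattr/__import__ lookup resolves to
    # transformers.GPT2TokenizerFast.from_pretrained("gpt2").
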
    # Disable JAX warnings about these two functions having been renamed
    jax.host_count = jax.process_count
@@ -844,5 +915,147 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
        path += "/"

    network = PenalizingCausalTransformer(params, dematerialized=True)
    network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])

    if not hf_checkpoint:
        network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])
        network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
        return

    # Convert from HF checkpoint

    move_xmap = jax.experimental.maps.xmap(
        fun=lambda x, _: to_bf16(x),
        in_axes=(["shard", ...], ["batch", ...]),
        out_axes=["shard", ...],
        axis_resources={'shard': 'mp', 'batch': 'dp'}
    )

    model_spec = {}
    for key, spec in lazy_load_spec.get("static_weights", {}).items():
        if spec.get("mtj") is not None:
            model_spec[key] = spec["mtj"].copy()
            model_spec[key]["module"] = "causal_transformer_shard/~/" + model_spec[key]["module"]
    for _key, spec in lazy_load_spec.get("layer_weights", {}).items():
        for layer in range(params["layers"]):
            if spec.get("mtj") is not None:
                key = _key.format(layer=layer)
                model_spec[key] = spec["mtj"].copy()
                model_spec[key]["module"] = "causal_transformer_shard/~/" + model_spec[key]["module"].format(layer=layer)

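    # For example, a static_weights entry of
    #   {"transformer.wte.weight": {"mtj": {"module": "embedding_shard/~/linear", "param": "w"}}}
    # (hypothetical HF key) yields
    #   model_spec["transformer.wte.weight"]["module"]
    #       == "causal_transformer_shard/~/embedding_shard/~/linear"
    # which matches the Haiku module paths used in network.state["params"].
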
    import torch_lazy_loader
    import torch
    from tqdm import tqdm

    def callback(model_dict, f, **_):
        with zipfile.ZipFile(f, "r") as z:
            try:
                last_storage_key = None
                f = None
                print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
                for key in tqdm(sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):

                    # Some model weights are used by transformers but not by
                    # MTJ.  We have to materialize these weights anyway because
                    # transformers will throw a tantrum otherwise.  To attain
                    # the least possible memory usage, we create them as meta
                    # tensors, which don't take up any actual CPU or TPU memory.
                    if key not in model_spec:
                        model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].storage_type(0).dtype, device="meta")
                        continue

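                    # (Illustration, not from the commit: a tensor on
                    # device="meta" carries only shape and dtype -- e.g.
                    # torch.empty(2 ** 40, device="meta") succeeds instantly
                    # because no storage is ever allocated.)
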
                    storage_key = model_dict[key].key
                    if storage_key != last_storage_key:
                        last_storage_key = storage_key
                        if isinstance(f, zipfile.ZipExtFile):
                            f.close()
                        f = z.open(f"archive/data/{storage_key}")
                    current_offset = f.tell()
                    if current_offset != model_dict[key].seek_offset:
                        f.seek(model_dict[key].seek_offset - current_offset, 1)
                    spec = model_spec[key]
                    transforms = set(spec.get("transforms", ()))
                    if not isinstance(model_dict[key], torch_lazy_loader.LazyTensor):
                        error = f"Duplicate key {repr(key)}"
                        print("\n\nERROR: " + error, file=sys.stderr)
                        raise RuntimeError(error)
                    tensor = model_dict[key].materialize(f, map_location="cpu")
                    model_dict[key] = tensor.to("meta")

                    # MTJ requires certain mathematical operations to be performed
                    # on tensors in order for them to be in the correct format
                    if "divide_by_shards" in transforms:
                        tensor /= params["cores_per_replica"]
                    if "vocab_pad" in transforms:
                        tensor = torch.nn.functional.pad(tensor, (0, 0, 0, params["n_vocab_padding"]))
                    if "no_transpose" not in transforms:
                        tensor = tensor.T
                    tensor.unsqueeze_(0)

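                    # (Illustration: an HF linear weight of shape
                    # (d_out, d_in) leaves this block as (1, d_in, d_out) --
                    # transposed, with a leading singleton axis -- which is
                    # the layout reshard_reverse expects when splitting the
                    # tensor into per-core shards.)
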
                    # Shard the tensor so that parts of the tensor can be used
                    # on different TPU cores
                    network.state["params"][spec["module"]][spec["param"]] = move_xmap(
                        jnp.array(
                            reshard_reverse(
                                tensor,
                                params["cores_per_replica"],
                                network.state["params"][spec["module"]][spec["param"]].shape,
                            ),
                            dtype=jnp.bfloat16,
                        ),
                        np.empty(params["cores_per_replica"]),
                    )

                # Check for tensors that MTJ needs that were not provided in
                # the HF model
                for mk, mv in network.state["params"].items():
                    for pk, pv in mv.items():
                        if isinstance(pv, PlaceholderTensor):
                            # The transformers GPT-J models apparently do not
                            # have an embedding bias, whereas MTJ GPT-J models
                            # do, so we have to supply an embedding bias tensor
                            # by creating a tensor with the necessary shape,
                            # filled with zeros.
                            if mk == "causal_transformer_shard/~/embedding_shard/~/linear" and pk == "b":
                                mv[pk] = move_xmap(jnp.zeros(mv[pk].shape, dtype=jnp.bfloat16), np.empty(params["cores_per_replica"]))

                            else:
                                error = f"{mk} {pk} could not be found in the model checkpoint"
                                print("\n\nERROR: " + error, file=sys.stderr)
                                raise RuntimeError(error)
            finally:
                if isinstance(f, zipfile.ZipExtFile):
                    f.close()

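    # The callback above relies on this torch_lazy_loader.LazyTensor
    # interface (inferred from the attribute accesses in the loop; the real
    # definitions live in torch_lazy_loader.py):
    #   .key                          -> which archive/data/<key> record in
    #                                    the checkpoint zip holds the bytes
    #   .seek_offset                  -> byte offset of this tensor's data
    #                                    within that record
    #   .storage_type(0).dtype        -> the tensor's torch dtype
    #   .materialize(f, map_location) -> read from f and return a real tensor
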
    if os.path.isdir(vars.model.replace('/', '_')):
        import shutil
        shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
    with torch_lazy_loader.use_lazy_torch_load(callback=callback, dematerialized_modules=True):
        if(os.path.isdir(vars.custmodpth)):
            try:
                tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache")
            except ValueError as e:
                tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache")
            try:
                model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache")
            except ValueError as e:
                model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache")
        elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))):
            try:
                tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
            except ValueError as e:
                tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
            try:
                model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
            except ValueError as e:
                model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
        else:
            try:
                tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache")
            except ValueError as e:
                tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache")
            try:
                model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache")
            except ValueError as e:
                model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache")

    network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
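
For completeness: the new code path is driven entirely by the hf_checkpoint flag and the kwargs, which become params. Assuming this file is tpu_mtj_backend.py (as in the KoboldAI repo) and that the module-level vars object has been set up by the caller, a hypothetical invocation looks like:

# Hypothetical usage, not from the commit; the HF config.json fields
# (vocab_size, activation_function, ...) must be passed through as kwargs.
import json
import tpu_mtj_backend

with open("models/gpt-j-6b/config.json") as f:   # illustrative path
    hf_config = json.load(f)
tpu_mtj_backend.load_model("models/gpt-j-6b/", hf_checkpoint=True, **hf_config)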