xmaps for moving things onto TPU

Gnome Ann 2022-01-12 21:45:30 -05:00
parent 49e2bcab1a
commit a3d6dc93e8
2 changed files with 41 additions and 11 deletions
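In short: both files stop shipping host-side numpy arrays to the TPU on every call and instead place them on the device mesh once, by running them through an identity function under jax.experimental.maps.xmap. shard_xmap lays an array's leading axis across the model-parallel ('mp') mesh axis, and batch_xmap lays it across the data-parallel ('dp') axis. Below is a minimal sketch of the trick, mirroring the ('dp', 'mp') mesh that load_model() registers via maps.thread_resources.env; the mesh setup and example array are illustrative, not taken from the commit.

import jax
import numpy as np
from jax.experimental import maps

cores = len(jax.devices())
mesh = maps.Mesh(np.array(jax.devices()).reshape(1, cores), ("dp", "mp"))

shard_put = maps.xmap(
    lambda x: x,                     # identity: no compute, placement only
    in_axes=["shard", ...],          # the input's leading axis is named 'shard'
    out_axes=["shard", ...],
    axis_resources={"shard": "mp"},  # split that axis across the 'mp' mesh axis
)

x = np.zeros((cores, 4, 16), dtype=np.float32)  # one slice per TPU core along axis 0
with mesh:                           # xmap's axis_resources need an active mesh
    y = shard_put(x)                 # y now lives across the TPU cores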

aiserver.py

@@ -993,7 +993,7 @@ else:
-1,
tpu_mtj_backend.params["d_model"],
)
- vars.sp = tensor
+ vars.sp = tpu_mtj_backend.shard_xmap(tensor)
soft_tokens = np.arange(
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
@@ -4005,7 +4005,7 @@ def spRequest(filename):
-1,
tpu_mtj_backend.params["d_model"],
)
- vars.sp = np.float32(tensor)
+ vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor))
else:
vars.sp = torch.from_numpy(tensor)
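Net effect in this file: on the TPU backend the soft-prompt tensor is handed to tpu_mtj_backend.shard_xmap(...) once when it is loaded, so it already sits across the model-parallel shards before generation starts; the torch fallback just above is unchanged. A hedged sketch of that call follows; the array shape and the cores_per_replica/d_model lookups are assumptions, not shown in this hunk.

import numpy as np
import tpu_mtj_backend

# Assumes tpu_mtj_backend.load_model(...) has already run, so params is filled
# in and the shard_xmap global exists.
cores = tpu_mtj_backend.params["cores_per_replica"]   # size of the 'mp' mesh axis
d_model = tpu_mtj_backend.params["d_model"]

# Hypothetical soft prompt whose leading axis is the per-shard axis, matching
# shard_xmap's in_axes of ["shard", ...].
soft_prompt = np.zeros((cores, 16, d_model), dtype=np.float32)
sp_on_tpu = tpu_mtj_backend.shard_xmap(np.float32(soft_prompt))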

tpu_mtj_backend.py

@@ -1,5 +1,5 @@
import multiprocessing
- from typing import Any, Dict, List, Optional
+ from typing import Any, Callable, Dict, List, Optional, TypeVar
import progressbar
import time
import os
@@ -28,6 +28,31 @@ def show_spinner():
time.sleep(0.1)
i += 1
+ __F = TypeVar("__F", bound=Callable)
+ __T = TypeVar("__T")
+ def __move_xmap(f: __F, out_axis: str) -> __F:
+     return maps.xmap(
+         f,
+         in_axes=(["shard", ...], ["batch", ...]),
+         out_axes=[out_axis, ...],
+         axis_resources={'shard': 'mp', 'batch': 'dp'},
+     )
+ def __shard_xmap(batch_dim=1):
+     xmap = __move_xmap(lambda s, b: s, "shard")
+     def inner(x: __T) -> __T:
+         return xmap(x, np.empty(batch_dim))
+     return inner
+ def __batch_xmap(shard_dim=1):
+     xmap = __move_xmap(lambda s, b: b, "batch")
+     def inner(x: __T) -> __T:
+         return xmap(np.empty(shard_dim), x)
+     return inner
def apply_repetition_penalty(logits, tokens, repetition_penalty):
'''
This gets called by generate_loop_fn to apply repetition penalty
@@ -255,16 +280,17 @@ class PenalizingCausalTransformer(CausalTransformer):
key = hk.PRNGSequence(random.randint(0, 2 ** 60))
batch_size = ctx.shape[0]
self.batch_size = batch_size
- return self.generate_xmap(
-     self.state,
-     jnp.array(key.take(batch_size)),
-     ctx,
-     np.array(ctx_length, dtype=np.uint32),
-     np.array(gen_length, dtype=np.uint32),
+ xargs = (
+     shard_xmap(self.state),
+     batch_xmap(jnp.array(key.take(batch_size))),
+     batch_xmap(ctx),
+     batch_xmap(np.array(ctx_length, dtype=np.uint32)),
+     batch_xmap(np.array(gen_length, dtype=np.uint32)),
      np.empty((batch_size, numseqs), dtype=np.uint8),
-     sampler_options,
-     soft_embeddings,
+     batch_xmap(sampler_options),
+     shard_xmap(soft_embeddings),
  )
+ return self.generate_xmap(*xargs)
def infer(
@@ -354,6 +380,10 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
maps.thread_resources.env = thread_resources_env
tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
+ global shard_xmap, batch_xmap
+ shard_xmap = __shard_xmap()
+ batch_xmap = __batch_xmap(shard_dim=cores_per_replica)
if not path.endswith("/"):
path += "/"
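load_model() defines the two wrappers only after the device mesh is registered, and batch_xmap gets shard_dim=cores_per_replica because its dummy first argument is still mapped over the 'mp' axis, so its leading dimension has to match the number of model-parallel cores. generate() above then wraps per-shard state (self.state, soft_embeddings) with shard_xmap and per-example inputs (RNG keys, tokens, lengths, sampler options) with batch_xmap before calling generate_xmap. A hedged call-site sketch; the token shape is a placeholder, not from the commit.

import numpy as np
import tpu_mtj_backend

# Assumes load_model(...) has already run, so the mesh is active and the
# batch_xmap global exists.
tokens = np.zeros((1, 2048), dtype=np.uint32)        # hypothetical (batch, seq) context
tokens_on_tpu = tpu_mtj_backend.batch_xmap(tokens)   # leading (batch) axis -> 'dp' mesh axis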