xmaps for moving things onto TPU

Gnome Ann
2022-01-12 21:45:30 -05:00
parent 49e2bcab1a
commit a3d6dc93e8
2 changed files with 41 additions and 11 deletions


@@ -1,5 +1,5 @@
 import multiprocessing
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, TypeVar
 import progressbar
 import time
 import os
@@ -28,6 +28,31 @@ def show_spinner():
         time.sleep(0.1)
         i += 1
 
+
+__F = TypeVar("__F", bound=Callable)
+__T = TypeVar("__T")
+
+def __move_xmap(f: __F, out_axis: str) -> __F:
+    return maps.xmap(
+        f,
+        in_axes=(["shard", ...], ["batch", ...]),
+        out_axes=[out_axis, ...],
+        axis_resources={'shard': 'mp', 'batch': 'dp'},
+    )
+
+def __shard_xmap(batch_dim=1):
+    xmap = __move_xmap(lambda s, b: s, "shard")
+    def inner(x: __T) -> __T:
+        return xmap(x, np.empty(batch_dim))
+    return inner
+
+def __batch_xmap(shard_dim=1):
+    xmap = __move_xmap(lambda s, b: b, "batch")
+    def inner(x: __T) -> __T:
+        return xmap(np.empty(shard_dim), x)
+    return inner
+
 def apply_repetition_penalty(logits, tokens, repetition_penalty):
     '''
     This gets called by generate_loop_fn to apply repetition penalty
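
Note (not part of the commit): a minimal, CPU-runnable sketch of the xmap pattern used by __move_xmap above. The lambda simply returns one of its operands; the point is that in_axes names the leading dimension of each operand ('shard' for the first, 'batch' for the second), so xmap knows which positional axis is which. The real helpers additionally pass axis_resources={'shard': 'mp', 'batch': 'dp'}, which requires the ('dp', 'mp') device mesh installed by load_model() and is what actually places the arrays on the TPU. The sketch assumes a JAX version that still provides jax.experimental.maps.xmap; the shapes are illustrative.

import numpy as np
from jax.experimental import maps

# Identity-like function: keep the "sharded" operand, ignore the dummy batch.
move = maps.xmap(
    lambda s, b: s,
    in_axes=(["shard", ...], ["batch", ...]),
    out_axes=["shard", ...],
)

params = np.zeros((8, 16), dtype=np.float32)  # leading dim of 8 is the 'shard' axis
dummy = np.empty(1)                           # placeholder 'batch' operand, as in __shard_xmap
out = move(params, dummy)                     # same values, now routed through xmap
print(out.shape)                              # (8, 16)
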
@@ -255,16 +280,17 @@ class PenalizingCausalTransformer(CausalTransformer):
         key = hk.PRNGSequence(random.randint(0, 2 ** 60))
         batch_size = ctx.shape[0]
         self.batch_size = batch_size
-        return self.generate_xmap(
-            self.state,
-            jnp.array(key.take(batch_size)),
-            ctx,
-            np.array(ctx_length, dtype=np.uint32),
-            np.array(gen_length, dtype=np.uint32),
+        xargs = (
+            shard_xmap(self.state),
+            batch_xmap(jnp.array(key.take(batch_size))),
+            batch_xmap(ctx),
+            batch_xmap(np.array(ctx_length, dtype=np.uint32)),
+            batch_xmap(np.array(gen_length, dtype=np.uint32)),
             np.empty((batch_size, numseqs), dtype=np.uint8),
-            sampler_options,
-            soft_embeddings,
+            batch_xmap(sampler_options),
+            shard_xmap(soft_embeddings),
         )
+        return self.generate_xmap(*xargs)
 
 def infer(
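
Note (not part of the commit): the hunk above rewrites generate() so that each operand is moved onto the device mesh before the xmapped generation function runs: per-shard data (the model state and soft-prompt embeddings) goes through shard_xmap, per-sequence data (PRNG keys, context tokens, lengths, sampler options) goes through batch_xmap, and the scratch buffer stays a plain host array. Below is a hedged standalone sketch of that pattern; prepare_generate_args and its parameter names are hypothetical, not part of the codebase.

import numpy as np

def prepare_generate_args(state, keys, ctx, ctx_length, gen_length,
                          numseqs, sampler_options, soft_embeddings,
                          shard_xmap, batch_xmap):
    batch_size = ctx.shape[0]
    return (
        shard_xmap(state),                                 # leading dim treated as the 'shard' axis ('mp')
        batch_xmap(keys),                                  # one PRNG key per sequence ('dp')
        batch_xmap(ctx),                                   # context tokens
        batch_xmap(np.array(ctx_length, dtype=np.uint32)),
        batch_xmap(np.array(gen_length, dtype=np.uint32)),
        np.empty((batch_size, numseqs), dtype=np.uint8),   # scratch buffer, left on the host
        batch_xmap(sampler_options),                       # per-sequence sampler settings
        shard_xmap(soft_embeddings),                       # soft-prompt embeddings
    )
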
@@ -354,6 +380,10 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
     maps.thread_resources.env = thread_resources_env
 
     tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
 
+    global shard_xmap, batch_xmap
+    shard_xmap = __shard_xmap()
+    batch_xmap = __batch_xmap(shard_dim=cores_per_replica)
+
     if not path.endswith("/"):
         path += "/"
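
Note (not part of the commit): why shard_dim is set to cores_per_replica here. Inside __batch_xmap the only operand carrying the 'shard' named axis is the dummy np.empty(shard_dim), so shard_dim fixes the size of that axis, and an axis mapped onto the 'mp' mesh resource must be divisible by (and here equals) the 'mp' mesh axis size, i.e. cores_per_replica. __shard_xmap keeps batch_dim at its default of 1, which assumes the 'dp' mesh axis has size 1 (a single data-parallel replica). Below is a hedged usage sketch; it assumes it runs inside this same module after load_model() has installed the ('dp', 'mp') mesh, and that cores_per_replica is 8.

import numpy as np

cores_per_replica = 8                                    # assumption for the sketch
shard_xmap = __shard_xmap()                              # dummy 'batch' axis of size 1
batch_xmap = __batch_xmap(shard_dim=cores_per_replica)   # dummy 'shard' axis sized to match 'mp'

tokens = np.zeros((1, 80), dtype=np.uint32)              # (batch, sequence) context
tokens_on_mesh = batch_xmap(tokens)                      # 'batch' axis spread over the 'dp' mesh axis
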