xmaps for moving things onto TPU

commit a3d6dc93e8 (parent 49e2bcab1a)
Author: Gnome Ann
Date:   2022-01-12 21:45:30 -05:00

2 changed files with 41 additions and 11 deletions
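
Both files now route arrays onto the TPU device mesh through identity functions compiled with jax.experimental.maps.xmap: the array's values come back unchanged, but its leading named axis ("shard" or "batch") gets bound to a mesh axis ("mp" for the model-parallel shards, "dp" for the data-parallel batch). The sketch below illustrates that pattern in isolation; it assumes a JAX release from this era that still ships jax.experimental.maps.xmap, and the single-device mesh, the move name and the toy shapes are stand-ins rather than code from this commit.

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.experimental import maps

    # A 1x1 (dp, mp) mesh over whichever single device is available, so the sketch
    # also runs on CPU; depending on the JAX release the context manager is spelled
    # maps.Mesh(...) or the older lowercase maps.mesh(...).
    devices = np.array(jax.devices()[:1]).reshape((1, 1))
    with maps.Mesh(devices, ("dp", "mp")):
        # Identity function compiled with xmap: the leading axis of the first operand
        # is named "shard" and pinned to the "mp" mesh axis, the leading axis of the
        # second operand is named "batch" and pinned to "dp".
        move = maps.xmap(
            lambda s, b: s,
            in_axes=(["shard", ...], ["batch", ...]),
            out_axes=["shard", ...],
            axis_resources={"shard": "mp", "batch": "dp"},
        )
        x = jnp.zeros((1, 4, 8))   # leading axis plays the "shard" role
        y = move(x, np.empty(1))   # throwaway "batch" operand, as in __shard_xmap below
        print(y.shape)             # (1, 4, 8): same values, now placed according to the mesh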

Changed file 1 of 2:

@@ -993,7 +993,7 @@ else:
             -1,
             tpu_mtj_backend.params["d_model"],
         )
-        vars.sp = tensor
+        vars.sp = tpu_mtj_backend.shard_xmap(tensor)
     soft_tokens = np.arange(
         tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
         tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
@@ -4005,7 +4005,7 @@ def spRequest(filename):
             -1,
             tpu_mtj_backend.params["d_model"],
         )
-        vars.sp = np.float32(tensor)
+        vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor))
     else:
         vars.sp = torch.from_numpy(tensor)

Changed file 2 of 2:

@@ -1,5 +1,5 @@
 import multiprocessing
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, TypeVar
 import progressbar
 import time
 import os
@@ -28,6 +28,31 @@ def show_spinner():
         time.sleep(0.1)
         i += 1
 
+__F = TypeVar("__F", bound=Callable)
+__T = TypeVar("__T")
+
+def __move_xmap(f: __F, out_axis: str) -> __F:
+    return maps.xmap(
+        f,
+        in_axes=(["shard", ...], ["batch", ...]),
+        out_axes=[out_axis, ...],
+        axis_resources={'shard': 'mp', 'batch': 'dp'},
+    )
+
+def __shard_xmap(batch_dim=1):
+    xmap = __move_xmap(lambda s, b: s, "shard")
+    def inner(x: __T) -> __T:
+        return xmap(x, np.empty(batch_dim))
+    return inner
+
+def __batch_xmap(shard_dim=1):
+    xmap = __move_xmap(lambda s, b: b, "batch")
+    def inner(x: __T) -> __T:
+        return xmap(np.empty(shard_dim), x)
+    return inner
+
 def apply_repetition_penalty(logits, tokens, repetition_penalty):
     '''
     This gets called by generate_loop_fn to apply repetition penalty
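
__move_xmap always hands xmap a (sharded, batched) pair of operands so that a single signature covers both directions; __shard_xmap and __batch_xmap then wrap it into one-argument functions, passing a throwaway np.empty() for whichever side they ignore. Below is a hedged sketch of that dummy-operand trick with the axis_resources mapping stripped out, so it runs on any backend (xmap then behaves like a named vmap); the move_shard name and the toy array are illustrative, not from the commit.

    import numpy as np
    from jax.experimental import maps

    # Same shape of call as __move_xmap, minus the mesh resources.
    move_shard = maps.xmap(
        lambda s, b: s,
        in_axes=(["shard", ...], ["batch", ...]),
        out_axes=["shard", ...],
    )

    x = np.arange(6, dtype=np.float32).reshape(2, 3)  # leading axis is the "shard" axis
    y = move_shard(x, np.empty(1))                    # dummy "batch" operand of size 1
    assert np.array_equal(np.asarray(y), x)           # values untouched; with a mesh, only placement changes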
@@ -255,16 +280,17 @@ class PenalizingCausalTransformer(CausalTransformer):
         key = hk.PRNGSequence(random.randint(0, 2 ** 60))
         batch_size = ctx.shape[0]
         self.batch_size = batch_size
-        return self.generate_xmap(
-            self.state,
-            jnp.array(key.take(batch_size)),
-            ctx,
-            np.array(ctx_length, dtype=np.uint32),
-            np.array(gen_length, dtype=np.uint32),
-            np.empty((batch_size, numseqs), dtype=np.uint8),
-            sampler_options,
-            soft_embeddings,
-        )
+        xargs = (
+            shard_xmap(self.state),
+            batch_xmap(jnp.array(key.take(batch_size))),
+            batch_xmap(ctx),
+            batch_xmap(np.array(ctx_length, dtype=np.uint32)),
+            batch_xmap(np.array(gen_length, dtype=np.uint32)),
+            np.empty((batch_size, numseqs), dtype=np.uint8),
+            batch_xmap(sampler_options),
+            shard_xmap(soft_embeddings),
+        )
+        return self.generate_xmap(*xargs)
     def infer(
@@ -354,6 +380,10 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
     maps.thread_resources.env = thread_resources_env
 
     tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
 
+    global shard_xmap, batch_xmap
+    shard_xmap = __shard_xmap()
+    batch_xmap = __batch_xmap(shard_dim=cores_per_replica)
+
     if not path.endswith("/"):
         path += "/"
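
load_model() is what binds the module-level shard_xmap and batch_xmap globals that the earlier hunks call. Note that the batched helper gets shard_dim=cores_per_replica: the throwaway "shard" operand inside __batch_xmap has to be partitionable across the "mp" mesh axis, and a named axis bound to a mesh resource must have a size divisible by that resource's size. A hedged, single-device sketch of the batch direction follows, with cores_per_replica shrunk to 1 to match the stand-in mesh; the names and shapes are illustrative, not from the commit.

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.experimental import maps

    cores_per_replica = 1  # stand-in matching the single-device mesh below; the real backend uses the TPU core count
    devices = np.array(jax.devices()[:1]).reshape((1, cores_per_replica))
    with maps.Mesh(devices, ("dp", "mp")):
        # Mirror image of the shard helper: return the *batched* operand instead.
        move_batch = maps.xmap(
            lambda s, b: b,
            in_axes=(["shard", ...], ["batch", ...]),
            out_axes=["batch", ...],
            axis_resources={"shard": "mp", "batch": "dp"},
        )
        ctx = jnp.zeros((1, 16), dtype=jnp.uint32)          # leading axis plays the "batch" role
        out = move_batch(np.empty(cores_per_replica), ctx)  # dummy "shard" operand sized to the "mp" axis
        print(out.shape)                                    # (1, 16), now associated with the "dp" axis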