diff --git a/aiserver.py b/aiserver.py
index 72d983ef..83df8c7e 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -993,7 +993,7 @@ else:
                 -1,
                 tpu_mtj_backend.params["d_model"],
             )
-            vars.sp = tensor
+            vars.sp = tpu_mtj_backend.shard_xmap(tensor)
             soft_tokens = np.arange(
                 tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
                 tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
@@ -4005,7 +4005,7 @@ def spRequest(filename):
             -1,
             tpu_mtj_backend.params["d_model"],
         )
-        vars.sp = np.float32(tensor)
+        vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor))
     else:
         vars.sp = torch.from_numpy(tensor)
 
diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index 845edd30..86413022 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -1,5 +1,5 @@
 import multiprocessing
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, TypeVar
 import progressbar
 import time
 import os
@@ -28,6 +28,31 @@ def show_spinner():
         time.sleep(0.1)
         i += 1
 
+
+__F = TypeVar("__F", bound=Callable)
+__T = TypeVar("__T")
+
+def __move_xmap(f: __F, out_axis: str) -> __F:
+    return maps.xmap(
+        f,
+        in_axes=(["shard", ...], ["batch", ...]),
+        out_axes=[out_axis, ...],
+        axis_resources={'shard': 'mp', 'batch': 'dp'},
+    )
+
+def __shard_xmap(batch_dim=1):
+    xmap = __move_xmap(lambda s, b: s, "shard")
+    def inner(x: __T) -> __T:
+        return xmap(x, np.empty(batch_dim))
+    return inner
+
+def __batch_xmap(shard_dim=1):
+    xmap = __move_xmap(lambda s, b: b, "batch")
+    def inner(x: __T) -> __T:
+        return xmap(np.empty(shard_dim), x)
+    return inner
+
+
 def apply_repetition_penalty(logits, tokens, repetition_penalty):
     '''
     This gets called by generate_loop_fn to apply repetition penalty
@@ -255,16 +280,17 @@ class PenalizingCausalTransformer(CausalTransformer):
         key = hk.PRNGSequence(random.randint(0, 2 ** 60))
         batch_size = ctx.shape[0]
         self.batch_size = batch_size
-        return self.generate_xmap(
-            self.state,
-            jnp.array(key.take(batch_size)),
-            ctx,
-            np.array(ctx_length, dtype=np.uint32),
-            np.array(gen_length, dtype=np.uint32),
+        xargs = (
+            shard_xmap(self.state),
+            batch_xmap(jnp.array(key.take(batch_size))),
+            batch_xmap(ctx),
+            batch_xmap(np.array(ctx_length, dtype=np.uint32)),
+            batch_xmap(np.array(gen_length, dtype=np.uint32)),
             np.empty((batch_size, numseqs), dtype=np.uint8),
-            sampler_options,
-            soft_embeddings,
+            batch_xmap(sampler_options),
+            shard_xmap(soft_embeddings),
         )
+        return self.generate_xmap(*xargs)
 
     def infer(
@@ -354,6 +380,10 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
     maps.thread_resources.env = thread_resources_env
 
     tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
 
+    global shard_xmap, batch_xmap
+    shard_xmap = __shard_xmap()
+    batch_xmap = __batch_xmap(shard_dim=cores_per_replica)
+
     if not path.endswith("/"):
         path += "/"
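
Note on the xmap helpers this patch adds: __shard_xmap and __batch_xmap each wrap an identity function in jax.experimental.maps.xmap, so pushing a host array through them returns the same values laid out across the TPU mesh, with the mapped leading axis bound to the model-parallel ("mp") or data-parallel ("dp") mesh axis respectively. The aiserver.py hunks use this so the soft-prompt tensor is distributed across cores once, when the soft prompt is loaded, rather than being shipped to the devices as a plain host array on every call into generate_xmap. What follows is a minimal, self-contained sketch of that pattern, assuming a JAX version where jax.experimental.maps.xmap and the maps.mesh context manager are available (as in the tree this patch targets); the demo_* names are illustrative and not part of the patch.

    import numpy as np
    import jax
    from jax.experimental import maps

    def demo_move_xmap(f, out_axis):
        # xmap f, naming the leading axis of its first argument "shard"
        # and of its second argument "batch", and bind those logical
        # axes to the "mp"/"dp" axes of the device mesh.
        return maps.xmap(
            f,
            in_axes=(["shard", ...], ["batch", ...]),
            out_axes=[out_axis, ...],
            axis_resources={"shard": "mp", "batch": "dp"},
        )

    def demo_shard_xmap(batch_dim=1):
        # Identity on the sharded argument: running an array through the
        # returned function splits its leading axis across the "mp" axis
        # and leaves the result resident on the devices.
        xmap = demo_move_xmap(lambda s, b: s, "shard")
        return lambda x: xmap(x, np.empty(batch_dim))

    # Trivial mesh: all local devices on "mp", a single replica on "dp".
    devices = np.array(jax.devices()).reshape(-1, 1)
    with maps.mesh(devices, ("mp", "dp")):
        x = np.zeros((devices.shape[0], 4, 8), dtype=np.float32)
        sharded = demo_shard_xmap()(x)
        print(type(sharded), sharded.shape)  # device-resident, same shape

The sketch installs the mesh with a with-block for brevity; the patch itself installs it persistently by assigning maps.thread_resources.env in load_model (visible in the last hunk), which is why the module-level shard_xmap/batch_xmap closures can be called later without any surrounding mesh context.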