From a3d6dc93e83a2924be08ce556f01506461eeea9b Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Wed, 12 Jan 2022 21:45:30 -0500 Subject: [PATCH 1/8] xmaps for moving things onto TPU --- aiserver.py | 4 ++-- tpu_mtj_backend.py | 48 +++++++++++++++++++++++++++++++++++++--------- 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/aiserver.py b/aiserver.py index 72d983ef..83df8c7e 100644 --- a/aiserver.py +++ b/aiserver.py @@ -993,7 +993,7 @@ else: -1, tpu_mtj_backend.params["d_model"], ) - vars.sp = tensor + vars.sp = tpu_mtj_backend.shard_xmap(tensor) soft_tokens = np.arange( tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"], tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length, @@ -4005,7 +4005,7 @@ def spRequest(filename): -1, tpu_mtj_backend.params["d_model"], ) - vars.sp = np.float32(tensor) + vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor)) else: vars.sp = torch.from_numpy(tensor) diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 845edd30..86413022 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -1,5 +1,5 @@ import multiprocessing -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, TypeVar import progressbar import time import os @@ -28,6 +28,31 @@ def show_spinner(): time.sleep(0.1) i += 1 + +__F = TypeVar("__F", bound=Callable) +__T = TypeVar("__T") + +def __move_xmap(f: __F, out_axis: str) -> __F: + return maps.xmap( + f, + in_axes=(["shard", ...], ["batch", ...]), + out_axes=[out_axis, ...], + axis_resources={'shard': 'mp', 'batch': 'dp'}, + ) + +def __shard_xmap(batch_dim=1): + xmap = __move_xmap(lambda s, b: s, "shard") + def inner(x: __T) -> __T: + return xmap(x, np.empty(batch_dim)) + return inner + +def __batch_xmap(shard_dim=1): + xmap = __move_xmap(lambda s, b: b, "batch") + def inner(x: __T) -> __T: + return xmap(np.empty(shard_dim), x) + return inner + + def apply_repetition_penalty(logits, tokens, repetition_penalty): ''' This gets called by generate_loop_fn to apply repetition penalty @@ -255,16 +280,17 @@ class PenalizingCausalTransformer(CausalTransformer): key = hk.PRNGSequence(random.randint(0, 2 ** 60)) batch_size = ctx.shape[0] self.batch_size = batch_size - return self.generate_xmap( - self.state, - jnp.array(key.take(batch_size)), - ctx, - np.array(ctx_length, dtype=np.uint32), - np.array(gen_length, dtype=np.uint32), + xargs = ( + shard_xmap(self.state), + batch_xmap(jnp.array(key.take(batch_size))), + batch_xmap(ctx), + batch_xmap(np.array(ctx_length, dtype=np.uint32)), + batch_xmap(np.array(gen_length, dtype=np.uint32)), np.empty((batch_size, numseqs), dtype=np.uint8), - sampler_options, - soft_embeddings, + batch_xmap(sampler_options), + shard_xmap(soft_embeddings), ) + return self.generate_xmap(*xargs) def infer( @@ -354,6 +380,10 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) maps.thread_resources.env = thread_resources_env tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2') + global shard_xmap, batch_xmap + shard_xmap = __shard_xmap() + batch_xmap = __batch_xmap(shard_dim=cores_per_replica) + if not path.endswith("/"): path += "/" From 09c4fdcb2e57508d5c35b30f4c2a3b301f0f0c02 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Thu, 13 Jan 2022 00:56:00 -0500 Subject: [PATCH 2/8] Split `generate_xmap` into two xmaps --- tpu_mtj_backend.py | 56 +++++++++++++++++++++++++++++++++++----------- 1 file changed, 43 insertions(+), 13 deletions(-) diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 86413022..4a9833aa 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -180,11 +180,9 @@ class PenalizingCausalTransformer(CausalTransformer): def __init__(self, config): # Initialize super().__init__(config) - def generate(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None): + def generate_initial(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None): numseqs = numseqs_aux.shape[0] - # These are the tokens that we don't want the AI to ever write - self.badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146]) - def generate_sample(context, ctx_length): + def generate_initial_inner(context, ctx_length): # Give the initial context to the transformer transformer = CausalTransformerShard(config) def generate_initial_scan_fn(sequence_index, _): @@ -201,6 +199,32 @@ class PenalizingCausalTransformer(CausalTransformer): _, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, None, numseqs) sample_key = initial_states[-1][0] initial_states = list(jax.tree_map(lambda x: x[i], initial_states[:-1]) for i in range(numseqs)) + return initial_states, sample_key + generate_initial_fn = hk.transform(generate_initial_inner).apply + return generate_initial_fn(state["params"], key, ctx, ctx_length) + self.generate_initial_xmap = jax.experimental.maps.xmap( + fun=generate_initial, + in_axes=( + ["shard", ...], + ["batch", ...], + ["batch", ...], + ["batch", ...], + ["batch", ...], + ["batch", ...], + ["batch", ...], + ["shard", ...], + ), + out_axes=["shard", "batch", ...], + axis_resources={'shard': 'mp', 'batch': 'dp'}, + ) + def generate_once(initial_states, sample_key, state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None): + numseqs = numseqs_aux.shape[0] + # These are the tokens that we don't want the AI to ever write + self.badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146]) + def generate_once_inner(context, ctx_length): + gi = initial_states[0][1] + # Give the initial context to the transformer + transformer = CausalTransformerShard(config) # Get repetition penalty from the arguments repetition_penalty = sampler_options.pop('repetition_penalty', None) # This is the main generation loop @@ -253,16 +277,18 @@ class PenalizingCausalTransformer(CausalTransformer): carry[0].append(carry[0].pop(0)) return carry[0], new_key final_state = jax.lax.while_loop( - lambda carry: carry[0][0][1] - config["seq"] < gen_length, + lambda carry: carry[0][0][1] == gi, generate_loop_fn, (initial_states, sample_key), ) return final_state - generate_fn = hk.transform(generate_sample).apply - return generate_fn(state["params"], key, ctx, ctx_length) - self.generate_xmap = jax.experimental.maps.xmap( - fun=generate, + generate_once_fn = hk.transform(generate_once_inner).apply + return generate_once_fn(state["params"], key, ctx, ctx_length) + self.generate_once_xmap = jax.experimental.maps.xmap( + fun=generate_once, in_axes=( + ["shard", "batch", ...], + ["shard", "batch", ...], ["shard", ...], ["batch", ...], ["batch", ...], @@ -277,11 +303,12 @@ class PenalizingCausalTransformer(CausalTransformer): ) def generate(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None): assert not return_logits + assert gen_length.ndim == 1 key = hk.PRNGSequence(random.randint(0, 2 ** 60)) batch_size = ctx.shape[0] self.batch_size = batch_size - xargs = ( - shard_xmap(self.state), + xargs = [ + self.state, batch_xmap(jnp.array(key.take(batch_size))), batch_xmap(ctx), batch_xmap(np.array(ctx_length, dtype=np.uint32)), @@ -289,8 +316,11 @@ class PenalizingCausalTransformer(CausalTransformer): np.empty((batch_size, numseqs), dtype=np.uint8), batch_xmap(sampler_options), shard_xmap(soft_embeddings), - ) - return self.generate_xmap(*xargs) + ] + initial_state, sample_key = self.generate_initial_xmap(*xargs) + for i in range(gen_length[0]): + initial_state, sample_key = self.generate_once_xmap(initial_state, sample_key, *xargs) + return initial_state, sample_key def infer( From 57a6886007bc30e8523aae4b3a2ac851968fc0e9 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Fri, 14 Jan 2022 02:23:19 -0500 Subject: [PATCH 3/8] Move sampling into a `jax.jit`ted function --- tpu_mtj_backend.py | 184 +++++++++++++++++++++++++-------------------- 1 file changed, 101 insertions(+), 83 deletions(-) diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 4a9833aa..f5d70e6b 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -172,36 +172,74 @@ def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a # probability distribution) - return jax.random.categorical(key, logits, -1).astype(jnp.uint32)[jnp.newaxis] + return jax.random.categorical(key, logits, -1).astype(jnp.uint32) pad_token_id = 50256 +def sample_jit(data, key, numseqs_aux, badwords, repetition_penalty, sampler_options): + numseqs = numseqs_aux.shape[0] + gi = data[0][1] + def sample_loop_fn(carry): + generated, generated_index, logits, _ = carry[0][0] + sample_key = carry[1] + # Get the pseudo-random number generator key that will + # be used by kobold_sample to randomly pick a token + sample_key, new_key = jax.random.split(sample_key, num=2) + # Apply repetition penalty to all tokens that are + # currently inside the "generated" array + logits = apply_repetition_penalty( + logits, + generated, + repetition_penalty + ) + # Remove any tokens in the badwords list by setting + # their logits to negative infinity which effectively + # makes their probabilities of being chosen zero + logits = logits.at[badwords].set(-jnp.inf) + # Use the sampler (kobold_sample) to pick one token + # based on the logits array as a 0D uint32 array + # (higher logit means higher probability of being + # picked, non-linearly) + next_token = kobold_sample( + sample_key, + logits, + **sampler_options, + ) + # Remember what token was picked + generated = generated.at[generated_index].set(next_token) + generated_index += 1 + # Re-pack the current sample_loop_fn's state so we can + # get back the same variables the next time + carry[0][0] = [generated, generated_index, logits, next_token] + carry[0].append(carry[0].pop(0)) + return carry[0], new_key + return jax.lax.while_loop( + lambda carry: carry[0][0][1] < gi, + sample_loop_fn, + (data, key), + ) + class PenalizingCausalTransformer(CausalTransformer): def __init__(self, config): # Initialize super().__init__(config) - def generate_initial(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None): + def generate_initial(state, key, ctx, ctx_length, numseqs_aux, soft_embeddings=None): numseqs = numseqs_aux.shape[0] + @hk.transform def generate_initial_inner(context, ctx_length): # Give the initial context to the transformer transformer = CausalTransformerShard(config) def generate_initial_scan_fn(sequence_index, _): _, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings) - # The "generated" array will contain the tokens from the - # context as well as the tokens picked by the sampler at - # each stage, padded with a bunch of 50256s, so we know - # which tokens have to be repetition penalized - generated = jnp.pad(context, (0, config["seq"]), constant_values=pad_token_id) # Let it start off with just the 2048 context tokens, plus some 50256s which will be eventually filled with sampler-chosen tokens generated_index = config["seq"] # Add that information to generate_loop_fn's starting state - initial_state = (generated, generated_index, sequence_index) + initial_state + initial_state = (jnp.empty(config["n_vocab"], dtype=jnp.float32), generated_index, sequence_index) + initial_state return sequence_index+1, initial_state _, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, None, numseqs) sample_key = initial_states[-1][0] initial_states = list(jax.tree_map(lambda x: x[i], initial_states[:-1]) for i in range(numseqs)) return initial_states, sample_key - generate_initial_fn = hk.transform(generate_initial_inner).apply - return generate_initial_fn(state["params"], key, ctx, ctx_length) + return generate_initial_inner.apply(state["params"], key, ctx, ctx_length) self.generate_initial_xmap = jax.experimental.maps.xmap( fun=generate_initial, in_axes=( @@ -210,31 +248,23 @@ class PenalizingCausalTransformer(CausalTransformer): ["batch", ...], ["batch", ...], ["batch", ...], - ["batch", ...], - ["batch", ...], ["shard", ...], ), out_axes=["shard", "batch", ...], axis_resources={'shard': 'mp', 'batch': 'dp'}, ) - def generate_once(initial_states, sample_key, state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None): + def generate_once(data, state, numseqs_aux, soft_embeddings=None): numseqs = numseqs_aux.shape[0] - # These are the tokens that we don't want the AI to ever write - self.badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146]) - def generate_once_inner(context, ctx_length): - gi = initial_states[0][1] + @hk.without_apply_rng + @hk.transform + def generate_once_inner(): + gi = data[0][1] # Give the initial context to the transformer transformer = CausalTransformerShard(config) - # Get repetition penalty from the arguments - repetition_penalty = sampler_options.pop('repetition_penalty', None) # This is the main generation loop def generate_loop_fn(carry): # Unpack current generate_loop_fn state - generated, generated_index, sequence_index, next_token, decode_state = carry[0][0] - sample_key = carry[1] - # Get the pseudo-random number generator key that will - # be used by kobold_sample to randomly pick a token - sample_key, new_key = jax.random.split(sample_key) + _, generated_index, sequence_index, next_token, decode_state = carry[0][0] # Give the context to the model and get the logits it # spits out # (a 2D array with 1 row and 50400 columns representing @@ -245,57 +275,27 @@ class PenalizingCausalTransformer(CausalTransformer): # Verify that logits does indeed have that many rows and # columns (if you get an error here, pray for mercy) assert logits.shape == (1, config["n_vocab"]) + assert logits.dtype == jnp.float32 # Flatten it into a 1D array to make it easier to use logits = logits[0] - # Apply repetition penalty to all tokens that are - # currently inside the "generated" array - if repetition_penalty is not None: - logits = apply_repetition_penalty( - logits, - generated, - repetition_penalty - ) - # Remove any tokens in the badwords list by setting - # their logits to negative infinity which effectively - # makes their probabilities of being chosen zero - logits = logits.at[self.badwords].set(-jnp.inf) - # Use the sampler (kobold_sample) to pick one token - # based on the logits array as a 1D array with 1 element - # (higher logit means higher probability of being - # picked, non-linearly) - next_token = kobold_sample( - sample_key, - logits, - **sampler_options, - ) - # Remember what token was picked - generated = generated.at[generated_index].set(next_token[0]) - generated_index += 1 # Re-pack the current generate_loop_fn's state so we can # get back the same variables the next time - carry[0][0] = (generated, generated_index, sequence_index, next_token, new_state) + generated_index += 1 + carry[0][0] = (logits, generated_index, sequence_index, next_token, new_state) carry[0].append(carry[0].pop(0)) - return carry[0], new_key - final_state = jax.lax.while_loop( + return carry[0], + return jax.lax.while_loop( lambda carry: carry[0][0][1] == gi, generate_loop_fn, - (initial_states, sample_key), + (data,), ) - return final_state - generate_once_fn = hk.transform(generate_once_inner).apply - return generate_once_fn(state["params"], key, ctx, ctx_length) + return generate_once_inner.apply(state["params"]) self.generate_once_xmap = jax.experimental.maps.xmap( fun=generate_once, in_axes=( - ["shard", "batch", ...], ["shard", "batch", ...], ["shard", ...], ["batch", ...], - ["batch", ...], - ["batch", ...], - ["batch", ...], - ["batch", ...], - ["batch", ...], ["shard", ...], ), out_axes=["shard", "batch", ...], @@ -304,23 +304,30 @@ class PenalizingCausalTransformer(CausalTransformer): def generate(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None): assert not return_logits assert gen_length.ndim == 1 + assert soft_embeddings is not None key = hk.PRNGSequence(random.randint(0, 2 ** 60)) batch_size = ctx.shape[0] self.batch_size = batch_size - xargs = [ - self.state, - batch_xmap(jnp.array(key.take(batch_size))), - batch_xmap(ctx), - batch_xmap(np.array(ctx_length, dtype=np.uint32)), - batch_xmap(np.array(gen_length, dtype=np.uint32)), - np.empty((batch_size, numseqs), dtype=np.uint8), - batch_xmap(sampler_options), - shard_xmap(soft_embeddings), + _numseqs_aux = jnp.empty((batch_size, numseqs), dtype=np.uint32) + numseqs_aux = batch_xmap(_numseqs_aux) + sample_data = [ + [ + jnp.pad(ctx, (0, params["seq"]), constant_values=pad_token_id), + params["seq"], + None, + jnp.empty((), dtype=jnp.uint32), + ] + for _ in range(numseqs) ] - initial_state, sample_key = self.generate_initial_xmap(*xargs) - for i in range(gen_length[0]): - initial_state, sample_key = self.generate_once_xmap(initial_state, sample_key, *xargs) - return initial_state, sample_key + repetition_penalty = sampler_options.pop("repetition_penalty", 1.0) + generate_data, sample_key = self.generate_initial_xmap(self.state, jnp.array(key.take(batch_size)), ctx, ctx_length, numseqs_aux, soft_embeddings) + sample_key = jax.device_put(sample_key[0, 0], cpu) + for _ in range(gen_length[0].item()): + generate_data, = self.generate_once_xmap(generate_data, self.state, numseqs_aux, soft_embeddings) + for i in range(numseqs): + sample_data[i][2] = jax.device_put(generate_data[0][i][0, 0], cpu) + sample_data, sample_key = sample_jit(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options) + return sample_data, sample_key def infer( @@ -345,23 +352,23 @@ def infer( padded_tokens = np.pad(tokens, ((pad_amount, 0),), constant_values=pad_token_id) batched_tokens = np.array([padded_tokens] * total_batch) samples = [] - batched_generator_params = { - "temp": temp * np.ones(total_batch), - "top_p": top_p * np.ones(total_batch), - "tfs": tfs * np.ones(total_batch), - "repetition_penalty": repetition_penalty * np.ones(total_batch), - "top_k": np.full(total_batch, top_k, dtype=np.uint32) + generator_params = { + "temp": float(temp), + "top_p": float(top_p), + "tfs": float(tfs), + "repetition_penalty": float(repetition_penalty), + "top_k": int(top_k), } output = network.generate( batched_tokens, np.ones(total_batch, dtype=np.uint32) * provided_ctx, np.ones(total_batch, dtype=np.uint32) * gen_len, numseqs, - batched_generator_params, + generator_params, soft_embeddings=soft_embeddings, )[0] - for o in output: - samples.append(o[0][0, 0, params["seq"] : params["seq"] + gen_len]) + for out in output: + samples.append(out[0][0, 0, params["seq"] : params["seq"] + gen_len]) return samples @@ -414,6 +421,17 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) shard_xmap = __shard_xmap() batch_xmap = __batch_xmap(shard_dim=cores_per_replica) + global cpu, sample_jit + cpu = jax.devices("cpu")[0] + sample_jit = jax.jit( + sample_jit, + device=cpu, + ) + + global badwords + # These are the tokens that we don't want the AI to ever write + badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146]) + if not path.endswith("/"): path += "/" From 0bef92419b653f771d99cce7c1148753c9836911 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Fri, 14 Jan 2022 15:05:21 -0500 Subject: [PATCH 4/8] Convert the `jit`ted function into ordinary NumPy operations --- tpu_mtj_backend.py | 146 ++++++++++++++++++++++++--------------------- 1 file changed, 77 insertions(+), 69 deletions(-) diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index f5d70e6b..c2caf270 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -63,19 +63,20 @@ def apply_repetition_penalty(logits, tokens, repetition_penalty): # logits array; e.g. # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1], # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5] - penalty_logits = jnp.take(logits, tokens) + penalty_logits = np.take(logits, tokens) # Divide positive values by repetition_penalty and multiply negative # values by repetition_penalty (the academic publication that described # this technique actually just only divided, but that would cause tokens # with negative logits to become more likely, which is obviously wrong) - penalty_logits = jnp.where( + penalty_logits = np.where( penalty_logits > 0, penalty_logits/repetition_penalty, penalty_logits*repetition_penalty, ) # Finally, put those penalized logit values back into their original # positions in the logits array - return logits.at[tokens].set(penalty_logits) + logits[tokens] = penalty_logits + return logits def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): ''' @@ -91,15 +92,16 @@ def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): # in the sorted logits array we want to remove and False for ones # we want to keep, in this case the first top_k elements will be # False and the rest will be True - sorted_indices_to_remove = jnp.arange(len(logits)) >= top_k + sorted_indices_to_remove = np.arange(len(logits)) >= top_k # Unsort the logits array back to its original configuration and # remove tokens we need to remove _, indices_to_remove = jax.lax.sort_key_val( - jnp.argsort(-logits), + np.argsort(-logits), sorted_indices_to_remove, ) - return jnp.where(indices_to_remove, -jnp.inf, logits) - logits = jax.lax.cond(top_k > 0, top_k_filter, lambda x: x, logits) + return np.where(indices_to_remove, -np.inf, logits) + if top_k > 0: + logits = top_k_filter(logits) # Top-p (after sorting the remaining tokens again in descending order of # logit, remove the ones that have cumulative softmax probability # greater than p) @@ -108,75 +110,75 @@ def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): # with e (Euler's number) to the power of that element, and divide # each element of the new array by the sum of the elements in the # new array - sorted_logits = -jnp.sort(-logits) - probabilities = jax.nn.softmax(sorted_logits) + sorted_logits = -np.sort(-logits) + probabilities = np.array(jax.nn.softmax(sorted_logits), copy=True) # Calculate cumulative_probabilities as the prefix-sum array of # probabilities - cumulative_probabilities = jnp.cumsum(probabilities, axis=-1) + cumulative_probabilities = np.cumsum(probabilities, axis=-1) # We want to remove tokens with cumulative probability higher # than top_p sorted_indices_to_remove = cumulative_probabilities > top_p # Don't ever remove the token with the highest logit, even if # the probability is higher than top_p - sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False) + sorted_indices_to_remove[0] = False # Unsort and remove _, indices_to_remove = jax.lax.sort_key_val( - jnp.argsort(-logits), + np.argsort(-logits), sorted_indices_to_remove, ) - return jnp.where(indices_to_remove, -jnp.inf, logits) - logits = jax.lax.cond(top_p < 1.0, top_p_filter, lambda x: x, logits) + return np.where(indices_to_remove, -np.inf, logits) + if top_p < 1.0: + logits = top_p_filter(logits) # Tail free sampling (basically top-p a second time on remaining tokens # except it's the "cumulative normalized absolute second finite # differences of the softmax probabilities" instead of just the # cumulative softmax probabilities) def tail_free_filter(logits): # Sort in descending order - sorted_logits = -jnp.sort(-logits) + sorted_logits = -np.sort(-logits) # Softmax again - probabilities = jax.nn.softmax(sorted_logits) + probabilities = np.array(jax.nn.softmax(sorted_logits), copy=True) # Calculate the second finite differences of that array (i.e. # calculate the difference array and then calculate the difference # array of the difference array) - d2 = jnp.diff(jnp.diff(probabilities)) + d2 = np.diff(np.diff(probabilities)) # Get the absolute values of all those second finite differences - d2 = jnp.abs(d2) + d2 = np.abs(d2) # Normalize (all elements in the array are divided by the sum of the # array's elements) d2 = d2 / d2.sum(axis=-1, keepdims=True) # Get the prefix-sum array - cumulative_d2 = jnp.cumsum(d2, axis=-1) + cumulative_d2 = np.cumsum(d2, axis=-1) # We will remove the tokens with a cumulative normalized absolute # second finite difference larger than the TFS value sorted_indices_to_remove = cumulative_d2 > tfs # Don't remove the token with the highest logit - sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False) + sorted_indices_to_remove[0] = False # Since the d2 array has two fewer elements than the logits array, # we'll add two extra Trues to the end - sorted_indices_to_remove = jnp.pad( + sorted_indices_to_remove = np.pad( sorted_indices_to_remove, (0, 2), constant_values=True, ) # Unsort and remove _, indices_to_remove = jax.lax.sort_key_val( - jnp.argsort(-logits), + np.argsort(-logits), sorted_indices_to_remove, ) - return jnp.where(indices_to_remove, -jnp.inf, logits) - logits = jax.lax.cond(tfs < 1.0, tail_free_filter, lambda x: x, logits) + return np.where(indices_to_remove, -np.inf, logits) + if tfs < 1.0: + logits = tail_free_filter(logits) # Temperature (just divide the logits by the temperature) - def temp_filter(logits): - return logits / temp - logits = jax.lax.cond(True, temp_filter, lambda x: x, logits) + logits /= temp # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a # probability distribution) - return jax.random.categorical(key, logits, -1).astype(jnp.uint32) + return jax.random.categorical(key, logits, -1).astype(np.uint32) pad_token_id = 50256 -def sample_jit(data, key, numseqs_aux, badwords, repetition_penalty, sampler_options): +def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, sampler_options): numseqs = numseqs_aux.shape[0] gi = data[0][1] def sample_loop_fn(carry): @@ -195,7 +197,7 @@ def sample_jit(data, key, numseqs_aux, badwords, repetition_penalty, sampler_opt # Remove any tokens in the badwords list by setting # their logits to negative infinity which effectively # makes their probabilities of being chosen zero - logits = logits.at[badwords].set(-jnp.inf) + logits[badwords] = -np.inf # Use the sampler (kobold_sample) to pick one token # based on the logits array as a 0D uint32 array # (higher logit means higher probability of being @@ -206,18 +208,22 @@ def sample_jit(data, key, numseqs_aux, badwords, repetition_penalty, sampler_opt **sampler_options, ) # Remember what token was picked - generated = generated.at[generated_index].set(next_token) + generated[generated_index] = next_token generated_index += 1 # Re-pack the current sample_loop_fn's state so we can # get back the same variables the next time carry[0][0] = [generated, generated_index, logits, next_token] carry[0].append(carry[0].pop(0)) return carry[0], new_key - return jax.lax.while_loop( - lambda carry: carry[0][0][1] < gi, - sample_loop_fn, - (data, key), - ) + # return jax.lax.while_loop( + # lambda carry: carry[0][0][1] == gi, + # sample_loop_fn, + # (data, key), + # ) + carry = (data, key) + while carry[0][0][1] == gi: + carry = sample_loop_fn(carry) + return carry class PenalizingCausalTransformer(CausalTransformer): def __init__(self, config): @@ -237,7 +243,7 @@ class PenalizingCausalTransformer(CausalTransformer): return sequence_index+1, initial_state _, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, None, numseqs) sample_key = initial_states[-1][0] - initial_states = list(jax.tree_map(lambda x: x[i], initial_states[:-1]) for i in range(numseqs)) + initial_states = list(list(jax.tree_map(lambda x: x[i], initial_states[:-1])) for i in range(numseqs)) return initial_states, sample_key return generate_initial_inner.apply(state["params"], key, ctx, ctx_length) self.generate_initial_xmap = jax.experimental.maps.xmap( @@ -281,7 +287,7 @@ class PenalizingCausalTransformer(CausalTransformer): # Re-pack the current generate_loop_fn's state so we can # get back the same variables the next time generated_index += 1 - carry[0][0] = (logits, generated_index, sequence_index, next_token, new_state) + carry[0][0] = [logits, generated_index, sequence_index, next_token, new_state] carry[0].append(carry[0].pop(0)) return carry[0], return jax.lax.while_loop( @@ -312,21 +318,23 @@ class PenalizingCausalTransformer(CausalTransformer): numseqs_aux = batch_xmap(_numseqs_aux) sample_data = [ [ - jnp.pad(ctx, (0, params["seq"]), constant_values=pad_token_id), + np.pad(ctx[0], (0, params["seq"]), constant_values=pad_token_id), params["seq"], None, - jnp.empty((), dtype=jnp.uint32), + np.empty((), dtype=np.uint32), ] for _ in range(numseqs) ] repetition_penalty = sampler_options.pop("repetition_penalty", 1.0) generate_data, sample_key = self.generate_initial_xmap(self.state, jnp.array(key.take(batch_size)), ctx, ctx_length, numseqs_aux, soft_embeddings) - sample_key = jax.device_put(sample_key[0, 0], cpu) + sample_key = np.asarray(sample_key[0, 0]) for _ in range(gen_length[0].item()): generate_data, = self.generate_once_xmap(generate_data, self.state, numseqs_aux, soft_embeddings) for i in range(numseqs): - sample_data[i][2] = jax.device_put(generate_data[0][i][0, 0], cpu) - sample_data, sample_key = sample_jit(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options) + sample_data[i][2] = np.array(generate_data[0][i][0, 0], copy=True) + sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options) + for i in range(numseqs): + generate_data[i][3] = np.tile(sample_data[i][0][sample_data[i][1]-1][np.newaxis, np.newaxis], (params["cores_per_replica"], 1, 1)) return sample_data, sample_key @@ -368,7 +376,7 @@ def infer( soft_embeddings=soft_embeddings, )[0] for out in output: - samples.append(out[0][0, 0, params["seq"] : params["seq"] + gen_len]) + samples.append(out[0][params["seq"] : params["seq"] + gen_len]) return samples @@ -397,37 +405,37 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) jax.host_count = jax.process_count jax.host_id = jax.process_index - print("Connecting to your Colab instance's TPU", flush=True) - spinner = multiprocessing.Process(target=show_spinner, args=()) - spinner.start() - colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0] - url = f'http://{colab_tpu_addr}:8475/requestversion/{driver_version}' - requests.post(url) - spinner.terminate() - print() - config.FLAGS.jax_xla_backend = "tpu_driver" - config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR'] + while True: + print("Connecting to your Colab instance's TPU", flush=True) + spinner = multiprocessing.Process(target=show_spinner, args=()) + spinner.start() + colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0] + url = f'http://{colab_tpu_addr}:8475/requestversion/{driver_version}' + requests.post(url) + spinner.terminate() + print() + config.FLAGS.jax_xla_backend = "tpu_driver" + config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR'] - cores_per_replica = params["cores_per_replica"] - seq = params["seq"] - params["optimizer"] = optax.scale(0) - mesh_shape = (1, cores_per_replica) - devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape) - thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ()) - maps.thread_resources.env = thread_resources_env - tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2') + cores_per_replica = params["cores_per_replica"] + seq = params["seq"] + params["optimizer"] = optax.scale(0) + mesh_shape = (1, cores_per_replica) + try: + devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape) + except RuntimeError as e: + if "DEADLINE_EXCEEDED" not in str(e): + raise e + continue + thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ()) + maps.thread_resources.env = thread_resources_env + tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2') + break global shard_xmap, batch_xmap shard_xmap = __shard_xmap() batch_xmap = __batch_xmap(shard_dim=cores_per_replica) - global cpu, sample_jit - cpu = jax.devices("cpu")[0] - sample_jit = jax.jit( - sample_jit, - device=cpu, - ) - global badwords # These are the tokens that we don't want the AI to ever write badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146]) From 932c393d6a418da6425646d90ffb7bfef9edacd6 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Fri, 14 Jan 2022 21:39:02 -0500 Subject: [PATCH 5/8] Add TPU support for dynamic WI scan and generation modifiers --- aiserver.py | 121 +++++++++++++++++++++++++++++++++++++-------- tpu_mtj_backend.py | 97 ++++++++++++++++++++---------------- 2 files changed, 155 insertions(+), 63 deletions(-) diff --git a/aiserver.py b/aiserver.py index 83df8c7e..96accfa3 100644 --- a/aiserver.py +++ b/aiserver.py @@ -22,7 +22,7 @@ import packaging import contextlib import traceback import threading -from typing import Any, Callable, TypeVar, Union, Dict, Set, List +from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List import requests import html @@ -1001,6 +1001,46 @@ else: ) return soft_tokens + def tpumtjgenerate_warper_callback(generated, scores, excluded_world_info, n_generated) -> Tuple[List[set], bool, bool]: + vars.generated_tkns += 1 + + assert len(excluded_world_info) == len(generated) + regeneration_required = vars.lua_koboldbridge.regeneration_required + halt = not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt + vars.lua_koboldbridge.regeneration_required = False + + global past + + for i in range(vars.numseqs): + vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(generated[i, tpu_mtj_backend.params["seq"] + n_generated - 1].item()) + + scores_shape = scores.shape + scores_list = scores.tolist() + vars.lua_koboldbridge.logits = vars.lua_state.table() + for r, row in enumerate(scores_list): + vars.lua_koboldbridge.logits[r+1] = vars.lua_state.table(*row) + vars.lua_koboldbridge.vocab_size = scores_shape[-1] + + execute_genmod() + + scores = np.array( + tuple(tuple(row.values()) for row in vars.lua_koboldbridge.logits.values()), + dtype=scores.dtype, + ) + assert scores.shape == scores_shape + + if(not vars.dynamicscan or halt): + return excluded_world_info, regeneration_required, halt + + for i, t in enumerate(generated): + decoded = tokenizer.decode(past[i]) + tokenizer.decode(t[tpu_mtj_backend.params["seq"] : tpu_mtj_backend.params["seq"] + n_generated]) + _, found = checkworldinfo(decoded, force_use_txt=True) + found -= excluded_world_info[i] + if(len(found) != 0): + regeneration_required = True + break + return excluded_world_info, regeneration_required, halt + # If we're running Colab or OAI, we still need a tokenizer. if(vars.model == "Colab"): from transformers import GPT2TokenizerFast @@ -1013,6 +1053,7 @@ else: print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END)) assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth) import tpu_mtj_backend + tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback tpu_mtj_backend.load_model(vars.custmodpth) vars.allowsp = True vars.modeldim = int(tpu_mtj_backend.params["d_model"]) @@ -1020,12 +1061,14 @@ else: soft_tokens = tpumtjgetsofttokens() threading.Thread( # Compile backend code in background target=tpu_mtj_backend.infer, - args=(np.uint32((23403, 727, 20185)),), + args=(np.tile(np.uint32((23403, 727, 20185)), (vars.numseqs, 1)),), kwargs={ "soft_embeddings": vars.sp, "soft_tokens": soft_tokens, + "use_callback": False, "gen_len": 1, "numseqs": vars.numseqs, + "excluded_world_info": list(set() for _ in range(vars.numseqs)), }, ).start() @@ -2890,32 +2933,68 @@ def sendtocolab(txt, min, max): # Send text to TPU mesh transformer backend #==================================================================# def tpumtjgenerate(txt, minimum, maximum, found_entries=None): + vars.generated_tkns = 0 + if(found_entries is None): found_entries = set() found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs)) print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, tokenizer.decode(txt), colors.END)) + vars._actions = vars.actions + vars._prompt = vars.prompt + if(vars.dynamicscan): + vars._actions = vars._actions.copy() + # Submit input text to generator try: - if(vars.dynamicscan): - raise ValueError("Dynamic world info scanning is not supported by the TPU backend yet") - + context = np.tile(np.uint32(txt), (vars.numseqs, 1)) soft_tokens = tpumtjgetsofttokens() - genout = tpool.execute( - tpu_mtj_backend.infer, - np.uint32(txt), - gen_len = maximum-minimum+1, - temp=vars.temp, - top_p=vars.top_p, - top_k=vars.top_k, - tfs=vars.tfs, - numseqs=vars.numseqs, - repetition_penalty=vars.rep_pen, - soft_embeddings=vars.sp, - soft_tokens=soft_tokens, - ) + global past + past = np.empty((vars.numseqs, 0), dtype=np.uint32) + + while(True): + genout, n_generated, regeneration_required, halt = tpool.execute( + tpu_mtj_backend.infer, + context, + gen_len = maximum-minimum+1, + temp=vars.temp, + top_p=vars.top_p, + top_k=vars.top_k, + tfs=vars.tfs, + numseqs=vars.numseqs, + repetition_penalty=vars.rep_pen, + soft_embeddings=vars.sp, + soft_tokens=soft_tokens, + excluded_world_info=found_entries, + ) + + past = np.pad(past, ((0, 0), (0, n_generated))) + for r in range(vars.numseqs): + for c in range(vars.lua_koboldbridge.generated_cols): + assert vars.lua_koboldbridge.generated[r+1][c+1] is not None + past[r, c] = vars.lua_koboldbridge.generated[r+1][c+1] + + if(halt or not regeneration_required): + break + + encoded = [] + for i in range(vars.numseqs): + txt = tokenizer.decode(past[i]) + winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True) + found_entries[i].update(_found_entries) + txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt) + encoded.append(np.array(txt, dtype=np.uint32)) + max_length = len(max(encoded, key=len)) + encoded = np.stack(tuple(np.pad(e, (max_length - len(e), 0), constant_values=tpu_mtj_backend.pad_token_id) for e in encoded)) + context = np.concatenate( + ( + encoded, + past, + ), + axis=-1, + ) except Exception as e: if(issubclass(type(e), lupa.LuaError)): @@ -2931,10 +3010,10 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): print("{0}{1}{2}".format(colors.RED, traceback.format_exc().replace("\033", ""), colors.END), file=sys.stderr) set_aibusy(0) return - + for i in range(vars.numseqs): - vars.lua_koboldbridge.generated[i+1] = vars.lua_state.table(*genout[i].tolist()) - vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(genout[i]) + vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(past[i]) + genout = past execute_outmod() if(vars.lua_koboldbridge.regeneration_required): diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index c2caf270..9cd49a12 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -1,5 +1,5 @@ import multiprocessing -from typing import Any, Callable, Dict, List, Optional, TypeVar +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar import progressbar import time import os @@ -20,6 +20,10 @@ from mesh_transformer.transformer_shard import CausalTransformer, CausalTransfor params: Dict[str, Any] = {} +def warper_callback(generated, logits, excluded_world_info, n_generated) -> Tuple[bool, bool]: + raise NotImplementedError("`tpu_mtj_backend.warper_callback()` needs to be defined") + + def show_spinner(): bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='█')]) i = 0 @@ -235,13 +239,13 @@ class PenalizingCausalTransformer(CausalTransformer): def generate_initial_inner(context, ctx_length): # Give the initial context to the transformer transformer = CausalTransformerShard(config) - def generate_initial_scan_fn(sequence_index, _): - _, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings) + def generate_initial_scan_fn(sequence_index, c): + _, initial_state = transformer.generate_initial(c, ctx_length, soft_embeddings=soft_embeddings) generated_index = config["seq"] # Add that information to generate_loop_fn's starting state initial_state = (jnp.empty(config["n_vocab"], dtype=jnp.float32), generated_index, sequence_index) + initial_state return sequence_index+1, initial_state - _, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, None, numseqs) + _, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, context, numseqs) sample_key = initial_states[-1][0] initial_states = list(list(jax.tree_map(lambda x: x[i], initial_states[:-1])) for i in range(numseqs)) return initial_states, sample_key @@ -307,7 +311,8 @@ class PenalizingCausalTransformer(CausalTransformer): out_axes=["shard", "batch", ...], axis_resources={'shard': 'mp', 'batch': 'dp'}, ) - def generate(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None): + def generate(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None, excluded_world_info=None, use_callback=True): + assert excluded_world_info is not None assert not return_logits assert gen_length.ndim == 1 assert soft_embeddings is not None @@ -318,24 +323,34 @@ class PenalizingCausalTransformer(CausalTransformer): numseqs_aux = batch_xmap(_numseqs_aux) sample_data = [ [ - np.pad(ctx[0], (0, params["seq"]), constant_values=pad_token_id), + np.pad(ctx[0][i], (0, params["seq"]), constant_values=pad_token_id), params["seq"], None, np.empty((), dtype=np.uint32), ] - for _ in range(numseqs) + for i in range(numseqs) ] repetition_penalty = sampler_options.pop("repetition_penalty", 1.0) + n_generated = 0 + regeneration_required = False + halt = False generate_data, sample_key = self.generate_initial_xmap(self.state, jnp.array(key.take(batch_size)), ctx, ctx_length, numseqs_aux, soft_embeddings) sample_key = np.asarray(sample_key[0, 0]) - for _ in range(gen_length[0].item()): + while True: generate_data, = self.generate_once_xmap(generate_data, self.state, numseqs_aux, soft_embeddings) for i in range(numseqs): - sample_data[i][2] = np.array(generate_data[0][i][0, 0], copy=True) + sample_data[i][2] = np.array(generate_data[i][0][0, 0], copy=True) sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options) for i in range(numseqs): generate_data[i][3] = np.tile(sample_data[i][0][sample_data[i][1]-1][np.newaxis, np.newaxis], (params["cores_per_replica"], 1, 1)) - return sample_data, sample_key + n_generated += 1 + if use_callback: + excluded_world_info, regeneration_required, halt = warper_callback(np.uint32(tuple(d[0] for d in sample_data)), np.float32(tuple(d[2] for d in sample_data)), excluded_world_info, n_generated) + if regeneration_required or halt: + break + else: + break + return sample_data, n_generated, regeneration_required, halt def infer( @@ -349,15 +364,18 @@ def infer( gen_len=80, soft_embeddings: Optional[np.array] = None, soft_tokens: Optional[np.array] = None, -) -> List[str]: + excluded_world_info = None, + use_callback=True, +) -> Tuple[List[np.array], int, bool, bool]: + assert excluded_world_info is not None maps.thread_resources.env = thread_resources_env total_batch = 1 tokens = context if(soft_tokens is not None): - tokens = np.uint32(np.concatenate((soft_tokens, tokens))) - provided_ctx = tokens.shape[0] + tokens = np.uint32(np.concatenate((np.tile(soft_tokens, (tokens.shape[0], 1)), tokens), axis=-1)) + provided_ctx = tokens.shape[-1] pad_amount = seq - provided_ctx - padded_tokens = np.pad(tokens, ((pad_amount, 0),), constant_values=pad_token_id) + padded_tokens = np.pad(tokens, ((0, 0), (pad_amount, 0)), constant_values=pad_token_id) batched_tokens = np.array([padded_tokens] * total_batch) samples = [] generator_params = { @@ -374,10 +392,12 @@ def infer( numseqs, generator_params, soft_embeddings=soft_embeddings, - )[0] - for out in output: + excluded_world_info=excluded_world_info, + use_callback=use_callback, + ) + for out in output[0]: samples.append(out[0][params["seq"] : params["seq"] + gen_len]) - return samples + return (samples,) + output[1:] def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) -> None: @@ -405,32 +425,25 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) jax.host_count = jax.process_count jax.host_id = jax.process_index - while True: - print("Connecting to your Colab instance's TPU", flush=True) - spinner = multiprocessing.Process(target=show_spinner, args=()) - spinner.start() - colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0] - url = f'http://{colab_tpu_addr}:8475/requestversion/{driver_version}' - requests.post(url) - spinner.terminate() - print() - config.FLAGS.jax_xla_backend = "tpu_driver" - config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR'] + print("Connecting to your Colab instance's TPU", flush=True) + spinner = multiprocessing.Process(target=show_spinner, args=()) + spinner.start() + colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0] + url = f'http://{colab_tpu_addr}:8475/requestversion/{driver_version}' + requests.post(url) + spinner.terminate() + print() + config.FLAGS.jax_xla_backend = "tpu_driver" + config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR'] - cores_per_replica = params["cores_per_replica"] - seq = params["seq"] - params["optimizer"] = optax.scale(0) - mesh_shape = (1, cores_per_replica) - try: - devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape) - except RuntimeError as e: - if "DEADLINE_EXCEEDED" not in str(e): - raise e - continue - thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ()) - maps.thread_resources.env = thread_resources_env - tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2') - break + cores_per_replica = params["cores_per_replica"] + seq = params["seq"] + params["optimizer"] = optax.scale(0) + mesh_shape = (1, cores_per_replica) + devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape) + thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ()) + maps.thread_resources.env = thread_resources_env + tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2') global shard_xmap, batch_xmap shard_xmap = __shard_xmap() From e0fdce2cc6ea64a300908f8cf36bc7ee037c9fb6 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Fri, 14 Jan 2022 23:00:06 -0500 Subject: [PATCH 6/8] Fix TPU generation modifier --- aiserver.py | 30 +++++++++++++++++------------- tpu_mtj_backend.py | 15 ++++++++++++--- 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/aiserver.py b/aiserver.py index 96accfa3..f232ff7c 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1001,19 +1001,7 @@ else: ) return soft_tokens - def tpumtjgenerate_warper_callback(generated, scores, excluded_world_info, n_generated) -> Tuple[List[set], bool, bool]: - vars.generated_tkns += 1 - - assert len(excluded_world_info) == len(generated) - regeneration_required = vars.lua_koboldbridge.regeneration_required - halt = not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt - vars.lua_koboldbridge.regeneration_required = False - - global past - - for i in range(vars.numseqs): - vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(generated[i, tpu_mtj_backend.params["seq"] + n_generated - 1].item()) - + def tpumtjgenerate_warper_callback(scores) -> "np.array": scores_shape = scores.shape scores_list = scores.tolist() vars.lua_koboldbridge.logits = vars.lua_state.table() @@ -1029,6 +1017,21 @@ else: ) assert scores.shape == scores_shape + return scores + + def tpumtjgenerate_stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]: + vars.generated_tkns += 1 + + assert len(excluded_world_info) == len(generated) + regeneration_required = vars.lua_koboldbridge.regeneration_required + halt = not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt + vars.lua_koboldbridge.regeneration_required = False + + global past + + for i in range(vars.numseqs): + vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(generated[i, tpu_mtj_backend.params["seq"] + n_generated - 1].item()) + if(not vars.dynamicscan or halt): return excluded_world_info, regeneration_required, halt @@ -1054,6 +1057,7 @@ else: assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth) import tpu_mtj_backend tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback + tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback tpu_mtj_backend.load_model(vars.custmodpth) vars.allowsp = True vars.modeldim = int(tpu_mtj_backend.params["d_model"]) diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 9cd49a12..67196645 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -20,9 +20,12 @@ from mesh_transformer.transformer_shard import CausalTransformer, CausalTransfor params: Dict[str, Any] = {} -def warper_callback(generated, logits, excluded_world_info, n_generated) -> Tuple[bool, bool]: +def warper_callback(logits) -> np.array: raise NotImplementedError("`tpu_mtj_backend.warper_callback()` needs to be defined") +def stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]: + raise NotImplementedError("`tpu_mtj_backend.stopping_callback()` needs to be defined") + def show_spinner(): bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='█')]) @@ -340,12 +343,18 @@ class PenalizingCausalTransformer(CausalTransformer): generate_data, = self.generate_once_xmap(generate_data, self.state, numseqs_aux, soft_embeddings) for i in range(numseqs): sample_data[i][2] = np.array(generate_data[i][0][0, 0], copy=True) + if use_callback: + logits = np.float32(tuple(d[2] for d in sample_data)) + logits = warper_callback(logits) + for i in range(numseqs): + sample_data[i][2] = logits[i] sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options) + n_generated += 1 for i in range(numseqs): generate_data[i][3] = np.tile(sample_data[i][0][sample_data[i][1]-1][np.newaxis, np.newaxis], (params["cores_per_replica"], 1, 1)) - n_generated += 1 if use_callback: - excluded_world_info, regeneration_required, halt = warper_callback(np.uint32(tuple(d[0] for d in sample_data)), np.float32(tuple(d[2] for d in sample_data)), excluded_world_info, n_generated) + generated = np.uint32(tuple(d[0] for d in sample_data)) + excluded_world_info, regeneration_required, halt = stopping_callback(generated, n_generated, excluded_world_info) if regeneration_required or halt: break else: From bdfde33e8ab83ba6997271b600217c20f355f4ba Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Fri, 14 Jan 2022 23:13:55 -0500 Subject: [PATCH 7/8] Add an indicator for when dynamic WI scan is triggered in TPU Colabs --- aiserver.py | 1 + 1 file changed, 1 insertion(+) diff --git a/aiserver.py b/aiserver.py index f232ff7c..e3ffba9d 100644 --- a/aiserver.py +++ b/aiserver.py @@ -2982,6 +2982,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): if(halt or not regeneration_required): break + print("(dynamic world info scanner triggered)") encoded = [] for i in range(vars.numseqs): From 877fa39b8a61424f355a14a26e236b5053f49f30 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Fri, 14 Jan 2022 23:21:27 -0500 Subject: [PATCH 8/8] Change TPU regeneration indicator message --- aiserver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiserver.py b/aiserver.py index e3ffba9d..8328ac0c 100644 --- a/aiserver.py +++ b/aiserver.py @@ -2982,7 +2982,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): if(halt or not regeneration_required): break - print("(dynamic world info scanner triggered)") + print("(regeneration triggered)") encoded = [] for i in range(vars.numseqs):