From 3c349e6aafd7f03578c14c30dc0e608013c78e63 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 30 Nov 2021 10:13:02 -0500 Subject: [PATCH 1/4] Modify TPU backend code to support JAX 0.2.21 The original one supported versions of JAX up to 0.2.12, and possibly also some earlier versions. This new code supports exclusively JAX 0.2.21 and does not work with any earlier or later versions of JAX. However, this new code benefits from not needing to recompile when changing "Amount To Generate" and also from supporting stopping generation early, which makes an implementation of Dynamic World Info Scan finally possible. --- requirements_mtj.txt | 4 +- tpu_mtj_backend.py | 90 +++++++++++++++++++++----------------------- 2 files changed, 45 insertions(+), 49 deletions(-) diff --git a/requirements_mtj.txt b/requirements_mtj.txt index 9d86ccae..d82bfcb0 100644 --- a/requirements_mtj.txt +++ b/requirements_mtj.txt @@ -4,11 +4,11 @@ requests optax >= 0.0.5, <= 0.0.9 dm-haiku ray[default] -jax == 0.2.12 +jax == 0.2.21 transformers progressbar2 git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck flask Flask-SocketIO flask-cloudflared >= 0.0.5 -flask-ngrok \ No newline at end of file +flask-ngrok diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 5bfb8b57..d86fcc03 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -30,7 +30,7 @@ def show_spinner(): def apply_repetition_penalty(logits, tokens, repetition_penalty): ''' - This gets called by generate_scan_fn to apply repetition penalty + This gets called by generate_loop_fn to apply repetition penalty to the 1D array logits using the provided 1D array of tokens to penalize ''' # Make a new array with the same length as the tokens array but with @@ -52,9 +52,9 @@ def apply_repetition_penalty(logits, tokens, repetition_penalty): # positions in the logits array return logits.at[tokens].set(penalty_logits) -def kobold_sample(key, logits, _, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): +def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): ''' - This gets called by generate_scan_fn to apply a series of 4 filters + This gets called by generate_loop_fn to apply a series of 4 filters to the logits (top-k, then top-p, then TFS, then temperature) before picking one token using the modified logits ''' @@ -147,7 +147,7 @@ def kobold_sample(key, logits, _, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a # probability distribution) - return jax.random.categorical(key, logits, -1).astype(jnp.uint32)[jnp.newaxis], None + return jax.random.categorical(key, logits, -1).astype(jnp.uint32)[jnp.newaxis] pad_token_id = 50256 @@ -155,11 +155,10 @@ class PenalizingCausalTransformer(CausalTransformer): def __init__(self, config): # Initialize super().__init__(config) - def generate(state, key, ctx, ctx_length, aux, sampler_options, soft_embeddings=None): - gen_length = self.gen_length + def generate(state, key, ctx, ctx_length, gen_length, sampler_options, soft_embeddings=None): # These are the tokens that we don't want the AI to ever write self.badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146]) - def generate_sample(context, ctx_length, aux): + def generate_sample(context, ctx_length): # Give the initial context to the transformer transformer = CausalTransformerShard(config) _, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings) @@ -167,14 +166,14 @@ class PenalizingCausalTransformer(CausalTransformer): # context as well as the tokens picked by the sampler at # each stage, padded with a bunch of 50256s, so we know # which tokens have to be repetition penalized - generated = jnp.pad(context, (0, gen_length), constant_values=pad_token_id) # Let it start off with just the 2048 context tokens, plus gen_length 50256s which will be eventually filled with sampler-chosen tokens + generated = jnp.pad(context, (0, config["seq"]), constant_values=pad_token_id) # Let it start off with just the 2048 context tokens, plus some 50256s which will be eventually filled with sampler-chosen tokens generated_index = config["seq"] - # Add that information to generate_scan_fn's starting state + # Add that information to generate_loop_fn's starting state initial_state = (generated, generated_index) + initial_state # Get repetition penalty from the arguments repetition_penalty = sampler_options.pop('repetition_penalty', None) - def generate_scan_fn(carry, sampler_input): - # Unpack current generate_scan_fn state + def generate_loop_fn(carry): + # Unpack current generate_loop_fn state generated, generated_index, next_token, decode_state, sample_key = carry # Get the pseudo-random number generator key that will # be used by kobold_sample to randomly pick a token @@ -207,56 +206,51 @@ class PenalizingCausalTransformer(CausalTransformer): # based on the logits array as a 1D array with 1 element # (higher logit means higher probability of being # picked, non-linearly) - next_token, sample_info = kobold_sample( + next_token = kobold_sample( sample_key, logits, - sampler_input, **sampler_options, ) - # Remember what token was picked so we can repetition - # penalize it next time + # Remember what token was picked generated = generated.at[generated_index].set(next_token[0]) generated_index += 1 - # self.return_logits isn't used in this program, but - # for the sake of compatibility... - if self.return_logits: - output = (next_token, sample_info, logits[jnp.newaxis]) - else: - output = (next_token, sample_info) - # Re-pack the current generate_scan_fn's state so we can + # Re-pack the current generate_loop_fn's state so we can # get back the same variables the next time new_carry = (generated, generated_index, next_token, new_state, new_key) - return new_carry, output - # jax.lax.scan is a function that calls generate_scan_fn - # gen_length times, each time passing a state object from - # its return value (new_carry) back into one of the - # function's arguments (carry), and of course gathering the - # token it generates each time into the "outputs" array; - # we have to use jax.lax.scan instead of a normal loop - # because of JAX's JIT-compilation shenanigans - final_state, outputs = jax.lax.scan( - generate_scan_fn, + return new_carry + final_state = jax.lax.while_loop( + lambda carry: carry[1] - config["seq"] < gen_length, + generate_loop_fn, initial_state, - xs=aux, - length=gen_length, ) - return final_state, outputs + return final_state generate_fn = hk.transform(generate_sample).apply - return generate_fn(state["params"], key, ctx, ctx_length, aux) - self.generate_xmap = jax.experimental.maps.xmap(fun=generate, in_axes=(["shard", ...], ["batch", ...], ["batch", ...], ["batch", ...], ["batch", ...], ["batch", ...], ["shard", ...]), out_axes=["batch", ...], axis_resources={'shard': 'mp', 'batch': 'dp'}) + return generate_fn(state["params"], key, ctx, ctx_length) + self.generate_xmap = jax.experimental.maps.xmap( + fun=generate, + in_axes=( + ["shard", ...], + ["batch", ...], + ["batch", ...], + ["batch", ...], + ["batch", ...], + ["batch", ...], + ["shard", ...], + ), + out_axes=["shard", "batch", ...], + axis_resources={'shard': 'mp', 'batch': 'dp'}, + ) def generate(self, ctx, ctx_length, gen_length, sampler_options, return_logits=False, soft_embeddings=None): + assert not return_logits key = hk.PRNGSequence(random.randint(0, 2 ** 60)) batch_size = ctx.shape[0] - aux = jnp.zeros((batch_size, gen_length), dtype=jnp.uint32) - self.gen_length = gen_length self.batch_size = batch_size - self.return_logits = return_logits return self.generate_xmap( self.state, jnp.array(key.take(batch_size)), ctx, np.array(ctx_length, dtype=np.uint32), - aux, + np.array(gen_length, dtype=np.uint32), sampler_options, soft_embeddings, ) @@ -283,7 +277,6 @@ def infer( pad_amount = seq - provided_ctx padded_tokens = np.pad(tokens, ((pad_amount, 0),), constant_values=pad_token_id) batched_tokens = np.array([padded_tokens] * total_batch) - length = np.ones(total_batch, dtype=np.uint32) * provided_ctx samples = [] batched_generator_params = { "temp": temp * np.ones(total_batch), @@ -294,13 +287,12 @@ def infer( } output = network.generate( batched_tokens, - length, - gen_len, + np.ones(total_batch, dtype=np.uint32) * provided_ctx, + np.ones(total_batch, dtype=np.uint32) * gen_len, batched_generator_params, soft_embeddings=soft_embeddings, ) - decoded_tokens = output[1][0] - for o in decoded_tokens[:, :, 0]: + for o in output[0][0, :, params["seq"] : params["seq"] + gen_len]: samples.append(tokenizer.decode(o)) return samples @@ -326,6 +318,10 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) if param not in params: params[param] = default_params[param] + # Disable JAX warnings about these two functions having been renamed + jax.host_count = jax.process_count + jax.host_id = jax.process_index + print("Connecting to your Colab instance's TPU", flush=True) spinner = multiprocessing.Process(target=show_spinner, args=()) spinner.start() @@ -342,7 +338,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) params["optimizer"] = optax.scale(0) mesh_shape = (1, cores_per_replica) devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape) - thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp'))) + thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ()) maps.thread_resources.env = thread_resources_env tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2') From c1e7c1643f8db61a1b297dbcbe0c74bf04fbd90c Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 30 Nov 2021 14:06:46 -0500 Subject: [PATCH 2/4] Fix unbound axis error in tpu_mtj_backend.py when `numseqs > 1` --- tpu_mtj_backend.py | 46 +++++++++++++++++++++++++++------------------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index d86fcc03..aa64da32 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -155,26 +155,30 @@ class PenalizingCausalTransformer(CausalTransformer): def __init__(self, config): # Initialize super().__init__(config) - def generate(state, key, ctx, ctx_length, gen_length, sampler_options, soft_embeddings=None): + def generate(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None): + numseqs = numseqs_aux.shape[0] # These are the tokens that we don't want the AI to ever write self.badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146]) def generate_sample(context, ctx_length): # Give the initial context to the transformer transformer = CausalTransformerShard(config) - _, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings) - # The "generated" array will contain the tokens from the - # context as well as the tokens picked by the sampler at - # each stage, padded with a bunch of 50256s, so we know - # which tokens have to be repetition penalized - generated = jnp.pad(context, (0, config["seq"]), constant_values=pad_token_id) # Let it start off with just the 2048 context tokens, plus some 50256s which will be eventually filled with sampler-chosen tokens - generated_index = config["seq"] - # Add that information to generate_loop_fn's starting state - initial_state = (generated, generated_index) + initial_state + initial_states = [] + for sequence_index in range(numseqs): + _, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings) + # The "generated" array will contain the tokens from the + # context as well as the tokens picked by the sampler at + # each stage, padded with a bunch of 50256s, so we know + # which tokens have to be repetition penalized + generated = jnp.pad(context, (0, config["seq"]), constant_values=pad_token_id) # Let it start off with just the 2048 context tokens, plus some 50256s which will be eventually filled with sampler-chosen tokens + generated_index = config["seq"] + # Add that information to generate_loop_fn's starting state + initial_state = (generated, generated_index, sequence_index) + initial_state + initial_states.append(initial_state) # Get repetition penalty from the arguments repetition_penalty = sampler_options.pop('repetition_penalty', None) def generate_loop_fn(carry): # Unpack current generate_loop_fn state - generated, generated_index, next_token, decode_state, sample_key = carry + generated, generated_index, sequence_index, next_token, decode_state, sample_key = carry[0] # Get the pseudo-random number generator key that will # be used by kobold_sample to randomly pick a token sample_key, new_key = jax.random.split(sample_key) @@ -216,12 +220,13 @@ class PenalizingCausalTransformer(CausalTransformer): generated_index += 1 # Re-pack the current generate_loop_fn's state so we can # get back the same variables the next time - new_carry = (generated, generated_index, next_token, new_state, new_key) - return new_carry + carry[0] = (generated, generated_index, sequence_index, next_token, new_state, new_key) + carry.append(carry.pop(0)) + return carry final_state = jax.lax.while_loop( - lambda carry: carry[1] - config["seq"] < gen_length, + lambda carry: carry[0][1] - config["seq"] < gen_length, generate_loop_fn, - initial_state, + initial_states, ) return final_state generate_fn = hk.transform(generate_sample).apply @@ -235,12 +240,13 @@ class PenalizingCausalTransformer(CausalTransformer): ["batch", ...], ["batch", ...], ["batch", ...], + ["batch", ...], ["shard", ...], ), out_axes=["shard", "batch", ...], axis_resources={'shard': 'mp', 'batch': 'dp'}, ) - def generate(self, ctx, ctx_length, gen_length, sampler_options, return_logits=False, soft_embeddings=None): + def generate(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None): assert not return_logits key = hk.PRNGSequence(random.randint(0, 2 ** 60)) batch_size = ctx.shape[0] @@ -251,6 +257,7 @@ class PenalizingCausalTransformer(CausalTransformer): ctx, np.array(ctx_length, dtype=np.uint32), np.array(gen_length, dtype=np.uint32), + np.empty((batch_size, numseqs), dtype=np.uint8), sampler_options, soft_embeddings, ) @@ -269,7 +276,7 @@ def infer( soft_tokens: Optional[np.array] = None, ) -> List[str]: maps.thread_resources.env = thread_resources_env - total_batch = numseqs + total_batch = 1 tokens = np.uint32(tokenizer.encode(context, max_length=params["seq"] - (soft_tokens.shape[0] if soft_tokens is not None else 0), truncation=True)) if(soft_tokens is not None): tokens = np.uint32(np.concatenate((soft_tokens, tokens))) @@ -289,11 +296,12 @@ def infer( batched_tokens, np.ones(total_batch, dtype=np.uint32) * provided_ctx, np.ones(total_batch, dtype=np.uint32) * gen_len, + numseqs, batched_generator_params, soft_embeddings=soft_embeddings, ) - for o in output[0][0, :, params["seq"] : params["seq"] + gen_len]: - samples.append(tokenizer.decode(o)) + for o in output: + samples.append(tokenizer.decode(o[0][0, 0, params["seq"] : params["seq"] + gen_len])) return samples From d2d338d3141c7b2569742ddc4d3b14bacfba8592 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Tue, 30 Nov 2021 19:22:40 -0500 Subject: [PATCH 3/4] Improve TPU backend compilation times with `numseqs > 1` A Python `for` loop was replaced with a `jax.lax.scan` call so that JAX only compiles the `transformer.generate_initial` function one time instead of `numseqs` times. This is because JAX unrolls Python built-in loops like `for`. The compilation times should now be about the same as they were before the upgrade to JAX 0.2.21. --- tpu_mtj_backend.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index aa64da32..cbf6b499 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -162,8 +162,7 @@ class PenalizingCausalTransformer(CausalTransformer): def generate_sample(context, ctx_length): # Give the initial context to the transformer transformer = CausalTransformerShard(config) - initial_states = [] - for sequence_index in range(numseqs): + def generate_initial_scan_fn(sequence_index, _): _, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings) # The "generated" array will contain the tokens from the # context as well as the tokens picked by the sampler at @@ -173,12 +172,17 @@ class PenalizingCausalTransformer(CausalTransformer): generated_index = config["seq"] # Add that information to generate_loop_fn's starting state initial_state = (generated, generated_index, sequence_index) + initial_state - initial_states.append(initial_state) + return sequence_index+1, initial_state + _, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, None, numseqs) + sample_key = initial_states[-1][0] + initial_states = list(jax.tree_map(lambda x: x[i], initial_states[:-1]) for i in range(numseqs)) # Get repetition penalty from the arguments repetition_penalty = sampler_options.pop('repetition_penalty', None) + # This is the main generation loop def generate_loop_fn(carry): # Unpack current generate_loop_fn state - generated, generated_index, sequence_index, next_token, decode_state, sample_key = carry[0] + generated, generated_index, sequence_index, next_token, decode_state = carry[0][0] + sample_key = carry[1] # Get the pseudo-random number generator key that will # be used by kobold_sample to randomly pick a token sample_key, new_key = jax.random.split(sample_key) @@ -220,13 +224,13 @@ class PenalizingCausalTransformer(CausalTransformer): generated_index += 1 # Re-pack the current generate_loop_fn's state so we can # get back the same variables the next time - carry[0] = (generated, generated_index, sequence_index, next_token, new_state, new_key) - carry.append(carry.pop(0)) - return carry + carry[0][0] = (generated, generated_index, sequence_index, next_token, new_state) + carry[0].append(carry[0].pop(0)) + return carry[0], new_key final_state = jax.lax.while_loop( - lambda carry: carry[0][1] - config["seq"] < gen_length, + lambda carry: carry[0][0][1] - config["seq"] < gen_length, generate_loop_fn, - initial_states, + (initial_states, sample_key), ) return final_state generate_fn = hk.transform(generate_sample).apply @@ -299,7 +303,7 @@ def infer( numseqs, batched_generator_params, soft_embeddings=soft_embeddings, - ) + )[0] for o in output: samples.append(tokenizer.decode(o[0][0, 0, params["seq"] : params["seq"] + gen_len])) return samples From 150ce033c93a076ad5cb0b6b0aa05d1574821b63 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Sun, 5 Dec 2021 02:49:15 -0500 Subject: [PATCH 4/4] TPU backend no longer needs to recompile after changing softprompt --- aiserver.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/aiserver.py b/aiserver.py index 28da1d21..6b41d72c 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1802,12 +1802,25 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): raise ValueError("Dynamic world info scanning is not supported by the TPU backend yet") soft_tokens = None - if(vars.sp is not None): - soft_tokens = np.arange( - tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"], - tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length, - dtype=np.uint32 + if(vars.sp is None): + global np + if 'np' not in globals(): + import numpy as np + tensor = np.zeros((1, tpu_mtj_backend.params["d_model"]), dtype=np.float32) + rows = tensor.shape[0] + padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows + tensor = np.pad(tensor, ((0, padding_amount), (0, 0))) + tensor = tensor.reshape( + tpu_mtj_backend.params["cores_per_replica"], + -1, + tpu_mtj_backend.params["d_model"], ) + vars.sp = tensor + soft_tokens = np.arange( + tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"], + tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length, + dtype=np.uint32 + ) genout = tpu_mtj_backend.infer( txt, @@ -2676,7 +2689,7 @@ def spRequest(filename): if(vars.model in ("TPUMeshTransformerGPTJ",)): rows = tensor.shape[0] - padding_amount = -(rows % -tpu_mtj_backend.params["cores_per_replica"]) + padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows tensor = np.pad(tensor, ((0, padding_amount), (0, 0))) tensor = tensor.reshape( tpu_mtj_backend.params["cores_per_replica"],