diff --git a/aiserver.py b/aiserver.py
index 28da1d21..6b41d72c 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -1802,12 +1802,25 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
         raise ValueError("Dynamic world info scanning is not supported by the TPU backend yet")
 
     soft_tokens = None
-    if(vars.sp is not None):
-        soft_tokens = np.arange(
-            tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
-            tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
-            dtype=np.uint32
+    if(vars.sp is None):
+        global np
+        if 'np' not in globals():
+            import numpy as np
+        tensor = np.zeros((1, tpu_mtj_backend.params["d_model"]), dtype=np.float32)
+        rows = tensor.shape[0]
+        padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
+        tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
+        tensor = tensor.reshape(
+            tpu_mtj_backend.params["cores_per_replica"],
+            -1,
+            tpu_mtj_backend.params["d_model"],
         )
+        vars.sp = tensor
+    soft_tokens = np.arange(
+        tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
+        tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
+        dtype=np.uint32
+    )
 
     genout = tpu_mtj_backend.infer(
         txt,
@@ -2676,7 +2689,7 @@ def spRequest(filename):
 
     if(vars.model in ("TPUMeshTransformerGPTJ",)):
         rows = tensor.shape[0]
-        padding_amount = -(rows % -tpu_mtj_backend.params["cores_per_replica"])
+        padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
         tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
         tensor = tensor.reshape(
             tpu_mtj_backend.params["cores_per_replica"],
diff --git a/requirements_mtj.txt b/requirements_mtj.txt
index 9d86ccae..d82bfcb0 100644
--- a/requirements_mtj.txt
+++ b/requirements_mtj.txt
@@ -4,11 +4,11 @@ requests
 optax >= 0.0.5, <= 0.0.9
 dm-haiku
 ray[default]
-jax == 0.2.12
+jax == 0.2.21
 transformers
 progressbar2
 git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck
 flask
 Flask-SocketIO
 flask-cloudflared >= 0.0.5
-flask-ngrok
\ No newline at end of file
+flask-ngrok
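A note on the padding arithmetic shared by the two aiserver.py hunks above: in Python, x % -n lies in (-n, 0], so seq - (seq % -n) rounds seq up to the nearest multiple of n. Both sites now pad the soft-prompt tensor out to that row count (a full sequence length's worth, rather than merely the next multiple of cores_per_replica as spRequest did before), so the reshape to (cores_per_replica, -1, d_model) always divides evenly. A minimal sketch of the same trick, using standalone names rather than the patch's own:

import numpy as np

def pad_soft_prompt_rows(tensor, seq, cores_per_replica):
    # seq % -n is in (-n, 0], so seq - (seq % -n) is seq rounded UP to a
    # multiple of n (e.g. seq=2049, n=8 gives 2056); the tensor is then
    # padded with zero rows up to that count.
    target_rows = seq - (seq % -cores_per_replica)
    return np.pad(tensor, ((0, target_rows - tensor.shape[0]), (0, 0)))

padded = pad_soft_prompt_rows(np.zeros((1, 4096), dtype=np.float32), 2048, 8)
assert padded.shape == (2048, 4096)  # 2048 is already a multiple of 8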
diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index 5bfb8b57..cbf6b499 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -30,7 +30,7 @@ def show_spinner():
 
 def apply_repetition_penalty(logits, tokens, repetition_penalty):
     '''
-    This gets called by generate_scan_fn to apply repetition penalty
+    This gets called by generate_loop_fn to apply repetition penalty
    to the 1D array logits using the provided 1D array of tokens to penalize
     '''
     # Make a new array with the same length as the tokens array but with
@@ -52,9 +52,9 @@ def apply_repetition_penalty(logits, tokens, repetition_penalty):
     # positions in the logits array
     return logits.at[tokens].set(penalty_logits)
 
-def kobold_sample(key, logits, _, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
+def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
     '''
-    This gets called by generate_scan_fn to apply a series of 4 filters
+    This gets called by generate_loop_fn to apply a series of 4 filters
     to the logits (top-k, then top-p, then TFS, then temperature) before
     picking one token using the modified logits
     '''
@@ -147,7 +147,7 @@ def kobold_sample(key, logits, _, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
     # Finally, pick one token using the softmax thingy again (it gives
     # an array whose elements sum to 1 so it can be used nicely as a
     # probability distribution)
-    return jax.random.categorical(key, logits, -1).astype(jnp.uint32)[jnp.newaxis], None
+    return jax.random.categorical(key, logits, -1).astype(jnp.uint32)[jnp.newaxis]
 
 pad_token_id = 50256
 
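On the kobold_sample changes: the dummy third positional parameter and the trailing None existed only to satisfy jax.lax.scan's per-step input/output interface; the jax.lax.while_loop introduced further down threads a single carry and gathers no per-step outputs, so the sampler can drop both and return just the picked token. A toy sketch of what the surviving return line computes, with made-up logits:

import jax
import jax.numpy as jnp

# jax.random.categorical treats the logits as an unnormalized
# log-probability distribution and draws one index from it.
key = jax.random.PRNGKey(0)
logits = jnp.log(jnp.array([0.1, 0.2, 0.7]))
token = jax.random.categorical(key, logits, -1).astype(jnp.uint32)[jnp.newaxis]
assert token.shape == (1,)  # one sampled token id, no auxiliary output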
@@ -155,27 +155,34 @@ class PenalizingCausalTransformer(CausalTransformer):
     def __init__(self, config):
         # Initialize
         super().__init__(config)
-        def generate(state, key, ctx, ctx_length, aux, sampler_options, soft_embeddings=None):
-            gen_length = self.gen_length
+        def generate(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None):
+            numseqs = numseqs_aux.shape[0]
             # These are the tokens that we don't want the AI to ever write
             self.badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146])
-            def generate_sample(context, ctx_length, aux):
+            def generate_sample(context, ctx_length):
                 # Give the initial context to the transformer
                 transformer = CausalTransformerShard(config)
-                _, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings)
-                # The "generated" array will contain the tokens from the
-                # context as well as the tokens picked by the sampler at
-                # each stage, padded with a bunch of 50256s, so we know
-                # which tokens have to be repetition penalized
-                generated = jnp.pad(context, (0, gen_length), constant_values=pad_token_id) # Let it start off with just the 2048 context tokens, plus gen_length 50256s which will be eventually filled with sampler-chosen tokens
-                generated_index = config["seq"]
-                # Add that information to generate_scan_fn's starting state
-                initial_state = (generated, generated_index) + initial_state
+                def generate_initial_scan_fn(sequence_index, _):
+                    _, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings)
+                    # The "generated" array will contain the tokens from the
+                    # context as well as the tokens picked by the sampler at
+                    # each stage, padded with a bunch of 50256s, so we know
+                    # which tokens have to be repetition penalized
+                    generated = jnp.pad(context, (0, config["seq"]), constant_values=pad_token_id) # Let it start off with just the 2048 context tokens, plus some 50256s which will be eventually filled with sampler-chosen tokens
+                    generated_index = config["seq"]
+                    # Add that information to generate_loop_fn's starting state
+                    initial_state = (generated, generated_index, sequence_index) + initial_state
+                    return sequence_index+1, initial_state
+                _, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, None, numseqs)
+                sample_key = initial_states[-1][0]
+                initial_states = list(jax.tree_map(lambda x: x[i], initial_states[:-1]) for i in range(numseqs))
                 # Get repetition penalty from the arguments
                 repetition_penalty = sampler_options.pop('repetition_penalty', None)
-                def generate_scan_fn(carry, sampler_input):
-                    # Unpack current generate_scan_fn state
-                    generated, generated_index, next_token, decode_state, sample_key = carry
+                # This is the main generation loop
+                def generate_loop_fn(carry):
+                    # Unpack current generate_loop_fn state
+                    generated, generated_index, sequence_index, next_token, decode_state = carry[0][0]
+                    sample_key = carry[1]
                     # Get the pseudo-random number generator key that will
                     # be used by kobold_sample to randomly pick a token
                     sample_key, new_key = jax.random.split(sample_key)
@@ -207,56 +214,54 @@ class PenalizingCausalTransformer(CausalTransformer):
                     # based on the logits array as a 1D array with 1 element
                     # (higher logit means higher probability of being
                     # picked, non-linearly)
-                    next_token, sample_info = kobold_sample(
+                    next_token = kobold_sample(
                         sample_key,
                         logits,
-                        sampler_input,
                         **sampler_options,
                     )
-                    # Remember what token was picked so we can repetition
-                    # penalize it next time
+                    # Remember what token was picked
                     generated = generated.at[generated_index].set(next_token[0])
                     generated_index += 1
-                    # self.return_logits isn't used in this program, but
-                    # for the sake of compatibility...
-                    if self.return_logits:
-                        output = (next_token, sample_info, logits[jnp.newaxis])
-                    else:
-                        output = (next_token, sample_info)
-                    # Re-pack the current generate_scan_fn's state so we can
+                    # Re-pack the current generate_loop_fn's state so we can
                     # get back the same variables the next time
-                    new_carry = (generated, generated_index, next_token, new_state, new_key)
-                    return new_carry, output
-                # jax.lax.scan is a function that calls generate_scan_fn
-                # gen_length times, each time passing a state object from
-                # its return value (new_carry) back into one of the
-                # function's arguments (carry), and of course gathering the
-                # token it generates each time into the "outputs" array;
-                # we have to use jax.lax.scan instead of a normal loop
-                # because of JAX's JIT-compilation shenanigans
-                final_state, outputs = jax.lax.scan(
-                    generate_scan_fn,
-                    initial_state,
-                    xs=aux,
-                    length=gen_length,
+                    carry[0][0] = (generated, generated_index, sequence_index, next_token, new_state)
+                    carry[0].append(carry[0].pop(0))
+                    return carry[0], new_key
+                final_state = jax.lax.while_loop(
+                    lambda carry: carry[0][0][1] - config["seq"] < gen_length,
+                    generate_loop_fn,
+                    (initial_states, sample_key),
                 )
-                return final_state, outputs
+                return final_state
             generate_fn = hk.transform(generate_sample).apply
-            return generate_fn(state["params"], key, ctx, ctx_length, aux)
-        self.generate_xmap = jax.experimental.maps.xmap(fun=generate, in_axes=(["shard", ...], ["batch", ...], ["batch", ...], ["batch", ...], ["batch", ...], ["batch", ...], ["shard", ...]), out_axes=["batch", ...], axis_resources={'shard': 'mp', 'batch': 'dp'})
-    def generate(self, ctx, ctx_length, gen_length, sampler_options, return_logits=False, soft_embeddings=None):
+            return generate_fn(state["params"], key, ctx, ctx_length)
+        self.generate_xmap = jax.experimental.maps.xmap(
+            fun=generate,
+            in_axes=(
+                ["shard", ...],
+                ["batch", ...],
+                ["batch", ...],
+                ["batch", ...],
+                ["batch", ...],
+                ["batch", ...],
+                ["batch", ...],
+                ["shard", ...],
+            ),
+            out_axes=["shard", "batch", ...],
+            axis_resources={'shard': 'mp', 'batch': 'dp'},
+        )
+    def generate(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None):
+        assert not return_logits
         key = hk.PRNGSequence(random.randint(0, 2 ** 60))
         batch_size = ctx.shape[0]
-        aux = jnp.zeros((batch_size, gen_length), dtype=jnp.uint32)
-        self.gen_length = gen_length
         self.batch_size = batch_size
-        self.return_logits = return_logits
         return self.generate_xmap(
             self.state,
             jnp.array(key.take(batch_size)),
             ctx,
             np.array(ctx_length, dtype=np.uint32),
-            aux,
+            np.array(gen_length, dtype=np.uint32),
+            np.empty((batch_size, numseqs), dtype=np.uint8),
             sampler_options,
             soft_embeddings,
         )
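The heart of the patch is in the two hunks above: jax.lax.scan needed a trip count (length=gen_length) that is a Python constant at trace time, so changing the requested output length meant recompiling, while jax.lax.while_loop re-evaluates its condition against the carried generated_index on every iteration, which is why gen_length can now be passed into the xmap as a traced np.array instead of being stored on self. The carry[0].append(carry[0].pop(0)) rotation cycles each sequence's state into slot 0 in turn, so one compiled loop body advances all numseqs sequences. A self-contained toy of the while_loop pattern, with hypothetical names unrelated to the patch:

import jax
import jax.numpy as jnp

def cond_fn(carry):
    # Keep looping while the write index is below the requested length;
    # gen_length is carried data, not a trace-time Python constant.
    _, generated_index, gen_length = carry
    return generated_index < gen_length

def body_fn(carry):
    generated, generated_index, gen_length = carry
    generated = generated.at[generated_index].set(1)
    return generated, generated_index + 1, gen_length

generated, _, _ = jax.lax.while_loop(
    cond_fn, body_fn, (jnp.zeros(8, dtype=jnp.uint32), 0, 5)
)
# generated == [1, 1, 1, 1, 1, 0, 0, 0]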
@@ -275,7 +280,7 @@ def infer(
     soft_tokens: Optional[np.array] = None,
 ) -> List[str]:
     maps.thread_resources.env = thread_resources_env
-    total_batch = numseqs
+    total_batch = 1
     tokens = np.uint32(tokenizer.encode(context, max_length=params["seq"] - (soft_tokens.shape[0] if soft_tokens is not None else 0), truncation=True))
     if(soft_tokens is not None):
         tokens = np.uint32(np.concatenate((soft_tokens, tokens)))
@@ -283,7 +288,6 @@ def infer(
     pad_amount = seq - provided_ctx
     padded_tokens = np.pad(tokens, ((pad_amount, 0),), constant_values=pad_token_id)
     batched_tokens = np.array([padded_tokens] * total_batch)
-    length = np.ones(total_batch, dtype=np.uint32) * provided_ctx
     samples = []
     batched_generator_params = {
         "temp": temp * np.ones(total_batch),
@@ -294,14 +298,14 @@ def infer(
     }
     output = network.generate(
         batched_tokens,
-        length,
-        gen_len,
+        np.ones(total_batch, dtype=np.uint32) * provided_ctx,
+        np.ones(total_batch, dtype=np.uint32) * gen_len,
+        numseqs,
         batched_generator_params,
         soft_embeddings=soft_embeddings,
-    )
-    decoded_tokens = output[1][0]
-    for o in decoded_tokens[:, :, 0]:
-        samples.append(tokenizer.decode(o))
+    )[0]
+    for o in output:
+        samples.append(tokenizer.decode(o[0][0, 0, params["seq"] : params["seq"] + gen_len]))
     return samples
 
 
@@ -326,6 +330,10 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
         if param not in params:
             params[param] = default_params[param]
 
+    # Disable JAX warnings about these two functions having been renamed
+    jax.host_count = jax.process_count
+    jax.host_id = jax.process_index
+
     print("Connecting to your Colab instance's TPU", flush=True)
     spinner = multiprocessing.Process(target=show_spinner, args=())
     spinner.start()
@@ -342,7 +350,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
         params["optimizer"] = optax.scale(0)
     mesh_shape = (1, cores_per_replica)
     devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
-    thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))
+    thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
     maps.thread_resources.env = thread_resources_env
 
     tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
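The final two tpu_mtj_backend.py hunks track the jax == 0.2.12 -> 0.2.21 bump in requirements_mtj.txt: jax.host_count and jax.host_id were renamed to jax.process_count and jax.process_index in jax 0.2.13, so the two aliases silence the deprecation warnings that code still using the old names would otherwise trigger, and maps.ResourceEnv appears to take a second constructor argument (loop resources, here an empty tuple) as of this jax version. A slightly more defensive variant of the alias shim, sketched on the assumption that only the rename matters:

import jax

# Point the pre-0.2.13 names at their replacements so code that still
# calls jax.host_count()/jax.host_id() runs warning-free on jax 0.2.21.
if hasattr(jax, "process_count"):
    jax.host_count = jax.process_count
    jax.host_id = jax.process_index

assert jax.host_count() == jax.process_count()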