diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 4a9833aa..f5d70e6b 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -172,36 +172,74 @@ def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a # probability distribution) - return jax.random.categorical(key, logits, -1).astype(jnp.uint32)[jnp.newaxis] + return jax.random.categorical(key, logits, -1).astype(jnp.uint32) pad_token_id = 50256 +def sample_jit(data, key, numseqs_aux, badwords, repetition_penalty, sampler_options): + numseqs = numseqs_aux.shape[0] + gi = data[0][1] + def sample_loop_fn(carry): + generated, generated_index, logits, _ = carry[0][0] + sample_key = carry[1] + # Get the pseudo-random number generator key that will + # be used by kobold_sample to randomly pick a token + sample_key, new_key = jax.random.split(sample_key, num=2) + # Apply repetition penalty to all tokens that are + # currently inside the "generated" array + logits = apply_repetition_penalty( + logits, + generated, + repetition_penalty + ) + # Remove any tokens in the badwords list by setting + # their logits to negative infinity which effectively + # makes their probabilities of being chosen zero + logits = logits.at[badwords].set(-jnp.inf) + # Use the sampler (kobold_sample) to pick one token + # based on the logits array as a 0D uint32 array + # (higher logit means higher probability of being + # picked, non-linearly) + next_token = kobold_sample( + sample_key, + logits, + **sampler_options, + ) + # Remember what token was picked + generated = generated.at[generated_index].set(next_token) + generated_index += 1 + # Re-pack the current sample_loop_fn's state so we can + # get back the same variables the next time + carry[0][0] = [generated, generated_index, logits, next_token] + carry[0].append(carry[0].pop(0)) + return carry[0], new_key + return jax.lax.while_loop( + lambda carry: carry[0][0][1] < gi, + sample_loop_fn, + (data, key), + ) + class PenalizingCausalTransformer(CausalTransformer): def __init__(self, config): # Initialize super().__init__(config) - def generate_initial(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None): + def generate_initial(state, key, ctx, ctx_length, numseqs_aux, soft_embeddings=None): numseqs = numseqs_aux.shape[0] + @hk.transform def generate_initial_inner(context, ctx_length): # Give the initial context to the transformer transformer = CausalTransformerShard(config) def generate_initial_scan_fn(sequence_index, _): _, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings) - # The "generated" array will contain the tokens from the - # context as well as the tokens picked by the sampler at - # each stage, padded with a bunch of 50256s, so we know - # which tokens have to be repetition penalized - generated = jnp.pad(context, (0, config["seq"]), constant_values=pad_token_id) # Let it start off with just the 2048 context tokens, plus some 50256s which will be eventually filled with sampler-chosen tokens generated_index = config["seq"] # Add that information to generate_loop_fn's starting state - initial_state = (generated, generated_index, sequence_index) + initial_state + initial_state = (jnp.empty(config["n_vocab"], dtype=jnp.float32), generated_index, sequence_index) + initial_state return sequence_index+1, initial_state _, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, None, numseqs) sample_key = initial_states[-1][0] initial_states = list(jax.tree_map(lambda x: x[i], initial_states[:-1]) for i in range(numseqs)) return initial_states, sample_key - generate_initial_fn = hk.transform(generate_initial_inner).apply - return generate_initial_fn(state["params"], key, ctx, ctx_length) + return generate_initial_inner.apply(state["params"], key, ctx, ctx_length) self.generate_initial_xmap = jax.experimental.maps.xmap( fun=generate_initial, in_axes=( @@ -210,31 +248,23 @@ class PenalizingCausalTransformer(CausalTransformer): ["batch", ...], ["batch", ...], ["batch", ...], - ["batch", ...], - ["batch", ...], ["shard", ...], ), out_axes=["shard", "batch", ...], axis_resources={'shard': 'mp', 'batch': 'dp'}, ) - def generate_once(initial_states, sample_key, state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None): + def generate_once(data, state, numseqs_aux, soft_embeddings=None): numseqs = numseqs_aux.shape[0] - # These are the tokens that we don't want the AI to ever write - self.badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146]) - def generate_once_inner(context, ctx_length): - gi = initial_states[0][1] + @hk.without_apply_rng + @hk.transform + def generate_once_inner(): + gi = data[0][1] # Give the initial context to the transformer transformer = CausalTransformerShard(config) - # Get repetition penalty from the arguments - repetition_penalty = sampler_options.pop('repetition_penalty', None) # This is the main generation loop def generate_loop_fn(carry): # Unpack current generate_loop_fn state - generated, generated_index, sequence_index, next_token, decode_state = carry[0][0] - sample_key = carry[1] - # Get the pseudo-random number generator key that will - # be used by kobold_sample to randomly pick a token - sample_key, new_key = jax.random.split(sample_key) + _, generated_index, sequence_index, next_token, decode_state = carry[0][0] # Give the context to the model and get the logits it # spits out # (a 2D array with 1 row and 50400 columns representing @@ -245,57 +275,27 @@ class PenalizingCausalTransformer(CausalTransformer): # Verify that logits does indeed have that many rows and # columns (if you get an error here, pray for mercy) assert logits.shape == (1, config["n_vocab"]) + assert logits.dtype == jnp.float32 # Flatten it into a 1D array to make it easier to use logits = logits[0] - # Apply repetition penalty to all tokens that are - # currently inside the "generated" array - if repetition_penalty is not None: - logits = apply_repetition_penalty( - logits, - generated, - repetition_penalty - ) - # Remove any tokens in the badwords list by setting - # their logits to negative infinity which effectively - # makes their probabilities of being chosen zero - logits = logits.at[self.badwords].set(-jnp.inf) - # Use the sampler (kobold_sample) to pick one token - # based on the logits array as a 1D array with 1 element - # (higher logit means higher probability of being - # picked, non-linearly) - next_token = kobold_sample( - sample_key, - logits, - **sampler_options, - ) - # Remember what token was picked - generated = generated.at[generated_index].set(next_token[0]) - generated_index += 1 # Re-pack the current generate_loop_fn's state so we can # get back the same variables the next time - carry[0][0] = (generated, generated_index, sequence_index, next_token, new_state) + generated_index += 1 + carry[0][0] = (logits, generated_index, sequence_index, next_token, new_state) carry[0].append(carry[0].pop(0)) - return carry[0], new_key - final_state = jax.lax.while_loop( + return carry[0], + return jax.lax.while_loop( lambda carry: carry[0][0][1] == gi, generate_loop_fn, - (initial_states, sample_key), + (data,), ) - return final_state - generate_once_fn = hk.transform(generate_once_inner).apply - return generate_once_fn(state["params"], key, ctx, ctx_length) + return generate_once_inner.apply(state["params"]) self.generate_once_xmap = jax.experimental.maps.xmap( fun=generate_once, in_axes=( - ["shard", "batch", ...], ["shard", "batch", ...], ["shard", ...], ["batch", ...], - ["batch", ...], - ["batch", ...], - ["batch", ...], - ["batch", ...], - ["batch", ...], ["shard", ...], ), out_axes=["shard", "batch", ...], @@ -304,23 +304,30 @@ class PenalizingCausalTransformer(CausalTransformer): def generate(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None): assert not return_logits assert gen_length.ndim == 1 + assert soft_embeddings is not None key = hk.PRNGSequence(random.randint(0, 2 ** 60)) batch_size = ctx.shape[0] self.batch_size = batch_size - xargs = [ - self.state, - batch_xmap(jnp.array(key.take(batch_size))), - batch_xmap(ctx), - batch_xmap(np.array(ctx_length, dtype=np.uint32)), - batch_xmap(np.array(gen_length, dtype=np.uint32)), - np.empty((batch_size, numseqs), dtype=np.uint8), - batch_xmap(sampler_options), - shard_xmap(soft_embeddings), + _numseqs_aux = jnp.empty((batch_size, numseqs), dtype=np.uint32) + numseqs_aux = batch_xmap(_numseqs_aux) + sample_data = [ + [ + jnp.pad(ctx, (0, params["seq"]), constant_values=pad_token_id), + params["seq"], + None, + jnp.empty((), dtype=jnp.uint32), + ] + for _ in range(numseqs) ] - initial_state, sample_key = self.generate_initial_xmap(*xargs) - for i in range(gen_length[0]): - initial_state, sample_key = self.generate_once_xmap(initial_state, sample_key, *xargs) - return initial_state, sample_key + repetition_penalty = sampler_options.pop("repetition_penalty", 1.0) + generate_data, sample_key = self.generate_initial_xmap(self.state, jnp.array(key.take(batch_size)), ctx, ctx_length, numseqs_aux, soft_embeddings) + sample_key = jax.device_put(sample_key[0, 0], cpu) + for _ in range(gen_length[0].item()): + generate_data, = self.generate_once_xmap(generate_data, self.state, numseqs_aux, soft_embeddings) + for i in range(numseqs): + sample_data[i][2] = jax.device_put(generate_data[0][i][0, 0], cpu) + sample_data, sample_key = sample_jit(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options) + return sample_data, sample_key def infer( @@ -345,23 +352,23 @@ def infer( padded_tokens = np.pad(tokens, ((pad_amount, 0),), constant_values=pad_token_id) batched_tokens = np.array([padded_tokens] * total_batch) samples = [] - batched_generator_params = { - "temp": temp * np.ones(total_batch), - "top_p": top_p * np.ones(total_batch), - "tfs": tfs * np.ones(total_batch), - "repetition_penalty": repetition_penalty * np.ones(total_batch), - "top_k": np.full(total_batch, top_k, dtype=np.uint32) + generator_params = { + "temp": float(temp), + "top_p": float(top_p), + "tfs": float(tfs), + "repetition_penalty": float(repetition_penalty), + "top_k": int(top_k), } output = network.generate( batched_tokens, np.ones(total_batch, dtype=np.uint32) * provided_ctx, np.ones(total_batch, dtype=np.uint32) * gen_len, numseqs, - batched_generator_params, + generator_params, soft_embeddings=soft_embeddings, )[0] - for o in output: - samples.append(o[0][0, 0, params["seq"] : params["seq"] + gen_len]) + for out in output: + samples.append(out[0][0, 0, params["seq"] : params["seq"] + gen_len]) return samples @@ -414,6 +421,17 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) shard_xmap = __shard_xmap() batch_xmap = __batch_xmap(shard_dim=cores_per_replica) + global cpu, sample_jit + cpu = jax.devices("cpu")[0] + sample_jit = jax.jit( + sample_jit, + device=cpu, + ) + + global badwords + # These are the tokens that we don't want the AI to ever write + badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146]) + if not path.endswith("/"): path += "/"