mirror of
				https://github.com/KoboldAI/KoboldAI-Client.git
				synced 2025-06-05 21:59:24 +02:00 
			
		
		
		
	Move sampling into a jax.jitted function
				
					
				
			This commit is contained in:
		| @@ -172,36 +172,74 @@ def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): | ||||
|     # Finally, pick one token using the softmax thingy again (it gives | ||||
|     # an array whose elements sum to 1 so it can be used nicely as a | ||||
|     # probability distribution) | ||||
|     return jax.random.categorical(key, logits, -1).astype(jnp.uint32)[jnp.newaxis] | ||||
|     return jax.random.categorical(key, logits, -1).astype(jnp.uint32) | ||||
|  | ||||
| pad_token_id = 50256 | ||||
|  | ||||
| def sample_jit(data, key, numseqs_aux, badwords, repetition_penalty, sampler_options): | ||||
|     numseqs = numseqs_aux.shape[0] | ||||
|     gi = data[0][1] | ||||
|     def sample_loop_fn(carry): | ||||
|         generated, generated_index, logits, _ = carry[0][0] | ||||
|         sample_key = carry[1] | ||||
|         # Get the pseudo-random number generator key that will | ||||
|         # be used by kobold_sample to randomly pick a token | ||||
|         sample_key, new_key = jax.random.split(sample_key, num=2) | ||||
|         # Apply repetition penalty to all tokens that are | ||||
|         # currently inside the "generated" array | ||||
|         logits = apply_repetition_penalty( | ||||
|             logits, | ||||
|             generated, | ||||
|             repetition_penalty | ||||
|         ) | ||||
|         # Remove any tokens in the badwords list by setting | ||||
|         # their logits to negative infinity which effectively | ||||
|         # makes their probabilities of being chosen zero | ||||
|         logits = logits.at[badwords].set(-jnp.inf) | ||||
|         # Use the sampler (kobold_sample) to pick one token | ||||
|         # based on the logits array as a 0D uint32 array | ||||
|         # (higher logit means higher probability of being | ||||
|         # picked, non-linearly) | ||||
|         next_token = kobold_sample( | ||||
|             sample_key, | ||||
|             logits, | ||||
|             **sampler_options, | ||||
|         ) | ||||
|         # Remember what token was picked | ||||
|         generated = generated.at[generated_index].set(next_token) | ||||
|         generated_index += 1 | ||||
|         # Re-pack the current sample_loop_fn's state so we can | ||||
|         # get back the same variables the next time | ||||
|         carry[0][0] = [generated, generated_index, logits, next_token] | ||||
|         carry[0].append(carry[0].pop(0)) | ||||
|         return carry[0], new_key | ||||
|     return jax.lax.while_loop( | ||||
|         lambda carry: carry[0][0][1] < gi, | ||||
|         sample_loop_fn, | ||||
|         (data, key), | ||||
|     ) | ||||
|  | ||||
| class PenalizingCausalTransformer(CausalTransformer): | ||||
|     def __init__(self, config): | ||||
|         # Initialize | ||||
|         super().__init__(config) | ||||
|         def generate_initial(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None): | ||||
|         def generate_initial(state, key, ctx, ctx_length, numseqs_aux, soft_embeddings=None): | ||||
|             numseqs = numseqs_aux.shape[0] | ||||
|             @hk.transform | ||||
|             def generate_initial_inner(context, ctx_length): | ||||
|                 # Give the initial context to the transformer | ||||
|                 transformer = CausalTransformerShard(config) | ||||
|                 def generate_initial_scan_fn(sequence_index, _): | ||||
|                     _, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings) | ||||
|                     # The "generated" array will contain the tokens from the | ||||
|                     # context as well as the tokens picked by the sampler at | ||||
|                     # each stage, padded with a bunch of 50256s, so we know | ||||
|                     # which tokens have to be repetition penalized | ||||
|                     generated = jnp.pad(context, (0, config["seq"]), constant_values=pad_token_id)  # Let it start off with just the 2048 context tokens, plus some 50256s which will be eventually filled with sampler-chosen tokens | ||||
|                     generated_index = config["seq"] | ||||
|                     # Add that information to generate_loop_fn's starting state | ||||
|                     initial_state = (generated, generated_index, sequence_index) + initial_state | ||||
|                     initial_state = (jnp.empty(config["n_vocab"], dtype=jnp.float32), generated_index, sequence_index) + initial_state | ||||
|                     return sequence_index+1, initial_state | ||||
|                 _, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, None, numseqs) | ||||
|                 sample_key = initial_states[-1][0] | ||||
|                 initial_states = list(jax.tree_map(lambda x: x[i], initial_states[:-1]) for i in range(numseqs)) | ||||
|                 return initial_states, sample_key | ||||
|             generate_initial_fn = hk.transform(generate_initial_inner).apply | ||||
|             return generate_initial_fn(state["params"], key, ctx, ctx_length) | ||||
|             return generate_initial_inner.apply(state["params"], key, ctx, ctx_length) | ||||
|         self.generate_initial_xmap = jax.experimental.maps.xmap( | ||||
|             fun=generate_initial, | ||||
|             in_axes=( | ||||
| @@ -210,31 +248,23 @@ class PenalizingCausalTransformer(CausalTransformer): | ||||
|                 ["batch", ...], | ||||
|                 ["batch", ...], | ||||
|                 ["batch", ...], | ||||
|                 ["batch", ...], | ||||
|                 ["batch", ...], | ||||
|                 ["shard", ...], | ||||
|             ), | ||||
|             out_axes=["shard", "batch", ...], | ||||
|             axis_resources={'shard': 'mp', 'batch': 'dp'}, | ||||
|         ) | ||||
|         def generate_once(initial_states, sample_key, state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None): | ||||
|         def generate_once(data, state, numseqs_aux, soft_embeddings=None): | ||||
|             numseqs = numseqs_aux.shape[0] | ||||
|             # These are the tokens that we don't want the AI to ever write | ||||
|             self.badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146]) | ||||
|             def generate_once_inner(context, ctx_length): | ||||
|                 gi = initial_states[0][1] | ||||
|             @hk.without_apply_rng | ||||
|             @hk.transform | ||||
|             def generate_once_inner(): | ||||
|                 gi = data[0][1] | ||||
|                 # Give the initial context to the transformer | ||||
|                 transformer = CausalTransformerShard(config) | ||||
|                 # Get repetition penalty from the arguments | ||||
|                 repetition_penalty = sampler_options.pop('repetition_penalty', None) | ||||
|                 # This is the main generation loop | ||||
|                 def generate_loop_fn(carry): | ||||
|                     # Unpack current generate_loop_fn state | ||||
|                     generated, generated_index, sequence_index, next_token, decode_state = carry[0][0] | ||||
|                     sample_key = carry[1] | ||||
|                     # Get the pseudo-random number generator key that will | ||||
|                     # be used by kobold_sample to randomly pick a token | ||||
|                     sample_key, new_key = jax.random.split(sample_key) | ||||
|                     _, generated_index, sequence_index, next_token, decode_state = carry[0][0] | ||||
|                     # Give the context to the model and get the logits it | ||||
|                     # spits out | ||||
|                     # (a 2D array with 1 row and 50400 columns representing | ||||
| @@ -245,57 +275,27 @@ class PenalizingCausalTransformer(CausalTransformer): | ||||
|                     # Verify that logits does indeed have that many rows and | ||||
|                     # columns (if you get an error here, pray for mercy) | ||||
|                     assert logits.shape == (1, config["n_vocab"]) | ||||
|                     assert logits.dtype == jnp.float32 | ||||
|                     # Flatten it into a 1D array to make it easier to use | ||||
|                     logits = logits[0] | ||||
|                     # Apply repetition penalty to all tokens that are | ||||
|                     # currently inside the "generated" array | ||||
|                     if repetition_penalty is not None: | ||||
|                         logits = apply_repetition_penalty( | ||||
|                             logits, | ||||
|                             generated, | ||||
|                             repetition_penalty | ||||
|                         ) | ||||
|                     # Remove any tokens in the badwords list by setting | ||||
|                     # their logits to negative infinity which effectively | ||||
|                     # makes their probabilities of being chosen zero | ||||
|                     logits = logits.at[self.badwords].set(-jnp.inf) | ||||
|                     # Use the sampler (kobold_sample) to pick one token | ||||
|                     # based on the logits array as a 1D array with 1 element | ||||
|                     # (higher logit means higher probability of being | ||||
|                     # picked, non-linearly) | ||||
|                     next_token = kobold_sample( | ||||
|                         sample_key, | ||||
|                         logits, | ||||
|                         **sampler_options, | ||||
|                     ) | ||||
|                     # Remember what token was picked | ||||
|                     generated = generated.at[generated_index].set(next_token[0]) | ||||
|                     generated_index += 1 | ||||
|                     # Re-pack the current generate_loop_fn's state so we can | ||||
|                     # get back the same variables the next time | ||||
|                     carry[0][0] = (generated, generated_index, sequence_index, next_token, new_state) | ||||
|                     generated_index += 1 | ||||
|                     carry[0][0] = (logits, generated_index, sequence_index, next_token, new_state) | ||||
|                     carry[0].append(carry[0].pop(0)) | ||||
|                     return carry[0], new_key | ||||
|                 final_state = jax.lax.while_loop( | ||||
|                     return carry[0], | ||||
|                 return jax.lax.while_loop( | ||||
|                     lambda carry: carry[0][0][1] == gi, | ||||
|                     generate_loop_fn, | ||||
|                     (initial_states, sample_key), | ||||
|                     (data,), | ||||
|                 ) | ||||
|                 return final_state | ||||
|             generate_once_fn = hk.transform(generate_once_inner).apply | ||||
|             return generate_once_fn(state["params"], key, ctx, ctx_length) | ||||
|             return generate_once_inner.apply(state["params"]) | ||||
|         self.generate_once_xmap = jax.experimental.maps.xmap( | ||||
|             fun=generate_once, | ||||
|             in_axes=( | ||||
|                 ["shard", "batch", ...], | ||||
|                 ["shard", "batch", ...], | ||||
|                 ["shard", ...], | ||||
|                 ["batch", ...], | ||||
|                 ["batch", ...], | ||||
|                 ["batch", ...], | ||||
|                 ["batch", ...], | ||||
|                 ["batch", ...], | ||||
|                 ["batch", ...], | ||||
|                 ["shard", ...], | ||||
|             ), | ||||
|             out_axes=["shard", "batch", ...], | ||||
| @@ -304,23 +304,30 @@ class PenalizingCausalTransformer(CausalTransformer): | ||||
|     def generate(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None): | ||||
|         assert not return_logits | ||||
|         assert gen_length.ndim == 1 | ||||
|         assert soft_embeddings is not None | ||||
|         key = hk.PRNGSequence(random.randint(0, 2 ** 60)) | ||||
|         batch_size = ctx.shape[0] | ||||
|         self.batch_size = batch_size | ||||
|         xargs = [ | ||||
|             self.state, | ||||
|             batch_xmap(jnp.array(key.take(batch_size))), | ||||
|             batch_xmap(ctx), | ||||
|             batch_xmap(np.array(ctx_length, dtype=np.uint32)), | ||||
|             batch_xmap(np.array(gen_length, dtype=np.uint32)), | ||||
|             np.empty((batch_size, numseqs), dtype=np.uint8), | ||||
|             batch_xmap(sampler_options), | ||||
|             shard_xmap(soft_embeddings), | ||||
|         _numseqs_aux = jnp.empty((batch_size, numseqs), dtype=np.uint32) | ||||
|         numseqs_aux = batch_xmap(_numseqs_aux) | ||||
|         sample_data = [ | ||||
|             [ | ||||
|                 jnp.pad(ctx, (0, params["seq"]), constant_values=pad_token_id), | ||||
|                 params["seq"], | ||||
|                 None, | ||||
|                 jnp.empty((), dtype=jnp.uint32), | ||||
|             ] | ||||
|         initial_state, sample_key = self.generate_initial_xmap(*xargs) | ||||
|         for i in range(gen_length[0]): | ||||
|             initial_state, sample_key = self.generate_once_xmap(initial_state, sample_key, *xargs) | ||||
|         return initial_state, sample_key | ||||
|             for _ in range(numseqs) | ||||
|         ] | ||||
|         repetition_penalty = sampler_options.pop("repetition_penalty", 1.0) | ||||
|         generate_data, sample_key = self.generate_initial_xmap(self.state, jnp.array(key.take(batch_size)), ctx, ctx_length, numseqs_aux, soft_embeddings) | ||||
|         sample_key = jax.device_put(sample_key[0, 0], cpu) | ||||
|         for _ in range(gen_length[0].item()): | ||||
|             generate_data, = self.generate_once_xmap(generate_data, self.state, numseqs_aux, soft_embeddings) | ||||
|             for i in range(numseqs): | ||||
|                 sample_data[i][2] = jax.device_put(generate_data[0][i][0, 0], cpu) | ||||
|             sample_data, sample_key = sample_jit(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options) | ||||
|         return sample_data, sample_key | ||||
|  | ||||
|  | ||||
| def infer( | ||||
| @@ -345,23 +352,23 @@ def infer( | ||||
|     padded_tokens = np.pad(tokens, ((pad_amount, 0),), constant_values=pad_token_id) | ||||
|     batched_tokens = np.array([padded_tokens] * total_batch) | ||||
|     samples = [] | ||||
|     batched_generator_params = { | ||||
|         "temp": temp * np.ones(total_batch), | ||||
|         "top_p": top_p * np.ones(total_batch), | ||||
|         "tfs": tfs * np.ones(total_batch), | ||||
|         "repetition_penalty": repetition_penalty * np.ones(total_batch), | ||||
|         "top_k": np.full(total_batch, top_k, dtype=np.uint32) | ||||
|     generator_params = { | ||||
|         "temp": float(temp), | ||||
|         "top_p": float(top_p), | ||||
|         "tfs": float(tfs), | ||||
|         "repetition_penalty": float(repetition_penalty), | ||||
|         "top_k": int(top_k), | ||||
|     } | ||||
|     output = network.generate( | ||||
|         batched_tokens, | ||||
|         np.ones(total_batch, dtype=np.uint32) * provided_ctx, | ||||
|         np.ones(total_batch, dtype=np.uint32) * gen_len, | ||||
|         numseqs, | ||||
|         batched_generator_params, | ||||
|         generator_params, | ||||
|         soft_embeddings=soft_embeddings, | ||||
|     )[0] | ||||
|     for o in output: | ||||
|         samples.append(o[0][0, 0, params["seq"] : params["seq"] + gen_len]) | ||||
|     for out in output: | ||||
|         samples.append(out[0][0, 0, params["seq"] : params["seq"] + gen_len]) | ||||
|     return samples | ||||
|  | ||||
|  | ||||
| @@ -414,6 +421,17 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) | ||||
|     shard_xmap = __shard_xmap() | ||||
|     batch_xmap = __batch_xmap(shard_dim=cores_per_replica) | ||||
|  | ||||
|     global cpu, sample_jit | ||||
|     cpu = jax.devices("cpu")[0] | ||||
|     sample_jit = jax.jit( | ||||
|         sample_jit, | ||||
|         device=cpu, | ||||
|     ) | ||||
|  | ||||
|     global badwords | ||||
|     # These are the tokens that we don't want the AI to ever write | ||||
|     badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146]) | ||||
|  | ||||
|     if not path.endswith("/"): | ||||
|         path += "/" | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Gnome Ann
					Gnome Ann