mirror of https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-04-03 05:01:02 +02:00

Move sampling into a jax.jitted function

This commit is contained in:
parent 09c4fdcb2e
commit 57a6886007
@@ -172,36 +172,74 @@ def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
    # Finally, pick one token using the softmax thingy again (it gives
    # an array whose elements sum to 1 so it can be used nicely as a
    # probability distribution)
    return jax.random.categorical(key, logits, -1).astype(jnp.uint32)[jnp.newaxis]
    return jax.random.categorical(key, logits, -1).astype(jnp.uint32)
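For intuition, a minimal standalone sketch (not part of this diff) of what the call above does: jax.random.categorical treats the logits as unnormalized log-probabilities and draws one token index from the implied softmax distribution.

    # Toy example of the sampling call; values are made up.
    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)
    logits = jnp.array([2.0, 0.5, -1.0])             # unnormalized log-probabilities
    token = jax.random.categorical(key, logits, -1)  # index 0 is drawn most often
    token = token.astype(jnp.uint32)                 # 0-D uint32, like the new return value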

pad_token_id = 50256

def sample_jit(data, key, numseqs_aux, badwords, repetition_penalty, sampler_options):
    numseqs = numseqs_aux.shape[0]
    gi = data[0][1]
    def sample_loop_fn(carry):
        generated, generated_index, logits, _ = carry[0][0]
        sample_key = carry[1]
        # Get the pseudo-random number generator key that will
        # be used by kobold_sample to randomly pick a token
        sample_key, new_key = jax.random.split(sample_key, num=2)
        # Apply repetition penalty to all tokens that are
        # currently inside the "generated" array
        logits = apply_repetition_penalty(
            logits,
            generated,
            repetition_penalty
        )
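        # (Aside, not part of this commit: a CTRL-style repetition penalty
        # typically divides the positive logits of already-generated tokens by
        # the penalty and multiplies the negative ones by it, roughly
        #     prev = logits[generated]
        #     logits = logits.at[generated].set(
        #         jnp.where(prev > 0, prev / repetition_penalty, prev * repetition_penalty))
        # The real apply_repetition_penalty defined earlier in this file may
        # differ in its details.)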
        # Remove any tokens in the badwords list by setting
        # their logits to negative infinity which effectively
        # makes their probabilities of being chosen zero
        logits = logits.at[badwords].set(-jnp.inf)
        # Use the sampler (kobold_sample) to pick one token
        # based on the logits array as a 0D uint32 array
        # (higher logit means higher probability of being
        # picked, non-linearly)
        next_token = kobold_sample(
            sample_key,
            logits,
            **sampler_options,
        )
        # Remember what token was picked
        generated = generated.at[generated_index].set(next_token)
        generated_index += 1
        # Re-pack the current sample_loop_fn's state so we can
        # get back the same variables the next time
        carry[0][0] = [generated, generated_index, logits, next_token]
        carry[0].append(carry[0].pop(0))
        return carry[0], new_key
    return jax.lax.while_loop(
        lambda carry: carry[0][0][1] < gi,
        sample_loop_fn,
        (data, key),
    )
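The carry threaded through jax.lax.while_loop above is a list of per-sequence states: each iteration services the state at the front of the list and then rotates it to the back, so a single loop handles every sequence in turn. A toy sketch of that pattern (hypothetical, not from this commit):

    import jax
    import jax.numpy as jnp

    def body(carry):
        states, handled = carry
        states[0] = states[0] + 1      # stand-in for "sample one token for this sequence"
        states.append(states.pop(0))   # rotate the front state to the back
        return states, handled + 1

    def cond(carry):
        return carry[1] < 3            # one pass over the three toy sequences

    states, _ = jax.lax.while_loop(cond, body, ([jnp.zeros((), jnp.uint32)] * 3, 0))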

class PenalizingCausalTransformer(CausalTransformer):
    def __init__(self, config):
        # Initialize
        super().__init__(config)
        def generate_initial(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None):
        def generate_initial(state, key, ctx, ctx_length, numseqs_aux, soft_embeddings=None):
            numseqs = numseqs_aux.shape[0]
            @hk.transform
            def generate_initial_inner(context, ctx_length):
                # Give the initial context to the transformer
                transformer = CausalTransformerShard(config)
                def generate_initial_scan_fn(sequence_index, _):
                    _, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings)
                    # The "generated" array will contain the tokens from the
                    # context as well as the tokens picked by the sampler at
                    # each stage, padded with a bunch of 50256s, so we know
                    # which tokens have to be repetition penalized
                    generated = jnp.pad(context, (0, config["seq"]), constant_values=pad_token_id) # Let it start off with just the 2048 context tokens, plus some 50256s which will be eventually filled with sampler-chosen tokens
                    generated_index = config["seq"]
                    # Add that information to generate_loop_fn's starting state
                    initial_state = (generated, generated_index, sequence_index) + initial_state
                    initial_state = (jnp.empty(config["n_vocab"], dtype=jnp.float32), generated_index, sequence_index) + initial_state
                    return sequence_index+1, initial_state
                _, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, None, numseqs)
                sample_key = initial_states[-1][0]
                initial_states = list(jax.tree_map(lambda x: x[i], initial_states[:-1]) for i in range(numseqs))
                return initial_states, sample_key
            generate_initial_fn = hk.transform(generate_initial_inner).apply
            return generate_initial_fn(state["params"], key, ctx, ctx_length)
            return generate_initial_inner.apply(state["params"], key, ctx, ctx_length)
        self.generate_initial_xmap = jax.experimental.maps.xmap(
            fun=generate_initial,
            in_axes=(
@@ -210,31 +248,23 @@ class PenalizingCausalTransformer(CausalTransformer):
                ["batch", ...],
                ["batch", ...],
                ["batch", ...],
                ["batch", ...],
                ["batch", ...],
                ["shard", ...],
            ),
            out_axes=["shard", "batch", ...],
            axis_resources={'shard': 'mp', 'batch': 'dp'},
        )
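        # (Aside, not part of this commit: in this xmap, arguments whose named
        # axis is "shard" are split across the model-parallel mesh axis "mp"
        # (one shard per TPU core), arguments tagged "batch" are split across
        # the data-parallel axis "dp", and `...` leaves the remaining
        # positional axes unmapped.)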
        def generate_once(initial_states, sample_key, state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None):
        def generate_once(data, state, numseqs_aux, soft_embeddings=None):
            numseqs = numseqs_aux.shape[0]
            # These are the tokens that we don't want the AI to ever write
            self.badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146])
            def generate_once_inner(context, ctx_length):
                gi = initial_states[0][1]
            @hk.without_apply_rng
            @hk.transform
            def generate_once_inner():
                gi = data[0][1]
                # Give the initial context to the transformer
                transformer = CausalTransformerShard(config)
                # Get repetition penalty from the arguments
                repetition_penalty = sampler_options.pop('repetition_penalty', None)
                # This is the main generation loop
                def generate_loop_fn(carry):
                    # Unpack current generate_loop_fn state
                    generated, generated_index, sequence_index, next_token, decode_state = carry[0][0]
                    sample_key = carry[1]
                    # Get the pseudo-random number generator key that will
                    # be used by kobold_sample to randomly pick a token
                    sample_key, new_key = jax.random.split(sample_key)
                    _, generated_index, sequence_index, next_token, decode_state = carry[0][0]
                    # Give the context to the model and get the logits it
                    # spits out
                    # (a 2D array with 1 row and 50400 columns representing
@@ -245,57 +275,27 @@ class PenalizingCausalTransformer(CausalTransformer):
                    # Verify that logits does indeed have that many rows and
                    # columns (if you get an error here, pray for mercy)
                    assert logits.shape == (1, config["n_vocab"])
                    assert logits.dtype == jnp.float32
                    # Flatten it into a 1D array to make it easier to use
                    logits = logits[0]
                    # Apply repetition penalty to all tokens that are
                    # currently inside the "generated" array
                    if repetition_penalty is not None:
                        logits = apply_repetition_penalty(
                            logits,
                            generated,
                            repetition_penalty
                        )
                    # Remove any tokens in the badwords list by setting
                    # their logits to negative infinity which effectively
                    # makes their probabilities of being chosen zero
                    logits = logits.at[self.badwords].set(-jnp.inf)
                    # Use the sampler (kobold_sample) to pick one token
                    # based on the logits array as a 1D array with 1 element
                    # (higher logit means higher probability of being
                    # picked, non-linearly)
                    next_token = kobold_sample(
                        sample_key,
                        logits,
                        **sampler_options,
                    )
                    # Remember what token was picked
                    generated = generated.at[generated_index].set(next_token[0])
                    generated_index += 1
                    # Re-pack the current generate_loop_fn's state so we can
                    # get back the same variables the next time
                    carry[0][0] = (generated, generated_index, sequence_index, next_token, new_state)
                    generated_index += 1
                    carry[0][0] = (logits, generated_index, sequence_index, next_token, new_state)
                    carry[0].append(carry[0].pop(0))
                    return carry[0], new_key
                final_state = jax.lax.while_loop(
                    return carry[0],
                return jax.lax.while_loop(
                    lambda carry: carry[0][0][1] == gi,
                    generate_loop_fn,
                    (initial_states, sample_key),
                    (data,),
                )
                return final_state
            generate_once_fn = hk.transform(generate_once_inner).apply
            return generate_once_fn(state["params"], key, ctx, ctx_length)
            return generate_once_inner.apply(state["params"])
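generate_once_inner uses Haiku's transform pattern: hk.transform turns the module-building function into an init/apply pair, and hk.without_apply_rng removes the RNG argument from apply since nothing inside needs one. A minimal sketch of the same pattern with a hypothetical module (not from this commit):

    import haiku as hk
    import jax
    import jax.numpy as jnp

    @hk.without_apply_rng
    @hk.transform
    def forward(x):
        return hk.Linear(4)(x)  # any Haiku module works here

    params = forward.init(jax.random.PRNGKey(0), jnp.zeros((1, 8)))
    out = forward.apply(params, jnp.zeros((1, 8)))  # no rng argument needed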
        self.generate_once_xmap = jax.experimental.maps.xmap(
            fun=generate_once,
            in_axes=(
                ["shard", "batch", ...],
                ["shard", "batch", ...],
                ["shard", ...],
                ["batch", ...],
                ["batch", ...],
                ["batch", ...],
                ["batch", ...],
                ["batch", ...],
                ["batch", ...],
                ["shard", ...],
            ),
            out_axes=["shard", "batch", ...],
@@ -304,23 +304,30 @@ class PenalizingCausalTransformer(CausalTransformer):
    def generate(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None):
        assert not return_logits
        assert gen_length.ndim == 1
        assert soft_embeddings is not None
        key = hk.PRNGSequence(random.randint(0, 2 ** 60))
        batch_size = ctx.shape[0]
        self.batch_size = batch_size
        xargs = [
            self.state,
            batch_xmap(jnp.array(key.take(batch_size))),
            batch_xmap(ctx),
            batch_xmap(np.array(ctx_length, dtype=np.uint32)),
            batch_xmap(np.array(gen_length, dtype=np.uint32)),
            np.empty((batch_size, numseqs), dtype=np.uint8),
            batch_xmap(sampler_options),
            shard_xmap(soft_embeddings),
        _numseqs_aux = jnp.empty((batch_size, numseqs), dtype=np.uint32)
        numseqs_aux = batch_xmap(_numseqs_aux)
        sample_data = [
            [
                jnp.pad(ctx, (0, params["seq"]), constant_values=pad_token_id),
                params["seq"],
                None,
                jnp.empty((), dtype=jnp.uint32),
            ]
            for _ in range(numseqs)
        ]
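        # (Reading aid, not part of this commit: each per-sequence entry of
        # sample_data lines up with the unpacking in sample_loop_fn above --
        # [0] the "generated" token array padded with pad_token_id, [1] the
        # generated_index write position, [2] a slot later filled with logits
        # copied back from the TPU, [3] the most recently sampled token.)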
        initial_state, sample_key = self.generate_initial_xmap(*xargs)
        for i in range(gen_length[0]):
            initial_state, sample_key = self.generate_once_xmap(initial_state, sample_key, *xargs)
        return initial_state, sample_key
        repetition_penalty = sampler_options.pop("repetition_penalty", 1.0)
        generate_data, sample_key = self.generate_initial_xmap(self.state, jnp.array(key.take(batch_size)), ctx, ctx_length, numseqs_aux, soft_embeddings)
        sample_key = jax.device_put(sample_key[0, 0], cpu)
        for _ in range(gen_length[0].item()):
            generate_data, = self.generate_once_xmap(generate_data, self.state, numseqs_aux, soft_embeddings)
            for i in range(numseqs):
                sample_data[i][2] = jax.device_put(generate_data[0][i][0, 0], cpu)
            sample_data, sample_key = sample_jit(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options)
        return sample_data, sample_key

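Taken together, the rewritten generate() splits each step across devices: generate_once_xmap runs the model on the TPU and yields per-sequence logits, jax.device_put copies those logits to the CPU, and the CPU-jitted sample_jit picks the next tokens. A hedged outline of one step as this diff shows it (commentary, not literal code; the exact data layout may differ):

    # generate_data, = self.generate_once_xmap(generate_data, self.state, numseqs_aux, soft_embeddings)  # TPU forward pass
    # sample_data[i][2] = jax.device_put(generate_data[0][i][0, 0], cpu)                                 # copy logits to the CPU
    # sample_data, sample_key = sample_jit(sample_data, sample_key, ...)                                 # CPU-side token picking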
@@ -345,23 +352,23 @@ def infer(
    padded_tokens = np.pad(tokens, ((pad_amount, 0),), constant_values=pad_token_id)
    batched_tokens = np.array([padded_tokens] * total_batch)
    samples = []
    batched_generator_params = {
        "temp": temp * np.ones(total_batch),
        "top_p": top_p * np.ones(total_batch),
        "tfs": tfs * np.ones(total_batch),
        "repetition_penalty": repetition_penalty * np.ones(total_batch),
        "top_k": np.full(total_batch, top_k, dtype=np.uint32)
    generator_params = {
        "temp": float(temp),
        "top_p": float(top_p),
        "tfs": float(tfs),
        "repetition_penalty": float(repetition_penalty),
        "top_k": int(top_k),
    }
    output = network.generate(
        batched_tokens,
        np.ones(total_batch, dtype=np.uint32) * provided_ctx,
        np.ones(total_batch, dtype=np.uint32) * gen_len,
        numseqs,
        batched_generator_params,
        generator_params,
        soft_embeddings=soft_embeddings,
    )[0]
    for o in output:
        samples.append(o[0][0, 0, params["seq"] : params["seq"] + gen_len])
    for out in output:
        samples.append(out[0][0, 0, params["seq"] : params["seq"] + gen_len])
    return samples

@@ -414,6 +421,17 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
    shard_xmap = __shard_xmap()
    batch_xmap = __batch_xmap(shard_dim=cores_per_replica)

    global cpu, sample_jit
    cpu = jax.devices("cpu")[0]
    sample_jit = jax.jit(
        sample_jit,
        device=cpu,
    )
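    # (Aside, not part of this commit: this is the change the commit title
    # refers to. jax.jit compiles the module-level sample_jit and, via the
    # device= argument available in JAX at the time, pins it to the CPU, so
    # token picking no longer runs on the TPU. A toy version of the same
    # pattern, with a hypothetical function:
    #     cpu = jax.devices("cpu")[0]
    #     double = jax.jit(lambda x: x * 2, device=cpu)
    #     double(jnp.arange(4))  # runs on the CPU backend
    # )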

    global badwords
    # These are the tokens that we don't want the AI to ever write
    badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146])

    if not path.endswith("/"):
        path += "/"
