Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-02-05 03:48:07 +01:00
Modify TPU backend code to support JAX 0.2.21
The original code supported JAX versions up to 0.2.12, and possibly some earlier versions as well. The new code supports JAX 0.2.21 exclusively and does not work with any earlier or later version. In exchange, it no longer needs to recompile when "Amount To Generate" changes, and it supports stopping generation early, which finally makes an implementation of Dynamic World Info Scan possible.
parent 9e3318c696
commit 3c349e6aaf
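
The key change is replacing jax.lax.scan (which bakes a fixed trip count into the compiled program) with jax.lax.while_loop, whose loop bound is an ordinary traced value. Below is a minimal, self-contained sketch of that pattern, not the commit's actual code: the model and sampler are replaced by a random stand-in, and all names are illustrative. Changing gen_length between calls reuses the compiled program, and the condition function is also the natural place to test an early-stop flag such as the one Dynamic World Info Scan needs:

import jax
import jax.numpy as jnp

def generate(key, gen_length, max_length=16):
    # Output buffer padded with 50256s (the pad token), as in the real code
    tokens = jnp.full(max_length, 50256, dtype=jnp.uint32)

    def cond(carry):
        _tokens, i, _key = carry
        # gen_length is a traced value, so a new value does not recompile;
        # an early-stopping flag could also be checked here
        return i < gen_length

    def body(carry):
        tokens, i, key = carry
        key, subkey = jax.random.split(key)
        # Stand-in for the real sampler (kobold_sample over model logits)
        next_token = jax.random.randint(subkey, (), 0, 50257).astype(jnp.uint32)
        return tokens.at[i].set(next_token), i + 1, key

    tokens, _, _ = jax.lax.while_loop(cond, body, (tokens, 0, key))
    return tokens

generate_fn = jax.jit(generate, static_argnums=(2,))
generate_fn(jax.random.PRNGKey(0), 5)   # compiles once
generate_fn(jax.random.PRNGKey(1), 9)   # reuses the same compilation

By contrast, jax.lax.scan's length argument must be a static Python integer, which is why the old code had to recompile whenever "Amount To Generate" changed.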
@@ -4,11 +4,11 @@ requests
 optax >= 0.0.5, <= 0.0.9
 dm-haiku
 ray[default]
-jax == 0.2.12
+jax == 0.2.21
 transformers
 progressbar2
 git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck
 flask
 Flask-SocketIO
 flask-cloudflared >= 0.0.5
 flask-ngrok
@@ -30,7 +30,7 @@ def show_spinner():
 
 def apply_repetition_penalty(logits, tokens, repetition_penalty):
     '''
-    This gets called by generate_scan_fn to apply repetition penalty
+    This gets called by generate_loop_fn to apply repetition penalty
     to the 1D array logits using the provided 1D array of tokens to penalize
     '''
     # Make a new array with the same length as the tokens array but with
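
The body of apply_repetition_penalty is mostly untouched by this commit, so the diff only shows fragments of it. For context, a hypothetical sketch of the standard technique its docstring describes (the divide/multiply rule is an assumption here, not read from the diff):

import jax.numpy as jnp

def apply_repetition_penalty_sketch(logits, tokens, repetition_penalty):
    # Gather the logits at the positions of already-generated tokens
    penalty_logits = logits[tokens]
    # Make repeats less attractive: shrink positive logits, push
    # negative ones further down
    penalty_logits = jnp.where(
        penalty_logits > 0,
        penalty_logits / repetition_penalty,
        penalty_logits * repetition_penalty,
    )
    # Write the penalized values back into their positions in the
    # logits array (matches the return statement visible in the diff)
    return logits.at[tokens].set(penalty_logits)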
@@ -52,9 +52,9 @@ def apply_repetition_penalty(logits, tokens, repetition_penalty):
     # positions in the logits array
     return logits.at[tokens].set(penalty_logits)
 
-def kobold_sample(key, logits, _, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
+def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
     '''
-    This gets called by generate_scan_fn to apply a series of 4 filters
+    This gets called by generate_loop_fn to apply a series of 4 filters
     to the logits (top-k, then top-p, then TFS, then temperature) before
     picking one token using the modified logits
     '''
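
The docstring above names the four filters, but the diff does not show their implementations. As a point of reference, a simplified hypothetical sketch of just the top-k and temperature steps (top-p and TFS likewise mask logits before the final pick):

import jax
import jax.numpy as jnp

def sample_sketch(key, logits, top_k=40, temp=0.5):
    # Top-k: anything below the k-th largest logit can never be picked
    kth_largest = jnp.sort(logits)[-top_k]
    logits = jnp.where(logits < kth_largest, -jnp.inf, logits)
    # Temperature: dividing by temp < 1 sharpens the distribution,
    # temp > 1 flattens it
    logits = logits / temp
    return jax.random.categorical(key, logits, -1).astype(jnp.uint32)[jnp.newaxis]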
@@ -147,7 +147,7 @@ def kobold_sample(key, logits, _, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
     # Finally, pick one token using the softmax thingy again (it gives
     # an array whose elements sum to 1 so it can be used nicely as a
     # probability distribution)
-    return jax.random.categorical(key, logits, -1).astype(jnp.uint32)[jnp.newaxis], None
+    return jax.random.categorical(key, logits, -1).astype(jnp.uint32)[jnp.newaxis]
 
 pad_token_id = 50256
 
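
For reference, jax.random.categorical treats its input as unnormalized log-probabilities and normalizes internally (the "softmax thingy" the comment refers to), so the changed return line samples a token id directly from the filtered logits. A standalone illustration:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
logits = jnp.array([0.1, 2.0, -1.0])
# Index 1 has the largest logit, so it is picked most often
token = jax.random.categorical(key, logits, -1).astype(jnp.uint32)[jnp.newaxis]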
@@ -155,11 +155,10 @@ class PenalizingCausalTransformer(CausalTransformer):
     def __init__(self, config):
         # Initialize
         super().__init__(config)
-        def generate(state, key, ctx, ctx_length, aux, sampler_options, soft_embeddings=None):
-            gen_length = self.gen_length
+        def generate(state, key, ctx, ctx_length, gen_length, sampler_options, soft_embeddings=None):
             # These are the tokens that we don't want the AI to ever write
             self.badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146])
-            def generate_sample(context, ctx_length, aux):
+            def generate_sample(context, ctx_length):
                 # Give the initial context to the transformer
                 transformer = CausalTransformerShard(config)
                 _, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings)
@@ -167,14 +166,14 @@ class PenalizingCausalTransformer(CausalTransformer):
                 # context as well as the tokens picked by the sampler at
                 # each stage, padded with a bunch of 50256s, so we know
                 # which tokens have to be repetition penalized
-                generated = jnp.pad(context, (0, gen_length), constant_values=pad_token_id)  # Let it start off with just the 2048 context tokens, plus gen_length 50256s which will be eventually filled with sampler-chosen tokens
+                generated = jnp.pad(context, (0, config["seq"]), constant_values=pad_token_id)  # Let it start off with just the 2048 context tokens, plus some 50256s which will be eventually filled with sampler-chosen tokens
                 generated_index = config["seq"]
-                # Add that information to generate_scan_fn's starting state
+                # Add that information to generate_loop_fn's starting state
                 initial_state = (generated, generated_index) + initial_state
                 # Get repetition penalty from the arguments
                 repetition_penalty = sampler_options.pop('repetition_penalty', None)
-                def generate_scan_fn(carry, sampler_input):
-                    # Unpack current generate_scan_fn state
+                def generate_loop_fn(carry):
+                    # Unpack current generate_loop_fn state
                     generated, generated_index, next_token, decode_state, sample_key = carry
                     # Get the pseudo-random number generator key that will
                     # be used by kobold_sample to randomly pick a token
@@ -207,56 +206,51 @@ class PenalizingCausalTransformer(CausalTransformer):
                     # based on the logits array as a 1D array with 1 element
                     # (higher logit means higher probability of being
                     # picked, non-linearly)
-                    next_token, sample_info = kobold_sample(
+                    next_token = kobold_sample(
                         sample_key,
                         logits,
-                        sampler_input,
                         **sampler_options,
                     )
-                    # Remember what token was picked so we can repetition
-                    # penalize it next time
+                    # Remember what token was picked
                     generated = generated.at[generated_index].set(next_token[0])
                     generated_index += 1
-                    # self.return_logits isn't used in this program, but
-                    # for the sake of compatibility...
-                    if self.return_logits:
-                        output = (next_token, sample_info, logits[jnp.newaxis])
-                    else:
-                        output = (next_token, sample_info)
-                    # Re-pack the current generate_scan_fn's state so we can
+                    # Re-pack the current generate_loop_fn's state so we can
                     # get back the same variables the next time
                     new_carry = (generated, generated_index, next_token, new_state, new_key)
-                    return new_carry, output
-                # jax.lax.scan is a function that calls generate_scan_fn
-                # gen_length times, each time passing a state object from
-                # its return value (new_carry) back into one of the
-                # function's arguments (carry), and of course gathering the
-                # token it generates each time into the "outputs" array;
-                # we have to use jax.lax.scan instead of a normal loop
-                # because of JAX's JIT-compilation shenanigans
-                final_state, outputs = jax.lax.scan(
-                    generate_scan_fn,
+                    return new_carry
+                final_state = jax.lax.while_loop(
+                    lambda carry: carry[1] - config["seq"] < gen_length,
+                    generate_loop_fn,
                     initial_state,
-                    xs=aux,
-                    length=gen_length,
                 )
-                return final_state, outputs
+                return final_state
             generate_fn = hk.transform(generate_sample).apply
-            return generate_fn(state["params"], key, ctx, ctx_length, aux)
-        self.generate_xmap = jax.experimental.maps.xmap(fun=generate, in_axes=(["shard", ...], ["batch", ...], ["batch", ...], ["batch", ...], ["batch", ...], ["batch", ...], ["shard", ...]), out_axes=["batch", ...], axis_resources={'shard': 'mp', 'batch': 'dp'})
+            return generate_fn(state["params"], key, ctx, ctx_length)
+        self.generate_xmap = jax.experimental.maps.xmap(
+            fun=generate,
+            in_axes=(
+                ["shard", ...],
+                ["batch", ...],
+                ["batch", ...],
+                ["batch", ...],
+                ["batch", ...],
+                ["batch", ...],
+                ["shard", ...],
+            ),
+            out_axes=["shard", "batch", ...],
+            axis_resources={'shard': 'mp', 'batch': 'dp'},
+        )
     def generate(self, ctx, ctx_length, gen_length, sampler_options, return_logits=False, soft_embeddings=None):
+        assert not return_logits
         key = hk.PRNGSequence(random.randint(0, 2 ** 60))
         batch_size = ctx.shape[0]
-        aux = jnp.zeros((batch_size, gen_length), dtype=jnp.uint32)
-        self.gen_length = gen_length
         self.batch_size = batch_size
-        self.return_logits = return_logits
         return self.generate_xmap(
             self.state,
             jnp.array(key.take(batch_size)),
             ctx,
             np.array(ctx_length, dtype=np.uint32),
-            aux,
+            np.array(gen_length, dtype=np.uint32),
             sampler_options,
             soft_embeddings,
         )
@@ -283,7 +277,6 @@ def infer(
     pad_amount = seq - provided_ctx
     padded_tokens = np.pad(tokens, ((pad_amount, 0),), constant_values=pad_token_id)
     batched_tokens = np.array([padded_tokens] * total_batch)
-    length = np.ones(total_batch, dtype=np.uint32) * provided_ctx
     samples = []
     batched_generator_params = {
         "temp": temp * np.ones(total_batch),
@@ -294,13 +287,12 @@ def infer(
     }
     output = network.generate(
         batched_tokens,
-        length,
-        gen_len,
+        np.ones(total_batch, dtype=np.uint32) * provided_ctx,
+        np.ones(total_batch, dtype=np.uint32) * gen_len,
         batched_generator_params,
         soft_embeddings=soft_embeddings,
     )
-    decoded_tokens = output[1][0]
-    for o in decoded_tokens[:, :, 0]:
+    for o in output[0][0, :, params["seq"] : params["seq"] + gen_len]:
         samples.append(tokenizer.decode(o))
     return samples
 
@@ -326,6 +318,10 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
         if param not in params:
             params[param] = default_params[param]
 
+    # Disable JAX warnings about these two functions having been renamed
+    jax.host_count = jax.process_count
+    jax.host_id = jax.process_index
+
     print("Connecting to your Colab instance's TPU", flush=True)
     spinner = multiprocessing.Process(target=show_spinner, args=())
     spinner.start()
@@ -342,7 +338,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
     params["optimizer"] = optax.scale(0)
     mesh_shape = (1, cores_per_replica)
     devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
-    thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))
+    thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
     maps.thread_resources.env = thread_resources_env
     tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
 