Merge pull request #41 from VE-FORBRYDERNE/jax21

TPU backend improvements
henk717 2021-12-05 18:10:52 +01:00 committed by GitHub
commit a442a2a67e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 90 additions and 69 deletions

View File

@@ -1802,12 +1802,25 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
         raise ValueError("Dynamic world info scanning is not supported by the TPU backend yet")
     soft_tokens = None
-    if(vars.sp is not None):
-        soft_tokens = np.arange(
-            tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
-            tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
-            dtype=np.uint32
-        )
+    if(vars.sp is None):
+        global np
+        if 'np' not in globals():
+            import numpy as np
+        tensor = np.zeros((1, tpu_mtj_backend.params["d_model"]), dtype=np.float32)
+        rows = tensor.shape[0]
+        padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
+        tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
+        tensor = tensor.reshape(
+            tpu_mtj_backend.params["cores_per_replica"],
+            -1,
+            tpu_mtj_backend.params["d_model"],
+        )
+        vars.sp = tensor
+    soft_tokens = np.arange(
+        tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
+        tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
+        dtype=np.uint32
+    )
     genout = tpu_mtj_backend.infer(
         txt,
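Note (editor's sketch, not part of the diff): the new branch above builds a dummy all-zeros soft prompt when none is loaded, pads it up to the model's sequence length rounded up to a multiple of cores_per_replica, and reshapes it so each TPU core receives an equal block of rows. With illustrative parameter values (the real ones come from tpu_mtj_backend.params):

import numpy as np

# Illustrative values only, not read from an actual model config
cores_per_replica, d_model, seq = 8, 4096, 2048

tensor = np.zeros((1, d_model), dtype=np.float32)          # dummy 1-row soft prompt
rows = tensor.shape[0]
padding_amount = seq - (seq % -cores_per_replica) - rows   # pad up to seq rounded up to a multiple of cores
tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
tensor = tensor.reshape(cores_per_replica, -1, d_model)    # one block of rows per TPU core
assert tensor.shape == (8, 256, 4096)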
@@ -2676,7 +2689,7 @@ def spRequest(filename):
     if(vars.model in ("TPUMeshTransformerGPTJ",)):
         rows = tensor.shape[0]
-        padding_amount = -(rows % -tpu_mtj_backend.params["cores_per_replica"])
+        padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
         tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
         tensor = tensor.reshape(
             tpu_mtj_backend.params["cores_per_replica"],
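Note (editor's sketch): the change above grows the soft-prompt padding from "next multiple of cores_per_replica" to "seq rounded up to a multiple of cores_per_replica", so every soft prompt is padded to the same number of rows regardless of its own length. The negative-modulo trick n - (n % -k) rounds n up to the nearest multiple of k. With toy numbers:

rows, cores_per_replica, seq = 5, 8, 2048

old_padding = -(rows % -cores_per_replica)             # old: pad 5 rows up to the next multiple of 8
new_padding = seq - (seq % -cores_per_replica) - rows  # new: pad 5 rows up to 2048 (already a multiple of 8)
assert (rows + old_padding, rows + new_padding) == (8, 2048)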

View File

@@ -4,11 +4,11 @@ requests
 optax >= 0.0.5, <= 0.0.9
 dm-haiku
 ray[default]
-jax == 0.2.12
+jax == 0.2.21
 transformers
 progressbar2
 git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck
 flask
 Flask-SocketIO
 flask-cloudflared >= 0.0.5
 flask-ngrok

View File

@@ -30,7 +30,7 @@ def show_spinner():
 def apply_repetition_penalty(logits, tokens, repetition_penalty):
     '''
-    This gets called by generate_scan_fn to apply repetition penalty
+    This gets called by generate_loop_fn to apply repetition penalty
     to the 1D array logits using the provided 1D array of tokens to penalize
     '''
     # Make a new array with the same length as the tokens array but with
@@ -52,9 +52,9 @@ def apply_repetition_penalty(logits, tokens, repetition_penalty):
     # positions in the logits array
     return logits.at[tokens].set(penalty_logits)
 
-def kobold_sample(key, logits, _, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
+def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
     '''
-    This gets called by generate_scan_fn to apply a series of 4 filters
+    This gets called by generate_loop_fn to apply a series of 4 filters
     to the logits (top-k, then top-p, then TFS, then temperature) before
     picking one token using the modified logits
     '''
@@ -147,7 +147,7 @@ def kobold_sample(key, logits, _, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
     # Finally, pick one token using the softmax thingy again (it gives
     # an array whose elements sum to 1 so it can be used nicely as a
     # probability distribution)
-    return jax.random.categorical(key, logits, -1).astype(jnp.uint32)[jnp.newaxis], None
+    return jax.random.categorical(key, logits, -1).astype(jnp.uint32)[jnp.newaxis]
 
 pad_token_id = 50256
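Note (editor's reading): the dropped underscore parameter and the trailing ", None" in kobold_sample appear to have existed only to satisfy the per-step input/output convention of jax.lax.scan; with the switch to jax.lax.while_loop further down, the sampler simply returns the picked token. A standalone sketch of the final sampling call, with made-up logits:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
logits = jnp.zeros(50400)  # made-up uniform logits over a padded vocabulary
# Same shape of call as the new return statement: one sampled token id, shape (1,)
token = jax.random.categorical(key, logits, -1).astype(jnp.uint32)[jnp.newaxis]
print(token.shape)  # (1,)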
@@ -155,27 +155,34 @@ class PenalizingCausalTransformer(CausalTransformer):
     def __init__(self, config):
         # Initialize
         super().__init__(config)
-        def generate(state, key, ctx, ctx_length, aux, sampler_options, soft_embeddings=None):
-            gen_length = self.gen_length
+        def generate(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None):
+            numseqs = numseqs_aux.shape[0]
             # These are the tokens that we don't want the AI to ever write
             self.badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146])
-            def generate_sample(context, ctx_length, aux):
+            def generate_sample(context, ctx_length):
                 # Give the initial context to the transformer
                 transformer = CausalTransformerShard(config)
-                _, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings)
-                # The "generated" array will contain the tokens from the
-                # context as well as the tokens picked by the sampler at
-                # each stage, padded with a bunch of 50256s, so we know
-                # which tokens have to be repetition penalized
-                generated = jnp.pad(context, (0, gen_length), constant_values=pad_token_id) # Let it start off with just the 2048 context tokens, plus gen_length 50256s which will be eventually filled with sampler-chosen tokens
-                generated_index = config["seq"]
-                # Add that information to generate_scan_fn's starting state
-                initial_state = (generated, generated_index) + initial_state
+                def generate_initial_scan_fn(sequence_index, _):
+                    _, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings)
+                    # The "generated" array will contain the tokens from the
+                    # context as well as the tokens picked by the sampler at
+                    # each stage, padded with a bunch of 50256s, so we know
+                    # which tokens have to be repetition penalized
+                    generated = jnp.pad(context, (0, config["seq"]), constant_values=pad_token_id) # Let it start off with just the 2048 context tokens, plus some 50256s which will be eventually filled with sampler-chosen tokens
+                    generated_index = config["seq"]
+                    # Add that information to generate_loop_fn's starting state
+                    initial_state = (generated, generated_index, sequence_index) + initial_state
+                    return sequence_index+1, initial_state
+                _, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, None, numseqs)
+                sample_key = initial_states[-1][0]
+                initial_states = list(jax.tree_map(lambda x: x[i], initial_states[:-1]) for i in range(numseqs))
                 # Get repetition penalty from the arguments
                 repetition_penalty = sampler_options.pop('repetition_penalty', None)
-                def generate_scan_fn(carry, sampler_input):
-                    # Unpack current generate_scan_fn state
-                    generated, generated_index, next_token, decode_state, sample_key = carry
+                # This is the main generation loop
+                def generate_loop_fn(carry):
+                    # Unpack current generate_loop_fn state
+                    generated, generated_index, sequence_index, next_token, decode_state = carry[0][0]
+                    sample_key = carry[1]
                     # Get the pseudo-random number generator key that will
                     # be used by kobold_sample to randomly pick a token
                     sample_key, new_key = jax.random.split(sample_key)
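Note (editor's sketch of the jax.lax.scan pattern used above): with xs=None and an explicit length, scan simply runs its body a fixed number of times, threading the carry through and stacking the per-step outputs along a new leading axis, which is how one initial generation state per sequence is built here before being unstacked with jax.tree_map. A toy example:

import jax

def body(carry, _):
    # carry is threaded between steps; the second return value is stacked
    return carry + 1, carry * 10

final_carry, stacked = jax.lax.scan(body, 0, None, length=4)
# final_carry == 4, stacked == [0, 10, 20, 30]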
@@ -207,56 +214,54 @@ class PenalizingCausalTransformer(CausalTransformer):
                     # based on the logits array as a 1D array with 1 element
                     # (higher logit means higher probability of being
                     # picked, non-linearly)
-                    next_token, sample_info = kobold_sample(
+                    next_token = kobold_sample(
                         sample_key,
                         logits,
-                        sampler_input,
                         **sampler_options,
                     )
-                    # Remember what token was picked so we can repetition
-                    # penalize it next time
+                    # Remember what token was picked
                     generated = generated.at[generated_index].set(next_token[0])
                     generated_index += 1
-                    # self.return_logits isn't used in this program, but
-                    # for the sake of compatibility...
-                    if self.return_logits:
-                        output = (next_token, sample_info, logits[jnp.newaxis])
-                    else:
-                        output = (next_token, sample_info)
-                    # Re-pack the current generate_scan_fn's state so we can
+                    # Re-pack the current generate_loop_fn's state so we can
                     # get back the same variables the next time
-                    new_carry = (generated, generated_index, next_token, new_state, new_key)
-                    return new_carry, output
-                # jax.lax.scan is a function that calls generate_scan_fn
-                # gen_length times, each time passing a state object from
-                # its return value (new_carry) back into one of the
-                # function's arguments (carry), and of course gathering the
-                # token it generates each time into the "outputs" array;
-                # we have to use jax.lax.scan instead of a normal loop
-                # because of JAX's JIT-compilation shenanigans
-                final_state, outputs = jax.lax.scan(
-                    generate_scan_fn,
-                    initial_state,
-                    xs=aux,
-                    length=gen_length,
-                )
-                return final_state, outputs
+                    carry[0][0] = (generated, generated_index, sequence_index, next_token, new_state)
+                    carry[0].append(carry[0].pop(0))
+                    return carry[0], new_key
+                final_state = jax.lax.while_loop(
+                    lambda carry: carry[0][0][1] - config["seq"] < gen_length,
+                    generate_loop_fn,
+                    (initial_states, sample_key),
+                )
+                return final_state
             generate_fn = hk.transform(generate_sample).apply
-            return generate_fn(state["params"], key, ctx, ctx_length, aux)
-        self.generate_xmap = jax.experimental.maps.xmap(fun=generate, in_axes=(["shard", ...], ["batch", ...], ["batch", ...], ["batch", ...], ["batch", ...], ["batch", ...], ["shard", ...]), out_axes=["batch", ...], axis_resources={'shard': 'mp', 'batch': 'dp'})
-    def generate(self, ctx, ctx_length, gen_length, sampler_options, return_logits=False, soft_embeddings=None):
+            return generate_fn(state["params"], key, ctx, ctx_length)
+        self.generate_xmap = jax.experimental.maps.xmap(
+            fun=generate,
+            in_axes=(
+                ["shard", ...],
+                ["batch", ...],
+                ["batch", ...],
+                ["batch", ...],
+                ["batch", ...],
+                ["batch", ...],
+                ["batch", ...],
+                ["shard", ...],
+            ),
+            out_axes=["shard", "batch", ...],
+            axis_resources={'shard': 'mp', 'batch': 'dp'},
+        )
+    def generate(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None):
+        assert not return_logits
         key = hk.PRNGSequence(random.randint(0, 2 ** 60))
         batch_size = ctx.shape[0]
-        aux = jnp.zeros((batch_size, gen_length), dtype=jnp.uint32)
-        self.gen_length = gen_length
         self.batch_size = batch_size
-        self.return_logits = return_logits
         return self.generate_xmap(
             self.state,
             jnp.array(key.take(batch_size)),
             ctx,
             np.array(ctx_length, dtype=np.uint32),
-            aux,
+            np.array(gen_length, dtype=np.uint32),
+            np.empty((batch_size, numseqs), dtype=np.uint8),
             sampler_options,
             soft_embeddings,
         )
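Note (editor's reading): the structural change above replaces jax.lax.scan, whose trip count must be a static Python integer, with jax.lax.while_loop, whose stopping condition may depend on traced values; that is presumably what lets gen_length be passed in as a runtime uint32 array instead of being baked in via self.gen_length. A toy sketch of the while_loop pattern, not from the diff:

import jax
import jax.numpy as jnp

def count_up_to(n):
    # The carry is (i, total); the stopping condition may depend on the traced value n
    def cond(carry):
        i, _ = carry
        return i < n
    def body(carry):
        i, total = carry
        return i + 1, total + i
    return jax.lax.while_loop(cond, body, (jnp.uint32(0), jnp.uint32(0)))

print(jax.jit(count_up_to)(jnp.uint32(5)))  # (5, 10)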
@@ -275,7 +280,7 @@ def infer(
     soft_tokens: Optional[np.array] = None,
 ) -> List[str]:
     maps.thread_resources.env = thread_resources_env
-    total_batch = numseqs
+    total_batch = 1
     tokens = np.uint32(tokenizer.encode(context, max_length=params["seq"] - (soft_tokens.shape[0] if soft_tokens is not None else 0), truncation=True))
     if(soft_tokens is not None):
         tokens = np.uint32(np.concatenate((soft_tokens, tokens)))
@@ -283,7 +288,6 @@ def infer(
     pad_amount = seq - provided_ctx
     padded_tokens = np.pad(tokens, ((pad_amount, 0),), constant_values=pad_token_id)
     batched_tokens = np.array([padded_tokens] * total_batch)
-    length = np.ones(total_batch, dtype=np.uint32) * provided_ctx
     samples = []
     batched_generator_params = {
         "temp": temp * np.ones(total_batch),
@@ -294,14 +298,14 @@ def infer(
     }
     output = network.generate(
         batched_tokens,
-        length,
-        gen_len,
+        np.ones(total_batch, dtype=np.uint32) * provided_ctx,
+        np.ones(total_batch, dtype=np.uint32) * gen_len,
+        numseqs,
         batched_generator_params,
         soft_embeddings=soft_embeddings,
-    )
-    decoded_tokens = output[1][0]
-    for o in decoded_tokens[:, :, 0]:
-        samples.append(tokenizer.decode(o))
+    )[0]
+    for o in output:
+        samples.append(tokenizer.decode(o[0][0, 0, params["seq"] : params["seq"] + gen_len]))
     return samples
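Note (editor's sketch with toy sizes): each per-sequence "generated" buffer returned by the network is laid out as the padded context (params["seq"] tokens) followed by the sampler-chosen tokens, which is why the new decode call slices from params["seq"] to params["seq"] + gen_len:

import numpy as np

seq, gen_len = 8, 3                                    # toy sizes; the real seq is params["seq"]
context = np.full(seq, 50256, dtype=np.uint32)         # stand-in for the padded prompt tokens
sampled = np.array([464, 3290, 318], dtype=np.uint32)  # stand-in for sampler-chosen tokens
filler = np.full(seq - gen_len, 50256, dtype=np.uint32)
generated = np.concatenate([context, sampled, filler])
# Only the newly generated tail gets decoded:
assert (generated[seq : seq + gen_len] == sampled).all()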
@@ -326,6 +330,10 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
         if param not in params:
             params[param] = default_params[param]
+    # Disable JAX warnings about these two functions having been renamed
+    jax.host_count = jax.process_count
+    jax.host_id = jax.process_index
     print("Connecting to your Colab instance's TPU", flush=True)
     spinner = multiprocessing.Process(target=show_spinner, args=())
     spinner.start()
@@ -342,7 +350,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
     params["optimizer"] = optax.scale(0)
     mesh_shape = (1, cores_per_replica)
     devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
-    thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))
+    thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
     maps.thread_resources.env = thread_resources_env
     tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')