Improve TPU backend compilation times with `numseqs > 1`

A Python `for` loop was replaced with a `jax.lax.scan` call so that JAX
compiles the `transformer.generate_initial` function once instead of
`numseqs` times. JAX unrolls native Python loops such as `for` while
tracing, so the loop body was previously inlined into the XLA program
once per sequence. Compilation times should now be roughly what they
were before the upgrade to JAX 0.2.21.
Gnome Ann 2021-11-30 19:22:40 -05:00
parent c1e7c1643f
commit d2d338d314
1 changed file with 14 additions and 10 deletions
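
For context, here is a minimal, self-contained sketch (toy code, not from this repo) of the difference the commit message describes: a native Python `for` loop is unrolled while tracing, so XLA compiles N copies of the body, whereas `jax.lax.scan` traces and compiles the body once.

    # Toy sketch (not KoboldAI code): why lax.scan compiles faster than a
    # Python for loop under jit.
    import jax
    import jax.numpy as jnp

    N = 8  # stands in for numseqs

    def step(carry, _):
        return carry + 1, None

    @jax.jit
    def unrolled(x):
        # A native Python loop is unrolled while tracing: XLA receives N
        # copies of the loop body, so compile time grows with N.
        for _ in range(N):
            x, _ = step(x, None)
        return x

    @jax.jit
    def scanned(x):
        # lax.scan traces and compiles the body exactly once, whatever N is.
        x, _ = jax.lax.scan(step, x, None, N)
        return x

    print(unrolled(jnp.float32(0)), scanned(jnp.float32(0)))  # 8.0 8.0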


@@ -162,8 +162,7 @@ class PenalizingCausalTransformer(CausalTransformer):
         def generate_sample(context, ctx_length):
             # Give the initial context to the transformer
             transformer = CausalTransformerShard(config)
-            initial_states = []
-            for sequence_index in range(numseqs):
+            def generate_initial_scan_fn(sequence_index, _):
                 _, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings)
                 # The "generated" array will contain the tokens from the
                 # context as well as the tokens picked by the sampler at
@@ -173,12 +172,17 @@ class PenalizingCausalTransformer(CausalTransformer):
                 generated_index = config["seq"]
                 # Add that information to generate_loop_fn's starting state
                 initial_state = (generated, generated_index, sequence_index) + initial_state
-                initial_states.append(initial_state)
+                return sequence_index+1, initial_state
+            _, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, None, numseqs)
+            sample_key = initial_states[-1][0]
+            initial_states = list(jax.tree_map(lambda x: x[i], initial_states[:-1]) for i in range(numseqs))
             # Get repetition penalty from the arguments
             repetition_penalty = sampler_options.pop('repetition_penalty', None)
             # This is the main generation loop
             def generate_loop_fn(carry):
                 # Unpack current generate_loop_fn state
-                generated, generated_index, sequence_index, next_token, decode_state, sample_key = carry[0]
+                generated, generated_index, sequence_index, next_token, decode_state = carry[0][0]
+                sample_key = carry[1]
                 # Get the pseudo-random number generator key that will
                 # be used by kobold_sample to randomly pick a token
                 sample_key, new_key = jax.random.split(sample_key)
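
The added lines above rely on how `jax.lax.scan` stacks its per-step outputs: every leaf of the output pytree gains a leading axis of length `numseqs`, and `jax.tree_map(lambda x: x[i], ...)` slices that axis to rebuild one state pytree per sequence. A toy sketch of the pattern (the state contents here are invented; the real state also carries the generated tokens and the transformer's decode state):

    # Toy sketch (not KoboldAI code): unstacking lax.scan outputs with tree_map.
    import jax
    import jax.numpy as jnp

    numseqs = 3

    def init_fn(sequence_index, _):
        # Invented stand-in for the real per-sequence state.
        state = (jnp.zeros(4) + sequence_index, sequence_index)
        return sequence_index + 1, state

    # Every leaf of the stacked output has a new leading axis of length numseqs.
    _, stacked = jax.lax.scan(init_fn, 0, None, numseqs)
    print(stacked[0].shape)  # (3, 4): one row per sequence

    # Slicing index i out of every leaf rebuilds one state pytree per sequence,
    # which is what the jax.tree_map line in the diff does.
    per_sequence = [jax.tree_map(lambda x: x[i], stacked) for i in range(numseqs)]
    print(per_sequence[1][0].shape, per_sequence[1][1])  # (4,) 1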
@@ -220,13 +224,13 @@ class PenalizingCausalTransformer(CausalTransformer):
                 generated_index += 1
                 # Re-pack the current generate_loop_fn's state so we can
                 # get back the same variables the next time
-                carry[0] = (generated, generated_index, sequence_index, next_token, new_state, new_key)
-                carry.append(carry.pop(0))
-                return carry
+                carry[0][0] = (generated, generated_index, sequence_index, next_token, new_state)
+                carry[0].append(carry[0].pop(0))
+                return carry[0], new_key
             final_state = jax.lax.while_loop(
-                lambda carry: carry[0][1] - config["seq"] < gen_length,
+                lambda carry: carry[0][0][1] - config["seq"] < gen_length,
                 generate_loop_fn,
-                initial_states,
+                (initial_states, sample_key),
             )
             return final_state
         generate_fn = hk.transform(generate_sample).apply
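
This hunk restructures the `while_loop` carry: the PRNG key moves out of the per-sequence tuples into a top-level pair `(states, sample_key)`, and rotating the list keeps the sequence to advance at index 0. A toy sketch of the same carry layout (state contents invented):

    # Toy sketch (not KoboldAI code): a while_loop carry of the shape
    # (list of per-sequence states, PRNG key), with list rotation.
    import jax
    import jax.numpy as jnp

    numseqs = 2

    def cond(carry):
        # carry[0][0][1] is the generated_index of whichever sequence is
        # currently at the head of the rotated list.
        return carry[0][0][1] < 4

    def body(carry):
        states, sample_key = carry          # the key now lives outside the list
        value, generated_index = states[0]  # advance only the head sequence
        sample_key, subkey = jax.random.split(sample_key)
        states[0] = (value + jax.random.uniform(subkey), generated_index + 1)
        states.append(states.pop(0))        # rotate the next sequence to the head
        return states, sample_key

    init = [(jnp.float32(0), jnp.int32(0)) for _ in range(numseqs)]
    final_states, _ = jax.lax.while_loop(cond, body, (init, jax.random.PRNGKey(0)))
    print([int(i) for _, i in final_states])  # [4, 4]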
@@ -299,7 +303,7 @@ def infer(
         numseqs,
         batched_generator_params,
         soft_embeddings=soft_embeddings,
-    )
+    )[0]
     for o in output:
         samples.append(tokenizer.decode(o[0][0, 0, params["seq"] : params["seq"] + gen_len]))
     return samples