Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Improve TPU backend compilation times with numseqs > 1
A Python `for` loop was replaced with a `jax.lax.scan` call so that JAX compiles the `transformer.generate_initial` function only once instead of `numseqs` times; JAX unrolls built-in Python loops such as `for` during tracing, so each iteration was being compiled as a separate copy of the loop body. Compilation times should now be about the same as they were before the upgrade to JAX 0.2.21.
parent c1e7c1643f
commit d2d338d314
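For context on the change described above, here is a minimal, self-contained sketch of the pattern (not the KoboldAI code itself; `init_fn` is a toy stand-in for `transformer.generate_initial`): under tracing, a Python `for` loop bakes `numseqs` copies of its body into the compiled program, while `jax.lax.scan` traces the body once and emits a single loop.

# Toy sketch, not from the repository: init_fn stands in for
# transformer.generate_initial and the real per-sequence state.
import jax
import jax.numpy as jnp

numseqs = 4

def init_fn(sequence_index, _):
    # Build one sequence's initial state.
    return sequence_index + 1, jnp.zeros(8) + sequence_index

@jax.jit
def init_with_for_loop():
    # A Python for loop is unrolled during tracing, so XLA compiles
    # numseqs separate copies of init_fn's body.
    states = []
    index = 0
    for _ in range(numseqs):
        index, state = init_fn(index, None)
        states.append(state)
    return jnp.stack(states)

@jax.jit
def init_with_scan():
    # jax.lax.scan traces init_fn once and emits a single loop, so
    # compilation time no longer grows with numseqs.
    _, states = jax.lax.scan(init_fn, 0, None, length=numseqs)
    return states  # stacked along a new leading axis, shape (numseqs, 8)

assert jnp.allclose(init_with_for_loop(), init_with_scan())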
@@ -162,8 +162,7 @@ class PenalizingCausalTransformer(CausalTransformer):
         def generate_sample(context, ctx_length):
             # Give the initial context to the transformer
             transformer = CausalTransformerShard(config)
-            initial_states = []
-            for sequence_index in range(numseqs):
+            def generate_initial_scan_fn(sequence_index, _):
                 _, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings)
                 # The "generated" array will contain the tokens from the
                 # context as well as the tokens picked by the sampler at
@@ -173,12 +172,17 @@ class PenalizingCausalTransformer(CausalTransformer):
                 generated_index = config["seq"]
                 # Add that information to generate_loop_fn's starting state
                 initial_state = (generated, generated_index, sequence_index) + initial_state
-                initial_states.append(initial_state)
+                return sequence_index+1, initial_state
+            _, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, None, numseqs)
+            sample_key = initial_states[-1][0]
+            initial_states = list(jax.tree_map(lambda x: x[i], initial_states[:-1]) for i in range(numseqs))
             # Get repetition penalty from the arguments
             repetition_penalty = sampler_options.pop('repetition_penalty', None)
             # This is the main generation loop
             def generate_loop_fn(carry):
                 # Unpack current generate_loop_fn state
-                generated, generated_index, sequence_index, next_token, decode_state, sample_key = carry[0]
+                generated, generated_index, sequence_index, next_token, decode_state = carry[0][0]
+                sample_key = carry[1]
                 # Get the pseudo-random number generator key that will
                 # be used by kobold_sample to randomly pick a token
                 sample_key, new_key = jax.random.split(sample_key)
@@ -220,13 +224,13 @@ class PenalizingCausalTransformer(CausalTransformer):
                 generated_index += 1
                 # Re-pack the current generate_loop_fn's state so we can
                 # get back the same variables the next time
-                carry[0] = (generated, generated_index, sequence_index, next_token, new_state, new_key)
-                carry.append(carry.pop(0))
-                return carry
+                carry[0][0] = (generated, generated_index, sequence_index, next_token, new_state)
+                carry[0].append(carry[0].pop(0))
+                return carry[0], new_key
             final_state = jax.lax.while_loop(
-                lambda carry: carry[0][1] - config["seq"] < gen_length,
+                lambda carry: carry[0][0][1] - config["seq"] < gen_length,
                 generate_loop_fn,
-                initial_states,
+                (initial_states, sample_key),
             )
             return final_state
         generate_fn = hk.transform(generate_sample).apply
@@ -299,7 +303,7 @@ def infer(
         numseqs,
         batched_generator_params,
         soft_embeddings=soft_embeddings,
-    )
+    )[0]
     for o in output:
         samples.append(tokenizer.decode(o[0][0, 0, params["seq"] : params["seq"] + gen_len]))
     return samples
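The rest of the diff reshapes the `jax.lax.while_loop` carry: instead of a list of per-sequence tuples that each hold their own `sample_key`, the carry is now a pair of (per-sequence states, one shared `sample_key`), and the loop body rotates the state list and returns a fresh key. A hedged toy sketch of that carry layout, with made-up state in place of the real transformer state:

# Toy sketch, not the repository code: the per-sequence state and the
# sampling step are stand-ins for the real transformer state and sampler.
import jax
import jax.numpy as jnp

numseqs = 3
gen_length = 5
seq = 4  # stands in for config["seq"]

# Toy per-sequence state: (generated tokens, generated_index, sequence_index).
initial_states = [
    (jnp.zeros(seq + gen_length, dtype=jnp.uint32), jnp.uint32(seq), jnp.uint32(i))
    for i in range(numseqs)
]
sample_key = jax.random.PRNGKey(0)

def generate_loop_fn(carry):
    states, sample_key = carry
    generated, generated_index, sequence_index = states[0]
    # A single key lives in the carry; split it for this iteration's sample.
    sample_key, new_key = jax.random.split(sample_key)
    next_token = jax.random.randint(sample_key, (), 0, 255).astype(jnp.uint32)
    generated = generated.at[generated_index].set(next_token)
    generated_index = generated_index + 1
    # Rotate the list so each sequence is advanced once every numseqs steps.
    states = states[1:] + [(generated, generated_index, sequence_index)]
    return states, new_key

final_states, _ = jax.lax.while_loop(
    # Stop once the sequence at the head of the list has all gen_length tokens.
    lambda carry: carry[0][0][1] - seq < gen_length,
    generate_loop_fn,
    (initial_states, sample_key),
)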