Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-04-20 13:07:29 +02:00)

Fix unbound axis error in tpu_mtj_backend.py when numseqs > 1

Parent: 3c349e6aaf
Commit: c1e7c1643f
@@ -155,26 +155,30 @@ class PenalizingCausalTransformer(CausalTransformer):
     def __init__(self, config):
         # Initialize
         super().__init__(config)
-        def generate(state, key, ctx, ctx_length, gen_length, sampler_options, soft_embeddings=None):
+        def generate(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None):
+            numseqs = numseqs_aux.shape[0]
             # These are the tokens that we don't want the AI to ever write
             self.badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146])
             def generate_sample(context, ctx_length):
                 # Give the initial context to the transformer
                 transformer = CausalTransformerShard(config)
-                _, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings)
-                # The "generated" array will contain the tokens from the
-                # context as well as the tokens picked by the sampler at
-                # each stage, padded with a bunch of 50256s, so we know
-                # which tokens have to be repetition penalized
-                generated = jnp.pad(context, (0, config["seq"]), constant_values=pad_token_id) # Let it start off with just the 2048 context tokens, plus some 50256s which will be eventually filled with sampler-chosen tokens
-                generated_index = config["seq"]
-                # Add that information to generate_loop_fn's starting state
-                initial_state = (generated, generated_index) + initial_state
+                initial_states = []
+                for sequence_index in range(numseqs):
+                    _, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings)
+                    # The "generated" array will contain the tokens from the
+                    # context as well as the tokens picked by the sampler at
+                    # each stage, padded with a bunch of 50256s, so we know
+                    # which tokens have to be repetition penalized
+                    generated = jnp.pad(context, (0, config["seq"]), constant_values=pad_token_id) # Let it start off with just the 2048 context tokens, plus some 50256s which will be eventually filled with sampler-chosen tokens
+                    generated_index = config["seq"]
+                    # Add that information to generate_loop_fn's starting state
+                    initial_state = (generated, generated_index, sequence_index) + initial_state
+                    initial_states.append(initial_state)
                 # Get repetition penalty from the arguments
                 repetition_penalty = sampler_options.pop('repetition_penalty', None)
                 def generate_loop_fn(carry):
                     # Unpack current generate_loop_fn state
-                    generated, generated_index, next_token, decode_state, sample_key = carry
+                    generated, generated_index, sequence_index, next_token, decode_state, sample_key = carry[0]
                     # Get the pseudo-random number generator key that will
                     # be used by kobold_sample to randomly pick a token
                     sample_key, new_key = jax.random.split(sample_key)
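The new numseqs_aux argument's contents are never read; only its shape matters. The extra ["batch", ...] entry added to in_axes further down maps over its first dimension, and inside the traced function numseqs_aux.shape[0] is an ordinary static Python int, which is what lets it drive the range(numseqs) loop. A minimal sketch of the same shape-as-argument trick, using jax.vmap rather than this file's xmap call, with hypothetical names (per_example, ctx, aux); not the KoboldAI code:

    import jax
    import jax.numpy as jnp
    import numpy as np

    # Toy illustration: the sequence count travels as the trailing dimension of a
    # throwaway uint8 array, so it can be mapped over the batch axis like any other
    # argument and recovered as a static Python int inside the traced function.
    def per_example(ctx, numseqs_aux):
        numseqs = numseqs_aux.shape[0]          # static at trace time
        # Build one state per candidate sequence, as the patched loop above does.
        return jnp.stack([ctx + i for i in range(numseqs)])

    batch_size, numseqs = 2, 3
    ctx = jnp.arange(8.0).reshape(batch_size, 4)
    aux = np.empty((batch_size, numseqs), dtype=np.uint8)   # contents ignored
    out = jax.vmap(per_example)(ctx, aux)       # shape (batch_size, numseqs, 4)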
@@ -216,12 +220,13 @@ class PenalizingCausalTransformer(CausalTransformer):
                     generated_index += 1
                     # Re-pack the current generate_loop_fn's state so we can
                     # get back the same variables the next time
-                    new_carry = (generated, generated_index, next_token, new_state, new_key)
-                    return new_carry
+                    carry[0] = (generated, generated_index, sequence_index, next_token, new_state, new_key)
+                    carry.append(carry.pop(0))
+                    return carry
                 final_state = jax.lax.while_loop(
-                    lambda carry: carry[1] - config["seq"] < gen_length,
+                    lambda carry: carry[0][1] - config["seq"] < gen_length,
                     generate_loop_fn,
-                    initial_state,
+                    initial_states,
                 )
                 return final_state
             generate_fn = hk.transform(generate_sample).apply
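Instead of a single carry tuple, the loop state is now a Python list holding one tuple per candidate sequence: each iteration advances only the tuple at the front and then rotates the list with carry.append(carry.pop(0)), so the sequences take turns. The rotation preserves the pytree structure (a list of numseqs identically shaped tuples), which is all jax.lax.while_loop requires of its body. A self-contained toy of just that round-robin pattern, with hypothetical names; not the generation code:

    import jax
    import jax.numpy as jnp

    NUMSEQS, STEPS = 3, 5

    def body(carry):
        # Advance only the sequence currently at the front of the list.
        step, value = carry[0]
        carry[0] = (step + 1, value + 1.0)
        # Rotate so a different sequence is serviced on the next iteration;
        # the pytree structure (a list of NUMSEQS identical tuples) is unchanged.
        carry.append(carry.pop(0))
        return carry

    def cond(carry):
        # The front of the list is always the least-advanced sequence, so when
        # it has taken STEPS steps, every sequence has.
        return carry[0][0] < STEPS

    init = [(jnp.zeros((), dtype=jnp.int32), jnp.zeros(())) for _ in range(NUMSEQS)]
    final = jax.lax.while_loop(cond, body, init)   # every counter ends at STEPS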
@@ -235,12 +240,13 @@ class PenalizingCausalTransformer(CausalTransformer):
                 ["batch", ...],
                 ["batch", ...],
                 ["batch", ...],
+                ["batch", ...],
                 ["shard", ...],
             ),
             out_axes=["shard", "batch", ...],
             axis_resources={'shard': 'mp', 'batch': 'dp'},
         )
-    def generate(self, ctx, ctx_length, gen_length, sampler_options, return_logits=False, soft_embeddings=None):
+    def generate(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None):
         assert not return_logits
         key = hk.PRNGSequence(random.randint(0, 2 ** 60))
         batch_size = ctx.shape[0]
@@ -251,6 +257,7 @@ class PenalizingCausalTransformer(CausalTransformer):
             ctx,
             np.array(ctx_length, dtype=np.uint32),
             np.array(gen_length, dtype=np.uint32),
+            np.empty((batch_size, numseqs), dtype=np.uint8),
             sampler_options,
             soft_embeddings,
         )
@@ -269,7 +276,7 @@ def infer(
     soft_tokens: Optional[np.array] = None,
 ) -> List[str]:
     maps.thread_resources.env = thread_resources_env
-    total_batch = numseqs
+    total_batch = 1
     tokens = np.uint32(tokenizer.encode(context, max_length=params["seq"] - (soft_tokens.shape[0] if soft_tokens is not None else 0), truncation=True))
     if(soft_tokens is not None):
         tokens = np.uint32(np.concatenate((soft_tokens, tokens)))
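With every candidate sequence now carried as its own state inside the single xmapped call, the caller presumably no longer needs to replicate the prompt numseqs times along the batch axis, hence total_batch = 1 above.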
@@ -289,11 +296,12 @@ def infer(
         batched_tokens,
         np.ones(total_batch, dtype=np.uint32) * provided_ctx,
         np.ones(total_batch, dtype=np.uint32) * gen_len,
+        numseqs,
         batched_generator_params,
         soft_embeddings=soft_embeddings,
     )
-    for o in output[0][0, :, params["seq"] : params["seq"] + gen_len]:
-        samples.append(tokenizer.decode(o))
+    for o in output:
+        samples.append(tokenizer.decode(o[0][0, 0, params["seq"] : params["seq"] + gen_len]))
     return samples