Merge branch 'united' into world-info
This commit is contained in:
commit
683bcb824f
25
aiserver.py
25
aiserver.py
|
@ -1838,12 +1838,25 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
|||
raise ValueError("Dynamic world info scanning is not supported by the TPU backend yet")
|
||||
|
||||
soft_tokens = None
|
||||
if(vars.sp is not None):
|
||||
soft_tokens = np.arange(
|
||||
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
|
||||
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
|
||||
dtype=np.uint32
|
||||
if(vars.sp is None):
|
||||
global np
|
||||
if 'np' not in globals():
|
||||
import numpy as np
|
||||
tensor = np.zeros((1, tpu_mtj_backend.params["d_model"]), dtype=np.float32)
|
||||
rows = tensor.shape[0]
|
||||
padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
|
||||
tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
|
||||
tensor = tensor.reshape(
|
||||
tpu_mtj_backend.params["cores_per_replica"],
|
||||
-1,
|
||||
tpu_mtj_backend.params["d_model"],
|
||||
)
|
||||
vars.sp = tensor
|
||||
soft_tokens = np.arange(
|
||||
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
|
||||
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
|
||||
dtype=np.uint32
|
||||
)
|
||||
|
||||
genout = tpu_mtj_backend.infer(
|
||||
txt,
|
||||
|
@ -2806,7 +2819,7 @@ def spRequest(filename):
|
|||
|
||||
if(vars.model in ("TPUMeshTransformerGPTJ",)):
|
||||
rows = tensor.shape[0]
|
||||
padding_amount = -(rows % -tpu_mtj_backend.params["cores_per_replica"])
|
||||
padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
|
||||
tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
|
||||
tensor = tensor.reshape(
|
||||
tpu_mtj_backend.params["cores_per_replica"],
|
||||
|
|
|
@ -4,11 +4,11 @@ requests
|
|||
optax >= 0.0.5, <= 0.0.9
|
||||
dm-haiku
|
||||
ray[default]
|
||||
jax == 0.2.12
|
||||
jax == 0.2.21
|
||||
transformers
|
||||
progressbar2
|
||||
git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck
|
||||
flask
|
||||
Flask-SocketIO
|
||||
flask-cloudflared >= 0.0.5
|
||||
flask-ngrok
|
||||
flask-ngrok
|
||||
|
|
|
@ -30,7 +30,7 @@ def show_spinner():
|
|||
|
||||
def apply_repetition_penalty(logits, tokens, repetition_penalty):
|
||||
'''
|
||||
This gets called by generate_scan_fn to apply repetition penalty
|
||||
This gets called by generate_loop_fn to apply repetition penalty
|
||||
to the 1D array logits using the provided 1D array of tokens to penalize
|
||||
'''
|
||||
# Make a new array with the same length as the tokens array but with
|
||||
|
@ -52,9 +52,9 @@ def apply_repetition_penalty(logits, tokens, repetition_penalty):
|
|||
# positions in the logits array
|
||||
return logits.at[tokens].set(penalty_logits)
|
||||
|
||||
def kobold_sample(key, logits, _, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
|
||||
def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
|
||||
'''
|
||||
This gets called by generate_scan_fn to apply a series of 4 filters
|
||||
This gets called by generate_loop_fn to apply a series of 4 filters
|
||||
to the logits (top-k, then top-p, then TFS, then temperature) before
|
||||
picking one token using the modified logits
|
||||
'''
|
||||
|
@ -147,7 +147,7 @@ def kobold_sample(key, logits, _, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
|
|||
# Finally, pick one token using the softmax thingy again (it gives
|
||||
# an array whose elements sum to 1 so it can be used nicely as a
|
||||
# probability distribution)
|
||||
return jax.random.categorical(key, logits, -1).astype(jnp.uint32)[jnp.newaxis], None
|
||||
return jax.random.categorical(key, logits, -1).astype(jnp.uint32)[jnp.newaxis]
|
||||
|
||||
pad_token_id = 50256
|
||||
|
||||
|
@ -155,27 +155,34 @@ class PenalizingCausalTransformer(CausalTransformer):
|
|||
def __init__(self, config):
|
||||
# Initialize
|
||||
super().__init__(config)
|
||||
def generate(state, key, ctx, ctx_length, aux, sampler_options, soft_embeddings=None):
|
||||
gen_length = self.gen_length
|
||||
def generate(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None):
|
||||
numseqs = numseqs_aux.shape[0]
|
||||
# These are the tokens that we don't want the AI to ever write
|
||||
self.badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146])
|
||||
def generate_sample(context, ctx_length, aux):
|
||||
def generate_sample(context, ctx_length):
|
||||
# Give the initial context to the transformer
|
||||
transformer = CausalTransformerShard(config)
|
||||
_, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings)
|
||||
# The "generated" array will contain the tokens from the
|
||||
# context as well as the tokens picked by the sampler at
|
||||
# each stage, padded with a bunch of 50256s, so we know
|
||||
# which tokens have to be repetition penalized
|
||||
generated = jnp.pad(context, (0, gen_length), constant_values=pad_token_id) # Let it start off with just the 2048 context tokens, plus gen_length 50256s which will be eventually filled with sampler-chosen tokens
|
||||
generated_index = config["seq"]
|
||||
# Add that information to generate_scan_fn's starting state
|
||||
initial_state = (generated, generated_index) + initial_state
|
||||
def generate_initial_scan_fn(sequence_index, _):
|
||||
_, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings)
|
||||
# The "generated" array will contain the tokens from the
|
||||
# context as well as the tokens picked by the sampler at
|
||||
# each stage, padded with a bunch of 50256s, so we know
|
||||
# which tokens have to be repetition penalized
|
||||
generated = jnp.pad(context, (0, config["seq"]), constant_values=pad_token_id) # Let it start off with just the 2048 context tokens, plus some 50256s which will be eventually filled with sampler-chosen tokens
|
||||
generated_index = config["seq"]
|
||||
# Add that information to generate_loop_fn's starting state
|
||||
initial_state = (generated, generated_index, sequence_index) + initial_state
|
||||
return sequence_index+1, initial_state
|
||||
_, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, None, numseqs)
|
||||
sample_key = initial_states[-1][0]
|
||||
initial_states = list(jax.tree_map(lambda x: x[i], initial_states[:-1]) for i in range(numseqs))
|
||||
# Get repetition penalty from the arguments
|
||||
repetition_penalty = sampler_options.pop('repetition_penalty', None)
|
||||
def generate_scan_fn(carry, sampler_input):
|
||||
# Unpack current generate_scan_fn state
|
||||
generated, generated_index, next_token, decode_state, sample_key = carry
|
||||
# This is the main generation loop
|
||||
def generate_loop_fn(carry):
|
||||
# Unpack current generate_loop_fn state
|
||||
generated, generated_index, sequence_index, next_token, decode_state = carry[0][0]
|
||||
sample_key = carry[1]
|
||||
# Get the pseudo-random number generator key that will
|
||||
# be used by kobold_sample to randomly pick a token
|
||||
sample_key, new_key = jax.random.split(sample_key)
|
||||
|
@ -207,56 +214,54 @@ class PenalizingCausalTransformer(CausalTransformer):
|
|||
# based on the logits array as a 1D array with 1 element
|
||||
# (higher logit means higher probability of being
|
||||
# picked, non-linearly)
|
||||
next_token, sample_info = kobold_sample(
|
||||
next_token = kobold_sample(
|
||||
sample_key,
|
||||
logits,
|
||||
sampler_input,
|
||||
**sampler_options,
|
||||
)
|
||||
# Remember what token was picked so we can repetition
|
||||
# penalize it next time
|
||||
# Remember what token was picked
|
||||
generated = generated.at[generated_index].set(next_token[0])
|
||||
generated_index += 1
|
||||
# self.return_logits isn't used in this program, but
|
||||
# for the sake of compatibility...
|
||||
if self.return_logits:
|
||||
output = (next_token, sample_info, logits[jnp.newaxis])
|
||||
else:
|
||||
output = (next_token, sample_info)
|
||||
# Re-pack the current generate_scan_fn's state so we can
|
||||
# Re-pack the current generate_loop_fn's state so we can
|
||||
# get back the same variables the next time
|
||||
new_carry = (generated, generated_index, next_token, new_state, new_key)
|
||||
return new_carry, output
|
||||
# jax.lax.scan is a function that calls generate_scan_fn
|
||||
# gen_length times, each time passing a state object from
|
||||
# its return value (new_carry) back into one of the
|
||||
# function's arguments (carry), and of course gathering the
|
||||
# token it generates each time into the "outputs" array;
|
||||
# we have to use jax.lax.scan instead of a normal loop
|
||||
# because of JAX's JIT-compilation shenanigans
|
||||
final_state, outputs = jax.lax.scan(
|
||||
generate_scan_fn,
|
||||
initial_state,
|
||||
xs=aux,
|
||||
length=gen_length,
|
||||
carry[0][0] = (generated, generated_index, sequence_index, next_token, new_state)
|
||||
carry[0].append(carry[0].pop(0))
|
||||
return carry[0], new_key
|
||||
final_state = jax.lax.while_loop(
|
||||
lambda carry: carry[0][0][1] - config["seq"] < gen_length,
|
||||
generate_loop_fn,
|
||||
(initial_states, sample_key),
|
||||
)
|
||||
return final_state, outputs
|
||||
return final_state
|
||||
generate_fn = hk.transform(generate_sample).apply
|
||||
return generate_fn(state["params"], key, ctx, ctx_length, aux)
|
||||
self.generate_xmap = jax.experimental.maps.xmap(fun=generate, in_axes=(["shard", ...], ["batch", ...], ["batch", ...], ["batch", ...], ["batch", ...], ["batch", ...], ["shard", ...]), out_axes=["batch", ...], axis_resources={'shard': 'mp', 'batch': 'dp'})
|
||||
def generate(self, ctx, ctx_length, gen_length, sampler_options, return_logits=False, soft_embeddings=None):
|
||||
return generate_fn(state["params"], key, ctx, ctx_length)
|
||||
self.generate_xmap = jax.experimental.maps.xmap(
|
||||
fun=generate,
|
||||
in_axes=(
|
||||
["shard", ...],
|
||||
["batch", ...],
|
||||
["batch", ...],
|
||||
["batch", ...],
|
||||
["batch", ...],
|
||||
["batch", ...],
|
||||
["batch", ...],
|
||||
["shard", ...],
|
||||
),
|
||||
out_axes=["shard", "batch", ...],
|
||||
axis_resources={'shard': 'mp', 'batch': 'dp'},
|
||||
)
|
||||
def generate(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None):
|
||||
assert not return_logits
|
||||
key = hk.PRNGSequence(random.randint(0, 2 ** 60))
|
||||
batch_size = ctx.shape[0]
|
||||
aux = jnp.zeros((batch_size, gen_length), dtype=jnp.uint32)
|
||||
self.gen_length = gen_length
|
||||
self.batch_size = batch_size
|
||||
self.return_logits = return_logits
|
||||
return self.generate_xmap(
|
||||
self.state,
|
||||
jnp.array(key.take(batch_size)),
|
||||
ctx,
|
||||
np.array(ctx_length, dtype=np.uint32),
|
||||
aux,
|
||||
np.array(gen_length, dtype=np.uint32),
|
||||
np.empty((batch_size, numseqs), dtype=np.uint8),
|
||||
sampler_options,
|
||||
soft_embeddings,
|
||||
)
|
||||
|
@ -275,7 +280,7 @@ def infer(
|
|||
soft_tokens: Optional[np.array] = None,
|
||||
) -> List[str]:
|
||||
maps.thread_resources.env = thread_resources_env
|
||||
total_batch = numseqs
|
||||
total_batch = 1
|
||||
tokens = np.uint32(tokenizer.encode(context, max_length=params["seq"] - (soft_tokens.shape[0] if soft_tokens is not None else 0), truncation=True))
|
||||
if(soft_tokens is not None):
|
||||
tokens = np.uint32(np.concatenate((soft_tokens, tokens)))
|
||||
|
@ -283,7 +288,6 @@ def infer(
|
|||
pad_amount = seq - provided_ctx
|
||||
padded_tokens = np.pad(tokens, ((pad_amount, 0),), constant_values=pad_token_id)
|
||||
batched_tokens = np.array([padded_tokens] * total_batch)
|
||||
length = np.ones(total_batch, dtype=np.uint32) * provided_ctx
|
||||
samples = []
|
||||
batched_generator_params = {
|
||||
"temp": temp * np.ones(total_batch),
|
||||
|
@ -294,14 +298,14 @@ def infer(
|
|||
}
|
||||
output = network.generate(
|
||||
batched_tokens,
|
||||
length,
|
||||
gen_len,
|
||||
np.ones(total_batch, dtype=np.uint32) * provided_ctx,
|
||||
np.ones(total_batch, dtype=np.uint32) * gen_len,
|
||||
numseqs,
|
||||
batched_generator_params,
|
||||
soft_embeddings=soft_embeddings,
|
||||
)
|
||||
decoded_tokens = output[1][0]
|
||||
for o in decoded_tokens[:, :, 0]:
|
||||
samples.append(tokenizer.decode(o))
|
||||
)[0]
|
||||
for o in output:
|
||||
samples.append(tokenizer.decode(o[0][0, 0, params["seq"] : params["seq"] + gen_len]))
|
||||
return samples
|
||||
|
||||
|
||||
|
@ -326,6 +330,10 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
|
|||
if param not in params:
|
||||
params[param] = default_params[param]
|
||||
|
||||
# Disable JAX warnings about these two functions having been renamed
|
||||
jax.host_count = jax.process_count
|
||||
jax.host_id = jax.process_index
|
||||
|
||||
print("Connecting to your Colab instance's TPU", flush=True)
|
||||
spinner = multiprocessing.Process(target=show_spinner, args=())
|
||||
spinner.start()
|
||||
|
@ -342,7 +350,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
|
|||
params["optimizer"] = optax.scale(0)
|
||||
mesh_shape = (1, cores_per_replica)
|
||||
devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
|
||||
thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))
|
||||
thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
|
||||
maps.thread_resources.env = thread_resources_env
|
||||
tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
|
||||
|
||||
|
|
Loading…
Reference in New Issue