From e068aa9f26dc51b162cdad91bd096bced8e2c543 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Sun, 21 Nov 2021 18:08:04 -0500 Subject: [PATCH] Add soft prompt support to TPU backend --- aiserver.py | 32 +++++++++++++++++++++++++++++--- tpu_mtj_backend.py | 46 +++++++++++++++++++++++++++++++++------------- 2 files changed, 62 insertions(+), 16 deletions(-) diff --git a/aiserver.py b/aiserver.py index 1d33eb27..2eb06c32 100644 --- a/aiserver.py +++ b/aiserver.py @@ -108,6 +108,7 @@ class vars: loadselect = "" # Temporary storage for story filename to load spselect = "" # Temporary storage for soft prompt filename to load sp = None # Current soft prompt tensor (as a NumPy array) + sp_length = 0 # Length of current soft prompt in tokens, or 0 if not using a soft prompt svowname = "" # Filename that was flagged for overwrite confirm saveow = False # Whether or not overwrite confirm has been displayed genseqs = [] # Temporary storage for generated sequences @@ -700,6 +701,8 @@ else: assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth) import tpu_mtj_backend tpu_mtj_backend.load_model(vars.custmodpth) + vars.allowsp = True + vars.modeldim = int(tpu_mtj_backend.params["d_model"]) tokenizer = tpu_mtj_backend.tokenizer # Set up Flask routes @@ -1684,10 +1687,17 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): # Submit input text to generator try: - if(vars.sp is not None): - raise ValueError("Softprompts are not supported by the TPU backend yet") if(vars.dynamicscan): raise ValueError("Dynamic world info scanning is not supported by the TPU backend yet") + + soft_tokens = None + if(vars.sp is not None): + soft_tokens = np.arange( + tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"], + tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length, + dtype=np.uint32 + ) + genout = tpu_mtj_backend.infer( txt, gen_len = maximum-minimum+1, @@ -1697,6 +1707,8 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): tfs=vars.tfs, numseqs=vars.numseqs, repetition_penalty=vars.rep_pen, + soft_embeddings=vars.sp, + soft_tokens=soft_tokens, ) except Exception as e: @@ -2525,6 +2537,7 @@ def loadRequest(loadpath, filename=None): def spRequest(filename): if(len(filename) == 0): vars.sp = None + vars.sp_length = 0 return global np @@ -2548,7 +2561,20 @@ def spRequest(filename): tensor = np.float32(tensor) assert not np.isinf(tensor).any() and not np.isnan(tensor).any() - vars.sp = torch.from_numpy(tensor) + vars.sp_length = tensor.shape[0] + + if(vars.model in ("TPUMeshTransformerGPTJ",)): + rows = tensor.shape[0] + padding_amount = -(rows % -tpu_mtj_backend.params["cores_per_replica"]) + tensor = np.pad(tensor, ((0, padding_amount), (0, 0))) + tensor = tensor.reshape( + tpu_mtj_backend.params["cores_per_replica"], + -1, + tpu_mtj_backend.params["d_model"], + ) + vars.sp = tensor + else: + vars.sp = torch.from_numpy(tensor) #==================================================================# # Import an AIDungon game exported with Mimi's tool diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index f33982ba..4be7575f 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -1,5 +1,5 @@ import multiprocessing -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import progressbar import time import os @@ -155,14 +155,14 @@ class PenalizingCausalTransformer(CausalTransformer): def __init__(self, config): # Initialize super().__init__(config) - def generate(state, key, ctx, ctx_length, aux, sampler_options): + def generate(state, key, ctx, ctx_length, aux, sampler_options, soft_embeddings=None): gen_length = self.gen_length # These are the tokens that we don't want the AI to ever write self.badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146]) def generate_sample(context, ctx_length, aux): # Give the initial context to the transformer transformer = CausalTransformerShard(config) - _, initial_state = transformer.generate_initial(context, ctx_length) + _, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings) # The "generated" array will contain the tokens from the # context as well as the tokens picked by the sampler at # each stage, padded with a bunch of 50256s, so we know @@ -185,7 +185,7 @@ class PenalizingCausalTransformer(CausalTransformer): # how strongly it thinks each of the 50257 tokens in its # vocabulary should be appended to the context, followed # by 143 apparently useless columns ???) - logits, new_state = transformer.generate_once(next_token, decode_state) + logits, new_state = transformer.generate_once(next_token, decode_state, soft_embeddings=soft_embeddings) # Verify that logits does indeed have that many rows and # columns (if you get an error here, pray for mercy) assert logits.shape == (1, config["n_vocab"]) @@ -243,8 +243,8 @@ class PenalizingCausalTransformer(CausalTransformer): return final_state, outputs generate_fn = hk.transform(generate_sample).apply return generate_fn(state["params"], key, ctx, ctx_length, aux) - self.generate_xmap = jax.experimental.maps.xmap(fun=generate, in_axes=(["shard", ...], ["batch", ...], ["batch", ...], ["batch", ...], ["batch", ...], ["batch", ...]), out_axes=["batch", ...], axis_resources={'shard': 'mp', 'batch': 'dp'}) - def generate(self, ctx, ctx_length, gen_length, sampler_options, return_logits=False): + self.generate_xmap = jax.experimental.maps.xmap(fun=generate, in_axes=(["shard", ...], ["batch", ...], ["batch", ...], ["batch", ...], ["batch", ...], ["batch", ...], ["shard", ...]), out_axes=["batch", ...], axis_resources={'shard': 'mp', 'batch': 'dp'}) + def generate(self, ctx, ctx_length, gen_length, sampler_options, return_logits=False, soft_embeddings=None): key = hk.PRNGSequence(random.randint(0, 2 ** 60)) batch_size = ctx.shape[0] aux = jnp.zeros((batch_size, gen_length), dtype=jnp.uint32) @@ -257,19 +257,33 @@ class PenalizingCausalTransformer(CausalTransformer): ctx, np.array(ctx_length, dtype=np.uint32), aux, - sampler_options + sampler_options, + soft_embeddings, ) -def infer(context, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, repetition_penalty=1.0, numseqs=1, gen_len=80) -> List[str]: +def infer( + context: str, + top_p=0.9, + temp=0.5, + top_k=0, + tfs=1.0, + repetition_penalty=1.0, + numseqs=1, + gen_len=80, + soft_embeddings: Optional[np.array] = None, + soft_tokens: Optional[np.array] = None, +) -> List[str]: maps.thread_resources.env = thread_resources_env total_batch = numseqs - tokens = tokenizer.encode(context, max_length=params["seq"], truncation=True) - provided_ctx = len(tokens) + tokens = np.uint32(tokenizer.encode(context, max_length=params["seq"] - soft_tokens.shape[0] if soft_tokens is not None else 0, truncation=True)) + if(soft_tokens is not None): + tokens = np.uint32(np.concatenate((soft_tokens, tokens))) + provided_ctx = tokens.shape[0] pad_amount = seq - provided_ctx - padded_tokens = np.pad(np.asarray(tokens, dtype=np.uint32), ((pad_amount, 0),), constant_values=pad_token_id) + padded_tokens = np.pad(tokens, ((pad_amount, 0),), constant_values=pad_token_id) batched_tokens = np.array([padded_tokens] * total_batch) - length = np.ones(total_batch, dtype=np.uint32) * len(tokens) + length = np.ones(total_batch, dtype=np.uint32) * provided_ctx samples = [] batched_generator_params = { "temp": temp * np.ones(total_batch), @@ -278,7 +292,13 @@ def infer(context, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, repetition_penalty=1.0 "repetition_penalty": repetition_penalty * np.ones(total_batch), "top_k": np.full(total_batch, top_k, dtype=np.uint32) } - output = network.generate(batched_tokens, length, gen_len, batched_generator_params) + output = network.generate( + batched_tokens, + length, + gen_len, + batched_generator_params, + soft_embeddings=soft_embeddings, + ) decoded_tokens = output[1][0] for o in decoded_tokens[:, :, 0]: samples.append(tokenizer.decode(o))