From 932c393d6a418da6425646d90ffb7bfef9edacd6 Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Fri, 14 Jan 2022 21:39:02 -0500
Subject: [PATCH] Add TPU support for dynamic WI scan and generation modifiers

---
 aiserver.py        | 121 +++++++++++++++++++++++++++++++++++++--------
 tpu_mtj_backend.py |  97 ++++++++++++++++++++----------
 2 files changed, 155 insertions(+), 63 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index 83df8c7e..96accfa3 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -22,7 +22,7 @@ import packaging
 import contextlib
 import traceback
 import threading
-from typing import Any, Callable, TypeVar, Union, Dict, Set, List
+from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
 
 import requests
 import html
@@ -1001,6 +1001,46 @@ else:
         )
         return soft_tokens
 
+    def tpumtjgenerate_warper_callback(generated, scores, excluded_world_info, n_generated) -> Tuple[List[set], bool, bool]:
+        vars.generated_tkns += 1
+
+        assert len(excluded_world_info) == len(generated)
+        regeneration_required = vars.lua_koboldbridge.regeneration_required
+        halt = not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt
+        vars.lua_koboldbridge.regeneration_required = False
+
+        global past
+
+        for i in range(vars.numseqs):
+            vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(generated[i, tpu_mtj_backend.params["seq"] + n_generated - 1].item())
+
+        scores_shape = scores.shape
+        scores_list = scores.tolist()
+        vars.lua_koboldbridge.logits = vars.lua_state.table()
+        for r, row in enumerate(scores_list):
+            vars.lua_koboldbridge.logits[r+1] = vars.lua_state.table(*row)
+        vars.lua_koboldbridge.vocab_size = scores_shape[-1]
+
+        execute_genmod()
+
+        scores = np.array(
+            tuple(tuple(row.values()) for row in vars.lua_koboldbridge.logits.values()),
+            dtype=scores.dtype,
+        )
+        assert scores.shape == scores_shape
+
+        if(not vars.dynamicscan or halt):
+            return excluded_world_info, regeneration_required, halt
+
+        for i, t in enumerate(generated):
+            decoded = tokenizer.decode(past[i]) + tokenizer.decode(t[tpu_mtj_backend.params["seq"] : tpu_mtj_backend.params["seq"] + n_generated])
+            _, found = checkworldinfo(decoded, force_use_txt=True)
+            found -= excluded_world_info[i]
+            if(len(found) != 0):
+                regeneration_required = True
+                break
+        return excluded_world_info, regeneration_required, halt
+
     # If we're running Colab or OAI, we still need a tokenizer.
     if(vars.model == "Colab"):
         from transformers import GPT2TokenizerFast
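
For reference while reading the hunk above: tpumtjgenerate_warper_callback() is invoked by the TPU backend once per generated token, with every sequence's tokens so far and the logits for the next token. It mirrors the logits into the Lua bridge so generation modifiers can run, then (when dynamic WI scanning is enabled) decodes each sequence and asks checkworldinfo() whether any not-yet-excluded world info entry now matches. A minimal sketch of that contract, with the Lua bridge and the world info scan stubbed out (illustrative only, not part of the patch):

    import numpy as np

    def example_warper_callback(generated, scores, excluded_world_info, n_generated):
        # generated: (numseqs, seq + gen) token IDs
        # scores:    (numseqs, n_vocab) next-token logits
        assert len(excluded_world_info) == len(generated)
        halt = n_generated >= 80  # hypothetical token budget
        regeneration_required = False
        for i, tokens in enumerate(generated):
            # A real implementation decodes `tokens` here and sets
            # regeneration_required if a world info entry outside
            # excluded_world_info[i] now matches the decoded text.
            pass
        return excluded_world_info, regeneration_required, halt
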
@@ -1013,6 +1053,7 @@ else:
         print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
         assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth)
         import tpu_mtj_backend
+        tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
         tpu_mtj_backend.load_model(vars.custmodpth)
         vars.allowsp = True
         vars.modeldim = int(tpu_mtj_backend.params["d_model"])
@@ -1020,12 +1061,14 @@ else:
         soft_tokens = tpumtjgetsofttokens()
         threading.Thread( # Compile backend code in background
             target=tpu_mtj_backend.infer,
-            args=(np.uint32((23403, 727, 20185)),),
+            args=(np.tile(np.uint32((23403, 727, 20185)), (vars.numseqs, 1)),),
             kwargs={
                 "soft_embeddings": vars.sp,
                 "soft_tokens": soft_tokens,
+                "use_callback": False,
                 "gen_len": 1,
                 "numseqs": vars.numseqs,
+                "excluded_world_info": list(set() for _ in range(vars.numseqs)),
             },
         ).start()
 
@@ -2890,32 +2933,68 @@ def sendtocolab(txt, min, max):
 # Send text to TPU mesh transformer backend
 #==================================================================#
 def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
+    vars.generated_tkns = 0
+
     if(found_entries is None):
         found_entries = set()
     found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
 
     print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, tokenizer.decode(txt), colors.END))
 
+    vars._actions = vars.actions
+    vars._prompt = vars.prompt
+    if(vars.dynamicscan):
+        vars._actions = vars._actions.copy()
+
     # Submit input text to generator
    try:
-        if(vars.dynamicscan):
-            raise ValueError("Dynamic world info scanning is not supported by the TPU backend yet")
-
+        context = np.tile(np.uint32(txt), (vars.numseqs, 1))
         soft_tokens = tpumtjgetsofttokens()
-        genout = tpool.execute(
-            tpu_mtj_backend.infer,
-            np.uint32(txt),
-            gen_len = maximum-minimum+1,
-            temp=vars.temp,
-            top_p=vars.top_p,
-            top_k=vars.top_k,
-            tfs=vars.tfs,
-            numseqs=vars.numseqs,
-            repetition_penalty=vars.rep_pen,
-            soft_embeddings=vars.sp,
-            soft_tokens=soft_tokens,
-        )
+
+        global past
+        past = np.empty((vars.numseqs, 0), dtype=np.uint32)
+
+        while(True):
+            genout, n_generated, regeneration_required, halt = tpool.execute(
+                tpu_mtj_backend.infer,
+                context,
+                gen_len = maximum-minimum+1,
+                temp=vars.temp,
+                top_p=vars.top_p,
+                top_k=vars.top_k,
+                tfs=vars.tfs,
+                numseqs=vars.numseqs,
+                repetition_penalty=vars.rep_pen,
+                soft_embeddings=vars.sp,
+                soft_tokens=soft_tokens,
+                excluded_world_info=found_entries,
+            )
+
+            past = np.pad(past, ((0, 0), (0, n_generated)))
+            for r in range(vars.numseqs):
+                for c in range(vars.lua_koboldbridge.generated_cols):
+                    assert vars.lua_koboldbridge.generated[r+1][c+1] is not None
+                    past[r, c] = vars.lua_koboldbridge.generated[r+1][c+1]
+
+            if(halt or not regeneration_required):
+                break
+
+            encoded = []
+            for i in range(vars.numseqs):
+                txt = tokenizer.decode(past[i])
+                winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
+                found_entries[i].update(_found_entries)
+                txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
+                encoded.append(np.array(txt, dtype=np.uint32))
+            max_length = len(max(encoded, key=len))
+            encoded = np.stack(tuple(np.pad(e, (max_length - len(e), 0), constant_values=tpu_mtj_backend.pad_token_id) for e in encoded))
+            context = np.concatenate(
+                (
+                    encoded,
+                    past,
+                ),
+                axis=-1,
+            )
 
     except Exception as e:
         if(issubclass(type(e), lupa.LuaError)):
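
The loop above is the heart of dynamic WI scanning on the TPU: infer() returns early when the callback reports regeneration_required, the tokens confirmed so far are copied out of the Lua bridge into `past`, and the context is rebuilt with the newly found entries before resubmitting. The rebuild step left-pads each re-encoded prompt so the batch stays rectangular, then appends the already-generated tokens. A self-contained sketch of just that step (illustrative, not part of the patch; PAD stands in for tpu_mtj_backend.pad_token_id):

    import numpy as np

    PAD = 50256  # assumption: stand-in for tpu_mtj_backend.pad_token_id

    def rebuild_context(encoded_prompts, past):
        # Left-pad every re-encoded prompt to a common length, then append
        # the tokens generated so far, keeping the batch rectangular.
        max_length = max(len(e) for e in encoded_prompts)
        encoded = np.stack([
            np.pad(e, (max_length - len(e), 0), constant_values=PAD)
            for e in encoded_prompts
        ])
        return np.concatenate((encoded, past), axis=-1)

    past = np.zeros((2, 4), dtype=np.uint32)         # 4 tokens generated so far
    prompts = [np.arange(3, dtype=np.uint32),        # re-encoded prompts of
               np.arange(5, dtype=np.uint32)]        # different lengths
    assert rebuild_context(prompts, past).shape == (2, 9)
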
@@ -2931,10 +3010,10 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
             print("{0}{1}{2}".format(colors.RED, traceback.format_exc().replace("\033", ""), colors.END), file=sys.stderr)
         set_aibusy(0)
         return
-        
+
     for i in range(vars.numseqs):
-        vars.lua_koboldbridge.generated[i+1] = vars.lua_state.table(*genout[i].tolist())
-        vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(genout[i])
+        vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(past[i])
+    genout = past
 
     execute_outmod()
     if(vars.lua_koboldbridge.regeneration_required):

diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index c2caf270..9cd49a12 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -1,5 +1,5 @@
 import multiprocessing
-from typing import Any, Callable, Dict, List, Optional, TypeVar
+from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
 import progressbar
 import time
 import os
@@ -20,6 +20,10 @@ from mesh_transformer.transformer_shard import CausalTransformer, CausalTransfor
 
 params: Dict[str, Any] = {}
 
+def warper_callback(generated, logits, excluded_world_info, n_generated) -> Tuple[List[set], bool, bool]:
+    raise NotImplementedError("`tpu_mtj_backend.warper_callback()` needs to be defined")
+
+
 def show_spinner():
     bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='█')])
     i = 0
@@ -235,13 +239,13 @@ class PenalizingCausalTransformer(CausalTransformer):
         def generate_initial_inner(context, ctx_length):
             # Give the initial context to the transformer
             transformer = CausalTransformerShard(config)
-            def generate_initial_scan_fn(sequence_index, _):
-                _, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings)
+            def generate_initial_scan_fn(sequence_index, c):
+                _, initial_state = transformer.generate_initial(c, ctx_length, soft_embeddings=soft_embeddings)
                 generated_index = config["seq"]
                 # Add that information to generate_loop_fn's starting state
                 initial_state = (jnp.empty(config["n_vocab"], dtype=jnp.float32), generated_index, sequence_index) + initial_state
                 return sequence_index+1, initial_state
-            _, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, None, numseqs)
+            _, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, context, numseqs)
             sample_key = initial_states[-1][0]
             initial_states = list(list(jax.tree_map(lambda x: x[i], initial_states[:-1])) for i in range(numseqs))
             return initial_states, sample_key
@@ -307,7 +311,8 @@ class PenalizingCausalTransformer(CausalTransformer):
             out_axes=["shard", "batch", ...],
             axis_resources={'shard': 'mp', 'batch': 'dp'},
         )
-    def generate(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None):
+    def generate(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None, excluded_world_info=None, use_callback=True):
+        assert excluded_world_info is not None
         assert not return_logits
         assert gen_length.ndim == 1
         assert soft_embeddings is not None
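
A note on the generate_initial change above: jax.lax.scan(f, init, xs, length) feeds one leading-axis slice of xs to f per step, so passing `context` instead of None gives each of the numseqs sequences its own context row, where previously every sequence was initialized from the same ctx[0]. A toy example of the same scan pattern, with a stand-in for transformer.generate_initial() (illustrative, not part of the patch):

    import jax
    import jax.numpy as jnp

    def generate_initial_scan_fn(sequence_index, c):
        # stand-in for transformer.generate_initial(c, ctx_length, ...)
        initial_state = (c.sum(), sequence_index)
        return sequence_index + 1, initial_state

    context = jnp.arange(6, dtype=jnp.uint32).reshape(3, 2)  # one row per sequence
    _, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, context, 3)
    # initial_states now stacks one per-sequence state along the leading axis
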
@@ -318,24 +323,34 @@ class PenalizingCausalTransformer(CausalTransformer):
         numseqs_aux = batch_xmap(_numseqs_aux)
         sample_data = [
             [
-                np.pad(ctx[0], (0, params["seq"]), constant_values=pad_token_id),
+                np.pad(ctx[0][i], (0, params["seq"]), constant_values=pad_token_id),
                 params["seq"],
                 None,
                 np.empty((), dtype=np.uint32),
             ]
-            for _ in range(numseqs)
+            for i in range(numseqs)
         ]
         repetition_penalty = sampler_options.pop("repetition_penalty", 1.0)
+        n_generated = 0
+        regeneration_required = False
+        halt = False
         generate_data, sample_key = self.generate_initial_xmap(self.state, jnp.array(key.take(batch_size)), ctx, ctx_length, numseqs_aux, soft_embeddings)
         sample_key = np.asarray(sample_key[0, 0])
-        for _ in range(gen_length[0].item()):
+        while True:
             generate_data, = self.generate_once_xmap(generate_data, self.state, numseqs_aux, soft_embeddings)
             for i in range(numseqs):
-                sample_data[i][2] = np.array(generate_data[0][i][0, 0], copy=True)
+                sample_data[i][2] = np.array(generate_data[i][0][0, 0], copy=True)
             sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options)
             for i in range(numseqs):
                 generate_data[i][3] = np.tile(sample_data[i][0][sample_data[i][1]-1][np.newaxis, np.newaxis], (params["cores_per_replica"], 1, 1))
-        return sample_data, sample_key
+            n_generated += 1
+            if use_callback:
+                excluded_world_info, regeneration_required, halt = warper_callback(np.uint32(tuple(d[0] for d in sample_data)), np.float32(tuple(d[2] for d in sample_data)), excluded_world_info, n_generated)
+                if regeneration_required or halt:
+                    break
+            else:
+                break
+        return sample_data, n_generated, regeneration_required, halt
 
 
 def infer(
@@ -349,15 +364,18 @@ def infer(
     gen_len=80,
     soft_embeddings: Optional[np.array] = None,
     soft_tokens: Optional[np.array] = None,
-) -> List[str]:
+    excluded_world_info = None,
+    use_callback=True,
+) -> Tuple[List[np.array], int, bool, bool]:
+    assert excluded_world_info is not None
     maps.thread_resources.env = thread_resources_env
     total_batch = 1
     tokens = context
     if(soft_tokens is not None):
-        tokens = np.uint32(np.concatenate((soft_tokens, tokens)))
-    provided_ctx = tokens.shape[0]
+        tokens = np.uint32(np.concatenate((np.tile(soft_tokens, (tokens.shape[0], 1)), tokens), axis=-1))
+    provided_ctx = tokens.shape[-1]
     pad_amount = seq - provided_ctx
-    padded_tokens = np.pad(tokens, ((pad_amount, 0),), constant_values=pad_token_id)
+    padded_tokens = np.pad(tokens, ((0, 0), (pad_amount, 0)), constant_values=pad_token_id)
     batched_tokens = np.array([padded_tokens] * total_batch)
     samples = []
     generator_params = {
@@ -374,10 +392,12 @@ def infer(
         numseqs,
         generator_params,
         soft_embeddings=soft_embeddings,
-    )[0]
-    for out in output:
+        excluded_world_info=excluded_world_info,
+        use_callback=use_callback,
+    )
+    for out in output[0]:
         samples.append(out[0][params["seq"] : params["seq"] + gen_len])
-    return samples
+    return (samples,) + output[1:]
 
 
 def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) -> None:
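
The infer() changes above switch the public API from a single 1-D context to a (numseqs, ctx_length) batch: soft-prompt token IDs are tiled onto every row, and the left padding to params["seq"] now applies only to the last axis. A self-contained sketch of that shape handling (illustrative, not part of the patch; PAD and SEQ stand in for pad_token_id and params["seq"]):

    import numpy as np

    PAD = 50256  # assumption: stand-in for pad_token_id
    SEQ = 16     # assumption: stand-in for params["seq"]

    context = np.tile(np.uint32((23403, 727, 20185)), (3, 1))  # 3 sequences
    soft_tokens = np.uint32((60000, 60001))                    # 2 soft-prompt slots
    # Tile the soft tokens onto every row, then left-pad the last axis only.
    tokens = np.uint32(np.concatenate(
        (np.tile(soft_tokens, (context.shape[0], 1)), context), axis=-1))
    pad_amount = SEQ - tokens.shape[-1]
    padded = np.pad(tokens, ((0, 0), (pad_amount, 0)), constant_values=PAD)
    assert padded.shape == (3, SEQ)
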
@@ -405,32 +425,25 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
     jax.host_count = jax.process_count
     jax.host_id = jax.process_index
 
-    while True:
-        print("Connecting to your Colab instance's TPU", flush=True)
-        spinner = multiprocessing.Process(target=show_spinner, args=())
-        spinner.start()
-        colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
-        url = f'http://{colab_tpu_addr}:8475/requestversion/{driver_version}'
-        requests.post(url)
-        spinner.terminate()
-        print()
-        config.FLAGS.jax_xla_backend = "tpu_driver"
-        config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
+    print("Connecting to your Colab instance's TPU", flush=True)
+    spinner = multiprocessing.Process(target=show_spinner, args=())
+    spinner.start()
+    colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
+    url = f'http://{colab_tpu_addr}:8475/requestversion/{driver_version}'
+    requests.post(url)
+    spinner.terminate()
+    print()
+    config.FLAGS.jax_xla_backend = "tpu_driver"
+    config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
 
-        cores_per_replica = params["cores_per_replica"]
-        seq = params["seq"]
-        params["optimizer"] = optax.scale(0)
-        mesh_shape = (1, cores_per_replica)
-        try:
-            devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
-        except RuntimeError as e:
-            if "DEADLINE_EXCEEDED" not in str(e):
-                raise e
-            continue
-        thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
-        maps.thread_resources.env = thread_resources_env
-        tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
-        break
+    cores_per_replica = params["cores_per_replica"]
+    seq = params["seq"]
+    params["optimizer"] = optax.scale(0)
+    mesh_shape = (1, cores_per_replica)
+    devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
+    thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
+    maps.thread_resources.env = thread_resources_env
+    tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
 
     global shard_xmap, batch_xmap
     shard_xmap = __shard_xmap()
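
Taken together, the backend now returns (samples, n_generated, regeneration_required, halt) instead of a bare list of samples. A sketch of how a host drives the new API, assuming load_model() has already been called and ignoring soft prompts (illustrative only, not part of the patch; aiserver.py wires warper_callback to its Lua bridge rather than the trivial callback used here):

    import numpy as np
    import tpu_mtj_backend

    def my_callback(generated, logits, excluded_world_info, n_generated):
        # hypothetical: never regenerate, stop after 8 tokens
        return excluded_world_info, False, n_generated >= 8

    tpu_mtj_backend.warper_callback = my_callback

    context = np.tile(np.uint32((23403, 727, 20185)), (2, 1))  # 2 sequences
    samples, n_generated, regeneration_required, halt = tpu_mtj_backend.infer(
        context,
        gen_len=8,
        numseqs=2,
        excluded_world_info=[set() for _ in range(2)],
    )
    # If regeneration_required were True, the host would rebuild `context`
    # with the newly triggered world info (as tpumtjgenerate() does above)
    # and call infer() again; otherwise the samples are final.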