Add TPU support for dynamic WI scan and generation modifiers

Gnome Ann 2022-01-14 21:39:02 -05:00
parent 0bef92419b
commit 932c393d6a
2 changed files with 155 additions and 63 deletions
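The heart of the change is a callback contract between the two files: aiserver.py assigns its new tpumtjgenerate_warper_callback to the module-level warper_callback hook in tpu_mtj_backend, and the backend invokes that hook once per generated token so the frontend can run Lua generation modifiers, record generated tokens, and rescan world info. The stub below is a minimal sketch of that contract (illustrative only, not part of the commit); parameter and return names follow the diff, while the stub body is a hypothetical placeholder.

    import numpy as np
    from typing import List, Set, Tuple

    def example_warper_callback(
        generated: np.ndarray,            # (numseqs, seq + generated) token ids produced so far
        logits: np.ndarray,               # (numseqs, n_vocab) scores for the token just sampled
        excluded_world_info: List[Set],   # world info entries already triggered, one set per sequence
        n_generated: int,                 # tokens generated so far in the current infer() call
    ) -> Tuple[List[Set], bool, bool]:
        # A real implementation (tpumtjgenerate_warper_callback in the diff) would run the Lua
        # generation modifiers here and rescan world info against the decoded text.
        regeneration_required = False  # True asks the caller to rebuild the context and retry
        halt = False                   # True stops generation early
        return excluded_world_info, regeneration_required, halt

    # Registered the same way the diff registers tpumtjgenerate_warper_callback:
    #     tpu_mtj_backend.warper_callback = example_warper_callback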

aiserver.py

@@ -22,7 +22,7 @@ import packaging
 import contextlib
 import traceback
 import threading
-from typing import Any, Callable, TypeVar, Union, Dict, Set, List
+from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
 import requests
 import html
@@ -1001,6 +1001,46 @@ else:
         )
         return soft_tokens

+    def tpumtjgenerate_warper_callback(generated, scores, excluded_world_info, n_generated) -> Tuple[List[set], bool, bool]:
+        vars.generated_tkns += 1
+
+        assert len(excluded_world_info) == len(generated)
+
+        regeneration_required = vars.lua_koboldbridge.regeneration_required
+        halt = not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt
+        vars.lua_koboldbridge.regeneration_required = False
+
+        global past
+        for i in range(vars.numseqs):
+            vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(generated[i, tpu_mtj_backend.params["seq"] + n_generated - 1].item())
+
+        scores_shape = scores.shape
+        scores_list = scores.tolist()
+        vars.lua_koboldbridge.logits = vars.lua_state.table()
+        for r, row in enumerate(scores_list):
+            vars.lua_koboldbridge.logits[r+1] = vars.lua_state.table(*row)
+        vars.lua_koboldbridge.vocab_size = scores_shape[-1]
+
+        execute_genmod()
+
+        scores = np.array(
+            tuple(tuple(row.values()) for row in vars.lua_koboldbridge.logits.values()),
+            dtype=scores.dtype,
+        )
+        assert scores.shape == scores_shape
+
+        if(not vars.dynamicscan or halt):
+            return excluded_world_info, regeneration_required, halt
+
+        for i, t in enumerate(generated):
+            decoded = tokenizer.decode(past[i]) + tokenizer.decode(t[tpu_mtj_backend.params["seq"] : tpu_mtj_backend.params["seq"] + n_generated])
+            _, found = checkworldinfo(decoded, force_use_txt=True)
+            found -= excluded_world_info[i]
+            if(len(found) != 0):
+                regeneration_required = True
+                break
+        return excluded_world_info, regeneration_required, halt
+
     # If we're running Colab or OAI, we still need a tokenizer.
     if(vars.model == "Colab"):
         from transformers import GPT2TokenizerFast
@@ -1013,6 +1053,7 @@ else:
         print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
         assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth)
         import tpu_mtj_backend
+        tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
         tpu_mtj_backend.load_model(vars.custmodpth)
         vars.allowsp = True
         vars.modeldim = int(tpu_mtj_backend.params["d_model"])
@@ -1020,12 +1061,14 @@ else:
         soft_tokens = tpumtjgetsofttokens()
         threading.Thread( # Compile backend code in background
             target=tpu_mtj_backend.infer,
-            args=(np.uint32((23403, 727, 20185)),),
+            args=(np.tile(np.uint32((23403, 727, 20185)), (vars.numseqs, 1)),),
             kwargs={
                 "soft_embeddings": vars.sp,
                 "soft_tokens": soft_tokens,
+                "use_callback": False,
                 "gen_len": 1,
                 "numseqs": vars.numseqs,
+                "excluded_world_info": list(set() for _ in range(vars.numseqs)),
             },
         ).start()
@@ -2890,32 +2933,68 @@ def sendtocolab(txt, min, max):
 # Send text to TPU mesh transformer backend
 #==================================================================#
 def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
+    vars.generated_tkns = 0
     if(found_entries is None):
         found_entries = set()
     found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))

     print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, tokenizer.decode(txt), colors.END))

+    vars._actions = vars.actions
+    vars._prompt = vars.prompt
+    if(vars.dynamicscan):
+        vars._actions = vars._actions.copy()
+
     # Submit input text to generator
     try:
-        if(vars.dynamicscan):
-            raise ValueError("Dynamic world info scanning is not supported by the TPU backend yet")
+        context = np.tile(np.uint32(txt), (vars.numseqs, 1))
         soft_tokens = tpumtjgetsofttokens()

-        genout = tpool.execute(
-            tpu_mtj_backend.infer,
-            np.uint32(txt),
-            gen_len = maximum-minimum+1,
-            temp=vars.temp,
-            top_p=vars.top_p,
-            top_k=vars.top_k,
-            tfs=vars.tfs,
-            numseqs=vars.numseqs,
-            repetition_penalty=vars.rep_pen,
-            soft_embeddings=vars.sp,
-            soft_tokens=soft_tokens,
-        )
+        global past
+        past = np.empty((vars.numseqs, 0), dtype=np.uint32)
+
+        while(True):
+            genout, n_generated, regeneration_required, halt = tpool.execute(
+                tpu_mtj_backend.infer,
+                context,
+                gen_len = maximum-minimum+1,
+                temp=vars.temp,
+                top_p=vars.top_p,
+                top_k=vars.top_k,
+                tfs=vars.tfs,
+                numseqs=vars.numseqs,
+                repetition_penalty=vars.rep_pen,
+                soft_embeddings=vars.sp,
+                soft_tokens=soft_tokens,
+                excluded_world_info=found_entries,
+            )
+
+            past = np.pad(past, ((0, 0), (0, n_generated)))
+            for r in range(vars.numseqs):
+                for c in range(vars.lua_koboldbridge.generated_cols):
+                    assert vars.lua_koboldbridge.generated[r+1][c+1] is not None
+                    past[r, c] = vars.lua_koboldbridge.generated[r+1][c+1]
+
+            if(halt or not regeneration_required):
+                break
+
+            encoded = []
+            for i in range(vars.numseqs):
+                txt = tokenizer.decode(past[i])
+                winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
+                found_entries[i].update(_found_entries)
+                txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
+                encoded.append(np.array(txt, dtype=np.uint32))
+            max_length = len(max(encoded, key=len))
+            encoded = np.stack(tuple(np.pad(e, (max_length - len(e), 0), constant_values=tpu_mtj_backend.pad_token_id) for e in encoded))
+            context = np.concatenate(
+                (
+                    encoded,
+                    past,
+                ),
+                axis=-1,
+            )
     except Exception as e:
         if(issubclass(type(e), lupa.LuaError)):
@@ -2933,8 +3012,8 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
         return

     for i in range(vars.numseqs):
-        vars.lua_koboldbridge.generated[i+1] = vars.lua_state.table(*genout[i].tolist())
-        vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(genout[i])
+        vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(past[i])
+    genout = past

     execute_outmod()
     if(vars.lua_koboldbridge.regeneration_required):

tpu_mtj_backend.py

@@ -1,5 +1,5 @@
 import multiprocessing
-from typing import Any, Callable, Dict, List, Optional, TypeVar
+from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
 import progressbar
 import time
 import os
@@ -20,6 +20,10 @@ from mesh_transformer.transformer_shard import CausalTransformer, CausalTransfor
 params: Dict[str, Any] = {}


+def warper_callback(generated, logits, excluded_world_info, n_generated) -> Tuple[bool, bool]:
+    raise NotImplementedError("`tpu_mtj_backend.warper_callback()` needs to be defined")
+
+
 def show_spinner():
     bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='')])
     i = 0
@@ -235,13 +239,13 @@ class PenalizingCausalTransformer(CausalTransformer):
         def generate_initial_inner(context, ctx_length):
             # Give the initial context to the transformer
             transformer = CausalTransformerShard(config)
-            def generate_initial_scan_fn(sequence_index, _):
-                _, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings)
+            def generate_initial_scan_fn(sequence_index, c):
+                _, initial_state = transformer.generate_initial(c, ctx_length, soft_embeddings=soft_embeddings)
                 generated_index = config["seq"]
                 # Add that information to generate_loop_fn's starting state
                 initial_state = (jnp.empty(config["n_vocab"], dtype=jnp.float32), generated_index, sequence_index) + initial_state
                 return sequence_index+1, initial_state
-            _, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, None, numseqs)
+            _, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, context, numseqs)
             sample_key = initial_states[-1][0]
             initial_states = list(list(jax.tree_map(lambda x: x[i], initial_states[:-1])) for i in range(numseqs))
             return initial_states, sample_key
@@ -307,7 +311,8 @@ class PenalizingCausalTransformer(CausalTransformer):
             out_axes=["shard", "batch", ...],
             axis_resources={'shard': 'mp', 'batch': 'dp'},
         )
-    def generate(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None):
+    def generate(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None, excluded_world_info=None, use_callback=True):
+        assert excluded_world_info is not None
         assert not return_logits
         assert gen_length.ndim == 1
         assert soft_embeddings is not None
@@ -318,24 +323,34 @@ class PenalizingCausalTransformer(CausalTransformer):
         numseqs_aux = batch_xmap(_numseqs_aux)
         sample_data = [
             [
-                np.pad(ctx[0], (0, params["seq"]), constant_values=pad_token_id),
+                np.pad(ctx[0][i], (0, params["seq"]), constant_values=pad_token_id),
                 params["seq"],
                 None,
                 np.empty((), dtype=np.uint32),
             ]
-            for _ in range(numseqs)
+            for i in range(numseqs)
         ]
         repetition_penalty = sampler_options.pop("repetition_penalty", 1.0)
+        n_generated = 0
+        regeneration_required = False
+        halt = False
         generate_data, sample_key = self.generate_initial_xmap(self.state, jnp.array(key.take(batch_size)), ctx, ctx_length, numseqs_aux, soft_embeddings)
         sample_key = np.asarray(sample_key[0, 0])
-        for _ in range(gen_length[0].item()):
+        while True:
             generate_data, = self.generate_once_xmap(generate_data, self.state, numseqs_aux, soft_embeddings)
             for i in range(numseqs):
-                sample_data[i][2] = np.array(generate_data[0][i][0, 0], copy=True)
+                sample_data[i][2] = np.array(generate_data[i][0][0, 0], copy=True)
             sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options)
             for i in range(numseqs):
                 generate_data[i][3] = np.tile(sample_data[i][0][sample_data[i][1]-1][np.newaxis, np.newaxis], (params["cores_per_replica"], 1, 1))
-        return sample_data, sample_key
+            n_generated += 1
+            if use_callback:
+                excluded_world_info, regeneration_required, halt = warper_callback(np.uint32(tuple(d[0] for d in sample_data)), np.float32(tuple(d[2] for d in sample_data)), excluded_world_info, n_generated)
+                if regeneration_required or halt:
+                    break
+            else:
+                break
+        return sample_data, n_generated, regeneration_required, halt

 def infer(
@@ -349,15 +364,18 @@ def infer(
     gen_len=80,
     soft_embeddings: Optional[np.array] = None,
     soft_tokens: Optional[np.array] = None,
-) -> List[str]:
+    excluded_world_info = None,
+    use_callback=True,
+) -> Tuple[List[np.array], int, bool, bool]:
+    assert excluded_world_info is not None
     maps.thread_resources.env = thread_resources_env
     total_batch = 1
     tokens = context
     if(soft_tokens is not None):
-        tokens = np.uint32(np.concatenate((soft_tokens, tokens)))
-    provided_ctx = tokens.shape[0]
+        tokens = np.uint32(np.concatenate((np.tile(soft_tokens, (tokens.shape[0], 1)), tokens), axis=-1))
+    provided_ctx = tokens.shape[-1]
     pad_amount = seq - provided_ctx
-    padded_tokens = np.pad(tokens, ((pad_amount, 0),), constant_values=pad_token_id)
+    padded_tokens = np.pad(tokens, ((0, 0), (pad_amount, 0)), constant_values=pad_token_id)
     batched_tokens = np.array([padded_tokens] * total_batch)
     samples = []
     generator_params = {
@@ -374,10 +392,12 @@ def infer(
         numseqs,
         generator_params,
         soft_embeddings=soft_embeddings,
-    )[0]
-    for out in output:
+        excluded_world_info=excluded_world_info,
+        use_callback=use_callback,
+    )
+    for out in output[0]:
         samples.append(out[0][params["seq"] : params["seq"] + gen_len])
-    return samples
+    return (samples,) + output[1:]

 def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) -> None:
@@ -405,32 +425,25 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
     jax.host_count = jax.process_count
     jax.host_id = jax.process_index

-    while True:
-        print("Connecting to your Colab instance's TPU", flush=True)
-        spinner = multiprocessing.Process(target=show_spinner, args=())
-        spinner.start()
-        colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
-        url = f'http://{colab_tpu_addr}:8475/requestversion/{driver_version}'
-        requests.post(url)
-        spinner.terminate()
-        print()
-        config.FLAGS.jax_xla_backend = "tpu_driver"
-        config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
+    print("Connecting to your Colab instance's TPU", flush=True)
+    spinner = multiprocessing.Process(target=show_spinner, args=())
+    spinner.start()
+    colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
+    url = f'http://{colab_tpu_addr}:8475/requestversion/{driver_version}'
+    requests.post(url)
+    spinner.terminate()
+    print()
+    config.FLAGS.jax_xla_backend = "tpu_driver"
+    config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']

-        cores_per_replica = params["cores_per_replica"]
-        seq = params["seq"]
-        params["optimizer"] = optax.scale(0)
-        mesh_shape = (1, cores_per_replica)
-        try:
-            devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
-        except RuntimeError as e:
-            if "DEADLINE_EXCEEDED" not in str(e):
-                raise e
-            continue
-        thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
-        maps.thread_resources.env = thread_resources_env
-        tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
-        break
+    cores_per_replica = params["cores_per_replica"]
+    seq = params["seq"]
+    params["optimizer"] = optax.scale(0)
+    mesh_shape = (1, cores_per_replica)
+    devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
+    thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
+    maps.thread_resources.env = thread_resources_env
+    tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

     global shard_xmap, batch_xmap
     shard_xmap = __shard_xmap()
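For reference, the background warm-up call in aiserver.py above already exercises the new infer() signature. The call below is a minimal sketch of the same invocation outside a thread, assuming tpu_mtj_backend.load_model() has already run and reusing the warm-up's placeholder token ids; the values are illustrative only. With use_callback=False the generation loop stops after a single token, which is exactly how the warm-up forces compilation.

    import numpy as np
    import tpu_mtj_backend  # assumes the backend module above is importable and a model is loaded

    numseqs = 2
    # Same placeholder context as the warm-up thread in the diff, tiled to one row per sequence
    context = np.tile(np.uint32((23403, 727, 20185)), (numseqs, 1))

    samples, n_generated, regeneration_required, halt = tpu_mtj_backend.infer(
        context,
        gen_len=1,                 # one token is enough to trigger compilation
        numseqs=numseqs,
        use_callback=False,        # skip the frontend warper_callback, as the warm-up does
        excluded_world_info=[set() for _ in range(numseqs)],  # now mandatory: one set per sequence
    )
    # samples is a list of numseqs uint32 arrays; the extra return values report how many tokens
    # were generated and whether the callback requested a regeneration or a halt.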