Merge pull request #58 from VE-FORBRYDERNE/xmap

Dynamic TPU backend xmaps
This commit is contained in:
henk717
2022-01-15 16:20:58 +01:00
committed by GitHub
2 changed files with 330 additions and 138 deletions

View File

@ -22,7 +22,7 @@ import packaging
import contextlib import contextlib
import traceback import traceback
import threading import threading
from typing import Any, Callable, TypeVar, Union, Dict, Set, List from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
import requests import requests
import html import html
@ -993,7 +993,7 @@ else:
-1, -1,
tpu_mtj_backend.params["d_model"], tpu_mtj_backend.params["d_model"],
) )
vars.sp = tensor vars.sp = tpu_mtj_backend.shard_xmap(tensor)
soft_tokens = np.arange( soft_tokens = np.arange(
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"], tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length, tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
@ -1001,6 +1001,49 @@ else:
) )
return soft_tokens return soft_tokens
def tpumtjgenerate_warper_callback(scores) -> "np.array":
    """Let the Lua generation modifier rewrite the logits of every sequence.

    Each row of ``scores`` is copied into a Lua table stored at a 1-based
    index of ``vars.lua_koboldbridge.logits``, ``execute_genmod()`` is run,
    and the table contents are then read back into a new array.

    :param scores: array of logits; its last dimension is published to the
        Lua bridge as ``vocab_size``.
    :returns: a new array with the dtype of ``scores`` built from the Lua
        tables after ``execute_genmod()`` has run.
    :raises AssertionError: if the array read back does not have the same
        shape as ``scores``.
    """
    expected_shape = scores.shape
    bridge = vars.lua_koboldbridge

    # Lua tables are 1-indexed, hence start=1
    bridge.logits = vars.lua_state.table()
    for lua_index, row in enumerate(scores.tolist(), start=1):
        bridge.logits[lua_index] = vars.lua_state.table(*row)
    bridge.vocab_size = expected_shape[-1]

    execute_genmod()

    # Read the tables back in iteration order
    rows = tuple(
        tuple(row.values())
        for row in vars.lua_koboldbridge.logits.values()
    )
    warped = np.array(rows, dtype=scores.dtype)
    assert warped.shape == expected_shape
    return warped
def tpumtjgenerate_stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]:
    """Record the newest token of every sequence and decide whether to stop.

    Increments ``vars.generated_tkns``, stores the most recently generated
    token of each sequence in ``vars.lua_koboldbridge.generated`` and clears
    the bridge's ``regeneration_required`` flag after reading it.

    :param generated: 2D token array indexed as ``generated[sequence, position]``;
        generated tokens start at position ``tpu_mtj_backend.params["seq"]``.
    :param n_generated: number of tokens generated so far in this call to
        the backend.
    :param excluded_world_info: one collection per sequence of world info
        entries that must not trigger a regeneration.
    :returns: ``(excluded_world_info, regeneration_required, halt)``.
    :raises AssertionError: if ``excluded_world_info`` and ``generated``
        differ in length.
    """
    global past

    vars.generated_tkns += 1
    assert len(excluded_world_info) == len(generated)

    bridge = vars.lua_koboldbridge
    regeneration_required = bridge.regeneration_required
    halt = not bridge.generating or vars.generated_tkns >= vars.genamt
    bridge.regeneration_required = False

    seq = tpu_mtj_backend.params["seq"]
    newest = seq + n_generated - 1
    for row in range(vars.numseqs):
        # Lua tables are 1-indexed on the sequence axis
        bridge.generated[row + 1][vars.generated_tkns] = int(generated[row, newest].item())

    # Nothing to scan when dynamic world info scanning is off or we are
    # stopping anyway
    if halt or not vars.dynamicscan:
        return excluded_world_info, regeneration_required, halt

    for row, tokens in enumerate(generated):
        decoded = tokenizer.decode(past[row]) + tokenizer.decode(tokens[seq : seq + n_generated])
        _, found = checkworldinfo(decoded, force_use_txt=True)
        # Ignore entries that have already been accounted for
        found -= excluded_world_info[row]
        if len(found) > 0:
            regeneration_required = True
            break

    return excluded_world_info, regeneration_required, halt
# If we're running Colab or OAI, we still need a tokenizer. # If we're running Colab or OAI, we still need a tokenizer.
if(vars.model == "Colab"): if(vars.model == "Colab"):
from transformers import GPT2TokenizerFast from transformers import GPT2TokenizerFast
@ -1013,6 +1056,8 @@ else:
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END)) print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth) assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth)
import tpu_mtj_backend import tpu_mtj_backend
tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback
tpu_mtj_backend.load_model(vars.custmodpth) tpu_mtj_backend.load_model(vars.custmodpth)
vars.allowsp = True vars.allowsp = True
vars.modeldim = int(tpu_mtj_backend.params["d_model"]) vars.modeldim = int(tpu_mtj_backend.params["d_model"])
@ -1020,12 +1065,14 @@ else:
soft_tokens = tpumtjgetsofttokens() soft_tokens = tpumtjgetsofttokens()
threading.Thread( # Compile backend code in background threading.Thread( # Compile backend code in background
target=tpu_mtj_backend.infer, target=tpu_mtj_backend.infer,
args=(np.uint32((23403, 727, 20185)),), args=(np.tile(np.uint32((23403, 727, 20185)), (vars.numseqs, 1)),),
kwargs={ kwargs={
"soft_embeddings": vars.sp, "soft_embeddings": vars.sp,
"soft_tokens": soft_tokens, "soft_tokens": soft_tokens,
"use_callback": False,
"gen_len": 1, "gen_len": 1,
"numseqs": vars.numseqs, "numseqs": vars.numseqs,
"excluded_world_info": list(set() for _ in range(vars.numseqs)),
}, },
).start() ).start()
@ -2890,32 +2937,69 @@ def sendtocolab(txt, min, max):
# Send text to TPU mesh transformer backend # Send text to TPU mesh transformer backend
#==================================================================# #==================================================================#
def tpumtjgenerate(txt, minimum, maximum, found_entries=None): def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
vars.generated_tkns = 0
if(found_entries is None): if(found_entries is None):
found_entries = set() found_entries = set()
found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs)) found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, tokenizer.decode(txt), colors.END)) print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, tokenizer.decode(txt), colors.END))
vars._actions = vars.actions
vars._prompt = vars.prompt
if(vars.dynamicscan):
vars._actions = vars._actions.copy()
# Submit input text to generator # Submit input text to generator
try: try:
if(vars.dynamicscan): context = np.tile(np.uint32(txt), (vars.numseqs, 1))
raise ValueError("Dynamic world info scanning is not supported by the TPU backend yet")
soft_tokens = tpumtjgetsofttokens() soft_tokens = tpumtjgetsofttokens()
genout = tpool.execute( global past
tpu_mtj_backend.infer, past = np.empty((vars.numseqs, 0), dtype=np.uint32)
np.uint32(txt),
gen_len = maximum-minimum+1, while(True):
temp=vars.temp, genout, n_generated, regeneration_required, halt = tpool.execute(
top_p=vars.top_p, tpu_mtj_backend.infer,
top_k=vars.top_k, context,
tfs=vars.tfs, gen_len = maximum-minimum+1,
numseqs=vars.numseqs, temp=vars.temp,
repetition_penalty=vars.rep_pen, top_p=vars.top_p,
soft_embeddings=vars.sp, top_k=vars.top_k,
soft_tokens=soft_tokens, tfs=vars.tfs,
) numseqs=vars.numseqs,
repetition_penalty=vars.rep_pen,
soft_embeddings=vars.sp,
soft_tokens=soft_tokens,
excluded_world_info=found_entries,
)
past = np.pad(past, ((0, 0), (0, n_generated)))
for r in range(vars.numseqs):
for c in range(vars.lua_koboldbridge.generated_cols):
assert vars.lua_koboldbridge.generated[r+1][c+1] is not None
past[r, c] = vars.lua_koboldbridge.generated[r+1][c+1]
if(halt or not regeneration_required):
break
print("(regeneration triggered)")
encoded = []
for i in range(vars.numseqs):
txt = tokenizer.decode(past[i])
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
found_entries[i].update(_found_entries)
txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
encoded.append(np.array(txt, dtype=np.uint32))
max_length = len(max(encoded, key=len))
encoded = np.stack(tuple(np.pad(e, (max_length - len(e), 0), constant_values=tpu_mtj_backend.pad_token_id) for e in encoded))
context = np.concatenate(
(
encoded,
past,
),
axis=-1,
)
except Exception as e: except Exception as e:
if(issubclass(type(e), lupa.LuaError)): if(issubclass(type(e), lupa.LuaError)):
@ -2933,8 +3017,8 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
return return
for i in range(vars.numseqs): for i in range(vars.numseqs):
vars.lua_koboldbridge.generated[i+1] = vars.lua_state.table(*genout[i].tolist()) vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(past[i])
vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(genout[i]) genout = past
execute_outmod() execute_outmod()
if(vars.lua_koboldbridge.regeneration_required): if(vars.lua_koboldbridge.regeneration_required):
@ -4005,7 +4089,7 @@ def spRequest(filename):
-1, -1,
tpu_mtj_backend.params["d_model"], tpu_mtj_backend.params["d_model"],
) )
vars.sp = np.float32(tensor) vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor))
else: else:
vars.sp = torch.from_numpy(tensor) vars.sp = torch.from_numpy(tensor)

View File

@ -1,5 +1,5 @@
import multiprocessing import multiprocessing
from typing import Any, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
import progressbar import progressbar
import time import time
import os import os
@ -20,6 +20,13 @@ from mesh_transformer.transformer_shard import CausalTransformer, CausalTransfor
params: Dict[str, Any] = {} params: Dict[str, Any] = {}
def warper_callback(logits) -> np.array:
    """Placeholder hook; always raises until the host application replaces it.

    :param logits: unused by this placeholder.
    :raises NotImplementedError: always.
    """
    message = "`tpu_mtj_backend.warper_callback()` needs to be defined"
    raise NotImplementedError(message)
def stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]:
    """Placeholder hook; always raises until the host application replaces it.

    :param generated: unused by this placeholder.
    :param n_generated: unused by this placeholder.
    :param excluded_world_info: unused by this placeholder.
    :raises NotImplementedError: always.
    """
    message = "`tpu_mtj_backend.stopping_callback()` needs to be defined"
    raise NotImplementedError(message)
def show_spinner(): def show_spinner():
bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='')]) bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='')])
i = 0 i = 0
@ -28,6 +35,31 @@ def show_spinner():
time.sleep(0.1) time.sleep(0.1)
i += 1 i += 1
# Type variables used only for annotating the xmap helper functions
__F = TypeVar("__F", bound=Callable)
__T = TypeVar("__T")


def __move_xmap(f: __F, out_axis: str) -> __F:
    """Wrap a two-argument function with ``maps.xmap``.

    The first argument of ``f`` is mapped over a leading named axis
    "shard" and the second over a leading named axis "batch"; those axes
    are assigned to the mesh resources 'mp' and 'dp' respectively.

    :param f: function taking ``(shard_mapped, batch_mapped)``.
    :param out_axis: name of the leading named axis of the output.
    :returns: the xmapped version of ``f``.
    """
    input_axes = (["shard", ...], ["batch", ...])
    output_axes = [out_axis, ...]
    resources = {'shard': 'mp', 'batch': 'dp'}
    return maps.xmap(
        f,
        in_axes=input_axes,
        out_axes=output_axes,
        axis_resources=resources,
    )
def __shard_xmap(batch_dim=1):
    """Build a function that passes its argument through an xmap whose
    output's leading named axis is "shard".

    :param batch_dim: length of the uninitialised placeholder array that is
        supplied as the "batch"-mapped argument.
    :returns: a one-argument function returning the xmapped copy of its input.
    """
    def select_shard_arg(sharded, batched):
        # Identity on the "shard"-mapped argument; the other is ignored
        return sharded

    mapped = __move_xmap(select_shard_arg, "shard")

    def inner(x: __T) -> __T:
        placeholder = np.empty(batch_dim)
        return mapped(x, placeholder)

    return inner
def __batch_xmap(shard_dim=1):
    """Build a function that passes its argument through an xmap whose
    output's leading named axis is "batch".

    :param shard_dim: length of the uninitialised placeholder array that is
        supplied as the "shard"-mapped argument.
    :returns: a one-argument function returning the xmapped copy of its input.
    """
    def select_batch_arg(sharded, batched):
        # Identity on the "batch"-mapped argument; the other is ignored
        return batched

    mapped = __move_xmap(select_batch_arg, "batch")

    def inner(x: __T) -> __T:
        placeholder = np.empty(shard_dim)
        return mapped(placeholder, x)

    return inner
def apply_repetition_penalty(logits, tokens, repetition_penalty): def apply_repetition_penalty(logits, tokens, repetition_penalty):
''' '''
This gets called by generate_loop_fn to apply repetition penalty This gets called by generate_loop_fn to apply repetition penalty
@ -38,19 +70,20 @@ def apply_repetition_penalty(logits, tokens, repetition_penalty):
# logits array; e.g. # logits array; e.g.
# if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1], # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
# then penalty_logits will be [77, 5, 3, 98, 3, 98, 5] # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
penalty_logits = jnp.take(logits, tokens) penalty_logits = np.take(logits, tokens)
# Divide positive values by repetition_penalty and multiply negative # Divide positive values by repetition_penalty and multiply negative
# values by repetition_penalty (the academic publication that described # values by repetition_penalty (the academic publication that described
# this technique actually just only divided, but that would cause tokens # this technique actually just only divided, but that would cause tokens
# with negative logits to become more likely, which is obviously wrong) # with negative logits to become more likely, which is obviously wrong)
penalty_logits = jnp.where( penalty_logits = np.where(
penalty_logits > 0, penalty_logits > 0,
penalty_logits/repetition_penalty, penalty_logits/repetition_penalty,
penalty_logits*repetition_penalty, penalty_logits*repetition_penalty,
) )
# Finally, put those penalized logit values back into their original # Finally, put those penalized logit values back into their original
# positions in the logits array # positions in the logits array
return logits.at[tokens].set(penalty_logits) logits[tokens] = penalty_logits
return logits
def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
''' '''
@ -66,15 +99,16 @@ def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
# in the sorted logits array we want to remove and False for ones # in the sorted logits array we want to remove and False for ones
# we want to keep, in this case the first top_k elements will be # we want to keep, in this case the first top_k elements will be
# False and the rest will be True # False and the rest will be True
sorted_indices_to_remove = jnp.arange(len(logits)) >= top_k sorted_indices_to_remove = np.arange(len(logits)) >= top_k
# Unsort the logits array back to its original configuration and # Unsort the logits array back to its original configuration and
# remove tokens we need to remove # remove tokens we need to remove
_, indices_to_remove = jax.lax.sort_key_val( _, indices_to_remove = jax.lax.sort_key_val(
jnp.argsort(-logits), np.argsort(-logits),
sorted_indices_to_remove, sorted_indices_to_remove,
) )
return jnp.where(indices_to_remove, -jnp.inf, logits) return np.where(indices_to_remove, -np.inf, logits)
logits = jax.lax.cond(top_k > 0, top_k_filter, lambda x: x, logits) if top_k > 0:
logits = top_k_filter(logits)
# Top-p (after sorting the remaining tokens again in descending order of # Top-p (after sorting the remaining tokens again in descending order of
# logit, remove the ones that have cumulative softmax probability # logit, remove the ones that have cumulative softmax probability
# greater than p) # greater than p)
@ -83,109 +117,167 @@ def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
# with e (Euler's number) to the power of that element, and divide # with e (Euler's number) to the power of that element, and divide
# each element of the new array by the sum of the elements in the # each element of the new array by the sum of the elements in the
# new array # new array
sorted_logits = -jnp.sort(-logits) sorted_logits = -np.sort(-logits)
probabilities = jax.nn.softmax(sorted_logits) probabilities = np.array(jax.nn.softmax(sorted_logits), copy=True)
# Calculate cumulative_probabilities as the prefix-sum array of # Calculate cumulative_probabilities as the prefix-sum array of
# probabilities # probabilities
cumulative_probabilities = jnp.cumsum(probabilities, axis=-1) cumulative_probabilities = np.cumsum(probabilities, axis=-1)
# We want to remove tokens with cumulative probability higher # We want to remove tokens with cumulative probability higher
# than top_p # than top_p
sorted_indices_to_remove = cumulative_probabilities > top_p sorted_indices_to_remove = cumulative_probabilities > top_p
# Don't ever remove the token with the highest logit, even if # Don't ever remove the token with the highest logit, even if
# the probability is higher than top_p # the probability is higher than top_p
sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False) sorted_indices_to_remove[0] = False
# Unsort and remove # Unsort and remove
_, indices_to_remove = jax.lax.sort_key_val( _, indices_to_remove = jax.lax.sort_key_val(
jnp.argsort(-logits), np.argsort(-logits),
sorted_indices_to_remove, sorted_indices_to_remove,
) )
return jnp.where(indices_to_remove, -jnp.inf, logits) return np.where(indices_to_remove, -np.inf, logits)
logits = jax.lax.cond(top_p < 1.0, top_p_filter, lambda x: x, logits) if top_p < 1.0:
logits = top_p_filter(logits)
# Tail free sampling (basically top-p a second time on remaining tokens # Tail free sampling (basically top-p a second time on remaining tokens
# except it's the "cumulative normalized absolute second finite # except it's the "cumulative normalized absolute second finite
# differences of the softmax probabilities" instead of just the # differences of the softmax probabilities" instead of just the
# cumulative softmax probabilities) # cumulative softmax probabilities)
def tail_free_filter(logits): def tail_free_filter(logits):
# Sort in descending order # Sort in descending order
sorted_logits = -jnp.sort(-logits) sorted_logits = -np.sort(-logits)
# Softmax again # Softmax again
probabilities = jax.nn.softmax(sorted_logits) probabilities = np.array(jax.nn.softmax(sorted_logits), copy=True)
# Calculate the second finite differences of that array (i.e. # Calculate the second finite differences of that array (i.e.
# calculate the difference array and then calculate the difference # calculate the difference array and then calculate the difference
# array of the difference array) # array of the difference array)
d2 = jnp.diff(jnp.diff(probabilities)) d2 = np.diff(np.diff(probabilities))
# Get the absolute values of all those second finite differences # Get the absolute values of all those second finite differences
d2 = jnp.abs(d2) d2 = np.abs(d2)
# Normalize (all elements in the array are divided by the sum of the # Normalize (all elements in the array are divided by the sum of the
# array's elements) # array's elements)
d2 = d2 / d2.sum(axis=-1, keepdims=True) d2 = d2 / d2.sum(axis=-1, keepdims=True)
# Get the prefix-sum array # Get the prefix-sum array
cumulative_d2 = jnp.cumsum(d2, axis=-1) cumulative_d2 = np.cumsum(d2, axis=-1)
# We will remove the tokens with a cumulative normalized absolute # We will remove the tokens with a cumulative normalized absolute
# second finite difference larger than the TFS value # second finite difference larger than the TFS value
sorted_indices_to_remove = cumulative_d2 > tfs sorted_indices_to_remove = cumulative_d2 > tfs
# Don't remove the token with the highest logit # Don't remove the token with the highest logit
sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False) sorted_indices_to_remove[0] = False
# Since the d2 array has two fewer elements than the logits array, # Since the d2 array has two fewer elements than the logits array,
# we'll add two extra Trues to the end # we'll add two extra Trues to the end
sorted_indices_to_remove = jnp.pad( sorted_indices_to_remove = np.pad(
sorted_indices_to_remove, sorted_indices_to_remove,
(0, 2), (0, 2),
constant_values=True, constant_values=True,
) )
# Unsort and remove # Unsort and remove
_, indices_to_remove = jax.lax.sort_key_val( _, indices_to_remove = jax.lax.sort_key_val(
jnp.argsort(-logits), np.argsort(-logits),
sorted_indices_to_remove, sorted_indices_to_remove,
) )
return jnp.where(indices_to_remove, -jnp.inf, logits) return np.where(indices_to_remove, -np.inf, logits)
logits = jax.lax.cond(tfs < 1.0, tail_free_filter, lambda x: x, logits) if tfs < 1.0:
logits = tail_free_filter(logits)
# Temperature (just divide the logits by the temperature) # Temperature (just divide the logits by the temperature)
def temp_filter(logits): logits /= temp
return logits / temp
logits = jax.lax.cond(True, temp_filter, lambda x: x, logits)
# Finally, pick one token using the softmax thingy again (it gives # Finally, pick one token using the softmax thingy again (it gives
# an array whose elements sum to 1 so it can be used nicely as a # an array whose elements sum to 1 so it can be used nicely as a
# probability distribution) # probability distribution)
return jax.random.categorical(key, logits, -1).astype(jnp.uint32)[jnp.newaxis] return jax.random.categorical(key, logits, -1).astype(np.uint32)
pad_token_id = 50256 pad_token_id = 50256
def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, sampler_options):
    '''
    Pick one new token for every entry of `data`.

    Each entry of `data` is a list of the form
    [generated, generated_index, logits, next_token].  The entries are
    processed one at a time: the entry at the front of the list is sampled,
    updated in place and moved to the back, until the entry at the front has
    a generated_index different from the one the first entry started with.

    data:               list of per-sequence entries (modified in place)
    key:                pseudo-random number generator key
    numseqs_aux:        array whose first dimension is read (value unused)
    badwords:           indices of logits that are set to negative infinity
    repetition_penalty: passed on to apply_repetition_penalty
    sampler_options:    keyword arguments passed on to kobold_sample

    Returns a tuple (data, key) holding the updated entries and the
    pseudo-random number generator key to use next.
    '''
    # Not used below; kept because reading shape[0] fails early when
    # numseqs_aux is not at least one-dimensional
    numseqs = numseqs_aux.shape[0]
    starting_index = data[0][1]

    def sample_one(state):
        entries, rng = state
        tokens, position, seq_logits, _ = entries[0]
        # The first half of the split key is used for sampling right now,
        # the second half is handed on to the next step
        draw_key, carried_key = jax.random.split(rng, num=2)
        # Penalize every token currently present in the "generated" array
        seq_logits = apply_repetition_penalty(
            seq_logits,
            tokens,
            repetition_penalty
        )
        # Negative infinity logits make the bad words impossible to pick
        seq_logits[badwords] = -np.inf
        # Let the sampler pick one token based on the logits
        picked = kobold_sample(
            draw_key,
            seq_logits,
            **sampler_options,
        )
        # Store the picked token and advance the write position
        tokens[position] = picked
        position += 1
        # Write the updated entry back and rotate it to the end of the
        # list so that the next entry is at the front
        entries[0] = [tokens, position, seq_logits, picked]
        entries.append(entries.pop(0))
        return entries, carried_key

    # Plain Python counterpart of a jax.lax.while_loop over (data, key)
    # with sample_one as the loop body
    state = (data, key)
    while state[0][0][1] == starting_index:
        state = sample_one(state)
    return state
class PenalizingCausalTransformer(CausalTransformer): class PenalizingCausalTransformer(CausalTransformer):
def __init__(self, config): def __init__(self, config):
# Initialize # Initialize
super().__init__(config) super().__init__(config)
def generate(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None): def generate_initial(state, key, ctx, ctx_length, numseqs_aux, soft_embeddings=None):
numseqs = numseqs_aux.shape[0] numseqs = numseqs_aux.shape[0]
# These are the tokens that we don't want the AI to ever write @hk.transform
self.badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146]) def generate_initial_inner(context, ctx_length):
def generate_sample(context, ctx_length):
# Give the initial context to the transformer # Give the initial context to the transformer
transformer = CausalTransformerShard(config) transformer = CausalTransformerShard(config)
def generate_initial_scan_fn(sequence_index, _): def generate_initial_scan_fn(sequence_index, c):
_, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings) _, initial_state = transformer.generate_initial(c, ctx_length, soft_embeddings=soft_embeddings)
# The "generated" array will contain the tokens from the
# context as well as the tokens picked by the sampler at
# each stage, padded with a bunch of 50256s, so we know
# which tokens have to be repetition penalized
generated = jnp.pad(context, (0, config["seq"]), constant_values=pad_token_id) # Let it start off with just the 2048 context tokens, plus some 50256s which will be eventually filled with sampler-chosen tokens
generated_index = config["seq"] generated_index = config["seq"]
# Add that information to generate_loop_fn's starting state # Add that information to generate_loop_fn's starting state
initial_state = (generated, generated_index, sequence_index) + initial_state initial_state = (jnp.empty(config["n_vocab"], dtype=jnp.float32), generated_index, sequence_index) + initial_state
return sequence_index+1, initial_state return sequence_index+1, initial_state
_, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, None, numseqs) _, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, context, numseqs)
sample_key = initial_states[-1][0] sample_key = initial_states[-1][0]
initial_states = list(jax.tree_map(lambda x: x[i], initial_states[:-1]) for i in range(numseqs)) initial_states = list(list(jax.tree_map(lambda x: x[i], initial_states[:-1])) for i in range(numseqs))
# Get repetition penalty from the arguments return initial_states, sample_key
repetition_penalty = sampler_options.pop('repetition_penalty', None) return generate_initial_inner.apply(state["params"], key, ctx, ctx_length)
self.generate_initial_xmap = jax.experimental.maps.xmap(
fun=generate_initial,
in_axes=(
["shard", ...],
["batch", ...],
["batch", ...],
["batch", ...],
["batch", ...],
["shard", ...],
),
out_axes=["shard", "batch", ...],
axis_resources={'shard': 'mp', 'batch': 'dp'},
)
def generate_once(data, state, numseqs_aux, soft_embeddings=None):
numseqs = numseqs_aux.shape[0]
@hk.without_apply_rng
@hk.transform
def generate_once_inner():
gi = data[0][1]
# Give the initial context to the transformer
transformer = CausalTransformerShard(config)
# This is the main generation loop # This is the main generation loop
def generate_loop_fn(carry): def generate_loop_fn(carry):
# Unpack current generate_loop_fn state # Unpack current generate_loop_fn state
generated, generated_index, sequence_index, next_token, decode_state = carry[0][0] _, generated_index, sequence_index, next_token, decode_state = carry[0][0]
sample_key = carry[1]
# Get the pseudo-random number generator key that will
# be used by kobold_sample to randomly pick a token
sample_key, new_key = jax.random.split(sample_key)
# Give the context to the model and get the logits it # Give the context to the model and get the logits it
# spits out # spits out
# (a 2D array with 1 row and 50400 columns representing # (a 2D array with 1 row and 50400 columns representing
@ -196,75 +288,78 @@ class PenalizingCausalTransformer(CausalTransformer):
# Verify that logits does indeed have that many rows and # Verify that logits does indeed have that many rows and
# columns (if you get an error here, pray for mercy) # columns (if you get an error here, pray for mercy)
assert logits.shape == (1, config["n_vocab"]) assert logits.shape == (1, config["n_vocab"])
assert logits.dtype == jnp.float32
# Flatten it into a 1D array to make it easier to use # Flatten it into a 1D array to make it easier to use
logits = logits[0] logits = logits[0]
# Apply repetition penalty to all tokens that are
# currently inside the "generated" array
if repetition_penalty is not None:
logits = apply_repetition_penalty(
logits,
generated,
repetition_penalty
)
# Remove any tokens in the badwords list by setting
# their logits to negative infinity which effectively
# makes their probabilities of being chosen zero
logits = logits.at[self.badwords].set(-jnp.inf)
# Use the sampler (kobold_sample) to pick one token
# based on the logits array as a 1D array with 1 element
# (higher logit means higher probability of being
# picked, non-linearly)
next_token = kobold_sample(
sample_key,
logits,
**sampler_options,
)
# Remember what token was picked
generated = generated.at[generated_index].set(next_token[0])
generated_index += 1
# Re-pack the current generate_loop_fn's state so we can # Re-pack the current generate_loop_fn's state so we can
# get back the same variables the next time # get back the same variables the next time
carry[0][0] = (generated, generated_index, sequence_index, next_token, new_state) generated_index += 1
carry[0][0] = [logits, generated_index, sequence_index, next_token, new_state]
carry[0].append(carry[0].pop(0)) carry[0].append(carry[0].pop(0))
return carry[0], new_key return carry[0],
final_state = jax.lax.while_loop( return jax.lax.while_loop(
lambda carry: carry[0][0][1] - config["seq"] < gen_length, lambda carry: carry[0][0][1] == gi,
generate_loop_fn, generate_loop_fn,
(initial_states, sample_key), (data,),
) )
return final_state return generate_once_inner.apply(state["params"])
generate_fn = hk.transform(generate_sample).apply self.generate_once_xmap = jax.experimental.maps.xmap(
return generate_fn(state["params"], key, ctx, ctx_length) fun=generate_once,
self.generate_xmap = jax.experimental.maps.xmap(
fun=generate,
in_axes=( in_axes=(
["shard", "batch", ...],
["shard", ...], ["shard", ...],
["batch", ...], ["batch", ...],
["batch", ...],
["batch", ...],
["batch", ...],
["batch", ...],
["batch", ...],
["shard", ...], ["shard", ...],
), ),
out_axes=["shard", "batch", ...], out_axes=["shard", "batch", ...],
axis_resources={'shard': 'mp', 'batch': 'dp'}, axis_resources={'shard': 'mp', 'batch': 'dp'},
) )
def generate(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None): def generate(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None, excluded_world_info=None, use_callback=True):
assert excluded_world_info is not None
assert not return_logits assert not return_logits
assert gen_length.ndim == 1
assert soft_embeddings is not None
key = hk.PRNGSequence(random.randint(0, 2 ** 60)) key = hk.PRNGSequence(random.randint(0, 2 ** 60))
batch_size = ctx.shape[0] batch_size = ctx.shape[0]
self.batch_size = batch_size self.batch_size = batch_size
return self.generate_xmap( _numseqs_aux = jnp.empty((batch_size, numseqs), dtype=np.uint32)
self.state, numseqs_aux = batch_xmap(_numseqs_aux)
jnp.array(key.take(batch_size)), sample_data = [
ctx, [
np.array(ctx_length, dtype=np.uint32), np.pad(ctx[0][i], (0, params["seq"]), constant_values=pad_token_id),
np.array(gen_length, dtype=np.uint32), params["seq"],
np.empty((batch_size, numseqs), dtype=np.uint8), None,
sampler_options, np.empty((), dtype=np.uint32),
soft_embeddings, ]
) for i in range(numseqs)
]
repetition_penalty = sampler_options.pop("repetition_penalty", 1.0)
n_generated = 0
regeneration_required = False
halt = False
generate_data, sample_key = self.generate_initial_xmap(self.state, jnp.array(key.take(batch_size)), ctx, ctx_length, numseqs_aux, soft_embeddings)
sample_key = np.asarray(sample_key[0, 0])
while True:
generate_data, = self.generate_once_xmap(generate_data, self.state, numseqs_aux, soft_embeddings)
for i in range(numseqs):
sample_data[i][2] = np.array(generate_data[i][0][0, 0], copy=True)
if use_callback:
logits = np.float32(tuple(d[2] for d in sample_data))
logits = warper_callback(logits)
for i in range(numseqs):
sample_data[i][2] = logits[i]
sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options)
n_generated += 1
for i in range(numseqs):
generate_data[i][3] = np.tile(sample_data[i][0][sample_data[i][1]-1][np.newaxis, np.newaxis], (params["cores_per_replica"], 1, 1))
if use_callback:
generated = np.uint32(tuple(d[0] for d in sample_data))
excluded_world_info, regeneration_required, halt = stopping_callback(generated, n_generated, excluded_world_info)
if regeneration_required or halt:
break
else:
break
return sample_data, n_generated, regeneration_required, halt
def infer( def infer(
@ -278,35 +373,40 @@ def infer(
gen_len=80, gen_len=80,
soft_embeddings: Optional[np.array] = None, soft_embeddings: Optional[np.array] = None,
soft_tokens: Optional[np.array] = None, soft_tokens: Optional[np.array] = None,
) -> List[str]: excluded_world_info = None,
use_callback=True,
) -> Tuple[List[np.array], int, bool, bool]:
assert excluded_world_info is not None
maps.thread_resources.env = thread_resources_env maps.thread_resources.env = thread_resources_env
total_batch = 1 total_batch = 1
tokens = context tokens = context
if(soft_tokens is not None): if(soft_tokens is not None):
tokens = np.uint32(np.concatenate((soft_tokens, tokens))) tokens = np.uint32(np.concatenate((np.tile(soft_tokens, (tokens.shape[0], 1)), tokens), axis=-1))
provided_ctx = tokens.shape[0] provided_ctx = tokens.shape[-1]
pad_amount = seq - provided_ctx pad_amount = seq - provided_ctx
padded_tokens = np.pad(tokens, ((pad_amount, 0),), constant_values=pad_token_id) padded_tokens = np.pad(tokens, ((0, 0), (pad_amount, 0)), constant_values=pad_token_id)
batched_tokens = np.array([padded_tokens] * total_batch) batched_tokens = np.array([padded_tokens] * total_batch)
samples = [] samples = []
batched_generator_params = { generator_params = {
"temp": temp * np.ones(total_batch), "temp": float(temp),
"top_p": top_p * np.ones(total_batch), "top_p": float(top_p),
"tfs": tfs * np.ones(total_batch), "tfs": float(tfs),
"repetition_penalty": repetition_penalty * np.ones(total_batch), "repetition_penalty": float(repetition_penalty),
"top_k": np.full(total_batch, top_k, dtype=np.uint32) "top_k": int(top_k),
} }
output = network.generate( output = network.generate(
batched_tokens, batched_tokens,
np.ones(total_batch, dtype=np.uint32) * provided_ctx, np.ones(total_batch, dtype=np.uint32) * provided_ctx,
np.ones(total_batch, dtype=np.uint32) * gen_len, np.ones(total_batch, dtype=np.uint32) * gen_len,
numseqs, numseqs,
batched_generator_params, generator_params,
soft_embeddings=soft_embeddings, soft_embeddings=soft_embeddings,
)[0] excluded_world_info=excluded_world_info,
for o in output: use_callback=use_callback,
samples.append(o[0][0, 0, params["seq"] : params["seq"] + gen_len]) )
return samples for out in output[0]:
samples.append(out[0][params["seq"] : params["seq"] + gen_len])
return (samples,) + output[1:]
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) -> None: def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) -> None:
@ -354,6 +454,14 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
maps.thread_resources.env = thread_resources_env maps.thread_resources.env = thread_resources_env
tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2') tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
global shard_xmap, batch_xmap
shard_xmap = __shard_xmap()
batch_xmap = __batch_xmap(shard_dim=cores_per_replica)
global badwords
# These are the tokens that we don't want the AI to ever write
badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146])
if not path.endswith("/"): if not path.endswith("/"):
path += "/" path += "/"