Merge pull request #58 from VE-FORBRYDERNE/xmap

Dynamic TPU backend xmaps
This commit is contained in:
henk717 2022-01-15 16:20:58 +01:00 committed by GitHub
commit 9bcc24c07e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 330 additions and 138 deletions

View File

@ -22,7 +22,7 @@ import packaging
import contextlib
import traceback
import threading
from typing import Any, Callable, TypeVar, Union, Dict, Set, List
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
import requests
import html
@ -993,7 +993,7 @@ else:
-1,
tpu_mtj_backend.params["d_model"],
)
vars.sp = tensor
vars.sp = tpu_mtj_backend.shard_xmap(tensor)
soft_tokens = np.arange(
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
@ -1001,6 +1001,49 @@ else:
)
return soft_tokens
def tpumtjgenerate_warper_callback(scores) -> "np.array":
scores_shape = scores.shape
scores_list = scores.tolist()
vars.lua_koboldbridge.logits = vars.lua_state.table()
for r, row in enumerate(scores_list):
vars.lua_koboldbridge.logits[r+1] = vars.lua_state.table(*row)
vars.lua_koboldbridge.vocab_size = scores_shape[-1]
execute_genmod()
scores = np.array(
tuple(tuple(row.values()) for row in vars.lua_koboldbridge.logits.values()),
dtype=scores.dtype,
)
assert scores.shape == scores_shape
return scores
def tpumtjgenerate_stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]:
vars.generated_tkns += 1
assert len(excluded_world_info) == len(generated)
regeneration_required = vars.lua_koboldbridge.regeneration_required
halt = not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt
vars.lua_koboldbridge.regeneration_required = False
global past
for i in range(vars.numseqs):
vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(generated[i, tpu_mtj_backend.params["seq"] + n_generated - 1].item())
if(not vars.dynamicscan or halt):
return excluded_world_info, regeneration_required, halt
for i, t in enumerate(generated):
decoded = tokenizer.decode(past[i]) + tokenizer.decode(t[tpu_mtj_backend.params["seq"] : tpu_mtj_backend.params["seq"] + n_generated])
_, found = checkworldinfo(decoded, force_use_txt=True)
found -= excluded_world_info[i]
if(len(found) != 0):
regeneration_required = True
break
return excluded_world_info, regeneration_required, halt
# If we're running Colab or OAI, we still need a tokenizer.
if(vars.model == "Colab"):
from transformers import GPT2TokenizerFast
@ -1013,6 +1056,8 @@ else:
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth)
import tpu_mtj_backend
tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback
tpu_mtj_backend.load_model(vars.custmodpth)
vars.allowsp = True
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
@ -1020,12 +1065,14 @@ else:
soft_tokens = tpumtjgetsofttokens()
threading.Thread( # Compile backend code in background
target=tpu_mtj_backend.infer,
args=(np.uint32((23403, 727, 20185)),),
args=(np.tile(np.uint32((23403, 727, 20185)), (vars.numseqs, 1)),),
kwargs={
"soft_embeddings": vars.sp,
"soft_tokens": soft_tokens,
"use_callback": False,
"gen_len": 1,
"numseqs": vars.numseqs,
"excluded_world_info": list(set() for _ in range(vars.numseqs)),
},
).start()
@ -2890,32 +2937,69 @@ def sendtocolab(txt, min, max):
# Send text to TPU mesh transformer backend
#==================================================================#
def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
vars.generated_tkns = 0
if(found_entries is None):
found_entries = set()
found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, tokenizer.decode(txt), colors.END))
vars._actions = vars.actions
vars._prompt = vars.prompt
if(vars.dynamicscan):
vars._actions = vars._actions.copy()
# Submit input text to generator
try:
if(vars.dynamicscan):
raise ValueError("Dynamic world info scanning is not supported by the TPU backend yet")
context = np.tile(np.uint32(txt), (vars.numseqs, 1))
soft_tokens = tpumtjgetsofttokens()
genout = tpool.execute(
tpu_mtj_backend.infer,
np.uint32(txt),
gen_len = maximum-minimum+1,
temp=vars.temp,
top_p=vars.top_p,
top_k=vars.top_k,
tfs=vars.tfs,
numseqs=vars.numseqs,
repetition_penalty=vars.rep_pen,
soft_embeddings=vars.sp,
soft_tokens=soft_tokens,
)
global past
past = np.empty((vars.numseqs, 0), dtype=np.uint32)
while(True):
genout, n_generated, regeneration_required, halt = tpool.execute(
tpu_mtj_backend.infer,
context,
gen_len = maximum-minimum+1,
temp=vars.temp,
top_p=vars.top_p,
top_k=vars.top_k,
tfs=vars.tfs,
numseqs=vars.numseqs,
repetition_penalty=vars.rep_pen,
soft_embeddings=vars.sp,
soft_tokens=soft_tokens,
excluded_world_info=found_entries,
)
past = np.pad(past, ((0, 0), (0, n_generated)))
for r in range(vars.numseqs):
for c in range(vars.lua_koboldbridge.generated_cols):
assert vars.lua_koboldbridge.generated[r+1][c+1] is not None
past[r, c] = vars.lua_koboldbridge.generated[r+1][c+1]
if(halt or not regeneration_required):
break
print("(regeneration triggered)")
encoded = []
for i in range(vars.numseqs):
txt = tokenizer.decode(past[i])
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
found_entries[i].update(_found_entries)
txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
encoded.append(np.array(txt, dtype=np.uint32))
max_length = len(max(encoded, key=len))
encoded = np.stack(tuple(np.pad(e, (max_length - len(e), 0), constant_values=tpu_mtj_backend.pad_token_id) for e in encoded))
context = np.concatenate(
(
encoded,
past,
),
axis=-1,
)
except Exception as e:
if(issubclass(type(e), lupa.LuaError)):
@ -2931,10 +3015,10 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
print("{0}{1}{2}".format(colors.RED, traceback.format_exc().replace("\033", ""), colors.END), file=sys.stderr)
set_aibusy(0)
return
for i in range(vars.numseqs):
vars.lua_koboldbridge.generated[i+1] = vars.lua_state.table(*genout[i].tolist())
vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(genout[i])
vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(past[i])
genout = past
execute_outmod()
if(vars.lua_koboldbridge.regeneration_required):
@ -4005,7 +4089,7 @@ def spRequest(filename):
-1,
tpu_mtj_backend.params["d_model"],
)
vars.sp = np.float32(tensor)
vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor))
else:
vars.sp = torch.from_numpy(tensor)

View File

@ -1,5 +1,5 @@
import multiprocessing
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
import progressbar
import time
import os
@ -20,6 +20,13 @@ from mesh_transformer.transformer_shard import CausalTransformer, CausalTransfor
params: Dict[str, Any] = {}
def warper_callback(logits) -> np.array:
raise NotImplementedError("`tpu_mtj_backend.warper_callback()` needs to be defined")
def stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]:
raise NotImplementedError("`tpu_mtj_backend.stopping_callback()` needs to be defined")
def show_spinner():
bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='')])
i = 0
@ -28,6 +35,31 @@ def show_spinner():
time.sleep(0.1)
i += 1
__F = TypeVar("__F", bound=Callable)
__T = TypeVar("__T")
def __move_xmap(f: __F, out_axis: str) -> __F:
return maps.xmap(
f,
in_axes=(["shard", ...], ["batch", ...]),
out_axes=[out_axis, ...],
axis_resources={'shard': 'mp', 'batch': 'dp'},
)
def __shard_xmap(batch_dim=1):
xmap = __move_xmap(lambda s, b: s, "shard")
def inner(x: __T) -> __T:
return xmap(x, np.empty(batch_dim))
return inner
def __batch_xmap(shard_dim=1):
xmap = __move_xmap(lambda s, b: b, "batch")
def inner(x: __T) -> __T:
return xmap(np.empty(shard_dim), x)
return inner
def apply_repetition_penalty(logits, tokens, repetition_penalty):
'''
This gets called by generate_loop_fn to apply repetition penalty
@ -38,19 +70,20 @@ def apply_repetition_penalty(logits, tokens, repetition_penalty):
# logits array; e.g.
# if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
# then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
penalty_logits = jnp.take(logits, tokens)
penalty_logits = np.take(logits, tokens)
# Divide positive values by repetition_penalty and multiply negative
# values by repetition_penalty (the academic publication that described
# this technique actually just only divided, but that would cause tokens
# with negative logits to become more likely, which is obviously wrong)
penalty_logits = jnp.where(
penalty_logits = np.where(
penalty_logits > 0,
penalty_logits/repetition_penalty,
penalty_logits*repetition_penalty,
)
# Finally, put those penalized logit values back into their original
# positions in the logits array
return logits.at[tokens].set(penalty_logits)
logits[tokens] = penalty_logits
return logits
def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
'''
@ -66,15 +99,16 @@ def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
# in the sorted logits array we want to remove and False for ones
# we want to keep, in this case the first top_k elements will be
# False and the rest will be True
sorted_indices_to_remove = jnp.arange(len(logits)) >= top_k
sorted_indices_to_remove = np.arange(len(logits)) >= top_k
# Unsort the logits array back to its original configuration and
# remove tokens we need to remove
_, indices_to_remove = jax.lax.sort_key_val(
jnp.argsort(-logits),
np.argsort(-logits),
sorted_indices_to_remove,
)
return jnp.where(indices_to_remove, -jnp.inf, logits)
logits = jax.lax.cond(top_k > 0, top_k_filter, lambda x: x, logits)
return np.where(indices_to_remove, -np.inf, logits)
if top_k > 0:
logits = top_k_filter(logits)
# Top-p (after sorting the remaining tokens again in descending order of
# logit, remove the ones that have cumulative softmax probability
# greater than p)
@ -83,109 +117,167 @@ def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
# with e (Euler's number) to the power of that element, and divide
# each element of the new array by the sum of the elements in the
# new array
sorted_logits = -jnp.sort(-logits)
probabilities = jax.nn.softmax(sorted_logits)
sorted_logits = -np.sort(-logits)
probabilities = np.array(jax.nn.softmax(sorted_logits), copy=True)
# Calculate cumulative_probabilities as the prefix-sum array of
# probabilities
cumulative_probabilities = jnp.cumsum(probabilities, axis=-1)
cumulative_probabilities = np.cumsum(probabilities, axis=-1)
# We want to remove tokens with cumulative probability higher
# than top_p
sorted_indices_to_remove = cumulative_probabilities > top_p
# Don't ever remove the token with the highest logit, even if
# the probability is higher than top_p
sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
sorted_indices_to_remove[0] = False
# Unsort and remove
_, indices_to_remove = jax.lax.sort_key_val(
jnp.argsort(-logits),
np.argsort(-logits),
sorted_indices_to_remove,
)
return jnp.where(indices_to_remove, -jnp.inf, logits)
logits = jax.lax.cond(top_p < 1.0, top_p_filter, lambda x: x, logits)
return np.where(indices_to_remove, -np.inf, logits)
if top_p < 1.0:
logits = top_p_filter(logits)
# Tail free sampling (basically top-p a second time on remaining tokens
# except it's the "cumulative normalized absolute second finite
# differences of the softmax probabilities" instead of just the
# cumulative softmax probabilities)
def tail_free_filter(logits):
# Sort in descending order
sorted_logits = -jnp.sort(-logits)
sorted_logits = -np.sort(-logits)
# Softmax again
probabilities = jax.nn.softmax(sorted_logits)
probabilities = np.array(jax.nn.softmax(sorted_logits), copy=True)
# Calculate the second finite differences of that array (i.e.
# calculate the difference array and then calculate the difference
# array of the difference array)
d2 = jnp.diff(jnp.diff(probabilities))
d2 = np.diff(np.diff(probabilities))
# Get the absolute values of all those second finite differences
d2 = jnp.abs(d2)
d2 = np.abs(d2)
# Normalize (all elements in the array are divided by the sum of the
# array's elements)
d2 = d2 / d2.sum(axis=-1, keepdims=True)
# Get the prefix-sum array
cumulative_d2 = jnp.cumsum(d2, axis=-1)
cumulative_d2 = np.cumsum(d2, axis=-1)
# We will remove the tokens with a cumulative normalized absolute
# second finite difference larger than the TFS value
sorted_indices_to_remove = cumulative_d2 > tfs
# Don't remove the token with the highest logit
sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
sorted_indices_to_remove[0] = False
# Since the d2 array has two fewer elements than the logits array,
# we'll add two extra Trues to the end
sorted_indices_to_remove = jnp.pad(
sorted_indices_to_remove = np.pad(
sorted_indices_to_remove,
(0, 2),
constant_values=True,
)
# Unsort and remove
_, indices_to_remove = jax.lax.sort_key_val(
jnp.argsort(-logits),
np.argsort(-logits),
sorted_indices_to_remove,
)
return jnp.where(indices_to_remove, -jnp.inf, logits)
logits = jax.lax.cond(tfs < 1.0, tail_free_filter, lambda x: x, logits)
return np.where(indices_to_remove, -np.inf, logits)
if tfs < 1.0:
logits = tail_free_filter(logits)
# Temperature (just divide the logits by the temperature)
def temp_filter(logits):
return logits / temp
logits = jax.lax.cond(True, temp_filter, lambda x: x, logits)
logits /= temp
# Finally, pick one token using the softmax thingy again (it gives
# an array whose elements sum to 1 so it can be used nicely as a
# probability distribution)
return jax.random.categorical(key, logits, -1).astype(jnp.uint32)[jnp.newaxis]
return jax.random.categorical(key, logits, -1).astype(np.uint32)
pad_token_id = 50256
def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, sampler_options):
numseqs = numseqs_aux.shape[0]
gi = data[0][1]
def sample_loop_fn(carry):
generated, generated_index, logits, _ = carry[0][0]
sample_key = carry[1]
# Get the pseudo-random number generator key that will
# be used by kobold_sample to randomly pick a token
sample_key, new_key = jax.random.split(sample_key, num=2)
# Apply repetition penalty to all tokens that are
# currently inside the "generated" array
logits = apply_repetition_penalty(
logits,
generated,
repetition_penalty
)
# Remove any tokens in the badwords list by setting
# their logits to negative infinity which effectively
# makes their probabilities of being chosen zero
logits[badwords] = -np.inf
# Use the sampler (kobold_sample) to pick one token
# based on the logits array as a 0D uint32 array
# (higher logit means higher probability of being
# picked, non-linearly)
next_token = kobold_sample(
sample_key,
logits,
**sampler_options,
)
# Remember what token was picked
generated[generated_index] = next_token
generated_index += 1
# Re-pack the current sample_loop_fn's state so we can
# get back the same variables the next time
carry[0][0] = [generated, generated_index, logits, next_token]
carry[0].append(carry[0].pop(0))
return carry[0], new_key
# return jax.lax.while_loop(
# lambda carry: carry[0][0][1] == gi,
# sample_loop_fn,
# (data, key),
# )
carry = (data, key)
while carry[0][0][1] == gi:
carry = sample_loop_fn(carry)
return carry
class PenalizingCausalTransformer(CausalTransformer):
def __init__(self, config):
# Initialize
super().__init__(config)
def generate(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None):
def generate_initial(state, key, ctx, ctx_length, numseqs_aux, soft_embeddings=None):
numseqs = numseqs_aux.shape[0]
# These are the tokens that we don't want the AI to ever write
self.badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146])
def generate_sample(context, ctx_length):
@hk.transform
def generate_initial_inner(context, ctx_length):
# Give the initial context to the transformer
transformer = CausalTransformerShard(config)
def generate_initial_scan_fn(sequence_index, _):
_, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings)
# The "generated" array will contain the tokens from the
# context as well as the tokens picked by the sampler at
# each stage, padded with a bunch of 50256s, so we know
# which tokens have to be repetition penalized
generated = jnp.pad(context, (0, config["seq"]), constant_values=pad_token_id) # Let it start off with just the 2048 context tokens, plus some 50256s which will be eventually filled with sampler-chosen tokens
def generate_initial_scan_fn(sequence_index, c):
_, initial_state = transformer.generate_initial(c, ctx_length, soft_embeddings=soft_embeddings)
generated_index = config["seq"]
# Add that information to generate_loop_fn's starting state
initial_state = (generated, generated_index, sequence_index) + initial_state
initial_state = (jnp.empty(config["n_vocab"], dtype=jnp.float32), generated_index, sequence_index) + initial_state
return sequence_index+1, initial_state
_, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, None, numseqs)
_, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, context, numseqs)
sample_key = initial_states[-1][0]
initial_states = list(jax.tree_map(lambda x: x[i], initial_states[:-1]) for i in range(numseqs))
# Get repetition penalty from the arguments
repetition_penalty = sampler_options.pop('repetition_penalty', None)
initial_states = list(list(jax.tree_map(lambda x: x[i], initial_states[:-1])) for i in range(numseqs))
return initial_states, sample_key
return generate_initial_inner.apply(state["params"], key, ctx, ctx_length)
self.generate_initial_xmap = jax.experimental.maps.xmap(
fun=generate_initial,
in_axes=(
["shard", ...],
["batch", ...],
["batch", ...],
["batch", ...],
["batch", ...],
["shard", ...],
),
out_axes=["shard", "batch", ...],
axis_resources={'shard': 'mp', 'batch': 'dp'},
)
def generate_once(data, state, numseqs_aux, soft_embeddings=None):
numseqs = numseqs_aux.shape[0]
@hk.without_apply_rng
@hk.transform
def generate_once_inner():
gi = data[0][1]
# Give the initial context to the transformer
transformer = CausalTransformerShard(config)
# This is the main generation loop
def generate_loop_fn(carry):
# Unpack current generate_loop_fn state
generated, generated_index, sequence_index, next_token, decode_state = carry[0][0]
sample_key = carry[1]
# Get the pseudo-random number generator key that will
# be used by kobold_sample to randomly pick a token
sample_key, new_key = jax.random.split(sample_key)
_, generated_index, sequence_index, next_token, decode_state = carry[0][0]
# Give the context to the model and get the logits it
# spits out
# (a 2D array with 1 row and 50400 columns representing
@ -196,75 +288,78 @@ class PenalizingCausalTransformer(CausalTransformer):
# Verify that logits does indeed have that many rows and
# columns (if you get an error here, pray for mercy)
assert logits.shape == (1, config["n_vocab"])
assert logits.dtype == jnp.float32
# Flatten it into a 1D array to make it easier to use
logits = logits[0]
# Apply repetition penalty to all tokens that are
# currently inside the "generated" array
if repetition_penalty is not None:
logits = apply_repetition_penalty(
logits,
generated,
repetition_penalty
)
# Remove any tokens in the badwords list by setting
# their logits to negative infinity which effectively
# makes their probabilities of being chosen zero
logits = logits.at[self.badwords].set(-jnp.inf)
# Use the sampler (kobold_sample) to pick one token
# based on the logits array as a 1D array with 1 element
# (higher logit means higher probability of being
# picked, non-linearly)
next_token = kobold_sample(
sample_key,
logits,
**sampler_options,
)
# Remember what token was picked
generated = generated.at[generated_index].set(next_token[0])
generated_index += 1
# Re-pack the current generate_loop_fn's state so we can
# get back the same variables the next time
carry[0][0] = (generated, generated_index, sequence_index, next_token, new_state)
generated_index += 1
carry[0][0] = [logits, generated_index, sequence_index, next_token, new_state]
carry[0].append(carry[0].pop(0))
return carry[0], new_key
final_state = jax.lax.while_loop(
lambda carry: carry[0][0][1] - config["seq"] < gen_length,
return carry[0],
return jax.lax.while_loop(
lambda carry: carry[0][0][1] == gi,
generate_loop_fn,
(initial_states, sample_key),
(data,),
)
return final_state
generate_fn = hk.transform(generate_sample).apply
return generate_fn(state["params"], key, ctx, ctx_length)
self.generate_xmap = jax.experimental.maps.xmap(
fun=generate,
return generate_once_inner.apply(state["params"])
self.generate_once_xmap = jax.experimental.maps.xmap(
fun=generate_once,
in_axes=(
["shard", "batch", ...],
["shard", ...],
["batch", ...],
["batch", ...],
["batch", ...],
["batch", ...],
["batch", ...],
["batch", ...],
["shard", ...],
),
out_axes=["shard", "batch", ...],
axis_resources={'shard': 'mp', 'batch': 'dp'},
)
def generate(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None):
def generate(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None, excluded_world_info=None, use_callback=True):
assert excluded_world_info is not None
assert not return_logits
assert gen_length.ndim == 1
assert soft_embeddings is not None
key = hk.PRNGSequence(random.randint(0, 2 ** 60))
batch_size = ctx.shape[0]
self.batch_size = batch_size
return self.generate_xmap(
self.state,
jnp.array(key.take(batch_size)),
ctx,
np.array(ctx_length, dtype=np.uint32),
np.array(gen_length, dtype=np.uint32),
np.empty((batch_size, numseqs), dtype=np.uint8),
sampler_options,
soft_embeddings,
)
_numseqs_aux = jnp.empty((batch_size, numseqs), dtype=np.uint32)
numseqs_aux = batch_xmap(_numseqs_aux)
sample_data = [
[
np.pad(ctx[0][i], (0, params["seq"]), constant_values=pad_token_id),
params["seq"],
None,
np.empty((), dtype=np.uint32),
]
for i in range(numseqs)
]
repetition_penalty = sampler_options.pop("repetition_penalty", 1.0)
n_generated = 0
regeneration_required = False
halt = False
generate_data, sample_key = self.generate_initial_xmap(self.state, jnp.array(key.take(batch_size)), ctx, ctx_length, numseqs_aux, soft_embeddings)
sample_key = np.asarray(sample_key[0, 0])
while True:
generate_data, = self.generate_once_xmap(generate_data, self.state, numseqs_aux, soft_embeddings)
for i in range(numseqs):
sample_data[i][2] = np.array(generate_data[i][0][0, 0], copy=True)
if use_callback:
logits = np.float32(tuple(d[2] for d in sample_data))
logits = warper_callback(logits)
for i in range(numseqs):
sample_data[i][2] = logits[i]
sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options)
n_generated += 1
for i in range(numseqs):
generate_data[i][3] = np.tile(sample_data[i][0][sample_data[i][1]-1][np.newaxis, np.newaxis], (params["cores_per_replica"], 1, 1))
if use_callback:
generated = np.uint32(tuple(d[0] for d in sample_data))
excluded_world_info, regeneration_required, halt = stopping_callback(generated, n_generated, excluded_world_info)
if regeneration_required or halt:
break
else:
break
return sample_data, n_generated, regeneration_required, halt
def infer(
@ -278,35 +373,40 @@ def infer(
gen_len=80,
soft_embeddings: Optional[np.array] = None,
soft_tokens: Optional[np.array] = None,
) -> List[str]:
excluded_world_info = None,
use_callback=True,
) -> Tuple[List[np.array], int, bool, bool]:
assert excluded_world_info is not None
maps.thread_resources.env = thread_resources_env
total_batch = 1
tokens = context
if(soft_tokens is not None):
tokens = np.uint32(np.concatenate((soft_tokens, tokens)))
provided_ctx = tokens.shape[0]
tokens = np.uint32(np.concatenate((np.tile(soft_tokens, (tokens.shape[0], 1)), tokens), axis=-1))
provided_ctx = tokens.shape[-1]
pad_amount = seq - provided_ctx
padded_tokens = np.pad(tokens, ((pad_amount, 0),), constant_values=pad_token_id)
padded_tokens = np.pad(tokens, ((0, 0), (pad_amount, 0)), constant_values=pad_token_id)
batched_tokens = np.array([padded_tokens] * total_batch)
samples = []
batched_generator_params = {
"temp": temp * np.ones(total_batch),
"top_p": top_p * np.ones(total_batch),
"tfs": tfs * np.ones(total_batch),
"repetition_penalty": repetition_penalty * np.ones(total_batch),
"top_k": np.full(total_batch, top_k, dtype=np.uint32)
generator_params = {
"temp": float(temp),
"top_p": float(top_p),
"tfs": float(tfs),
"repetition_penalty": float(repetition_penalty),
"top_k": int(top_k),
}
output = network.generate(
batched_tokens,
np.ones(total_batch, dtype=np.uint32) * provided_ctx,
np.ones(total_batch, dtype=np.uint32) * gen_len,
numseqs,
batched_generator_params,
generator_params,
soft_embeddings=soft_embeddings,
)[0]
for o in output:
samples.append(o[0][0, 0, params["seq"] : params["seq"] + gen_len])
return samples
excluded_world_info=excluded_world_info,
use_callback=use_callback,
)
for out in output[0]:
samples.append(out[0][params["seq"] : params["seq"] + gen_len])
return (samples,) + output[1:]
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) -> None:
@ -354,6 +454,14 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
maps.thread_resources.env = thread_resources_env
tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
global shard_xmap, batch_xmap
shard_xmap = __shard_xmap()
batch_xmap = __batch_xmap(shard_dim=cores_per_replica)
global badwords
# These are the tokens that we don't want the AI to ever write
badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146])
if not path.endswith("/"):
path += "/"