Add soft prompt support to TPU backend
This commit is contained in:
parent
a60e7d3310
commit
e068aa9f26
30
aiserver.py
30
aiserver.py
|
@ -108,6 +108,7 @@ class vars:
|
||||||
loadselect = "" # Temporary storage for story filename to load
|
loadselect = "" # Temporary storage for story filename to load
|
||||||
spselect = "" # Temporary storage for soft prompt filename to load
|
spselect = "" # Temporary storage for soft prompt filename to load
|
||||||
sp = None # Current soft prompt tensor (as a NumPy array)
|
sp = None # Current soft prompt tensor (as a NumPy array)
|
||||||
|
sp_length = 0 # Length of current soft prompt in tokens, or 0 if not using a soft prompt
|
||||||
svowname = "" # Filename that was flagged for overwrite confirm
|
svowname = "" # Filename that was flagged for overwrite confirm
|
||||||
saveow = False # Whether or not overwrite confirm has been displayed
|
saveow = False # Whether or not overwrite confirm has been displayed
|
||||||
genseqs = [] # Temporary storage for generated sequences
|
genseqs = [] # Temporary storage for generated sequences
|
||||||
|
@ -700,6 +701,8 @@ else:
|
||||||
assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth)
|
assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth)
|
||||||
import tpu_mtj_backend
|
import tpu_mtj_backend
|
||||||
tpu_mtj_backend.load_model(vars.custmodpth)
|
tpu_mtj_backend.load_model(vars.custmodpth)
|
||||||
|
vars.allowsp = True
|
||||||
|
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
|
||||||
tokenizer = tpu_mtj_backend.tokenizer
|
tokenizer = tpu_mtj_backend.tokenizer
|
||||||
|
|
||||||
# Set up Flask routes
|
# Set up Flask routes
|
||||||
|
@ -1684,10 +1687,17 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
||||||
|
|
||||||
# Submit input text to generator
|
# Submit input text to generator
|
||||||
try:
|
try:
|
||||||
if(vars.sp is not None):
|
|
||||||
raise ValueError("Softprompts are not supported by the TPU backend yet")
|
|
||||||
if(vars.dynamicscan):
|
if(vars.dynamicscan):
|
||||||
raise ValueError("Dynamic world info scanning is not supported by the TPU backend yet")
|
raise ValueError("Dynamic world info scanning is not supported by the TPU backend yet")
|
||||||
|
|
||||||
|
soft_tokens = None
|
||||||
|
if(vars.sp is not None):
|
||||||
|
soft_tokens = np.arange(
|
||||||
|
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
|
||||||
|
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
|
||||||
|
dtype=np.uint32
|
||||||
|
)
|
||||||
|
|
||||||
genout = tpu_mtj_backend.infer(
|
genout = tpu_mtj_backend.infer(
|
||||||
txt,
|
txt,
|
||||||
gen_len = maximum-minimum+1,
|
gen_len = maximum-minimum+1,
|
||||||
|
@ -1697,6 +1707,8 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
||||||
tfs=vars.tfs,
|
tfs=vars.tfs,
|
||||||
numseqs=vars.numseqs,
|
numseqs=vars.numseqs,
|
||||||
repetition_penalty=vars.rep_pen,
|
repetition_penalty=vars.rep_pen,
|
||||||
|
soft_embeddings=vars.sp,
|
||||||
|
soft_tokens=soft_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -2525,6 +2537,7 @@ def loadRequest(loadpath, filename=None):
|
||||||
def spRequest(filename):
|
def spRequest(filename):
|
||||||
if(len(filename) == 0):
|
if(len(filename) == 0):
|
||||||
vars.sp = None
|
vars.sp = None
|
||||||
|
vars.sp_length = 0
|
||||||
return
|
return
|
||||||
|
|
||||||
global np
|
global np
|
||||||
|
@ -2548,6 +2561,19 @@ def spRequest(filename):
|
||||||
tensor = np.float32(tensor)
|
tensor = np.float32(tensor)
|
||||||
assert not np.isinf(tensor).any() and not np.isnan(tensor).any()
|
assert not np.isinf(tensor).any() and not np.isnan(tensor).any()
|
||||||
|
|
||||||
|
vars.sp_length = tensor.shape[0]
|
||||||
|
|
||||||
|
if(vars.model in ("TPUMeshTransformerGPTJ",)):
|
||||||
|
rows = tensor.shape[0]
|
||||||
|
padding_amount = -(rows % -tpu_mtj_backend.params["cores_per_replica"])
|
||||||
|
tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
|
||||||
|
tensor = tensor.reshape(
|
||||||
|
tpu_mtj_backend.params["cores_per_replica"],
|
||||||
|
-1,
|
||||||
|
tpu_mtj_backend.params["d_model"],
|
||||||
|
)
|
||||||
|
vars.sp = tensor
|
||||||
|
else:
|
||||||
vars.sp = torch.from_numpy(tensor)
|
vars.sp = torch.from_numpy(tensor)
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List, Optional
|
||||||
import progressbar
|
import progressbar
|
||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
|
@ -155,14 +155,14 @@ class PenalizingCausalTransformer(CausalTransformer):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
# Initialize
|
# Initialize
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
def generate(state, key, ctx, ctx_length, aux, sampler_options):
|
def generate(state, key, ctx, ctx_length, aux, sampler_options, soft_embeddings=None):
|
||||||
gen_length = self.gen_length
|
gen_length = self.gen_length
|
||||||
# These are the tokens that we don't want the AI to ever write
|
# These are the tokens that we don't want the AI to ever write
|
||||||
self.badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146])
|
self.badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146])
|
||||||
def generate_sample(context, ctx_length, aux):
|
def generate_sample(context, ctx_length, aux):
|
||||||
# Give the initial context to the transformer
|
# Give the initial context to the transformer
|
||||||
transformer = CausalTransformerShard(config)
|
transformer = CausalTransformerShard(config)
|
||||||
_, initial_state = transformer.generate_initial(context, ctx_length)
|
_, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings)
|
||||||
# The "generated" array will contain the tokens from the
|
# The "generated" array will contain the tokens from the
|
||||||
# context as well as the tokens picked by the sampler at
|
# context as well as the tokens picked by the sampler at
|
||||||
# each stage, padded with a bunch of 50256s, so we know
|
# each stage, padded with a bunch of 50256s, so we know
|
||||||
|
@ -185,7 +185,7 @@ class PenalizingCausalTransformer(CausalTransformer):
|
||||||
# how strongly it thinks each of the 50257 tokens in its
|
# how strongly it thinks each of the 50257 tokens in its
|
||||||
# vocabulary should be appended to the context, followed
|
# vocabulary should be appended to the context, followed
|
||||||
# by 143 apparently useless columns ???)
|
# by 143 apparently useless columns ???)
|
||||||
logits, new_state = transformer.generate_once(next_token, decode_state)
|
logits, new_state = transformer.generate_once(next_token, decode_state, soft_embeddings=soft_embeddings)
|
||||||
# Verify that logits does indeed have that many rows and
|
# Verify that logits does indeed have that many rows and
|
||||||
# columns (if you get an error here, pray for mercy)
|
# columns (if you get an error here, pray for mercy)
|
||||||
assert logits.shape == (1, config["n_vocab"])
|
assert logits.shape == (1, config["n_vocab"])
|
||||||
|
@ -243,8 +243,8 @@ class PenalizingCausalTransformer(CausalTransformer):
|
||||||
return final_state, outputs
|
return final_state, outputs
|
||||||
generate_fn = hk.transform(generate_sample).apply
|
generate_fn = hk.transform(generate_sample).apply
|
||||||
return generate_fn(state["params"], key, ctx, ctx_length, aux)
|
return generate_fn(state["params"], key, ctx, ctx_length, aux)
|
||||||
self.generate_xmap = jax.experimental.maps.xmap(fun=generate, in_axes=(["shard", ...], ["batch", ...], ["batch", ...], ["batch", ...], ["batch", ...], ["batch", ...]), out_axes=["batch", ...], axis_resources={'shard': 'mp', 'batch': 'dp'})
|
self.generate_xmap = jax.experimental.maps.xmap(fun=generate, in_axes=(["shard", ...], ["batch", ...], ["batch", ...], ["batch", ...], ["batch", ...], ["batch", ...], ["shard", ...]), out_axes=["batch", ...], axis_resources={'shard': 'mp', 'batch': 'dp'})
|
||||||
def generate(self, ctx, ctx_length, gen_length, sampler_options, return_logits=False):
|
def generate(self, ctx, ctx_length, gen_length, sampler_options, return_logits=False, soft_embeddings=None):
|
||||||
key = hk.PRNGSequence(random.randint(0, 2 ** 60))
|
key = hk.PRNGSequence(random.randint(0, 2 ** 60))
|
||||||
batch_size = ctx.shape[0]
|
batch_size = ctx.shape[0]
|
||||||
aux = jnp.zeros((batch_size, gen_length), dtype=jnp.uint32)
|
aux = jnp.zeros((batch_size, gen_length), dtype=jnp.uint32)
|
||||||
|
@ -257,19 +257,33 @@ class PenalizingCausalTransformer(CausalTransformer):
|
||||||
ctx,
|
ctx,
|
||||||
np.array(ctx_length, dtype=np.uint32),
|
np.array(ctx_length, dtype=np.uint32),
|
||||||
aux,
|
aux,
|
||||||
sampler_options
|
sampler_options,
|
||||||
|
soft_embeddings,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def infer(context, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, repetition_penalty=1.0, numseqs=1, gen_len=80) -> List[str]:
|
def infer(
|
||||||
|
context: str,
|
||||||
|
top_p=0.9,
|
||||||
|
temp=0.5,
|
||||||
|
top_k=0,
|
||||||
|
tfs=1.0,
|
||||||
|
repetition_penalty=1.0,
|
||||||
|
numseqs=1,
|
||||||
|
gen_len=80,
|
||||||
|
soft_embeddings: Optional[np.array] = None,
|
||||||
|
soft_tokens: Optional[np.array] = None,
|
||||||
|
) -> List[str]:
|
||||||
maps.thread_resources.env = thread_resources_env
|
maps.thread_resources.env = thread_resources_env
|
||||||
total_batch = numseqs
|
total_batch = numseqs
|
||||||
tokens = tokenizer.encode(context, max_length=params["seq"], truncation=True)
|
tokens = np.uint32(tokenizer.encode(context, max_length=params["seq"] - soft_tokens.shape[0] if soft_tokens is not None else 0, truncation=True))
|
||||||
provided_ctx = len(tokens)
|
if(soft_tokens is not None):
|
||||||
|
tokens = np.uint32(np.concatenate((soft_tokens, tokens)))
|
||||||
|
provided_ctx = tokens.shape[0]
|
||||||
pad_amount = seq - provided_ctx
|
pad_amount = seq - provided_ctx
|
||||||
padded_tokens = np.pad(np.asarray(tokens, dtype=np.uint32), ((pad_amount, 0),), constant_values=pad_token_id)
|
padded_tokens = np.pad(tokens, ((pad_amount, 0),), constant_values=pad_token_id)
|
||||||
batched_tokens = np.array([padded_tokens] * total_batch)
|
batched_tokens = np.array([padded_tokens] * total_batch)
|
||||||
length = np.ones(total_batch, dtype=np.uint32) * len(tokens)
|
length = np.ones(total_batch, dtype=np.uint32) * provided_ctx
|
||||||
samples = []
|
samples = []
|
||||||
batched_generator_params = {
|
batched_generator_params = {
|
||||||
"temp": temp * np.ones(total_batch),
|
"temp": temp * np.ones(total_batch),
|
||||||
|
@ -278,7 +292,13 @@ def infer(context, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, repetition_penalty=1.0
|
||||||
"repetition_penalty": repetition_penalty * np.ones(total_batch),
|
"repetition_penalty": repetition_penalty * np.ones(total_batch),
|
||||||
"top_k": np.full(total_batch, top_k, dtype=np.uint32)
|
"top_k": np.full(total_batch, top_k, dtype=np.uint32)
|
||||||
}
|
}
|
||||||
output = network.generate(batched_tokens, length, gen_len, batched_generator_params)
|
output = network.generate(
|
||||||
|
batched_tokens,
|
||||||
|
length,
|
||||||
|
gen_len,
|
||||||
|
batched_generator_params,
|
||||||
|
soft_embeddings=soft_embeddings,
|
||||||
|
)
|
||||||
decoded_tokens = output[1][0]
|
decoded_tokens = output[1][0]
|
||||||
for o in decoded_tokens[:, :, 0]:
|
for o in decoded_tokens[:, :, 0]:
|
||||||
samples.append(tokenizer.decode(o))
|
samples.append(tokenizer.decode(o))
|
||||||
|
|
Loading…
Reference in New Issue