Add TPU support for dynamic WI scan and generation modifiers

This commit is contained in:
Gnome Ann
2022-01-14 21:39:02 -05:00
parent 0bef92419b
commit 932c393d6a
2 changed files with 155 additions and 63 deletions

View File

@ -22,7 +22,7 @@ import packaging
import contextlib
import traceback
import threading
from typing import Any, Callable, TypeVar, Union, Dict, Set, List
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
import requests
import html
@ -1001,6 +1001,46 @@ else:
)
return soft_tokens
def tpumtjgenerate_warper_callback(generated, scores, excluded_world_info, n_generated) -> Tuple[List[set], bool, bool]:
vars.generated_tkns += 1
assert len(excluded_world_info) == len(generated)
regeneration_required = vars.lua_koboldbridge.regeneration_required
halt = not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt
vars.lua_koboldbridge.regeneration_required = False
global past
for i in range(vars.numseqs):
vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(generated[i, tpu_mtj_backend.params["seq"] + n_generated - 1].item())
scores_shape = scores.shape
scores_list = scores.tolist()
vars.lua_koboldbridge.logits = vars.lua_state.table()
for r, row in enumerate(scores_list):
vars.lua_koboldbridge.logits[r+1] = vars.lua_state.table(*row)
vars.lua_koboldbridge.vocab_size = scores_shape[-1]
execute_genmod()
scores = np.array(
tuple(tuple(row.values()) for row in vars.lua_koboldbridge.logits.values()),
dtype=scores.dtype,
)
assert scores.shape == scores_shape
if(not vars.dynamicscan or halt):
return excluded_world_info, regeneration_required, halt
for i, t in enumerate(generated):
decoded = tokenizer.decode(past[i]) + tokenizer.decode(t[tpu_mtj_backend.params["seq"] : tpu_mtj_backend.params["seq"] + n_generated])
_, found = checkworldinfo(decoded, force_use_txt=True)
found -= excluded_world_info[i]
if(len(found) != 0):
regeneration_required = True
break
return excluded_world_info, regeneration_required, halt
# If we're running Colab or OAI, we still need a tokenizer.
if(vars.model == "Colab"):
from transformers import GPT2TokenizerFast
@ -1013,6 +1053,7 @@ else:
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth)
import tpu_mtj_backend
tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
tpu_mtj_backend.load_model(vars.custmodpth)
vars.allowsp = True
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
@ -1020,12 +1061,14 @@ else:
soft_tokens = tpumtjgetsofttokens()
threading.Thread( # Compile backend code in background
target=tpu_mtj_backend.infer,
args=(np.uint32((23403, 727, 20185)),),
args=(np.tile(np.uint32((23403, 727, 20185)), (vars.numseqs, 1)),),
kwargs={
"soft_embeddings": vars.sp,
"soft_tokens": soft_tokens,
"use_callback": False,
"gen_len": 1,
"numseqs": vars.numseqs,
"excluded_world_info": list(set() for _ in range(vars.numseqs)),
},
).start()
@ -2890,32 +2933,68 @@ def sendtocolab(txt, min, max):
# Send text to TPU mesh transformer backend
#==================================================================#
def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
vars.generated_tkns = 0
if(found_entries is None):
found_entries = set()
found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, tokenizer.decode(txt), colors.END))
vars._actions = vars.actions
vars._prompt = vars.prompt
if(vars.dynamicscan):
vars._actions = vars._actions.copy()
# Submit input text to generator
try:
if(vars.dynamicscan):
raise ValueError("Dynamic world info scanning is not supported by the TPU backend yet")
context = np.tile(np.uint32(txt), (vars.numseqs, 1))
soft_tokens = tpumtjgetsofttokens()
genout = tpool.execute(
tpu_mtj_backend.infer,
np.uint32(txt),
gen_len = maximum-minimum+1,
temp=vars.temp,
top_p=vars.top_p,
top_k=vars.top_k,
tfs=vars.tfs,
numseqs=vars.numseqs,
repetition_penalty=vars.rep_pen,
soft_embeddings=vars.sp,
soft_tokens=soft_tokens,
)
global past
past = np.empty((vars.numseqs, 0), dtype=np.uint32)
while(True):
genout, n_generated, regeneration_required, halt = tpool.execute(
tpu_mtj_backend.infer,
context,
gen_len = maximum-minimum+1,
temp=vars.temp,
top_p=vars.top_p,
top_k=vars.top_k,
tfs=vars.tfs,
numseqs=vars.numseqs,
repetition_penalty=vars.rep_pen,
soft_embeddings=vars.sp,
soft_tokens=soft_tokens,
excluded_world_info=found_entries,
)
past = np.pad(past, ((0, 0), (0, n_generated)))
for r in range(vars.numseqs):
for c in range(vars.lua_koboldbridge.generated_cols):
assert vars.lua_koboldbridge.generated[r+1][c+1] is not None
past[r, c] = vars.lua_koboldbridge.generated[r+1][c+1]
if(halt or not regeneration_required):
break
encoded = []
for i in range(vars.numseqs):
txt = tokenizer.decode(past[i])
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
found_entries[i].update(_found_entries)
txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
encoded.append(np.array(txt, dtype=np.uint32))
max_length = len(max(encoded, key=len))
encoded = np.stack(tuple(np.pad(e, (max_length - len(e), 0), constant_values=tpu_mtj_backend.pad_token_id) for e in encoded))
context = np.concatenate(
(
encoded,
past,
),
axis=-1,
)
except Exception as e:
if(issubclass(type(e), lupa.LuaError)):
@ -2931,10 +3010,10 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
print("{0}{1}{2}".format(colors.RED, traceback.format_exc().replace("\033", ""), colors.END), file=sys.stderr)
set_aibusy(0)
return
for i in range(vars.numseqs):
vars.lua_koboldbridge.generated[i+1] = vars.lua_state.table(*genout[i].tolist())
vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(genout[i])
vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(past[i])
genout = past
execute_outmod()
if(vars.lua_koboldbridge.regeneration_required):