Merge pull request #58 from VE-FORBRYDERNE/xmap

Dynamic TPU backend xmaps
This commit is contained in:
henk717
2022-01-15 16:20:58 +01:00
committed by GitHub
2 changed files with 330 additions and 138 deletions

View File

@ -22,7 +22,7 @@ import packaging
import contextlib
import traceback
import threading
from typing import Any, Callable, TypeVar, Union, Dict, Set, List
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
import requests
import html
@ -993,7 +993,7 @@ else:
-1,
tpu_mtj_backend.params["d_model"],
)
vars.sp = tensor
vars.sp = tpu_mtj_backend.shard_xmap(tensor)
soft_tokens = np.arange(
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
@ -1001,6 +1001,49 @@ else:
)
return soft_tokens
def tpumtjgenerate_warper_callback(scores) -> "np.array":
scores_shape = scores.shape
scores_list = scores.tolist()
vars.lua_koboldbridge.logits = vars.lua_state.table()
for r, row in enumerate(scores_list):
vars.lua_koboldbridge.logits[r+1] = vars.lua_state.table(*row)
vars.lua_koboldbridge.vocab_size = scores_shape[-1]
execute_genmod()
scores = np.array(
tuple(tuple(row.values()) for row in vars.lua_koboldbridge.logits.values()),
dtype=scores.dtype,
)
assert scores.shape == scores_shape
return scores
def tpumtjgenerate_stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]:
vars.generated_tkns += 1
assert len(excluded_world_info) == len(generated)
regeneration_required = vars.lua_koboldbridge.regeneration_required
halt = not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt
vars.lua_koboldbridge.regeneration_required = False
global past
for i in range(vars.numseqs):
vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(generated[i, tpu_mtj_backend.params["seq"] + n_generated - 1].item())
if(not vars.dynamicscan or halt):
return excluded_world_info, regeneration_required, halt
for i, t in enumerate(generated):
decoded = tokenizer.decode(past[i]) + tokenizer.decode(t[tpu_mtj_backend.params["seq"] : tpu_mtj_backend.params["seq"] + n_generated])
_, found = checkworldinfo(decoded, force_use_txt=True)
found -= excluded_world_info[i]
if(len(found) != 0):
regeneration_required = True
break
return excluded_world_info, regeneration_required, halt
# If we're running Colab or OAI, we still need a tokenizer.
if(vars.model == "Colab"):
from transformers import GPT2TokenizerFast
@ -1013,6 +1056,8 @@ else:
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth)
import tpu_mtj_backend
tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback
tpu_mtj_backend.load_model(vars.custmodpth)
vars.allowsp = True
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
@ -1020,12 +1065,14 @@ else:
soft_tokens = tpumtjgetsofttokens()
threading.Thread( # Compile backend code in background
target=tpu_mtj_backend.infer,
args=(np.uint32((23403, 727, 20185)),),
args=(np.tile(np.uint32((23403, 727, 20185)), (vars.numseqs, 1)),),
kwargs={
"soft_embeddings": vars.sp,
"soft_tokens": soft_tokens,
"use_callback": False,
"gen_len": 1,
"numseqs": vars.numseqs,
"excluded_world_info": list(set() for _ in range(vars.numseqs)),
},
).start()
@ -2890,32 +2937,69 @@ def sendtocolab(txt, min, max):
# Send text to TPU mesh transformer backend
#==================================================================#
def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
vars.generated_tkns = 0
if(found_entries is None):
found_entries = set()
found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, tokenizer.decode(txt), colors.END))
vars._actions = vars.actions
vars._prompt = vars.prompt
if(vars.dynamicscan):
vars._actions = vars._actions.copy()
# Submit input text to generator
try:
if(vars.dynamicscan):
raise ValueError("Dynamic world info scanning is not supported by the TPU backend yet")
context = np.tile(np.uint32(txt), (vars.numseqs, 1))
soft_tokens = tpumtjgetsofttokens()
genout = tpool.execute(
tpu_mtj_backend.infer,
np.uint32(txt),
gen_len = maximum-minimum+1,
temp=vars.temp,
top_p=vars.top_p,
top_k=vars.top_k,
tfs=vars.tfs,
numseqs=vars.numseqs,
repetition_penalty=vars.rep_pen,
soft_embeddings=vars.sp,
soft_tokens=soft_tokens,
)
global past
past = np.empty((vars.numseqs, 0), dtype=np.uint32)
while(True):
genout, n_generated, regeneration_required, halt = tpool.execute(
tpu_mtj_backend.infer,
context,
gen_len = maximum-minimum+1,
temp=vars.temp,
top_p=vars.top_p,
top_k=vars.top_k,
tfs=vars.tfs,
numseqs=vars.numseqs,
repetition_penalty=vars.rep_pen,
soft_embeddings=vars.sp,
soft_tokens=soft_tokens,
excluded_world_info=found_entries,
)
past = np.pad(past, ((0, 0), (0, n_generated)))
for r in range(vars.numseqs):
for c in range(vars.lua_koboldbridge.generated_cols):
assert vars.lua_koboldbridge.generated[r+1][c+1] is not None
past[r, c] = vars.lua_koboldbridge.generated[r+1][c+1]
if(halt or not regeneration_required):
break
print("(regeneration triggered)")
encoded = []
for i in range(vars.numseqs):
txt = tokenizer.decode(past[i])
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
found_entries[i].update(_found_entries)
txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
encoded.append(np.array(txt, dtype=np.uint32))
max_length = len(max(encoded, key=len))
encoded = np.stack(tuple(np.pad(e, (max_length - len(e), 0), constant_values=tpu_mtj_backend.pad_token_id) for e in encoded))
context = np.concatenate(
(
encoded,
past,
),
axis=-1,
)
except Exception as e:
if(issubclass(type(e), lupa.LuaError)):
@ -2931,10 +3015,10 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
print("{0}{1}{2}".format(colors.RED, traceback.format_exc().replace("\033", ""), colors.END), file=sys.stderr)
set_aibusy(0)
return
for i in range(vars.numseqs):
vars.lua_koboldbridge.generated[i+1] = vars.lua_state.table(*genout[i].tolist())
vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(genout[i])
vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(past[i])
genout = past
execute_outmod()
if(vars.lua_koboldbridge.regeneration_required):
@ -4005,7 +4089,7 @@ def spRequest(filename):
-1,
tpu_mtj_backend.params["d_model"],
)
vars.sp = np.float32(tensor)
vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor))
else:
vars.sp = torch.from_numpy(tensor)