mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Add TPU support for dynamic WI scan and generation modifiers
This commit is contained in:
121
aiserver.py
121
aiserver.py
@ -22,7 +22,7 @@ import packaging
|
||||
import contextlib
|
||||
import traceback
|
||||
import threading
|
||||
from typing import Any, Callable, TypeVar, Union, Dict, Set, List
|
||||
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
|
||||
|
||||
import requests
|
||||
import html
|
||||
@ -1001,6 +1001,46 @@ else:
|
||||
)
|
||||
return soft_tokens
|
||||
|
||||
def tpumtjgenerate_warper_callback(generated, scores, excluded_world_info, n_generated) -> Tuple[List[set], bool, bool]:
|
||||
vars.generated_tkns += 1
|
||||
|
||||
assert len(excluded_world_info) == len(generated)
|
||||
regeneration_required = vars.lua_koboldbridge.regeneration_required
|
||||
halt = not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt
|
||||
vars.lua_koboldbridge.regeneration_required = False
|
||||
|
||||
global past
|
||||
|
||||
for i in range(vars.numseqs):
|
||||
vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(generated[i, tpu_mtj_backend.params["seq"] + n_generated - 1].item())
|
||||
|
||||
scores_shape = scores.shape
|
||||
scores_list = scores.tolist()
|
||||
vars.lua_koboldbridge.logits = vars.lua_state.table()
|
||||
for r, row in enumerate(scores_list):
|
||||
vars.lua_koboldbridge.logits[r+1] = vars.lua_state.table(*row)
|
||||
vars.lua_koboldbridge.vocab_size = scores_shape[-1]
|
||||
|
||||
execute_genmod()
|
||||
|
||||
scores = np.array(
|
||||
tuple(tuple(row.values()) for row in vars.lua_koboldbridge.logits.values()),
|
||||
dtype=scores.dtype,
|
||||
)
|
||||
assert scores.shape == scores_shape
|
||||
|
||||
if(not vars.dynamicscan or halt):
|
||||
return excluded_world_info, regeneration_required, halt
|
||||
|
||||
for i, t in enumerate(generated):
|
||||
decoded = tokenizer.decode(past[i]) + tokenizer.decode(t[tpu_mtj_backend.params["seq"] : tpu_mtj_backend.params["seq"] + n_generated])
|
||||
_, found = checkworldinfo(decoded, force_use_txt=True)
|
||||
found -= excluded_world_info[i]
|
||||
if(len(found) != 0):
|
||||
regeneration_required = True
|
||||
break
|
||||
return excluded_world_info, regeneration_required, halt
|
||||
|
||||
# If we're running Colab or OAI, we still need a tokenizer.
|
||||
if(vars.model == "Colab"):
|
||||
from transformers import GPT2TokenizerFast
|
||||
@ -1013,6 +1053,7 @@ else:
|
||||
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
|
||||
assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth)
|
||||
import tpu_mtj_backend
|
||||
tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
|
||||
tpu_mtj_backend.load_model(vars.custmodpth)
|
||||
vars.allowsp = True
|
||||
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
|
||||
@ -1020,12 +1061,14 @@ else:
|
||||
soft_tokens = tpumtjgetsofttokens()
|
||||
threading.Thread( # Compile backend code in background
|
||||
target=tpu_mtj_backend.infer,
|
||||
args=(np.uint32((23403, 727, 20185)),),
|
||||
args=(np.tile(np.uint32((23403, 727, 20185)), (vars.numseqs, 1)),),
|
||||
kwargs={
|
||||
"soft_embeddings": vars.sp,
|
||||
"soft_tokens": soft_tokens,
|
||||
"use_callback": False,
|
||||
"gen_len": 1,
|
||||
"numseqs": vars.numseqs,
|
||||
"excluded_world_info": list(set() for _ in range(vars.numseqs)),
|
||||
},
|
||||
).start()
|
||||
|
||||
@ -2890,32 +2933,68 @@ def sendtocolab(txt, min, max):
|
||||
# Send text to TPU mesh transformer backend
|
||||
#==================================================================#
|
||||
def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
||||
vars.generated_tkns = 0
|
||||
|
||||
if(found_entries is None):
|
||||
found_entries = set()
|
||||
found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
|
||||
|
||||
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, tokenizer.decode(txt), colors.END))
|
||||
|
||||
vars._actions = vars.actions
|
||||
vars._prompt = vars.prompt
|
||||
if(vars.dynamicscan):
|
||||
vars._actions = vars._actions.copy()
|
||||
|
||||
# Submit input text to generator
|
||||
try:
|
||||
if(vars.dynamicscan):
|
||||
raise ValueError("Dynamic world info scanning is not supported by the TPU backend yet")
|
||||
|
||||
context = np.tile(np.uint32(txt), (vars.numseqs, 1))
|
||||
soft_tokens = tpumtjgetsofttokens()
|
||||
|
||||
genout = tpool.execute(
|
||||
tpu_mtj_backend.infer,
|
||||
np.uint32(txt),
|
||||
gen_len = maximum-minimum+1,
|
||||
temp=vars.temp,
|
||||
top_p=vars.top_p,
|
||||
top_k=vars.top_k,
|
||||
tfs=vars.tfs,
|
||||
numseqs=vars.numseqs,
|
||||
repetition_penalty=vars.rep_pen,
|
||||
soft_embeddings=vars.sp,
|
||||
soft_tokens=soft_tokens,
|
||||
)
|
||||
global past
|
||||
past = np.empty((vars.numseqs, 0), dtype=np.uint32)
|
||||
|
||||
while(True):
|
||||
genout, n_generated, regeneration_required, halt = tpool.execute(
|
||||
tpu_mtj_backend.infer,
|
||||
context,
|
||||
gen_len = maximum-minimum+1,
|
||||
temp=vars.temp,
|
||||
top_p=vars.top_p,
|
||||
top_k=vars.top_k,
|
||||
tfs=vars.tfs,
|
||||
numseqs=vars.numseqs,
|
||||
repetition_penalty=vars.rep_pen,
|
||||
soft_embeddings=vars.sp,
|
||||
soft_tokens=soft_tokens,
|
||||
excluded_world_info=found_entries,
|
||||
)
|
||||
|
||||
past = np.pad(past, ((0, 0), (0, n_generated)))
|
||||
for r in range(vars.numseqs):
|
||||
for c in range(vars.lua_koboldbridge.generated_cols):
|
||||
assert vars.lua_koboldbridge.generated[r+1][c+1] is not None
|
||||
past[r, c] = vars.lua_koboldbridge.generated[r+1][c+1]
|
||||
|
||||
if(halt or not regeneration_required):
|
||||
break
|
||||
|
||||
encoded = []
|
||||
for i in range(vars.numseqs):
|
||||
txt = tokenizer.decode(past[i])
|
||||
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
|
||||
found_entries[i].update(_found_entries)
|
||||
txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
|
||||
encoded.append(np.array(txt, dtype=np.uint32))
|
||||
max_length = len(max(encoded, key=len))
|
||||
encoded = np.stack(tuple(np.pad(e, (max_length - len(e), 0), constant_values=tpu_mtj_backend.pad_token_id) for e in encoded))
|
||||
context = np.concatenate(
|
||||
(
|
||||
encoded,
|
||||
past,
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if(issubclass(type(e), lupa.LuaError)):
|
||||
@ -2931,10 +3010,10 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
||||
print("{0}{1}{2}".format(colors.RED, traceback.format_exc().replace("\033", ""), colors.END), file=sys.stderr)
|
||||
set_aibusy(0)
|
||||
return
|
||||
|
||||
|
||||
for i in range(vars.numseqs):
|
||||
vars.lua_koboldbridge.generated[i+1] = vars.lua_state.table(*genout[i].tolist())
|
||||
vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(genout[i])
|
||||
vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(past[i])
|
||||
genout = past
|
||||
|
||||
execute_outmod()
|
||||
if(vars.lua_koboldbridge.regeneration_required):
|
||||
|
Reference in New Issue
Block a user