Fix TPU generation modifier

Gnome Ann 2022-01-14 23:00:06 -05:00
parent 932c393d6a
commit e0fdce2cc6
2 changed files with 29 additions and 16 deletions
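The old tpumtjgenerate_warper_callback conflated two jobs: warping the logits for Lua generation modifiers and deciding when generation should halt or be regenerated. This commit splits it into a logits-only warper callback and a separate stopping callback, registers both with tpu_mtj_backend, and moves the warper call to before sampling so that modified logits can actually influence the sampled token. A minimal sketch of the two callback contracts as they stand after this commit; the bodies are illustrative stand-ins, not KoboldAI's real implementations:

    import numpy as np
    from typing import List, Tuple

    def warper_callback(logits) -> np.array:
        # Called once per token *before* sampling; may rewrite the logits
        # but must return an array of the same shape it received.
        return logits

    def stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]:
        # Called once per token *after* sampling; returns the (possibly
        # updated) per-sequence world-info exclusions plus two flags:
        # regenerate the current batch, and halt generation entirely.
        halt = n_generated >= 80  # hypothetical token budget
        return excluded_world_info, False, halt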

@@ -1001,19 +1001,7 @@ else:
         )
         return soft_tokens
 
-    def tpumtjgenerate_warper_callback(generated, scores, excluded_world_info, n_generated) -> Tuple[List[set], bool, bool]:
-        vars.generated_tkns += 1
-
-        assert len(excluded_world_info) == len(generated)
-        regeneration_required = vars.lua_koboldbridge.regeneration_required
-        halt = not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt
-        vars.lua_koboldbridge.regeneration_required = False
-
-        global past
-
-        for i in range(vars.numseqs):
-            vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(generated[i, tpu_mtj_backend.params["seq"] + n_generated - 1].item())
+    def tpumtjgenerate_warper_callback(scores) -> "np.array":
         scores_shape = scores.shape
         scores_list = scores.tolist()
         vars.lua_koboldbridge.logits = vars.lua_state.table()
@@ -1029,6 +1017,21 @@ else:
         )
         assert scores.shape == scores_shape
         return scores
 
+    def tpumtjgenerate_stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]:
+        vars.generated_tkns += 1
+
+        assert len(excluded_world_info) == len(generated)
+        regeneration_required = vars.lua_koboldbridge.regeneration_required
+        halt = not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt
+        vars.lua_koboldbridge.regeneration_required = False
+
+        global past
+
+        for i in range(vars.numseqs):
+            vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(generated[i, tpu_mtj_backend.params["seq"] + n_generated - 1].item())
+
+        if(not vars.dynamicscan or halt):
+            return excluded_world_info, regeneration_required, halt
@@ -1054,6 +1057,7 @@ else:
     assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth)
     import tpu_mtj_backend
     tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
+    tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback
     tpu_mtj_backend.load_model(vars.custmodpth)
     vars.allowsp = True
     vars.modeldim = int(tpu_mtj_backend.params["d_model"])
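Since the backend looks the callbacks up as plain module attributes, a host wires them in by simple assignment before loading the model, exactly as the hunk above does. A sketch of that registration pattern, with hypothetical callback implementations and a hypothetical model path:

    import tpu_mtj_backend

    def my_warper(logits):  # hypothetical: pass logits through unchanged
        return logits

    def my_stopper(generated, n_generated, excluded_world_info):
        return excluded_world_info, False, False  # hypothetical: never stop early

    tpu_mtj_backend.warper_callback = my_warper
    tpu_mtj_backend.stopping_callback = my_stopper
    tpu_mtj_backend.load_model("/path/to/model")  # hypothetical path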


@@ -20,9 +20,12 @@ from mesh_transformer.transformer_shard import CausalTransformer, CausalTransfor
 params: Dict[str, Any] = {}
 
-def warper_callback(generated, logits, excluded_world_info, n_generated) -> Tuple[bool, bool]:
+def warper_callback(logits) -> np.array:
     raise NotImplementedError("`tpu_mtj_backend.warper_callback()` needs to be defined")
 
+def stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]:
+    raise NotImplementedError("`tpu_mtj_backend.stopping_callback()` needs to be defined")
+
 def show_spinner():
     bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='')])
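The module-level defaults raise rather than silently doing nothing, so a generation run started without both callbacks registered fails fast at the first token instead of producing subtly wrong output. A small sketch of that failure mode:

    import tpu_mtj_backend

    try:
        tpu_mtj_backend.stopping_callback(None, 1, [])
    except NotImplementedError as err:
        print(err)  # `tpu_mtj_backend.stopping_callback()` needs to be defined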
@@ -340,12 +343,18 @@ class PenalizingCausalTransformer(CausalTransformer):
                 generate_data, = self.generate_once_xmap(generate_data, self.state, numseqs_aux, soft_embeddings)
                 for i in range(numseqs):
                     sample_data[i][2] = np.array(generate_data[i][0][0, 0], copy=True)
+                if use_callback:
+                    logits = np.float32(tuple(d[2] for d in sample_data))
+                    logits = warper_callback(logits)
+                    for i in range(numseqs):
+                        sample_data[i][2] = logits[i]
                 sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options)
-                n_generated += 1
                 for i in range(numseqs):
                     generate_data[i][3] = np.tile(sample_data[i][0][sample_data[i][1]-1][np.newaxis, np.newaxis], (params["cores_per_replica"], 1, 1))
+                n_generated += 1
                 if use_callback:
-                    excluded_world_info, regeneration_required, halt = warper_callback(np.uint32(tuple(d[0] for d in sample_data)), np.float32(tuple(d[2] for d in sample_data)), excluded_world_info, n_generated)
+                    generated = np.uint32(tuple(d[0] for d in sample_data))
+                    excluded_world_info, regeneration_required, halt = stopping_callback(generated, n_generated, excluded_world_info)
                     if regeneration_required or halt:
                         break
                 else:
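After this hunk the per-token control flow has a fixed order: warp the logits, sample, feed the sampled token back into the model input, increment n_generated, then consult the stopping callback. A stripped-down sketch of that ordering; logits_fn and sample_fn are hypothetical stand-ins for the xmapped internals, and the two callbacks are taken as parameters rather than looked up on the module:

    import numpy as np

    def generate_loop(logits_fn, sample_fn, warper_callback, stopping_callback,
                      excluded_world_info, max_tokens, use_callback=True):
        generated = []
        n_generated = 0
        while n_generated < max_tokens:
            logits = logits_fn()
            if use_callback:
                logits = warper_callback(logits)   # 1. warp before sampling
            token = sample_fn(logits)              # 2. sample the next token
            generated.append(token)                # 3. feed the token back
            n_generated += 1                       # 4. count it
            if use_callback:                       # 5. regenerate or halt?
                excluded_world_info, regen, halt = stopping_callback(
                    np.uint32(generated), n_generated, excluded_world_info)
                if regen or halt:
                    break
        return generated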