From e0fdce2cc6ea64a300908f8cf36bc7ee037c9fb6 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Fri, 14 Jan 2022 23:00:06 -0500 Subject: [PATCH] Fix TPU generation modifier --- aiserver.py | 30 +++++++++++++++++------------- tpu_mtj_backend.py | 15 ++++++++++++--- 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/aiserver.py b/aiserver.py index 96accfa3..f232ff7c 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1001,19 +1001,7 @@ else: ) return soft_tokens - def tpumtjgenerate_warper_callback(generated, scores, excluded_world_info, n_generated) -> Tuple[List[set], bool, bool]: - vars.generated_tkns += 1 - - assert len(excluded_world_info) == len(generated) - regeneration_required = vars.lua_koboldbridge.regeneration_required - halt = not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt - vars.lua_koboldbridge.regeneration_required = False - - global past - - for i in range(vars.numseqs): - vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(generated[i, tpu_mtj_backend.params["seq"] + n_generated - 1].item()) - + def tpumtjgenerate_warper_callback(scores) -> "np.array": scores_shape = scores.shape scores_list = scores.tolist() vars.lua_koboldbridge.logits = vars.lua_state.table() @@ -1029,6 +1017,21 @@ else: ) assert scores.shape == scores_shape + return scores + + def tpumtjgenerate_stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]: + vars.generated_tkns += 1 + + assert len(excluded_world_info) == len(generated) + regeneration_required = vars.lua_koboldbridge.regeneration_required + halt = not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt + vars.lua_koboldbridge.regeneration_required = False + + global past + + for i in range(vars.numseqs): + vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(generated[i, tpu_mtj_backend.params["seq"] + n_generated - 1].item()) + if(not vars.dynamicscan or halt): return excluded_world_info, regeneration_required, halt @@ -1054,6 +1057,7 @@ else: assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth) import tpu_mtj_backend tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback + tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback tpu_mtj_backend.load_model(vars.custmodpth) vars.allowsp = True vars.modeldim = int(tpu_mtj_backend.params["d_model"]) diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 9cd49a12..67196645 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -20,9 +20,12 @@ from mesh_transformer.transformer_shard import CausalTransformer, CausalTransfor params: Dict[str, Any] = {} -def warper_callback(generated, logits, excluded_world_info, n_generated) -> Tuple[bool, bool]: +def warper_callback(logits) -> np.array: raise NotImplementedError("`tpu_mtj_backend.warper_callback()` needs to be defined") +def stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]: + raise NotImplementedError("`tpu_mtj_backend.stopping_callback()` needs to be defined") + def show_spinner(): bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='█')]) @@ -340,12 +343,18 @@ class PenalizingCausalTransformer(CausalTransformer): generate_data, = self.generate_once_xmap(generate_data, self.state, numseqs_aux, soft_embeddings) for i in range(numseqs): sample_data[i][2] = np.array(generate_data[i][0][0, 0], copy=True) + if use_callback: + logits = np.float32(tuple(d[2] for d in sample_data)) + logits = warper_callback(logits) + for i in range(numseqs): + sample_data[i][2] = logits[i] sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options) + n_generated += 1 for i in range(numseqs): generate_data[i][3] = np.tile(sample_data[i][0][sample_data[i][1]-1][np.newaxis, np.newaxis], (params["cores_per_replica"], 1, 1)) - n_generated += 1 if use_callback: - excluded_world_info, regeneration_required, halt = warper_callback(np.uint32(tuple(d[0] for d in sample_data)), np.float32(tuple(d[2] for d in sample_data)), excluded_world_info, n_generated) + generated = np.uint32(tuple(d[0] for d in sample_data)) + excluded_world_info, regeneration_required, halt = stopping_callback(generated, n_generated, excluded_world_info) if regeneration_required or halt: break else: