diff --git a/modeling/inference_models/hf_mtj.py b/modeling/inference_models/hf_mtj.py index 3a16f6f7..9f9b27c4 100644 --- a/modeling/inference_models/hf_mtj.py +++ b/modeling/inference_models/hf_mtj.py @@ -71,11 +71,10 @@ class HFMTJInferenceModel(HFInferenceModel): return scores def mtj_stopping_callback( - generated, n_generated, excluded_world_info + generated, n_generated ) -> Tuple[List[set], bool, bool]: utils.koboldai_vars.generated_tkns += 1 - assert len(excluded_world_info) == len(generated) regeneration_required = ( utils.koboldai_vars.lua_koboldbridge.regeneration_required ) @@ -98,7 +97,7 @@ class HFMTJInferenceModel(HFInferenceModel): ) if not utils.koboldai_vars.dynamicscan or halt: - return excluded_world_info, regeneration_required, halt + return regeneration_required, halt for i, t in enumerate(generated): decoded = utils.decodenewlines( @@ -114,14 +113,16 @@ class HFMTJInferenceModel(HFInferenceModel): ) ) # _, found = checkworldinfo(decoded, force_use_txt=True, actions=koboldai_vars.actions) - _, _, _, found = utils.koboldai_vars.calc_ai_text( + _, _, _, used_world_info = utils.koboldai_vars.calc_ai_text( submitted_text=decoded ) - found -= excluded_world_info[i] - if len(found) != 0: + print(utils.koboldai_vars.calc_ai_text()) + # found -= excluded_world_info[i] + if used_world_info: + print("lets regen") regeneration_required = True break - return excluded_world_info, regeneration_required, halt + return regeneration_required, halt def mtj_compiling_callback() -> None: print(Colors.GREEN + "TPU backend compilation triggered" + Colors.END) @@ -261,7 +262,7 @@ class HFMTJInferenceModel(HFInferenceModel): gen_settings: GenerationSettings, single_line: bool = False, batch_count: int = 1, - **kwargs + **kwargs, ) -> GenerationResult: soft_tokens = self.get_soft_tokens() @@ -289,19 +290,82 @@ class HFMTJInferenceModel(HFInferenceModel): ) genout = np.array(genout) else: - genout = tpool.execute( - tpu_mtj_backend.infer_dynamic, - context=np.uint32(prompt_tokens), - numseqs=batch_count, - gen_len=max_new, - soft_embeddings=utils.koboldai_vars.sp, - soft_tokens=soft_tokens, - # TODO: Fix Dynamic WI on TPU - excluded_world_info=set(), - use_callback=True + global past + context = np.tile( + np.uint32(prompt_tokens), (utils.koboldai_vars.numseqs, 1) ) - print(genout) - print(type(genout)) + past = np.empty((utils.koboldai_vars.numseqs, 0), dtype=np.uint32) + self.gen_state["wi_scanner_excluded_keys"] = set() + + while True: + genout, n_generated, regeneration_required, halt = tpool.execute( + tpu_mtj_backend.infer_dynamic, + context, + gen_len=max_new, + numseqs=utils.koboldai_vars.numseqs, + soft_embeddings=utils.koboldai_vars.sp, + soft_tokens=soft_tokens, + ) + + past = np.pad(past, ((0, 0), (0, n_generated))) + for r in range(utils.koboldai_vars.numseqs): + for c in range(utils.koboldai_vars.lua_koboldbridge.generated_cols): + assert ( + utils.koboldai_vars.lua_koboldbridge.generated[r + 1][c + 1] + is not None + ) + past[r, c] = utils.koboldai_vars.lua_koboldbridge.generated[ + r + 1 + ][c + 1] + + if utils.koboldai_vars.abort or halt or not regeneration_required: + break + + print("(regeneration triggered)") + + encoded = [] + for i in range(utils.koboldai_vars.numseqs): + txt = utils.decodenewlines(self.tokenizer.decode(past[i])) + # _, _, _, _found_entries = utils.koboldai_vars.calc_ai_text( + # self.tokenizer.decode(prompt_tokens) + # ) + # # utils.koboldai_vars.calc_ai_text() + # print(_found_entries) + # self.gen_state["wi_scanner_excluded_keys"].update(_found_entries) + encoded.append(np.array(txt, dtype=np.uint32)) + + max_length = len(max(encoded, key=len)) + encoded = np.stack( + tuple( + np.pad( + e, + (max_length - len(e), 0), + constant_values=tpu_mtj_backend.pad_token_id, + ) + for e in encoded + ) + ) + context = np.concatenate( + ( + encoded, + past, + ), + axis=-1, + ) + # genout = tpool.execute( + # tpu_mtj_backend.infer_dynamic, + # context=np.uint32(prompt_tokens), + # numseqs=batch_count, + # gen_len=max_new, + # soft_embeddings=utils.koboldai_vars.sp, + # soft_tokens=soft_tokens, + # # TODO: Fix Dynamic WI on TPU + # excluded_world_info=set(), + # use_callback=True + # ) + # print(genout) + # print(type(genout)) + print(context) genout = np.array(genout) return GenerationResult( diff --git a/modeling/warpers.py b/modeling/warpers.py index 0885842c..f0b0fbcb 100644 --- a/modeling/warpers.py +++ b/modeling/warpers.py @@ -258,7 +258,7 @@ class TopK(Warper): @classmethod def value_is_valid(cls) -> bool: - return cls.top_p > 0 + return cls.top_k > 0 class TailFree(Warper): diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index c01210cd..a20cb213 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -89,7 +89,7 @@ def new_rng_state(seed: int): def warper_callback(logits) -> np.array: raise NotImplementedError("`tpu_mtj_backend.warper_callback()` needs to be defined") -def stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]: +def stopping_callback(generated, n_generated) -> Tuple[bool, bool]: raise NotImplementedError("`tpu_mtj_backend.stopping_callback()` needs to be defined") def settings_callback() -> dict: @@ -219,7 +219,7 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarra warper = warpers.Warper.from_id(sid) if not warper.value_is_valid(): continue - logits = warper.jax_dynamic() + logits = warper.jax_dynamic(logits) # Finally, pick one token using the softmax thingy again (it gives # an array whose elements sum to 1 so it can be used nicely as a @@ -473,8 +473,7 @@ class PenalizingCausalTransformer(CausalTransformer): out_axes=["shard", "batch", ...], axis_resources={'shard': 'mp', 'batch': 'dp'}, ) - def generate_dynamic(self, ctx, ctx_length, gen_length, numseqs, return_logits=False, soft_embeddings=None, excluded_world_info=None, use_callback=True): - assert excluded_world_info is not None + def generate_dynamic(self, ctx, ctx_length, gen_length, numseqs, return_logits=False, soft_embeddings=None, use_callback=True): assert not return_logits assert gen_length.ndim == 1 assert soft_embeddings is not None @@ -517,7 +516,7 @@ class PenalizingCausalTransformer(CausalTransformer): generate_data[i][3] = np.tile(sample_data[i][0][sample_data[i][1]-1][np.newaxis, np.newaxis], (params["cores_per_replica"], 1, 1)) if use_callback: generated = np.uint32(tuple(d[0] for d in sample_data)) - excluded_world_info, regeneration_required, halt = stopping_callback(generated, n_generated, excluded_world_info) + regeneration_required, halt = stopping_callback(generated, n_generated) if regeneration_required or halt: break else: @@ -550,10 +549,8 @@ def infer_dynamic( gen_len=80, soft_embeddings: Optional[np.array] = None, soft_tokens: Optional[np.array] = None, - excluded_world_info = None, use_callback=True, ) -> Tuple[List[np.array], int, bool, bool]: - assert excluded_world_info is not None maps.thread_resources.env = thread_resources_env total_batch = 1 tokens = context @@ -570,7 +567,6 @@ def infer_dynamic( np.ones(total_batch, dtype=np.uint32) * gen_len, numseqs, soft_embeddings=soft_embeddings, - excluded_world_info=excluded_world_info, use_callback=use_callback, ) for out in output[0]: