From 2ed7a9dcee05b92ae58c16228b3e731d6a4f5d85 Mon Sep 17 00:00:00 2001 From: somebody Date: Sun, 18 Sep 2022 22:30:43 -0500 Subject: [PATCH] Eternal gen work --- aiserver.py | 157 ++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 123 insertions(+), 34 deletions(-) diff --git a/aiserver.py b/aiserver.py index 08a07183..85767c09 100644 --- a/aiserver.py +++ b/aiserver.py @@ -2013,6 +2013,49 @@ def patch_transformers(): koboldai_vars.actions.stream_tokens([utils.decodenewlines(tokenizer.decode(x[-1])) for x in input_ids]) return False + class CoreStopper(StoppingCriteria): + # Controls core generation stuff; aborting, counting generated tokens, etc + def __init__(self): + self.regeneration_required = False + self.halt = False + + def __call__( + self, + input_ids: torch.LongTensor, + scores: torch.FloatTensor, + **kwargs, + ) -> bool: + koboldai_vars.generated_tkns += 1 + + if ( + not koboldai_vars.standalone + and koboldai_vars.lua_koboldbridge.generated_cols + and koboldai_vars.generated_tkns != koboldai_vars.lua_koboldbridge.generated_cols + ): + print("[TODO] Fix generated_cols") + # raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({koboldai_vars.generated_tkns} != {koboldai_vars.lua_koboldbridge.generated_cols})") + + if koboldai_vars.abort: + koboldai_vars.abort = False + self.regeneration_required = False + self.halt = False + return True + + if koboldai_vars.standalone: + return False + + assert input_ids.ndim == 2 + + self.regeneration_required = koboldai_vars.lua_koboldbridge.regeneration_required + self.halt = not koboldai_vars.lua_koboldbridge.generating + koboldai_vars.lua_koboldbridge.regeneration_required = False + + for i in range(koboldai_vars.numseqs): + koboldai_vars.lua_koboldbridge.generated[i+1][koboldai_vars.generated_tkns] = int(input_ids[i, -1].item()) + + return self.regeneration_required or self.halt + + # Sets up dynamic world info scanner class DynamicWorldInfoScanCriteria(StoppingCriteria): def __init__( @@ -2024,6 +2067,7 @@ def patch_transformers(): self.halt = False self.tokenizer = tokenizer self.excluded_world_info = excluded_world_info + def __call__( self, input_ids: torch.LongTensor, @@ -2034,35 +2078,38 @@ def patch_transformers(): if not koboldai_vars.inference_config.do_dynamic_wi: return False - koboldai_vars.generated_tkns += 1 - if(not koboldai_vars.standalone and koboldai_vars.lua_koboldbridge.generated_cols and koboldai_vars.generated_tkns != koboldai_vars.lua_koboldbridge.generated_cols): - raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({koboldai_vars.generated_tkns} != {koboldai_vars.lua_koboldbridge.generated_cols})") - if(koboldai_vars.abort or koboldai_vars.generated_tkns >= koboldai_vars.genamt): - self.regeneration_required = False - self.halt = False - koboldai_vars.abort = False - return True - if(koboldai_vars.standalone): + # if(koboldai_vars.abort or koboldai_vars.generated_tkns >= koboldai_vars.genamt): + # self.regeneration_required = False + # self.halt = False + # koboldai_vars.abort = False + # return True + + # Pertains to WI I think + # if(koboldai_vars.standalone): + # return False + + # assert input_ids.ndim == 2 + assert len(self.excluded_world_info) == input_ids.shape[0] + # self.regeneration_required = koboldai_vars.lua_koboldbridge.regeneration_required + # self.halt = not koboldai_vars.lua_koboldbridge.generating + # koboldai_vars.lua_koboldbridge.regeneration_required = False + + # for i in range(koboldai_vars.numseqs): + # 
koboldai_vars.lua_koboldbridge.generated[i+1][koboldai_vars.generated_tkns] = int(input_ids[i, -1].item()) + + if not koboldai_vars.dynamicscan: + #return self.regeneration_required or self.halt return False - assert input_ids.ndim == 2 - assert len(self.excluded_world_info) == input_ids.shape[0] - self.regeneration_required = koboldai_vars.lua_koboldbridge.regeneration_required - self.halt = not koboldai_vars.lua_koboldbridge.generating - koboldai_vars.lua_koboldbridge.regeneration_required = False - - for i in range(koboldai_vars.numseqs): - koboldai_vars.lua_koboldbridge.generated[i+1][koboldai_vars.generated_tkns] = int(input_ids[i, -1].item()) - - if(not koboldai_vars.dynamicscan): - return self.regeneration_required or self.halt tail = input_ids[..., -koboldai_vars.generated_tkns:] for i, t in enumerate(tail): decoded = utils.decodenewlines(tokenizer.decode(t)) _, found = checkworldinfo(decoded, force_use_txt=True, actions=koboldai_vars.actions) found -= self.excluded_world_info[i] - if(len(found) != 0): - self.regeneration_required = True + if len(found) != 0: + # self.regeneration_required = True + model.core_stopper.regeneration_required = True + return True break return self.regeneration_required or self.halt old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria @@ -2070,12 +2117,14 @@ def patch_transformers(): global tokenizer stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs) + self.core_stopper = CoreStopper() self.kai_scanner = DynamicWorldInfoScanCriteria( tokenizer=tokenizer, excluded_world_info=self.kai_scanner_excluded_world_info, ) token_streamer = TokenStreamer(tokenizer=tokenizer) + stopping_criteria.insert(0, self.core_stopper) stopping_criteria.insert(0, self.kai_scanner) token_streamer = TokenStreamer(tokenizer=tokenizer) stopping_criteria.insert(0, token_streamer) @@ -4811,28 +4860,36 @@ def calcsubmit(txt): def legacy_generate(text: Union[str, list], min: int, max: int): koboldai_vars.lastctx = text - outputs = tpool.execute( + out_batches = tpool.execute( raw_generate, text, max_length=koboldai_vars.genamt, - do_streaming=True + do_streaming=True, + batch_count=koboldai_vars.numseqs, + decode=False ) + decoded_batches = tokenizer.batch_decode(out_batches) + # Lua bridge, genmod - for i, output in enumerate(outputs): - koboldai_vars.lua_koboldbridge.outputs[i + 1] = output + for i in range(koboldai_vars.numseqs): + koboldai_vars.lua_koboldbridge.generated[i + 1][koboldai_vars.generated_tkns] = int(out_batches[i, -1].item()) + koboldai_vars.lua_koboldbridge.outputs[i + 1] = utils.decodenewlines(tokenizer.decode(out_batches[i, -len(out_batches[i]):])) + + # for i, output in enumerate(outputs): + # koboldai_vars.lua_koboldbridge.outputs[i + 1] = output execute_genmod() if koboldai_vars.lua_koboldbridge.regeneration_required: koboldai_vars.lua_koboldbridge.regeneration_required = False genout = [] - for i in range(len(outputs)): + for i in range(len(out_batches)): out = koboldai_vars.lua_koboldbridge.outputs[i + 1] genout.append({"generated_text": out}) assert isinstance(out, str) else: - genout = [{"generated_text": utils.decodenewlines(x)} for x in outputs] + genout = [{"generated_text": utils.decodenewlines(x)} for x in decoded_batches] koboldai_vars.actions.append_options([applyoutputformatting(x["generated_text"]) for x in genout]) genout = [{"generated_text": x['text']} for x in koboldai_vars.actions.get_current_options()] @@ -4855,12 +4912,17 @@ def raw_generate( do_streaming: bool = False, 
do_dynamic_wi: bool = False, batch_count: int = 1, + decode: bool = True, ) -> List: if isinstance(prompt, str): prompt_tokens = tokenizer.encode(prompt) else: prompt_tokens = prompt + + # Some gen methods such as OAI don't return tokens. + batch_decoded = None + batch_encoded = None if koboldai_vars.model == "Colab": raise NotImplementedError("Colab API raw_generate unsupported") @@ -4874,13 +4936,15 @@ def raw_generate( raise NotImplementedError("No loaded model") if koboldai_vars.use_colab_tpu or model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"): - batch_out = tpu_raw_generate( + batch_encoded = tpu_raw_generate( prompt_tokens=prompt_tokens, max_length=max_length, batch_count=batch_count ) + elif model == "OAI": + batch_decoded = ... else: - batch_out = torch_raw_generate( + batch_encoded = torch_raw_generate( prompt_tokens=prompt_tokens, max_new=max_length, do_streaming=do_streaming, @@ -4888,9 +4952,21 @@ def raw_generate( batch_count=batch_count ) - decoded = tokenizer.batch_decode(batch_out[:, len(prompt_tokens):]) - return [utils.decodenewlines(x) for x in decoded] + assert batch_encoded or batch_decoded + + # Shave prompt off of encoded response. Decoded does not return prompt. + # TODO: Does MTJ generation shave this off automatically? Test it! + if batch_encoded: + batch_encoded = batch_encoded[:, len(prompt_tokens):] + + if not decode: + return batch_encoded + + if not batch_decoded: + batch_decoded = tokenizer.batch_decode(batch_encoded) + + return [utils.decodenewlines(x) for x in batch_decoded] def tpu_raw_generate( prompt_tokens: List[int], @@ -4961,7 +5037,7 @@ def torch_raw_generate( # Send text to generator and deal with output #==================================================================# -def _generate(txt, minimum, maximum, found_entries): +def old_underscore_generate(txt, minimum, maximum, found_entries): if(koboldai_vars.full_determinism): torch.manual_seed(koboldai_vars.seed) @@ -5000,19 +5076,30 @@ def _generate(txt, minimum, maximum, found_entries): ) already_generated += len(genout[0]) - len(gen_in[0]) assert already_generated <= koboldai_vars.genamt + # If we are halting, we stop if(model.kai_scanner.halt or not model.kai_scanner.regeneration_required): break + # if we require a generation, we continue + assert genout.ndim >= 2 assert genout.shape[0] == koboldai_vars.numseqs + if(koboldai_vars.lua_koboldbridge.generated_cols and koboldai_vars.generated_tkns != koboldai_vars.lua_koboldbridge.generated_cols): raise RuntimeError("Inconsistency detected between KoboldAI Python and Lua backends") + if(already_generated != koboldai_vars.generated_tkns): raise RuntimeError("WI scanning error") + for r in range(koboldai_vars.numseqs): for c in range(already_generated): assert koboldai_vars.lua_koboldbridge.generated[r+1][c+1] is not None genout[r][genout.shape[-1] - already_generated + c] = koboldai_vars.lua_koboldbridge.generated[r+1][c+1] + encoded = [] + + # DYNAMIC WI: + # IF WE FIND WORLD INFO MID-GENERATION, STOP, THEN ADD WI AND ADD NEW GENERATION + for i in range(koboldai_vars.numseqs): txt = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:])) winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=koboldai_vars.actions) @@ -5023,6 +5110,7 @@ def _generate(txt, minimum, maximum, found_entries): else: txt, _, _ = calcsubmitbudget(len(koboldai_vars.actions), winfo, mem, anotetxt, koboldai_vars.actions, submission=txt) encoded.append(torch.tensor(txt, dtype=torch.long, 
device=genout.device)) + max_length = len(max(encoded, key=len)) encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded)) genout = torch.cat( @@ -5032,6 +5120,7 @@ def _generate(txt, minimum, maximum, found_entries): ), dim=-1 ) + if(koboldai_vars.sp is not None): soft_tokens = torch.arange( model.config.vocab_size, @@ -5049,7 +5138,7 @@ def _generate(txt, minimum, maximum, found_entries): return genout, already_generated -def generate(txt, minimum, maximum, found_entries=None): +def old_generate(txt, minimum, maximum, found_entries=None): koboldai_vars.generated_tkns = 0 if(found_entries is None):
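
Note: CoreStopper and DynamicWorldInfoScanCriteria above plug into Hugging Face's stopping-criteria hook, which this patch wires up by monkey-patching _get_stopping_criteria. Below is a minimal, self-contained sketch of that hook only, assuming a transformers 4.x release from around the time of this patch; the model name, token budget, and TokenBudgetStopper class are illustrative placeholders, not code from the patch.

    import torch
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        StoppingCriteria,
        StoppingCriteriaList,
    )

    class TokenBudgetStopper(StoppingCriteria):
        """Stop generation after `budget` new tokens (illustrative only)."""

        def __init__(self, budget: int):
            self.budget = budget
            self.generated = 0

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            # Called once per generated token; returning True aborts generation.
            self.generated += 1
            return self.generated >= self.budget

    tokenizer = AutoTokenizer.from_pretrained("gpt2")      # placeholder model
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    prompt_ids = tokenizer("Once upon a time", return_tensors="pt").input_ids
    out = model.generate(
        prompt_ids,
        max_new_tokens=64,
        pad_token_id=tokenizer.eos_token_id,
        stopping_criteria=StoppingCriteriaList([TokenBudgetStopper(budget=32)]),
    )
    print(tokenizer.decode(out[0][prompt_ids.shape[-1]:]))

Because the patched _get_stopping_criteria inserts each criterion at index 0, the TokenStreamer ends up first in the list, followed by the scanner and then CoreStopper; in transformers versions where StoppingCriteriaList short-circuits via any(), a criterion returning True skips the criteria after it for that step.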
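
Note: legacy_generate now calls the new raw_generate entry point with decode=False and batch-decodes the tokens itself. A usage sketch of raw_generate as defined in this patch follows; the prompt and sizes are placeholders. Also worth noting: in the torch path batch_encoded is a multi-element tensor, so truthiness checks such as "assert batch_encoded or batch_decoded" and "if batch_encoded:" hit PyTorch's ambiguous-truth-value error and are safer written as explicit "is not None" comparisons.

    # Usage sketch only; raw_generate is the helper added by this patch.
    # With decode=False it returns the generated token batch with the prompt
    # sliced off; with decode=True it returns decoded strings instead.
    out_batches = raw_generate(
        "You are standing in an open field west of a white house.",
        max_length=80,     # up to 80 new tokens
        batch_count=2,     # two candidate continuations
        decode=False,      # keep token ids for the Lua bridge
    )
    texts = [utils.decodenewlines(t) for t in tokenizer.batch_decode(out_batches)]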
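
Note: when the dynamic WI scanner fires mid-generation, old_underscore_generate re-encodes each sequence with the newly triggered world info and left-pads the results to a common length before stacking them back into a batch and resuming generation. A standalone sketch of just that padding step, with placeholder token values and pad id:

    import torch
    import torch.nn.functional as F

    # Left-pad variable-length token sequences into one rectangular batch;
    # pad_id = 0 is a placeholder for model.config.pad_token_id.
    seqs = [torch.tensor([5, 6, 7, 8]), torch.tensor([9, 10])]
    pad_id = 0
    max_len = max(len(s) for s in seqs)
    batch = torch.stack([
        F.pad(s, (max_len - len(s), 0), value=pad_id)  # (left, right) padding
        for s in seqs
    ])
    # batch == tensor([[5, 6, 7, 8],
    #                  [0, 0, 9, 10]])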