diff --git a/aiserver.py b/aiserver.py
index 85767c09..45e4b639 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -2032,10 +2032,16 @@ def patch_transformers():
                 and koboldai_vars.lua_koboldbridge.generated_cols
                 and koboldai_vars.generated_tkns != koboldai_vars.lua_koboldbridge.generated_cols
             ):
-                print("[TODO] Fix generated_cols")
-                # raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({koboldai_vars.generated_tkns} != {koboldai_vars.lua_koboldbridge.generated_cols})")
+                raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({koboldai_vars.generated_tkns} != {koboldai_vars.lua_koboldbridge.generated_cols})")
 
-            if koboldai_vars.abort:
+            if (
+                koboldai_vars.abort
+                or (
+                    koboldai_vars.inference_config.stop_at_genamt
+                    and
+                    koboldai_vars.generated_tkns >= koboldai_vars.genamt
+                )
+            ):
                 koboldai_vars.abort = False
                 self.regeneration_required = False
                 self.halt = False
@@ -2063,8 +2069,8 @@ def patch_transformers():
             tokenizer,
             excluded_world_info: List[Set],
         ):
-            self.regeneration_required = False
-            self.halt = False
+            # self.regeneration_required = False
+            # self.halt = False
             self.tokenizer = tokenizer
             self.excluded_world_info = excluded_world_info
 
@@ -2078,27 +2084,9 @@ def patch_transformers():
             if not koboldai_vars.inference_config.do_dynamic_wi:
                 return False
 
-            # if(koboldai_vars.abort or koboldai_vars.generated_tkns >= koboldai_vars.genamt):
-            #     self.regeneration_required = False
-            #     self.halt = False
-            #     koboldai_vars.abort = False
-            #     return True
-
-            # Pertains to WI I think
-            # if(koboldai_vars.standalone):
-            #     return False
-
-            # assert input_ids.ndim == 2
             assert len(self.excluded_world_info) == input_ids.shape[0]
 
-            # self.regeneration_required = koboldai_vars.lua_koboldbridge.regeneration_required
-            # self.halt = not koboldai_vars.lua_koboldbridge.generating
-            # koboldai_vars.lua_koboldbridge.regeneration_required = False
-
-            # for i in range(koboldai_vars.numseqs):
-            #     koboldai_vars.lua_koboldbridge.generated[i+1][koboldai_vars.generated_tkns] = int(input_ids[i, -1].item())
             if not koboldai_vars.dynamicscan:
-                #return self.regeneration_required or self.halt
                 return False
 
             tail = input_ids[..., -koboldai_vars.generated_tkns:]
@@ -2107,11 +2095,10 @@
                 _, found = checkworldinfo(decoded, force_use_txt=True, actions=koboldai_vars.actions)
                 found -= self.excluded_world_info[i]
                 if len(found) != 0:
-                    # self.regeneration_required = True
                     model.core_stopper.regeneration_required = True
                     return True
-                    break
-            return self.regeneration_required or self.halt
+            return False
+
     old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
     def new_get_stopping_criteria(self, *args, **kwargs):
         global tokenizer
@@ -4789,7 +4776,7 @@ def calcsubmit(txt):
                 "TPUMeshTransformerGPTJ",
                 "TPUMeshTransformerGPTNeoX"
             ):
-                legacy_generate(subtxt, min, max)
+                generate(subtxt, min, max, found_entries)
             elif koboldai_vars.model == "Colab":
                 sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
             elif koboldai_vars.model == "API":
@@ -4857,52 +4844,117 @@
         # Send it!
         ikrequest(subtxt)
 
-def legacy_generate(text: Union[str, list], min: int, max: int):
-    koboldai_vars.lastctx = text
+def core_generate(text: list, minimum: int, maximum: int, found_entries: set):
+    # This generation function is tangled with koboldai_vars intentionally. It
+    # is meant for the story and nothing else.
 
-    out_batches = tpool.execute(
-        raw_generate,
-        text,
-        max_length=koboldai_vars.genamt,
-        do_streaming=True,
-        batch_count=koboldai_vars.numseqs,
-        decode=False
-    )
+    if koboldai_vars.full_determinism:
+        torch.manual_seed(koboldai_vars.seed)
 
-    decoded_batches = tokenizer.batch_decode(out_batches)
+    gen_in = torch.tensor(text, dtype=torch.long)[None]
+    if koboldai_vars.sp is not None:
+        soft_tokens = torch.arange(
+            model.config.vocab_size,
+            model.config.vocab_size + koboldai_vars.sp.shape[0],
+        )
+        gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
 
-    # Lua bridge, genmod
-    for i in range(koboldai_vars.numseqs):
-        koboldai_vars.lua_koboldbridge.generated[i + 1][koboldai_vars.generated_tkns] = int(out_batches[i, -1].item())
-        koboldai_vars.lua_koboldbridge.outputs[i + 1] = utils.decodenewlines(tokenizer.decode(out_batches[i, -len(out_batches[i]):]))
+    assert gen_in.shape[-1] + koboldai_vars.genamt <= koboldai_vars.max_length
 
-    # for i, output in enumerate(outputs):
-    #     koboldai_vars.lua_koboldbridge.outputs[i + 1] = output
-
-    execute_genmod()
-
-    if koboldai_vars.lua_koboldbridge.regeneration_required:
-        koboldai_vars.lua_koboldbridge.regeneration_required = False
-        genout = []
-        for i in range(len(out_batches)):
-            out = koboldai_vars.lua_koboldbridge.outputs[i + 1]
-            genout.append({"generated_text": out})
-            assert isinstance(out, str)
+    if koboldai_vars.hascuda and koboldai_vars.usegpu:
+        gen_in = gen_in.to(koboldai_vars.gpu_device)
+    elif koboldai_vars.hascuda and koboldai_vars.breakmodel:
+        gen_in = gen_in.to(breakmodel.primary_device)
     else:
-        genout = [{"generated_text": utils.decodenewlines(x)} for x in decoded_batches]
+        gen_in = gen_in.to("cpu")
 
-    koboldai_vars.actions.append_options([applyoutputformatting(x["generated_text"]) for x in genout])
-    genout = [{"generated_text": x['text']} for x in koboldai_vars.actions.get_current_options()]
+    found_entries = found_entries or set()
+    model.kai_scanner_excluded_world_info = found_entries
 
-    if len(genout) == 1:
-        genresult(genout[0]["generated_text"])
-    else:
-        restart_seq = koboldai_vars.lua_koboldbridge.restart_sequence
-        if restart_seq and restart_seq > 0:
-            genresult(genout[restart_seq - 1]["generated_text"])
-        else:
-            genselect(genout)
-    set_aibusy(0)
+    koboldai_vars._prompt = koboldai_vars.prompt
+
+    with torch.no_grad():
+        already_generated = 0
+        numseqs = koboldai_vars.numseqs
+        while True:
+            # The reason this is a loop is due to how Dynamic WI works. We
+            # cannot simply add the WI to the context mid-generation, so we
+            # stop early, then insert the WI, then continue generating. That
+            # stopping and continuing is this loop.
+
+            genout = raw_generate(
+                gen_in,
+                # Real max length is handled by CoreStopper.
+                max_length=int(2e9),
+                do_streaming=True,
+                do_dynamic_wi=True,
+                batch_count=numseqs,
+                decode=False,
+            )
+            already_generated += len(genout[0]) - len(gen_in[0])
+            assert already_generated <= koboldai_vars.genamt
+
+            # Generation stopped; why?
+            # If the stopper has told us to halt, or Dynamic WI has not asked
+            # us to pause so it can insert newly found world info, then we are
+            # done generating (for example, we reached our target token
+            # amount) and can break out of the loop.
+            if model.core_stopper.halt or not model.core_stopper.regeneration_required:
+                break
+
+            # Past this point we are handling a Dynamic WI pause.
+            assert genout.ndim >= 2
+            assert genout.shape[0] == koboldai_vars.numseqs
+
+            if(koboldai_vars.lua_koboldbridge.generated_cols and koboldai_vars.generated_tkns != koboldai_vars.lua_koboldbridge.generated_cols):
+                raise RuntimeError("Inconsistency detected between KoboldAI Python and Lua backends")
+
+            if(already_generated != koboldai_vars.generated_tkns):
+                raise RuntimeError("WI scanning error")
+
+            for r in range(koboldai_vars.numseqs):
+                for c in range(already_generated):
+                    assert koboldai_vars.lua_koboldbridge.generated[r+1][c+1] is not None
+                    genout[r][genout.shape[-1] - already_generated + c] = koboldai_vars.lua_koboldbridge.generated[r+1][c+1]
+
+            encoded = []
+
+            for i in range(koboldai_vars.numseqs):
+                txt = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:]))
+                winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=koboldai_vars.actions)
+                found_entries[i].update(_found_entries)
+                if koboldai_vars.alt_gen:
+                    txt, _, _ = koboldai_vars.calc_ai_text(submitted_text=txt)
+                    print("Using Alt Gen")
+                else:
+                    txt, _, _ = calcsubmitbudget(len(koboldai_vars.actions), winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
+                encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device))
+
+            max_length = len(max(encoded, key=len))
+            encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
+            genout = torch.cat(
+                (
+                    encoded,
+                    genout[..., -already_generated:],
+                ),
+                dim=-1
+            )
+
+            if(koboldai_vars.sp is not None):
+                soft_tokens = torch.arange(
+                    model.config.vocab_size,
+                    model.config.vocab_size + koboldai_vars.sp.shape[0],
+                    device=genout.device,
+                )
+                genout = torch.cat((soft_tokens.tile(koboldai_vars.numseqs, 1), genout), dim=-1)
+            assert genout.shape[-1] + koboldai_vars.genamt - already_generated <= koboldai_vars.max_length
+            diff = genout.shape[-1] - gen_in.shape[-1]
+            minimum += diff
+            maximum += diff
+            gen_in = genout
+            numseqs = 1
+
+    return genout, already_generated
 
 def raw_generate(
     # prompt is either a string (text) or a list (token ids)
@@ -4952,18 +5004,17 @@ def raw_generate(
         batch_count=batch_count
     )
 
-
-    assert batch_encoded or batch_decoded
+    assert batch_encoded is not None or batch_decoded is not None
 
     # Shave prompt off of encoded response. Decoded does not return prompt.
     # TODO: Does MTJ generation shave this off automatically? Test it!
-    if batch_encoded:
-        batch_encoded = batch_encoded[:, len(prompt_tokens):]
+    if batch_encoded is not None:
+        batch_encoded = batch_encoded[:, len(prompt_tokens) - 1:]
 
     if not decode:
         return batch_encoded
 
-    if not batch_decoded:
+    if batch_decoded is None:
         batch_decoded = tokenizer.batch_decode(batch_encoded)
 
     return [utils.decodenewlines(x) for x in batch_decoded]
@@ -4997,7 +5048,7 @@ def tpu_raw_generate(
     return genout
 
 def torch_raw_generate(
-    prompt_tokens: List[int],
+    prompt_tokens: Union[List[int], torch.Tensor],
     max_new: int,
 
     do_streaming: bool = False,
@@ -5008,10 +5059,16 @@
     koboldai_vars.inference_config.do_streaming = do_streaming
     koboldai_vars.inference_config.do_dynamic_wi = do_dynamic_wi
 
-    # Makes stopping criteria hook happy
-    model.kai_scanner_excluded_world_info = []
+    # Dynamic WI depends on this; do_dynamic_wi marks a main (story) generation call.
+    koboldai_vars.inference_config.stop_at_genamt = do_dynamic_wi
 
-    gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
+    # Makes stopping criteria hook happy
+    model.kai_scanner_excluded_world_info = model.kai_scanner_excluded_world_info or set()
+
+    if not isinstance(prompt_tokens, torch.Tensor):
+        gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
+    else:
+        gen_in = prompt_tokens
 
     device = "cpu"
     if koboldai_vars.hascuda and koboldai_vars.usegpu:
@@ -5037,108 +5094,7 @@
 # Send text to generator and deal with output
 #==================================================================#
 
-def old_underscore_generate(txt, minimum, maximum, found_entries):
-    if(koboldai_vars.full_determinism):
-        torch.manual_seed(koboldai_vars.seed)
-
-    gen_in = torch.tensor(txt, dtype=torch.long)[None]
-    if(koboldai_vars.sp is not None):
-        soft_tokens = torch.arange(
-            model.config.vocab_size,
-            model.config.vocab_size + koboldai_vars.sp.shape[0],
-        )
-        gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
-    assert gen_in.shape[-1] + koboldai_vars.genamt <= koboldai_vars.max_length
-
-    if(koboldai_vars.hascuda and koboldai_vars.usegpu):
-        gen_in = gen_in.to(koboldai_vars.gpu_device)
-    elif(koboldai_vars.hascuda and koboldai_vars.breakmodel):
-        gen_in = gen_in.to(breakmodel.primary_device)
-    else:
-        gen_in = gen_in.to('cpu')
-
-    model.kai_scanner_excluded_world_info = found_entries
-
-    koboldai_vars._prompt = koboldai_vars.prompt
-
-    with torch.no_grad():
-        already_generated = 0
-        numseqs = koboldai_vars.numseqs
-        while True:
-            genout = generator(
-                gen_in,
-                do_sample=True,
-                max_length=int(2e9),
-                repetition_penalty=1.0,
-                bad_words_ids=koboldai_vars.badwordsids,
-                use_cache=True,
-                num_return_sequences=numseqs
-            )
-            already_generated += len(genout[0]) - len(gen_in[0])
-            assert already_generated <= koboldai_vars.genamt
-            # If we are halting, we stop
-            if(model.kai_scanner.halt or not model.kai_scanner.regeneration_required):
-                break
-            # if we require a generation, we continue
-
-            assert genout.ndim >= 2
-            assert genout.shape[0] == koboldai_vars.numseqs
-
-            if(koboldai_vars.lua_koboldbridge.generated_cols and koboldai_vars.generated_tkns != koboldai_vars.lua_koboldbridge.generated_cols):
-                raise RuntimeError("Inconsistency detected between KoboldAI Python and Lua backends")
-
-            if(already_generated != koboldai_vars.generated_tkns):
-                raise RuntimeError("WI scanning error")
-
-            for r in range(koboldai_vars.numseqs):
-                for c in range(already_generated):
-                    assert koboldai_vars.lua_koboldbridge.generated[r+1][c+1] is not None
-                    genout[r][genout.shape[-1] - already_generated + c] = koboldai_vars.lua_koboldbridge.generated[r+1][c+1]
-
-            encoded = []
-
-            # DYNAMIC WI:
-            # IF WE FIND WORLD INFO MID-GENERATION, STOP, THEN ADD WI AND ADD NEW GENERATION
-
-            for i in range(koboldai_vars.numseqs):
-                txt = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:]))
-                winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=koboldai_vars.actions)
-                found_entries[i].update(_found_entries)
-                if koboldai_vars.alt_gen:
-                    txt, _, _ = koboldai_vars.calc_ai_text(submitted_text=txt)
-                    print("Using Alt Gen")
-                else:
-                    txt, _, _ = calcsubmitbudget(len(koboldai_vars.actions), winfo, mem, anotetxt, koboldai_vars.actions, submission=txt)
-                encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device))
-
-            max_length = len(max(encoded, key=len))
-            encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0),
-                value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
-            genout = torch.cat(
-                (
-                    encoded,
-                    genout[..., -already_generated:],
-                ),
-                dim=-1
-            )
-
-            if(koboldai_vars.sp is not None):
-                soft_tokens = torch.arange(
-                    model.config.vocab_size,
-                    model.config.vocab_size + koboldai_vars.sp.shape[0],
-                    device=genout.device,
-                )
-                genout = torch.cat((soft_tokens.tile(koboldai_vars.numseqs, 1), genout), dim=-1)
-            assert genout.shape[-1] + koboldai_vars.genamt - already_generated <= koboldai_vars.max_length
-            diff = genout.shape[-1] - gen_in.shape[-1]
-            minimum += diff
-            maximum += diff
-            gen_in = genout
-            numseqs = 1
-
-    return genout, already_generated
-
-
-def old_generate(txt, minimum, maximum, found_entries=None):
+def generate(txt, minimum, maximum, found_entries=None):
     koboldai_vars.generated_tkns = 0
 
     if(found_entries is None):
@@ -5158,7 +5114,7 @@
 
     # Submit input text to generator
     try:
-        genout, already_generated = tpool.execute(_generate, txt, minimum, maximum, found_entries)
+        genout, already_generated = tpool.execute(core_generate, txt, minimum, maximum, found_entries)
     except Exception as e:
         if(issubclass(type(e), lupa.LuaError)):
             koboldai_vars.lua_koboldbridge.obliterate_multiverse()
diff --git a/koboldai_settings.py b/koboldai_settings.py
index b4553fdd..9aa539e3 100644
--- a/koboldai_settings.py
+++ b/koboldai_settings.py
@@ -800,11 +800,9 @@ class system_settings(settings):
         @dataclass
         class _inference_config:
            do_streaming: bool = False
-
-            # NOTE: DynamicWorldInfoScanCriteria handles not only dynamic world
-            # info, but also max length, aborting, regeneration requests, etc
-            # for kobold-rooted stuff. This would be nice to change in the future.
             do_dynamic_wi: bool = False
+            # Genamt stopping is mostly tied to Dynamic WI
+            stop_at_genamt: bool = False
 
         self.inference_config = _inference_config()
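Illustrative sketch (not part of the patch above): the control flow that core_generate() and the new stop_at_genamt flag implement is "generate until a stopping criterion fires, then either finish (target token amount reached) or pause, splice the newly found world info into the context, and resume." The standalone Python sketch below mimics that flow under simplified assumptions; InferenceConfig, StopState, fake_generate_step, scan_for_wi and insert_wi are hypothetical stand-ins, not KoboldAI APIs.

# dynamic_wi_loop_sketch.py -- minimal sketch, not KoboldAI code.
from dataclasses import dataclass
from typing import Callable, List, Set


@dataclass
class InferenceConfig:
    # mirrors _inference_config in koboldai_settings.py (simplified)
    do_streaming: bool = False
    do_dynamic_wi: bool = False
    stop_at_genamt: bool = False  # genamt stopping is mostly tied to Dynamic WI


@dataclass
class StopState:
    # stands in for the CoreStopper / world-info scan criteria flags
    halt: bool = False
    regeneration_required: bool = False


def fake_generate_step(context: List[int]) -> int:
    # placeholder for one decoding step of the real model
    return len(context)


def core_generate_sketch(
    context: List[int],
    genamt: int,
    scan_for_wi: Callable[[List[int]], Set[str]],           # hypothetical WI scanner
    insert_wi: Callable[[List[int], Set[str]], List[int]],  # hypothetical context rebuilder
) -> List[int]:
    config = InferenceConfig(do_dynamic_wi=True, stop_at_genamt=True)
    stop = StopState()
    already_generated = 0
    excluded: Set[str] = set()  # WI entries already inserted (cf. excluded_world_info)

    while True:
        # inner "model call": run until a stopping criterion fires
        while True:
            context.append(fake_generate_step(context))
            already_generated += 1
            if config.stop_at_genamt and already_generated >= genamt:
                stop.halt = True  # done: reached the target token amount
                break
            found = scan_for_wi(context) - excluded
            if config.do_dynamic_wi and found:
                stop.regeneration_required = True  # pause so WI can be inserted
                excluded |= found
                break
        if stop.halt or not stop.regeneration_required:
            break
        # Dynamic WI pause: splice the found world info in and resume generating
        stop.regeneration_required = False
        context = insert_wi(context, excluded)

    return context


if __name__ == "__main__":
    # toy usage: "world info" is triggered once a token value exceeds 5
    out = core_generate_sketch(
        context=[0, 1, 2],
        genamt=10,
        scan_for_wi=lambda ctx: {"castle"} if max(ctx) > 5 else set(),
        insert_wi=lambda ctx, wi: [len(wi)] + ctx,
    )
    print(out)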