From a1bc10246c355631c9a9f33138a8c88b99cc0201 Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Wed, 17 Nov 2021 16:17:59 -0500
Subject: [PATCH] Support for multiple gens per action with dynamic scan

---
 aiserver.py    | 47 ++++++++++++++++++++++++++++-------------------
 gensettings.py |  2 +-
 2 files changed, 29 insertions(+), 20 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index ef9bf8ed..f1f1c713 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -14,7 +14,7 @@ from tkinter import messagebox
 import json
 import collections
 import zipfile
-from typing import Union, Dict, Set
+from typing import Union, Dict, Set, List
 import requests
 import html
 
@@ -565,8 +565,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
         def __init__(
             self,
             tokenizer,
-            excluded_world_info: set,
-            #head_length: torch.LongTensor,
+            excluded_world_info: List[Set],
             head_length: int,
         ):
             self.any_new_entries = False
@@ -580,15 +579,15 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly"]):
             **kwargs,
         ) -> bool:
             assert input_ids.ndim == 2
-            #assert input_ids.shape[:-1] == self.head_length.shape
+            assert len(self.excluded_world_info) == input_ids.shape[0]
             self.any_new_entries = False
             if(not vars.dynamicscan):
                 return False
             tail = input_ids[..., self.head_length:]
-            for t in tail:
+            for i, t in enumerate(tail):
                 decoded = tokenizer.decode(t)
                 _, found = checkworldinfo(decoded, force_use_txt=True)
-                found -= self.excluded_world_info
+                found -= self.excluded_world_info[i]
                 if(len(found) != 0):
                     self.any_new_entries = True
                     break
@@ -1423,8 +1422,12 @@ def calcsubmit(txt):
 #==================================================================#
 # Send text to generator and deal with output
 #==================================================================#
-def generate(txt, min, max, found_entries=set()):
-    print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, min, max, txt, colors.END))
+def generate(txt, minimum, maximum, found_entries=None):
+    if(found_entries is None):
+        found_entries = set()
+    found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
+
+    print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, txt, colors.END))
 
     # Store context in memory to use it for comparison with generated content
     vars.lastctx = txt
@@ -1466,13 +1469,13 @@ def generate(txt, min, max, found_entries=set()):
 
     with torch.no_grad():
         already_generated = 0
-        numseqs = vars.numseqs if not vars.dynamicscan else 1
+        numseqs = vars.numseqs
         while True:
             genout = generator(
                 gen_in,
                 do_sample=True,
-                min_length=min,
-                max_length=max-already_generated,
+                min_length=minimum,
+                max_length=maximum-already_generated,
                 repetition_penalty=vars.rep_pen,
                 top_p=top_p,
                 top_k=top_k,
@@ -1485,11 +1488,17 @@ def generate(txt, min, max, found_entries=set()):
             already_generated += len(genout[0]) - len(gen_in[0])
             if(not model.kai_scanner.any_new_entries):
                 break
-            txt = tokenizer.decode(genout[0, -already_generated:])
-            winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
-            found_entries |= _found_entries
-            txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
-            encoded = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(genout.device)
+            assert genout.ndim >= 2
+            assert genout.shape[0] == vars.numseqs
+            encoded = []
+            for i in range(vars.numseqs):
+                txt = tokenizer.decode(genout[i, -already_generated:])
+                winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
+                found_entries[i].update(_found_entries)
+                txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
+                encoded.append(tokenizer.encode(txt, return_tensors="pt", truncation=True)[0].long().to(genout.device))
+            max_length = len(max(encoded, key=len))
+            encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
             genout = torch.cat(
                 (
                     encoded,
@@ -1503,10 +1512,10 @@ def generate(txt, min, max, found_entries=set()):
                     model.config.vocab_size + vars.sp.shape[0],
                     device=genout.device,
                 )
-                genout = torch.cat((soft_tokens[None], genout), dim=-1)
+                genout = torch.cat((soft_tokens.tile(vars.numseqs, 1), genout), dim=-1)
             diff = genout.shape[-1] - gen_in.shape[-1]
-            min += diff
-            max += diff
+            minimum += diff
+            maximum += diff
             gen_in = genout
             model.kai_scanner_head_length = encoded.shape[-1]
             numseqs = 1
diff --git a/gensettings.py b/gensettings.py
index c567c94c..e23bc766 100644
--- a/gensettings.py
+++ b/gensettings.py
@@ -128,7 +128,7 @@ gensettingstf = [{
 	"max": 1,
 	"step": 1,
 	"default": 0,
-	"tooltip": "Scan the AI's output for world info keys as it's generating the output. Turning this on will set Gens Per Action to 1, as these two features are not currently compatible with each other."
+	"tooltip": "Scan the AI's output for world info keys as it's generating the output."
 }]
 
 gensettingsik =[{
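
Reviewer note, not part of the commit: the new left-padding step in generate() can be exercised in isolation. A minimal sketch, assuming a hypothetical stack_left_padded helper and a pad id of 0 for illustration; the patch itself uses model.config.pad_token_id or model.config.eos_token_id:

import torch
import torch.nn.functional as F

def stack_left_padded(encoded, pad_token_id):
    # Left-pad each 1-D token tensor to a common length and stack them into
    # a (batch, seq_len) tensor, mirroring the padding step added to generate().
    max_length = max(len(e) for e in encoded)
    # F.pad takes (left, right) pad amounts; padding only on the left keeps
    # every sequence right-aligned, so generation resumes from the same
    # position in each row of the batch.
    return torch.stack(tuple(
        F.pad(e, (max_length - len(e), 0), value=pad_token_id)
        for e in encoded
    ))

prompts = [torch.tensor([5, 6, 7]), torch.tensor([8, 9]), torch.tensor([1, 2, 3, 4])]
print(stack_left_padded(prompts, pad_token_id=0))
# tensor([[0, 5, 6, 7],
#         [0, 0, 8, 9],
#         [1, 2, 3, 4]])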
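Likewise, the soft_tokens change replaces a batch-size-1 unsqueeze with a tile across all sequences, so every row in the batch gets the soft-prompt ids prepended. A small shape-only sketch with made-up token ids:

import torch

soft_tokens = torch.arange(50257, 50277)  # 20 illustrative soft-prompt ids
print(soft_tokens[None].shape)       # torch.Size([1, 20]) -- old code, batch of 1 only
print(soft_tokens.tile(3, 1).shape)  # torch.Size([3, 20]) -- new code, numseqs == 3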