Dynamic WI scanner should ignore triggers that are already in context

This commit is contained in:
Gnome Ann
2021-11-03 18:55:53 -04:00
parent ecfbbdb4a9
commit a2d7735a51

View File

@ -1318,21 +1318,21 @@ def calcsubmit(txt):
anoteadded = False # In case our budget runs out before we hit A.N. depth anoteadded = False # In case our budget runs out before we hit A.N. depth
actionlen = len(vars.actions) actionlen = len(vars.actions)
winfo, mem, anotetxt, _ = calcsubmitbudgetheader(txt) winfo, mem, anotetxt, found_entries = calcsubmitbudgetheader(txt)
# For all transformers models # For all transformers models
if(vars.model != "InferKit"): if(vars.model != "InferKit"):
subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt) subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt)
if(actionlen == 0): if(actionlen == 0):
if(not vars.model in ["Colab", "OAI"]): if(not vars.model in ["Colab", "OAI"]):
generate(subtxt, min, max) generate(subtxt, min, max, found_entries=found_entries)
elif(vars.model == "Colab"): elif(vars.model == "Colab"):
sendtocolab(subtxt, min, max) sendtocolab(subtxt, min, max)
elif(vars.model == "OAI"): elif(vars.model == "OAI"):
oairequest(subtxt, min, max) oairequest(subtxt, min, max)
else: else:
if(not vars.model in ["Colab", "OAI"]): if(not vars.model in ["Colab", "OAI"]):
generate(subtxt, min, max) generate(subtxt, min, max, found_entries=found_entries)
elif(vars.model == "Colab"): elif(vars.model == "Colab"):
sendtocolab(subtxt, min, max) sendtocolab(subtxt, min, max)
elif(vars.model == "OAI"): elif(vars.model == "OAI"):
@ -1397,7 +1397,7 @@ def calcsubmit(txt):
#==================================================================# #==================================================================#
# Send text to generator and deal with output # Send text to generator and deal with output
#==================================================================# #==================================================================#
def generate(txt, min, max): def generate(txt, min, max, found_entries=set()):
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, min, max, txt, colors.END)) print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, min, max, txt, colors.END))
# Store context in memory to use it for comparison with generated content # Store context in memory to use it for comparison with generated content
@ -1432,7 +1432,7 @@ def generate(txt, min, max):
gen_in = gen_in.to('cpu') gen_in = gen_in.to('cpu')
model.kai_scanner_head_length = gen_in.shape[-1] model.kai_scanner_head_length = gen_in.shape[-1]
model.kai_scanner_excluded_world_info = set() model.kai_scanner_excluded_world_info = found_entries
actions = vars.actions actions = vars.actions
if(vars.dynamicscan): if(vars.dynamicscan):
@ -1441,7 +1441,6 @@ def generate(txt, min, max):
with torch.no_grad(): with torch.no_grad():
already_generated = 0 already_generated = 0
numseqs = vars.numseqs if not vars.dynamicscan else 1 numseqs = vars.numseqs if not vars.dynamicscan else 1
found_entries = model.kai_scanner_excluded_world_info
while True: while True:
genout = generator( genout = generator(
gen_in, gen_in,