Fix a bug with dynamic WI scan when using a soft prompt

The problem was that when a soft prompt is being used, the dynamic
scanning criteria search a different set of tokens for world info
keys than the `_generate()` function does, which results in generation
loops when a world info key appears in the former set of tokens but not
the latter.
This commit is contained in:
Gnome Ann 2022-01-10 15:52:49 -05:00
parent 5fc0509ae3
commit c84d864021
1 changed file with 10 additions and 12 deletions

View File

@ -806,13 +806,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
self, self,
tokenizer, tokenizer,
excluded_world_info: List[Set], excluded_world_info: List[Set],
head_length: int,
): ):
self.regeneration_required = False self.regeneration_required = False
self.halt = False self.halt = False
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.excluded_world_info = excluded_world_info self.excluded_world_info = excluded_world_info
self.head_length = head_length
def __call__( def __call__(
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
@ -838,10 +836,10 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
if(not vars.dynamicscan): if(not vars.dynamicscan):
return self.regeneration_required or self.halt return self.regeneration_required or self.halt
tail = input_ids[..., self.head_length:] tail = input_ids[..., -vars.generated_tkns:]
for i, t in enumerate(tail): for i, t in enumerate(tail):
decoded = tokenizer.decode(t) decoded = tokenizer.decode(t)
_, found = checkworldinfo(decoded, force_use_txt=True) _, found = checkworldinfo(decoded, force_use_txt=True, actions=vars._actions)
found -= self.excluded_world_info[i] found -= self.excluded_world_info[i]
if(len(found) != 0): if(len(found) != 0):
self.regeneration_required = True self.regeneration_required = True
@ -854,7 +852,6 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
self.kai_scanner = DynamicWorldInfoScanCriteria( self.kai_scanner = DynamicWorldInfoScanCriteria(
tokenizer=tokenizer, tokenizer=tokenizer,
excluded_world_info=self.kai_scanner_excluded_world_info, excluded_world_info=self.kai_scanner_excluded_world_info,
head_length=self.kai_scanner_head_length,
) )
stopping_criteria.insert(0, self.kai_scanner) stopping_criteria.insert(0, self.kai_scanner)
return stopping_criteria return stopping_criteria
@ -2615,7 +2612,6 @@ def _generate(txt, minimum, maximum, found_entries):
else: else:
gen_in = gen_in.to('cpu') gen_in = gen_in.to('cpu')
model.kai_scanner_head_length = gen_in.shape[-1]
model.kai_scanner_excluded_world_info = found_entries model.kai_scanner_excluded_world_info = found_entries
vars._actions = vars.actions vars._actions = vars.actions
@ -2654,7 +2650,7 @@ def _generate(txt, minimum, maximum, found_entries):
encoded = [] encoded = []
for i in range(vars.numseqs): for i in range(vars.numseqs):
txt = tokenizer.decode(genout[i, -already_generated:]) txt = tokenizer.decode(genout[i, -already_generated:])
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True) winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=vars._actions)
found_entries[i].update(_found_entries) found_entries[i].update(_found_entries)
txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt) txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device)) encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device))
@ -2679,7 +2675,6 @@ def _generate(txt, minimum, maximum, found_entries):
minimum += diff minimum += diff
maximum += diff maximum += diff
gen_in = genout gen_in = genout
model.kai_scanner_head_length = encoded.shape[-1]
numseqs = 1 numseqs = 1
return genout, already_generated return genout, already_generated
@ -3412,15 +3407,18 @@ def deletewifolder(uid):
#==================================================================# #==================================================================#
# Look for WI keys in text to generator # Look for WI keys in text to generator
#==================================================================# #==================================================================#
def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_txt=False, scan_story=True): def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_txt=False, scan_story=True, actions=None):
original_txt = txt original_txt = txt
if(actions is None):
actions = vars.actions
# Dont go any further if WI is empty # Dont go any further if WI is empty
if(len(vars.worldinfo) == 0): if(len(vars.worldinfo) == 0):
return "", set() return "", set()
# Cache actions length # Cache actions length
ln = len(vars.actions) ln = len(actions)
# Don't bother calculating action history if widepth is 0 # Don't bother calculating action history if widepth is 0
if(vars.widepth > 0 and scan_story): if(vars.widepth > 0 and scan_story):
@ -3434,8 +3432,8 @@ def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_tx
if(ln > 0): if(ln > 0):
chunks = collections.deque() chunks = collections.deque()
i = 0 i = 0
for key in reversed(vars.actions): for key in reversed(actions):
chunk = vars.actions[key] chunk = actions[key]
chunks.appendleft(chunk) chunks.appendleft(chunk)
i += 1 i += 1
if(i == depth): if(i == depth):