Fix a bug with dynamic WI scan when using a soft prompt

The problem was that, when a soft prompt is in use, the dynamic
scanning criterion searches a different set of tokens for world info
keys than the `_generate()` function does, which results in generation
loops whenever a world info key appears in the former set of tokens but
not in the latter.
Gnome Ann 2022-01-10 15:52:49 -05:00
parent 5fc0509ae3
commit c84d864021
1 changed file with 10 additions and 12 deletions
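To make the diff below easier to follow: the stopping criterion used to slice off everything past a head length recorded before generation, and now slices off exactly the last `vars.generated_tkns` tokens. A minimal sketch with made-up sizes; exactly how the old head length fell out of sync when a soft prompt was attached is assumed here for illustration, not taken from the diff:

    import torch

    # Made-up sizes: 20 soft prompt tokens + 80 context tokens + 10 generated.
    input_ids = torch.zeros(1, 20 + 80 + 10, dtype=torch.long)

    # Old: slice from a pre-recorded head length; if that length was measured
    # without the soft prompt tokens, the scanned "tail" also contains context.
    head_length = 80
    tail_old = input_ids[..., head_length:]       # 30 tokens

    # New: take exactly the generated tokens, counted from the end, which is
    # unaffected by anything prepended to the context.
    generated_tkns = 10
    tail_new = input_ids[..., -generated_tkns:]   # 10 tokens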


@@ -806,13 +806,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
self,
tokenizer,
excluded_world_info: List[Set],
head_length: int,
):
self.regeneration_required = False
self.halt = False
self.tokenizer = tokenizer
self.excluded_world_info = excluded_world_info
self.head_length = head_length
def __call__(
self,
input_ids: torch.LongTensor,
@@ -838,10 +836,10 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
if(not vars.dynamicscan):
return self.regeneration_required or self.halt
tail = input_ids[..., self.head_length:]
tail = input_ids[..., -vars.generated_tkns:]
for i, t in enumerate(tail):
decoded = tokenizer.decode(t)
_, found = checkworldinfo(decoded, force_use_txt=True)
_, found = checkworldinfo(decoded, force_use_txt=True, actions=vars._actions)
found -= self.excluded_world_info[i]
if(len(found) != 0):
self.regeneration_required = True
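The decision above in isolation, as a tiny self-contained sketch with made-up entry IDs: any key found in the decoded tail that is not already in the per-sequence excluded set triggers a regeneration.

    excluded_world_info = {12, 40}    # entries already handled for this sequence
    found = {40, 77}                  # entries whose keys appear in the decoded tail
    found -= excluded_world_info
    regeneration_required = len(found) != 0    # True: entry 77 is newly triggered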
@@ -854,7 +852,6 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
self.kai_scanner = DynamicWorldInfoScanCriteria(
tokenizer=tokenizer,
excluded_world_info=self.kai_scanner_excluded_world_info,
head_length=self.kai_scanner_head_length,
)
stopping_criteria.insert(0, self.kai_scanner)
return stopping_criteria
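For reference, this is the transformers mechanism the scanner hooks into: generate() asks each StoppingCriteria in a StoppingCriteriaList whether to stop after every decoding step. A minimal, self-contained example of that API, not KoboldAI's actual wiring (which inserts the scanner into the list built internally, as shown above):

    import torch
    from transformers import StoppingCriteria, StoppingCriteriaList

    class StopAfterNewTokens(StoppingCriteria):
        # Stop once a fixed number of tokens has been generated past the prompt.
        def __init__(self, prompt_length: int, max_new_tokens: int):
            self.prompt_length = prompt_length
            self.max_new_tokens = max_new_tokens

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            return input_ids.shape[-1] - self.prompt_length >= self.max_new_tokens

    # criteria = StoppingCriteriaList([StopAfterNewTokens(prompt_length=prompt_len, max_new_tokens=80)])
    # out = model.generate(input_ids, stopping_criteria=criteria)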
@@ -2615,7 +2612,6 @@ def _generate(txt, minimum, maximum, found_entries):
else:
gen_in = gen_in.to('cpu')
model.kai_scanner_head_length = gen_in.shape[-1]
model.kai_scanner_excluded_world_info = found_entries
vars._actions = vars.actions
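The `vars._actions` captured here is the action set `_generate()` works with; the new `actions=` argument added to checkworldinfo() and calcsubmitbudgetheader() further down lets the stopping criterion scan that same set instead of whatever `vars.actions` currently holds. A stripped-down sketch of that calling convention, with placeholder names rather than KoboldAI's:

    LIVE_ACTIONS = {0: "You meet Zoltan the alchemist."}     # plays the role of vars.actions

    def check_keys(txt, actions=None):
        # No explicit history given: fall back to the live story (old behaviour).
        if actions is None:
            actions = LIVE_ACTIONS
        story = txt + " " + " ".join(actions.values())
        return {key for key in ("Zoltan",) if key in story}

    generation_actions = LIVE_ACTIONS          # plays the role of vars._actions
    check_keys("The lab door opens.", actions=generation_actions)   # -> {'Zoltan'}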
@@ -2654,7 +2650,7 @@ def _generate(txt, minimum, maximum, found_entries):
encoded = []
for i in range(vars.numseqs):
txt = tokenizer.decode(genout[i, -already_generated:])
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=vars._actions)
found_entries[i].update(_found_entries)
txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device))
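Both sides now isolate the same tokens: the stopping criterion slices `input_ids[..., -vars.generated_tkns:]`, this loop decodes `genout[i, -already_generated:]`, and both then scan against `vars._actions`. A toy illustration of the tail extraction (no real tokenizer involved):

    import torch

    genout = torch.arange(2 * 12).reshape(2, 12)    # 2 sequences of 12 token ids
    already_generated = 4

    for i in range(genout.shape[0]):
        tail = genout[i, -already_generated:]       # only the newly generated ids
        # tokenizer.decode(tail) is what gets re-scanned for world info keys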
@@ -2679,7 +2675,6 @@ def _generate(txt, minimum, maximum, found_entries):
minimum += diff
maximum += diff
gen_in = genout
model.kai_scanner_head_length = encoded.shape[-1]
numseqs = 1
return genout, already_generated
@@ -3412,15 +3407,18 @@ def deletewifolder(uid):
#==================================================================#
# Look for WI keys in text to generator
#==================================================================#
def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_txt=False, scan_story=True):
def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_txt=False, scan_story=True, actions=None):
original_txt = txt
if(actions is None):
actions = vars.actions
# Dont go any further if WI is empty
if(len(vars.worldinfo) == 0):
return "", set()
# Cache actions length
ln = len(vars.actions)
ln = len(actions)
# Don't bother calculating action history if widepth is 0
if(vars.widepth > 0 and scan_story):
@@ -3434,8 +3432,8 @@ def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_tx
if(ln > 0):
chunks = collections.deque()
i = 0
for key in reversed(vars.actions):
chunk = vars.actions[key]
for key in reversed(actions):
chunk = actions[key]
chunks.appendleft(chunk)
i += 1
if(i == depth):