Fix a bug with dynamic WI scan when using a soft prompt
The problem was that when a soft prompt is being used, the dynamic scanning criteria searches a different set of tokens for world info keys than the `_generate()` function, which results in generation loops when a world info key appears in the former set of tokens but not the latter.
This commit is contained in:
parent
5fc0509ae3
commit
c84d864021
22
aiserver.py
22
aiserver.py
|
@ -806,13 +806,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
|||
self,
|
||||
tokenizer,
|
||||
excluded_world_info: List[Set],
|
||||
head_length: int,
|
||||
):
|
||||
self.regeneration_required = False
|
||||
self.halt = False
|
||||
self.tokenizer = tokenizer
|
||||
self.excluded_world_info = excluded_world_info
|
||||
self.head_length = head_length
|
||||
def __call__(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
|
@ -838,10 +836,10 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
|||
|
||||
if(not vars.dynamicscan):
|
||||
return self.regeneration_required or self.halt
|
||||
tail = input_ids[..., self.head_length:]
|
||||
tail = input_ids[..., -vars.generated_tkns:]
|
||||
for i, t in enumerate(tail):
|
||||
decoded = tokenizer.decode(t)
|
||||
_, found = checkworldinfo(decoded, force_use_txt=True)
|
||||
_, found = checkworldinfo(decoded, force_use_txt=True, actions=vars._actions)
|
||||
found -= self.excluded_world_info[i]
|
||||
if(len(found) != 0):
|
||||
self.regeneration_required = True
|
||||
|
@ -854,7 +852,6 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
|||
self.kai_scanner = DynamicWorldInfoScanCriteria(
|
||||
tokenizer=tokenizer,
|
||||
excluded_world_info=self.kai_scanner_excluded_world_info,
|
||||
head_length=self.kai_scanner_head_length,
|
||||
)
|
||||
stopping_criteria.insert(0, self.kai_scanner)
|
||||
return stopping_criteria
|
||||
|
@ -2615,7 +2612,6 @@ def _generate(txt, minimum, maximum, found_entries):
|
|||
else:
|
||||
gen_in = gen_in.to('cpu')
|
||||
|
||||
model.kai_scanner_head_length = gen_in.shape[-1]
|
||||
model.kai_scanner_excluded_world_info = found_entries
|
||||
|
||||
vars._actions = vars.actions
|
||||
|
@ -2654,7 +2650,7 @@ def _generate(txt, minimum, maximum, found_entries):
|
|||
encoded = []
|
||||
for i in range(vars.numseqs):
|
||||
txt = tokenizer.decode(genout[i, -already_generated:])
|
||||
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
|
||||
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=vars._actions)
|
||||
found_entries[i].update(_found_entries)
|
||||
txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
|
||||
encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device))
|
||||
|
@ -2679,7 +2675,6 @@ def _generate(txt, minimum, maximum, found_entries):
|
|||
minimum += diff
|
||||
maximum += diff
|
||||
gen_in = genout
|
||||
model.kai_scanner_head_length = encoded.shape[-1]
|
||||
numseqs = 1
|
||||
|
||||
return genout, already_generated
|
||||
|
@ -3412,15 +3407,18 @@ def deletewifolder(uid):
|
|||
#==================================================================#
|
||||
# Look for WI keys in text to generator
|
||||
#==================================================================#
|
||||
def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_txt=False, scan_story=True):
|
||||
def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_txt=False, scan_story=True, actions=None):
|
||||
original_txt = txt
|
||||
|
||||
if(actions is None):
|
||||
actions = vars.actions
|
||||
|
||||
# Dont go any further if WI is empty
|
||||
if(len(vars.worldinfo) == 0):
|
||||
return "", set()
|
||||
|
||||
# Cache actions length
|
||||
ln = len(vars.actions)
|
||||
ln = len(actions)
|
||||
|
||||
# Don't bother calculating action history if widepth is 0
|
||||
if(vars.widepth > 0 and scan_story):
|
||||
|
@ -3434,8 +3432,8 @@ def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_tx
|
|||
if(ln > 0):
|
||||
chunks = collections.deque()
|
||||
i = 0
|
||||
for key in reversed(vars.actions):
|
||||
chunk = vars.actions[key]
|
||||
for key in reversed(actions):
|
||||
chunk = actions[key]
|
||||
chunks.appendleft(chunk)
|
||||
i += 1
|
||||
if(i == depth):
|
||||
|
|
Loading…
Reference in New Issue