Fix a bug with dynamic WI scan when using a soft prompt
The problem was that when a soft prompt is being used, the dynamic scanning criteria searches a different set of tokens for world info keys than the `_generate()` function, which results in generation loops when a world info key appears in the former set of tokens but not the latter.
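To illustrate the mismatch, here is a minimal sketch assuming made-up lengths and assuming the soft prompt occupies extra placeholder positions at the front of input_ids (soft_len, prompt_len and head_length below are illustrative names, not variables from aiserver.py):

import torch

# Illustrative sizes: 20 soft-prompt positions, a 10-token prompt, 5 freshly generated tokens.
soft_len, prompt_len, generated_tkns = 20, 10, 5
input_ids = torch.arange(soft_len + prompt_len + generated_tkns).unsqueeze(0)

# Head length measured before the soft prompt is attached (the old behaviour).
head_length = prompt_len
stale_tail = input_ids[..., head_length:]       # 25 positions: soft-prompt and old prompt text leak into the scan
fresh_tail = input_ids[..., -generated_tkns:]   # 5 positions: exactly the tokens generated this pass

print(stale_tail.shape[-1], fresh_tail.shape[-1])  # prints: 25 5

Because the stale tail can contain world info keys that never appear in the text _generate() scans, the stopping criteria keeps demanding regeneration and generation loops. Slicing the last vars.generated_tkns tokens, as the diff below does, makes both code paths scan the same text.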
This commit is contained in:
parent 5fc0509ae3
commit c84d864021

aiserver.py | 22 changed lines
@@ -806,13 +806,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
             self,
             tokenizer,
             excluded_world_info: List[Set],
-            head_length: int,
         ):
             self.regeneration_required = False
             self.halt = False
             self.tokenizer = tokenizer
             self.excluded_world_info = excluded_world_info
-            self.head_length = head_length
         def __call__(
             self,
             input_ids: torch.LongTensor,
@@ -838,10 +836,10 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
 
             if(not vars.dynamicscan):
                 return self.regeneration_required or self.halt
-            tail = input_ids[..., self.head_length:]
+            tail = input_ids[..., -vars.generated_tkns:]
             for i, t in enumerate(tail):
                 decoded = tokenizer.decode(t)
-                _, found = checkworldinfo(decoded, force_use_txt=True)
+                _, found = checkworldinfo(decoded, force_use_txt=True, actions=vars._actions)
                 found -= self.excluded_world_info[i]
                 if(len(found) != 0):
                     self.regeneration_required = True
@@ -854,7 +852,6 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
         self.kai_scanner = DynamicWorldInfoScanCriteria(
             tokenizer=tokenizer,
             excluded_world_info=self.kai_scanner_excluded_world_info,
-            head_length=self.kai_scanner_head_length,
         )
         stopping_criteria.insert(0, self.kai_scanner)
         return stopping_criteria
@@ -2615,7 +2612,6 @@ def _generate(txt, minimum, maximum, found_entries):
     else:
         gen_in = gen_in.to('cpu')
 
-    model.kai_scanner_head_length = gen_in.shape[-1]
     model.kai_scanner_excluded_world_info = found_entries
 
     vars._actions = vars.actions
@@ -2654,7 +2650,7 @@ def _generate(txt, minimum, maximum, found_entries):
         encoded = []
         for i in range(vars.numseqs):
             txt = tokenizer.decode(genout[i, -already_generated:])
-            winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
+            winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=vars._actions)
             found_entries[i].update(_found_entries)
             txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
             encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device))
@@ -2679,7 +2675,6 @@ def _generate(txt, minimum, maximum, found_entries):
         minimum += diff
         maximum += diff
         gen_in = genout
-        model.kai_scanner_head_length = encoded.shape[-1]
         numseqs = 1
 
     return genout, already_generated
@@ -3412,15 +3407,18 @@ def deletewifolder(uid):
 #==================================================================#
 # Look for WI keys in text to generator
 #==================================================================#
-def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_txt=False, scan_story=True):
+def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_txt=False, scan_story=True, actions=None):
     original_txt = txt
 
+    if(actions is None):
+        actions = vars.actions
+
     # Dont go any further if WI is empty
     if(len(vars.worldinfo) == 0):
         return "", set()
 
     # Cache actions length
-    ln = len(vars.actions)
+    ln = len(actions)
 
     # Don't bother calculating action history if widepth is 0
     if(vars.widepth > 0 and scan_story):
@@ -3434,8 +3432,8 @@ def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_tx
         if(ln > 0):
             chunks = collections.deque()
             i = 0
-            for key in reversed(vars.actions):
-                chunk = vars.actions[key]
+            for key in reversed(actions):
+                chunk = actions[key]
                 chunks.appendleft(chunk)
                 i += 1
                 if(i == depth):
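With the new keyword argument, a caller that holds a snapshot of the action history can make checkworldinfo() scan that snapshot instead of vars.actions; omitting it preserves the old behaviour. A brief usage sketch (assuming the module's globals are in scope, as they are inside aiserver.py):

# Inside the stopping criteria: scan the same snapshot _generate() uses,
# so both code paths search the same action history for world info keys.
_, found = checkworldinfo(decoded, force_use_txt=True, actions=vars._actions)

# Existing callers are unchanged: with no actions argument the function
# falls back to vars.actions internally.
_, found = checkworldinfo(decoded, force_use_txt=True)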