diff --git a/aiserver.py b/aiserver.py index 2b90b54f..bf3c6202 100644 --- a/aiserver.py +++ b/aiserver.py @@ -806,13 +806,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme self, tokenizer, excluded_world_info: List[Set], - head_length: int, ): self.regeneration_required = False self.halt = False self.tokenizer = tokenizer self.excluded_world_info = excluded_world_info - self.head_length = head_length def __call__( self, input_ids: torch.LongTensor, @@ -838,10 +836,10 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme if(not vars.dynamicscan): return self.regeneration_required or self.halt - tail = input_ids[..., self.head_length:] + tail = input_ids[..., -vars.generated_tkns:] for i, t in enumerate(tail): decoded = tokenizer.decode(t) - _, found = checkworldinfo(decoded, force_use_txt=True) + _, found = checkworldinfo(decoded, force_use_txt=True, actions=vars._actions) found -= self.excluded_world_info[i] if(len(found) != 0): self.regeneration_required = True @@ -854,7 +852,6 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme self.kai_scanner = DynamicWorldInfoScanCriteria( tokenizer=tokenizer, excluded_world_info=self.kai_scanner_excluded_world_info, - head_length=self.kai_scanner_head_length, ) stopping_criteria.insert(0, self.kai_scanner) return stopping_criteria @@ -2615,7 +2612,6 @@ def _generate(txt, minimum, maximum, found_entries): else: gen_in = gen_in.to('cpu') - model.kai_scanner_head_length = gen_in.shape[-1] model.kai_scanner_excluded_world_info = found_entries vars._actions = vars.actions @@ -2654,7 +2650,7 @@ def _generate(txt, minimum, maximum, found_entries): encoded = [] for i in range(vars.numseqs): txt = tokenizer.decode(genout[i, -already_generated:]) - winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True) + winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=vars._actions) found_entries[i].update(_found_entries) txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt) encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device)) @@ -2679,7 +2675,6 @@ def _generate(txt, minimum, maximum, found_entries): minimum += diff maximum += diff gen_in = genout - model.kai_scanner_head_length = encoded.shape[-1] numseqs = 1 return genout, already_generated @@ -3412,15 +3407,18 @@ def deletewifolder(uid): #==================================================================# # Look for WI keys in text to generator #==================================================================# -def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_txt=False, scan_story=True): +def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_txt=False, scan_story=True, actions=None): original_txt = txt + if(actions is None): + actions = vars.actions + # Dont go any further if WI is empty if(len(vars.worldinfo) == 0): return "", set() # Cache actions length - ln = len(vars.actions) + ln = len(actions) # Don't bother calculating action history if widepth is 0 if(vars.widepth > 0 and scan_story): @@ -3434,8 +3432,8 @@ def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_tx if(ln > 0): chunks = collections.deque() i = 0 - for key in reversed(vars.actions): - chunk = vars.actions[key] + for key in reversed(actions): + chunk = actions[key] chunks.appendleft(chunk) i += 1 if(i == depth):