Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
	Fix a bug with dynamic WI scan when using a soft prompt
The problem was that, when a soft prompt is being used, the dynamic scan stopping criterion searches a different set of tokens for world info keys than the `_generate()` function does, which results in generation loops when a world info key appears in the former set of tokens but not the latter.
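A minimal, self-contained sketch of why the fix slices the scan window from the end of the sequence instead of from a fixed head offset. The tensors and names below (`soft_prompt`, `prompt`, `generated_tkns`) are invented for illustration and are not the actual KoboldAI variables; the point is only that counting back by the number of generated tokens is insensitive to whatever extra tokens end up at the front of the sequence:

```python
import torch

# Hypothetical numbers: a 20-token soft prompt, a 10-token story prompt, and
# 4 tokens generated so far in the current batch.
generated_tkns = 4
soft_prompt = torch.full((20,), -1)                  # placeholder IDs standing in for the soft prompt
prompt = torch.arange(10)                            # story/prompt tokens
generated = torch.arange(100, 100 + generated_tkns)  # newly generated tokens

# What a stopping criterion might see once the soft prompt is part of the sequence:
input_ids = torch.cat([soft_prompt, prompt, generated])

# A head offset measured without the soft prompt no longer lines up with that
# sequence, so the "tail" picks up tokens that were never generated:
head_length = prompt.shape[-1]
print(input_ids[head_length:].shape[-1])                     # 24 tokens, not 4

# Counting back by the number of generated tokens ignores whatever sits at the
# front, so it always yields exactly the generated tokens:
print(torch.equal(input_ids[-generated_tkns:], generated))   # True
```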
1 changed file: aiserver.py (22 lines changed)
@@ -806,13 +806,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
                 self,
                 tokenizer,
                 excluded_world_info: List[Set],
-                head_length: int,
             ):
                 self.regeneration_required = False
                 self.halt = False
                 self.tokenizer = tokenizer
                 self.excluded_world_info = excluded_world_info
-                self.head_length = head_length
             def __call__(
                 self,
                 input_ids: torch.LongTensor,
@@ -838,10 +836,10 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
 
                 if(not vars.dynamicscan):
                     return self.regeneration_required or self.halt
-                tail = input_ids[..., self.head_length:]
+                tail = input_ids[..., -vars.generated_tkns:]
                 for i, t in enumerate(tail):
                     decoded = tokenizer.decode(t)
-                    _, found = checkworldinfo(decoded, force_use_txt=True)
+                    _, found = checkworldinfo(decoded, force_use_txt=True, actions=vars._actions)
                     found -= self.excluded_world_info[i]
                     if(len(found) != 0):
                         self.regeneration_required = True
@@ -854,7 +852,6 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
             self.kai_scanner = DynamicWorldInfoScanCriteria(
                 tokenizer=tokenizer,
                 excluded_world_info=self.kai_scanner_excluded_world_info,
-                head_length=self.kai_scanner_head_length,
             )
             stopping_criteria.insert(0, self.kai_scanner)
             return stopping_criteria
@@ -2615,7 +2612,6 @@ def _generate(txt, minimum, maximum, found_entries):
     else:
         gen_in = gen_in.to('cpu')
 
-    model.kai_scanner_head_length = gen_in.shape[-1]
     model.kai_scanner_excluded_world_info = found_entries
 
     vars._actions = vars.actions
@@ -2654,7 +2650,7 @@ def _generate(txt, minimum, maximum, found_entries):
             encoded = []
             for i in range(vars.numseqs):
                 txt = tokenizer.decode(genout[i, -already_generated:])
-                winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
+                winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=vars._actions)
                 found_entries[i].update(_found_entries)
                 txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
                 encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device))
@@ -2679,7 +2675,6 @@ def _generate(txt, minimum, maximum, found_entries):
             minimum += diff
             maximum += diff
             gen_in = genout
-            model.kai_scanner_head_length = encoded.shape[-1]
             numseqs = 1
     
     return genout, already_generated
@@ -3412,15 +3407,18 @@ def deletewifolder(uid):
 #==================================================================#
 #  Look for WI keys in text to generator  
 #==================================================================#
-def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_txt=False, scan_story=True):
+def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_txt=False, scan_story=True, actions=None):
     original_txt = txt
 
+    if(actions is None):
+        actions = vars.actions
+
     # Dont go any further if WI is empty
     if(len(vars.worldinfo) == 0):
         return "", set()
     
     # Cache actions length
-    ln = len(vars.actions)
+    ln = len(actions)
     
     # Don't bother calculating action history if widepth is 0
     if(vars.widepth > 0 and scan_story):
@@ -3434,8 +3432,8 @@ def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_tx
         if(ln > 0):
             chunks = collections.deque()
             i = 0
-            for key in reversed(vars.actions):
-                chunk = vars.actions[key]
+            for key in reversed(actions):
+                chunk = actions[key]
                 chunks.appendleft(chunk)
                 i += 1
                 if(i == depth):
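For the other half of the change, here is a rough, self-contained sketch of the pattern behind the new `actions` parameter. All names and data below are invented for illustration; the idea is simply that an optional argument defaulting to the committed story lets the generation path pass its own working copy, so the scanner and `_generate()` look at the same chunks:

```python
# Invented stand-ins for the committed story and the in-progress copy used
# during generation (roughly the roles of vars.actions and vars._actions).
committed_actions = {0: "The knight enters the castle."}
working_actions = {0: "The knight enters the castle.", 1: " A dragon lands on the wall."}

WI_KEYS = {"castle", "dragon"}

def check_world_info(txt, actions=None):
    # Default to the committed story when no explicit history is passed,
    # preserving the old behaviour for existing callers.
    if actions is None:
        actions = committed_actions
    scanned = txt + "".join(actions.values())
    return {key for key in WI_KEYS if key in scanned}

# Old behaviour: the scanner used the committed story and could miss "dragon"
# even though the generator was already working with it.
print(check_world_info(""))                           # {'castle'}

# New behaviour: both code paths pass the same working copy.
print(check_world_info("", actions=working_actions))  # {'castle', 'dragon'}
```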
Gnome Ann