mirror of
				https://github.com/KoboldAI/KoboldAI-Client.git
				synced 2025-06-05 21:59:24 +02:00 
			
		
		
		
	Implement support for sampler order in the backend code
This commit is contained in:
		
							
								
								
									
										27
									
								
								aiserver.py
									
									
									
									
									
								
							
							
						
						
									
										27
									
								
								aiserver.py
									
									
									
									
									
								
							| @@ -306,6 +306,7 @@ class vars: | ||||
|     acregex_ui  = re.compile(r'^ *(>.*)$', re.MULTILINE)    # Pattern for matching actions in the HTML-escaped story so we can apply colouring, etc (make sure to encase part to format in parentheses) | ||||
|     comregex_ai = re.compile(r'(?:\n<\|(?:.|\n)*?\|>(?=\n|$))|(?:<\|(?:.|\n)*?\|>\n?)')  # Pattern for matching comments to remove them before sending them to the AI | ||||
|     comregex_ui = re.compile(r'(<\|(?:.|\n)*?\|>)')  # Pattern for matching comments in the editor | ||||
|     sampler_order = utils.default_sampler_order.copy() | ||||
|     chatmode    = False | ||||
|     chatname    = "You" | ||||
|     adventure   = False | ||||
| @@ -1448,15 +1449,23 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go | ||||
|         new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor | ||||
|         transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor | ||||
|  | ||||
|         class KoboldLogitsWarperList(LogitsProcessorList): | ||||
|             def __init__(self, beams: int = 1, **kwargs): | ||||
|                 self.__warper_list: List[LogitsWarper] = [] | ||||
|                 self.__warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1))) | ||||
|                 self.__warper_list.append(TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1))) | ||||
|                 self.__warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1))) | ||||
|                 self.__warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1))) | ||||
|                 self.__warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1))) | ||||
|                 self.__warper_list.append(TemperatureLogitsWarper(temperature=0.5)) | ||||
|  | ||||
|             def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, *args, **kwargs): | ||||
|                 for k in vars.sampler_order: | ||||
|                     scores = self.__warper_list[k](input_ids, scores, *args, **kwargs) | ||||
|                 return scores | ||||
|  | ||||
|         def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList: | ||||
|             warper_list = LogitsProcessorList() | ||||
|             warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1))) | ||||
|             warper_list.append(TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1))) | ||||
|             warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1))) | ||||
|             warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1))) | ||||
|             warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1))) | ||||
|             warper_list.append(TemperatureLogitsWarper(temperature=0.5)) | ||||
|             return warper_list | ||||
|             return KoboldLogitsWarperList(beams=beams) | ||||
|          | ||||
|         def new_sample(self, *args, **kwargs): | ||||
|             assert kwargs.pop("logits_warper", None) is not None | ||||
| @@ -1816,6 +1825,7 @@ else: | ||||
|      | ||||
|     def tpumtjgenerate_settings_callback() -> dict: | ||||
|         return { | ||||
|             "sampler_order": vars.sampler_order, | ||||
|             "top_p": float(vars.top_p), | ||||
|             "temp": float(vars.temp), | ||||
|             "top_k": int(vars.top_k), | ||||
| @@ -3910,6 +3920,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): | ||||
|                 rprange=vars.rep_pen_range, | ||||
|                 soft_embeddings=vars.sp, | ||||
|                 soft_tokens=soft_tokens, | ||||
|                 sampler_order=vars.sampler_order, | ||||
|             ) | ||||
|             past = genout | ||||
|             for i in range(vars.numseqs): | ||||
|   | ||||
| @@ -65,6 +65,7 @@ def stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List | ||||
|  | ||||
| def settings_callback() -> dict: | ||||
|     return { | ||||
|         "sampler_order": utils.default_sampler_order.copy(), | ||||
|         "top_p": 0.9, | ||||
|         "temp": 0.5, | ||||
|         "top_k": 0, | ||||
| @@ -159,7 +160,7 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat | ||||
|     logits[tokens] = penalty_logits | ||||
|     return logits | ||||
|  | ||||
| def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): | ||||
| def kobold_sample_dynamic(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): | ||||
|     ''' | ||||
|     This gets called by generate_loop_fn to apply a series of 6 filters | ||||
|     to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature) | ||||
| @@ -181,8 +182,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty | ||||
|             sorted_indices_to_remove, | ||||
|         ) | ||||
|         return np.where(indices_to_remove, -np.inf, logits) | ||||
|     if top_k > 0: | ||||
|         logits = top_k_filter(logits) | ||||
|     # Top-a (remove all tokens that have softmax probability less than | ||||
|     # a*m^2 where m is the maximum softmax probability) | ||||
|     def top_a_filter(logits): | ||||
| @@ -195,8 +194,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty | ||||
|         probs_max = probabilities.max() | ||||
|         # Remove tokens | ||||
|         return np.where(probabilities < probs_max * probs_max * top_a, -np.inf, logits) | ||||
|     if top_a > 0.0: | ||||
|         logits = top_a_filter(logits) | ||||
|     # Top-p (after sorting the remaining tokens again in descending order of | ||||
|     # logit, remove the ones that have cumulative softmax probability | ||||
|     # greater than p) | ||||
| @@ -222,8 +219,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty | ||||
|             sorted_indices_to_remove, | ||||
|         ) | ||||
|         return np.where(indices_to_remove, -np.inf, logits) | ||||
|     if top_p < 1.0: | ||||
|         logits = top_p_filter(logits) | ||||
|     # Tail free sampling (basically top-p a second time on remaining tokens | ||||
|     # except it's the "cumulative normalized absolute second finite | ||||
|     # differences of the softmax probabilities" instead of just the | ||||
| @@ -262,8 +257,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty | ||||
|             sorted_indices_to_remove, | ||||
|         ) | ||||
|         return np.where(indices_to_remove, -np.inf, logits) | ||||
|     if tfs < 1.0: | ||||
|         logits = tail_free_filter(logits) | ||||
|     # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf) | ||||
|     def typical_filter(logits): | ||||
|         # Compute softmax probabilities and the natural logarithms of them | ||||
| @@ -293,10 +286,16 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty | ||||
|             sorted_indices_to_remove, | ||||
|         ) | ||||
|         return np.where(indices_to_remove, -jnp.inf, logits) | ||||
|     if typical < 1.0: | ||||
|         logits = typical_filter(logits) | ||||
|     # Temperature (just divide the logits by the temperature) | ||||
|     logits /= temp | ||||
|     def temp_filter(logits): | ||||
|         return logits / temp | ||||
|     for k in sampler_order: | ||||
|         if k == 0 and top_k > 0: logits = top_k_filter(logits) | ||||
|         if k == 1 and top_a > 0.0: logits = top_a_filter(logits) | ||||
|         if k == 2 and top_p < 1.0: logits = top_p_filter(logits) | ||||
|         if k == 3 and tfs < 1.0: logits = tail_free_filter(logits) | ||||
|         if k == 4 and typical < 1.0: logits = typical_filter(logits) | ||||
|         if k == 5 and temp != 1.0: logits = temp_filter(logits) | ||||
|     # Finally, pick one token using the softmax thingy again (it gives | ||||
|     # an array whose elements sum to 1 so it can be used nicely as a | ||||
|     # probability distribution) | ||||
| @@ -347,7 +346,7 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate | ||||
|     # positions in the logits array | ||||
|     return logits.at[tokens].set(penalty_logits) | ||||
|  | ||||
| def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): | ||||
| def kobold_sample_static(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): | ||||
|     ''' | ||||
|     This gets called by generate_loop_fn to apply a series of 6 filters | ||||
|     to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature) | ||||
| @@ -369,7 +368,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ | ||||
|             sorted_indices_to_remove, | ||||
|         ) | ||||
|         return jnp.where(indices_to_remove, -jnp.inf, logits) | ||||
|     logits = jax.lax.cond(top_k > 0, top_k_filter, lambda x: x, logits) | ||||
|     # Top-a (remove all tokens that have softmax probability less than | ||||
|     # a*m^2 where m is the maximum softmax probability) | ||||
|     def top_a_filter(logits): | ||||
| @@ -382,7 +380,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ | ||||
|         probs_max = probabilities.max() | ||||
|         # Remove tokens | ||||
|         return jnp.where(probabilities < probs_max * probs_max * top_a, -jnp.inf, logits) | ||||
|     logits = jax.lax.cond(top_a > 0.0, top_a_filter, lambda x: x, logits) | ||||
|     # Top-p (after sorting the remaining tokens again in descending order of | ||||
|     # logit, remove the ones that have cumulative softmax probability | ||||
|     # greater than p) | ||||
| @@ -408,7 +405,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ | ||||
|             sorted_indices_to_remove, | ||||
|         ) | ||||
|         return jnp.where(indices_to_remove, -jnp.inf, logits) | ||||
|     logits = jax.lax.cond(top_p < 1.0, top_p_filter, lambda x: x, logits) | ||||
|     # Tail free sampling (basically top-p a second time on remaining tokens | ||||
|     # except it's the "cumulative normalized absolute second finite | ||||
|     # differences of the softmax probabilities" instead of just the | ||||
| @@ -447,7 +443,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ | ||||
|             sorted_indices_to_remove, | ||||
|         ) | ||||
|         return jnp.where(indices_to_remove, -jnp.inf, logits) | ||||
|     logits = jax.lax.cond(tfs < 1.0, tail_free_filter, lambda x: x, logits) | ||||
|     # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf) | ||||
|     def typical_filter(logits): | ||||
|         # Compute softmax probabilities and the natural logarithms of them | ||||
| @@ -476,11 +471,16 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ | ||||
|             sorted_indices_to_remove, | ||||
|         ) | ||||
|         return jnp.where(indices_to_remove, -jnp.inf, logits) | ||||
|     logits = jax.lax.cond(typical < 1.0, typical_filter, lambda x: x, logits) | ||||
|     # Temperature (just divide the logits by the temperature) | ||||
|     def temp_filter(logits): | ||||
|         return logits / temp | ||||
|     logits = jax.lax.cond(True, temp_filter, lambda x: x, logits) | ||||
|     for k in sampler_order: | ||||
|         logits = jax.lax.cond(jnp.logical_and(k == 0, top_k > 0), top_k_filter, lambda x: x, logits) | ||||
|         logits = jax.lax.cond(jnp.logical_and(k == 1, top_a > 0.0), top_a_filter, lambda x: x, logits) | ||||
|         logits = jax.lax.cond(jnp.logical_and(k == 2, top_p < 1.0), top_p_filter, lambda x: x, logits) | ||||
|         logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), tail_free_filter, lambda x: x, logits) | ||||
|         logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits) | ||||
|         logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits) | ||||
|     # Finally, pick one token using the softmax thingy again (it gives | ||||
|     # an array whose elements sum to 1 so it can be used nicely as a | ||||
|     # probability distribution) | ||||
| @@ -842,8 +842,12 @@ def infer_static( | ||||
|     gen_len=80, | ||||
|     soft_embeddings: Optional[np.array] = None, | ||||
|     soft_tokens: Optional[np.array] = None, | ||||
|     sampler_order: Optional[List[int]] = None, | ||||
| ) -> List[np.array]: | ||||
|     maps.thread_resources.env = thread_resources_env | ||||
|     if sampler_order is None: | ||||
|         sampler_order = utils.default_sampler_order.copy() | ||||
|     sampler_order = np.uint32(sampler_order) | ||||
|     total_batch = 1 | ||||
|     tokens = context | ||||
|     if(soft_tokens is not None): | ||||
| @@ -854,6 +858,7 @@ def infer_static( | ||||
|     batched_tokens = np.array([padded_tokens] * total_batch) | ||||
|     samples = [] | ||||
|     batched_generator_params = { | ||||
|         "sampler_order": np.repeat(sampler_order[np.newaxis], total_batch, axis=0), | ||||
|         "temp": temp * np.ones(total_batch), | ||||
|         "top_p": top_p * np.ones(total_batch), | ||||
|         "tfs": tfs * np.ones(total_batch), | ||||
| @@ -1015,6 +1020,9 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2): | ||||
| def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None: | ||||
|     global thread_resources_env, seq, tokenizer, network, params | ||||
|  | ||||
|     if not hasattr(vars, "sampler_order") or not vars.sampler_order: | ||||
|         vars.sampler_order = utils.default_sampler_order.copy() | ||||
|  | ||||
|     default_params = { | ||||
|         "compat": "j", | ||||
|         "layers": 28, | ||||
|   | ||||
							
								
								
									
										2
									
								
								utils.py
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								utils.py
									
									
									
									
									
								
							| @@ -20,6 +20,8 @@ from_pretrained_index_filename: Optional[str] = None | ||||
| from_pretrained_kwargs = {} | ||||
| bar = None | ||||
|  | ||||
| default_sampler_order = [0, 1, 2, 3, 4, 5] | ||||
|  | ||||
| #==================================================================# | ||||
| # Decorator to prevent a function's actions from being run until | ||||
| # at least x seconds have passed without the function being called | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Gnome Ann
					Gnome Ann