mirror of
				https://github.com/KoboldAI/KoboldAI-Client.git
				synced 2025-06-05 21:59:24 +02:00 
			
		
		
		
	Merge branch 'avril' into rep-pen-order
This commit is contained in:
		| @@ -1727,8 +1727,6 @@ def patch_transformers(): | |||||||
|     dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0) |     dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0) | ||||||
|     dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0) |     dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0) | ||||||
|     dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0) |     dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0) | ||||||
|     RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__ |  | ||||||
|     RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__ |  | ||||||
|  |  | ||||||
|     class LuaLogitsProcessor(LogitsProcessor): |     class LuaLogitsProcessor(LogitsProcessor): | ||||||
|  |  | ||||||
| @@ -1805,6 +1803,7 @@ def patch_transformers(): | |||||||
|             self.__warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1))) |             self.__warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1))) | ||||||
|             self.__warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1))) |             self.__warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1))) | ||||||
|             self.__warper_list.append(TemperatureLogitsWarper(temperature=0.5)) |             self.__warper_list.append(TemperatureLogitsWarper(temperature=0.5)) | ||||||
|  |             self.__warper_list.append(AdvancedRepetitionPenaltyLogitsProcessor()) | ||||||
|  |  | ||||||
|         def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, *args, **kwargs): |         def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, *args, **kwargs): | ||||||
|             for k in vars.sampler_order: |             for k in vars.sampler_order: | ||||||
| @@ -4617,7 +4616,7 @@ def _generate(txt, minimum, maximum, found_entries): | |||||||
|                 gen_in,  |                 gen_in,  | ||||||
|                 do_sample=True,  |                 do_sample=True,  | ||||||
|                 max_length=int(2e9), |                 max_length=int(2e9), | ||||||
|                 repetition_penalty=1.1, |                 repetition_penalty=1.0, | ||||||
|                 bad_words_ids=vars.badwordsids, |                 bad_words_ids=vars.badwordsids, | ||||||
|                 use_cache=True, |                 use_cache=True, | ||||||
|                 num_return_sequences=numseqs |                 num_return_sequences=numseqs | ||||||
|   | |||||||
| @@ -176,7 +176,7 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat | |||||||
|     logits[tokens] = penalty_logits |     logits[tokens] = penalty_logits | ||||||
|     return logits |     return logits | ||||||
|  |  | ||||||
| def kobold_sample_dynamic(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): | def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): | ||||||
|     ''' |     ''' | ||||||
|     This gets called by generate_loop_fn to apply a series of 6 filters |     This gets called by generate_loop_fn to apply a series of 6 filters | ||||||
|     to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature) |     to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature) | ||||||
| @@ -315,6 +315,7 @@ def kobold_sample_dynamic(key, logits, sampler_order: Optional[np.ndarray] = Non | |||||||
|     # Finally, pick one token using the softmax thingy again (it gives |     # Finally, pick one token using the softmax thingy again (it gives | ||||||
|     # an array whose elements sum to 1 so it can be used nicely as a |     # an array whose elements sum to 1 so it can be used nicely as a | ||||||
|     # probability distribution) |     # probability distribution) | ||||||
|  |     logits = apply_repetition_penalty_dynamic(logits, *rpargs) | ||||||
|     return jax.random.categorical(key, logits, -1).astype(np.uint32) |     return jax.random.categorical(key, logits, -1).astype(np.uint32) | ||||||
|  |  | ||||||
| def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange): | def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange): | ||||||
| @@ -362,7 +363,7 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate | |||||||
|     # positions in the logits array |     # positions in the logits array | ||||||
|     return logits.at[tokens].set(penalty_logits) |     return logits.at[tokens].set(penalty_logits) | ||||||
|  |  | ||||||
| def kobold_sample_static(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): | def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): | ||||||
|     ''' |     ''' | ||||||
|     This gets called by generate_loop_fn to apply a series of 6 filters |     This gets called by generate_loop_fn to apply a series of 6 filters | ||||||
|     to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature) |     to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature) | ||||||
| @@ -500,6 +501,7 @@ def kobold_sample_static(key, logits, sampler_order: Optional[np.ndarray] = None | |||||||
|     # Finally, pick one token using the softmax thingy again (it gives |     # Finally, pick one token using the softmax thingy again (it gives | ||||||
|     # an array whose elements sum to 1 so it can be used nicely as a |     # an array whose elements sum to 1 so it can be used nicely as a | ||||||
|     # probability distribution) |     # probability distribution) | ||||||
|  |     logits = apply_repetition_penalty_static(logits, *rpargs) | ||||||
|     return jax.random.categorical(key, logits, -1).astype(jnp.uint32) |     return jax.random.categorical(key, logits, -1).astype(jnp.uint32) | ||||||
|  |  | ||||||
| pad_token_id = 50256 | pad_token_id = 50256 | ||||||
| @@ -513,17 +515,6 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_ | |||||||
|         # Get the pseudo-random number generator key that will |         # Get the pseudo-random number generator key that will | ||||||
|         # be used by kobold_sample_dynamic to randomly pick a token |         # be used by kobold_sample_dynamic to randomly pick a token | ||||||
|         sample_key, new_key = jax.random.split(sample_key, num=2) |         sample_key, new_key = jax.random.split(sample_key, num=2) | ||||||
|         # Apply repetition penalty to all tokens that are |  | ||||||
|         # currently inside the "generated" array |  | ||||||
|         logits = apply_repetition_penalty_dynamic( |  | ||||||
|             logits, |  | ||||||
|             generated, |  | ||||||
|             repetition_penalty, |  | ||||||
|             generated_index,  |  | ||||||
|             gen_length, |  | ||||||
|             rpslope, |  | ||||||
|             rprange, |  | ||||||
|         ) |  | ||||||
|         # Remove any tokens in the badwords list by setting |         # Remove any tokens in the badwords list by setting | ||||||
|         # their logits to negative infinity which effectively |         # their logits to negative infinity which effectively | ||||||
|         # makes their probabilities of being chosen zero |         # makes their probabilities of being chosen zero | ||||||
| @@ -535,6 +526,14 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_ | |||||||
|         next_token = kobold_sample_dynamic( |         next_token = kobold_sample_dynamic( | ||||||
|             sample_key, |             sample_key, | ||||||
|             logits, |             logits, | ||||||
|  |             ( | ||||||
|  |                 generated, | ||||||
|  |                 repetition_penalty, | ||||||
|  |                 generated_index,  | ||||||
|  |                 gen_length, | ||||||
|  |                 rpslope, | ||||||
|  |                 rprange, | ||||||
|  |             ) | ||||||
|             **sampler_options, |             **sampler_options, | ||||||
|         ) |         ) | ||||||
|         # Remember what token was picked |         # Remember what token was picked | ||||||
| @@ -606,18 +605,6 @@ class PenalizingCausalTransformer(CausalTransformer): | |||||||
|                     assert logits.shape == (1, config["n_vocab"]) |                     assert logits.shape == (1, config["n_vocab"]) | ||||||
|                     # Flatten it into a 1D array to make it easier to use |                     # Flatten it into a 1D array to make it easier to use | ||||||
|                     logits = logits[0] |                     logits = logits[0] | ||||||
|                     # Apply repetition penalty to all tokens that are |  | ||||||
|                     # currently inside the "generated" array |  | ||||||
|                     if repetition_penalty is not None: |  | ||||||
|                         logits = apply_repetition_penalty_static( |  | ||||||
|                             logits, |  | ||||||
|                             generated, |  | ||||||
|                             repetition_penalty, |  | ||||||
|                             generated_index, |  | ||||||
|                             gen_length, |  | ||||||
|                             rpslope, |  | ||||||
|                             rprange, |  | ||||||
|                         ) |  | ||||||
|                     # Remove any tokens in the badwords list by setting |                     # Remove any tokens in the badwords list by setting | ||||||
|                     # their logits to negative infinity which effectively |                     # their logits to negative infinity which effectively | ||||||
|                     # makes their probabilities of being chosen zero |                     # makes their probabilities of being chosen zero | ||||||
| @@ -629,6 +616,14 @@ class PenalizingCausalTransformer(CausalTransformer): | |||||||
|                     next_token = kobold_sample_static( |                     next_token = kobold_sample_static( | ||||||
|                         sample_key, |                         sample_key, | ||||||
|                         logits, |                         logits, | ||||||
|  |                         ( | ||||||
|  |                             generated, | ||||||
|  |                             repetition_penalty, | ||||||
|  |                             generated_index, | ||||||
|  |                             gen_length, | ||||||
|  |                             rpslope, | ||||||
|  |                             rprange, | ||||||
|  |                         ), | ||||||
|                         **sampler_options, |                         **sampler_options, | ||||||
|                     ) |                     ) | ||||||
|                     # Remember what token was picked |                     # Remember what token was picked | ||||||
|   | |||||||
| @@ -31,7 +31,7 @@ import torch | |||||||
| from transformers import LogitsWarper, LogitsProcessor | from transformers import LogitsWarper, LogitsProcessor | ||||||
|  |  | ||||||
|  |  | ||||||
| class AdvancedRepetitionPenaltyLogitsProcessor(LogitsProcessor): | class AdvancedRepetitionPenaltyLogitsProcessor(LogitsWarper): | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
|         pass |         pass | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user