Implement support for sampler order in the backend code
parent a273a5ebc4
commit 2d3db7b4ba
aiserver.py (27 changes)

@@ -306,6 +306,7 @@ class vars:
     acregex_ui = re.compile(r'^ *(>.*)$', re.MULTILINE) # Pattern for matching actions in the HTML-escaped story so we can apply colouring, etc (make sure to encase part to format in parentheses)
     comregex_ai = re.compile(r'(?:\n<\|(?:.|\n)*?\|>(?=\n|$))|(?:<\|(?:.|\n)*?\|>\n?)') # Pattern for matching comments to remove them before sending them to the AI
     comregex_ui = re.compile(r'(<\|(?:.|\n)*?\|>)') # Pattern for matching comments in the editor
+    sampler_order = utils.default_sampler_order.copy()
     chatmode = False
     chatname = "You"
     adventure = False
@@ -1448,15 +1449,23 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
     new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
     transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
 
+    class KoboldLogitsWarperList(LogitsProcessorList):
+        def __init__(self, beams: int = 1, **kwargs):
+            self.__warper_list: List[LogitsWarper] = []
+            self.__warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1)))
+            self.__warper_list.append(TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1)))
+            self.__warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
+            self.__warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
+            self.__warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1)))
+            self.__warper_list.append(TemperatureLogitsWarper(temperature=0.5))
+
+        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, *args, **kwargs):
+            for k in vars.sampler_order:
+                scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
+            return scores
+
     def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList:
-        warper_list = LogitsProcessorList()
-        warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1)))
-        warper_list.append(TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1)))
-        warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
-        warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
-        warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1)))
-        warper_list.append(TemperatureLogitsWarper(temperature=0.5))
-        return warper_list
+        return KoboldLogitsWarperList(beams=beams)
 
     def new_sample(self, *args, **kwargs):
         assert kwargs.pop("logits_warper", None) is not None
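Aside (not part of the commit): the reordering above works because KoboldLogitsWarperList always builds its warpers in the fixed order top-k, top-a, top-p, TFS, typical, temperature, and vars.sampler_order holds indices into that list. A minimal standalone sketch of the same index-dispatch pattern, using hypothetical names (make_chain, toy_warpers) rather than the private warper list:

from typing import Callable, List

def make_chain(warpers: List[Callable], sampler_order: List[int]) -> Callable:
    def apply(scores):
        for k in sampler_order:          # e.g. [5, 0, 1, 2, 3, 4] runs the last warper first
            scores = warpers[k](scores)  # each warper maps scores -> scores
        return scores
    return apply

# Toy warpers that only record the order they were applied in:
order_log = []
toy_warpers = [lambda s, i=i: order_log.append(i) or s for i in range(6)]
make_chain(toy_warpers, [5, 0, 1, 2, 3, 4])(None)
print(order_log)  # [5, 0, 1, 2, 3, 4]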
@@ -1816,6 +1825,7 @@ else:
 
     def tpumtjgenerate_settings_callback() -> dict:
         return {
+            "sampler_order": vars.sampler_order,
             "top_p": float(vars.top_p),
             "temp": float(vars.temp),
             "top_k": int(vars.top_k),
@@ -3910,6 +3920,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
             rprange=vars.rep_pen_range,
             soft_embeddings=vars.sp,
             soft_tokens=soft_tokens,
+            sampler_order=vars.sampler_order,
         )
         past = genout
         for i in range(vars.numseqs):

tpu_mtj_backend.py
@@ -65,6 +65,7 @@ def stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List
 
 def settings_callback() -> dict:
     return {
+        "sampler_order": utils.default_sampler_order.copy(),
         "top_p": 0.9,
         "temp": 0.5,
         "top_k": 0,
@@ -159,7 +160,7 @@ def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generat
     logits[tokens] = penalty_logits
     return logits
 
-def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
+def kobold_sample_dynamic(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
     '''
     This gets called by generate_loop_fn to apply a series of 6 filters
     to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
@@ -181,8 +182,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
             sorted_indices_to_remove,
         )
         return np.where(indices_to_remove, -np.inf, logits)
-    if top_k > 0:
-        logits = top_k_filter(logits)
     # Top-a (remove all tokens that have softmax probability less than
     # a*m^2 where m is the maximum softmax probability)
     def top_a_filter(logits):
@@ -195,8 +194,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
         probs_max = probabilities.max()
         # Remove tokens
         return np.where(probabilities < probs_max * probs_max * top_a, -np.inf, logits)
-    if top_a > 0.0:
-        logits = top_a_filter(logits)
     # Top-p (after sorting the remaining tokens again in descending order of
     # logit, remove the ones that have cumulative softmax probability
     # greater than p)
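A worked example of the a*m^2 rule in the comment above (illustrative only; plain NumPy rather than the project's own helpers):

import numpy as np

logits = np.array([3.0, 1.0, 0.0, -2.0])
probs = np.exp(logits - logits.max())
probs /= probs.sum()                      # ~[0.839, 0.114, 0.042, 0.006]
top_a = 0.5
threshold = probs.max() ** 2 * top_a      # a*m^2 ~= 0.352
print(np.where(probs < threshold, -np.inf, logits))  # [3., -inf, -inf, -inf]: only the top token survives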
@@ -222,8 +219,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
             sorted_indices_to_remove,
         )
         return np.where(indices_to_remove, -np.inf, logits)
-    if top_p < 1.0:
-        logits = top_p_filter(logits)
     # Tail free sampling (basically top-p a second time on remaining tokens
     # except it's the "cumulative normalized absolute second finite
     # differences of the softmax probabilities" instead of just the
@@ -262,8 +257,6 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
             sorted_indices_to_remove,
         )
         return np.where(indices_to_remove, -np.inf, logits)
-    if tfs < 1.0:
-        logits = tail_free_filter(logits)
     # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
     def typical_filter(logits):
         # Compute softmax probabilities and the natural logarithms of them
@@ -293,10 +286,16 @@ def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, ty
             sorted_indices_to_remove,
         )
         return np.where(indices_to_remove, -jnp.inf, logits)
-    if typical < 1.0:
-        logits = typical_filter(logits)
     # Temperature (just divide the logits by the temperature)
-    logits /= temp
+    def temp_filter(logits):
+        return logits / temp
+    for k in sampler_order:
+        if k == 0 and top_k > 0: logits = top_k_filter(logits)
+        if k == 1 and top_a > 0.0: logits = top_a_filter(logits)
+        if k == 2 and top_p < 1.0: logits = top_p_filter(logits)
+        if k == 3 and tfs < 1.0: logits = tail_free_filter(logits)
+        if k == 4 and typical < 1.0: logits = typical_filter(logits)
+        if k == 5 and temp != 1.0: logits = temp_filter(logits)
     # Finally, pick one token using the softmax thingy again (it gives
     # an array whose elements sum to 1 so it can be used nicely as a
     # probability distribution)
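For intuition about why the order is worth exposing (commentary, not from the commit): running the filters in a different order can change which tokens survive. A toy sketch with simplified stand-ins for the top-p and temperature filters above:

import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def top_p_filter(logits, top_p=0.6):
    # remove tokens whose cumulative softmax probability (excluding their own) exceeds top_p
    order = np.argsort(-logits)
    probs = softmax(logits)[order]
    remove = np.cumsum(probs) - probs > top_p
    out = logits.copy()
    out[order[remove]] = -np.inf
    return out

def temp_filter(logits, temp=2.0):
    return logits / temp

logits = np.array([2.0, 1.0, 0.5, -1.0])
print(top_p_filter(temp_filter(logits)))  # temperature first flattens the distribution: two tokens survive
print(temp_filter(top_p_filter(logits)))  # top-p first uses the sharper distribution: only one token survives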
@@ -347,7 +346,7 @@ def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generate
     # positions in the logits array
     return logits.at[tokens].set(penalty_logits)
 
-def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
+def kobold_sample_static(key, logits, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
     '''
     This gets called by generate_loop_fn to apply a series of 6 filters
     to the logits (top-k, then top-a, then top-p, then TFS, then typical, then temperature)
@@ -369,7 +368,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
             sorted_indices_to_remove,
         )
         return jnp.where(indices_to_remove, -jnp.inf, logits)
-    logits = jax.lax.cond(top_k > 0, top_k_filter, lambda x: x, logits)
     # Top-a (remove all tokens that have softmax probability less than
     # a*m^2 where m is the maximum softmax probability)
     def top_a_filter(logits):
@@ -382,7 +380,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
         probs_max = probabilities.max()
         # Remove tokens
         return jnp.where(probabilities < probs_max * probs_max * top_a, -jnp.inf, logits)
-    logits = jax.lax.cond(top_a > 0.0, top_a_filter, lambda x: x, logits)
     # Top-p (after sorting the remaining tokens again in descending order of
     # logit, remove the ones that have cumulative softmax probability
     # greater than p)
@@ -408,7 +405,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
             sorted_indices_to_remove,
         )
         return jnp.where(indices_to_remove, -jnp.inf, logits)
-    logits = jax.lax.cond(top_p < 1.0, top_p_filter, lambda x: x, logits)
     # Tail free sampling (basically top-p a second time on remaining tokens
     # except it's the "cumulative normalized absolute second finite
     # differences of the softmax probabilities" instead of just the
@@ -447,7 +443,6 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
             sorted_indices_to_remove,
         )
         return jnp.where(indices_to_remove, -jnp.inf, logits)
-    logits = jax.lax.cond(tfs < 1.0, tail_free_filter, lambda x: x, logits)
     # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
     def typical_filter(logits):
         # Compute softmax probabilities and the natural logarithms of them
@@ -476,11 +471,16 @@ def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typ
             sorted_indices_to_remove,
         )
         return jnp.where(indices_to_remove, -jnp.inf, logits)
-    logits = jax.lax.cond(typical < 1.0, typical_filter, lambda x: x, logits)
     # Temperature (just divide the logits by the temperature)
     def temp_filter(logits):
         return logits / temp
-    logits = jax.lax.cond(True, temp_filter, lambda x: x, logits)
+    for k in sampler_order:
+        logits = jax.lax.cond(jnp.logical_and(k == 0, top_k > 0), top_k_filter, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 1, top_a > 0.0), top_a_filter, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 2, top_p < 1.0), top_p_filter, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), tail_free_filter, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits)
+        logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits)
     # Finally, pick one token using the softmax thingy again (it gives
     # an array whose elements sum to 1 so it can be used nicely as a
     # probability distribution)
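Note on the pattern above (commentary, not part of the diff): kobold_sample_static runs under JAX tracing, so k, top_k, etc. are tracers and Python if statements on them are not allowed; each choice has to be a jax.lax.cond whose two branches are both traceable. A minimal sketch of that constraint, assuming only that JAX is installed:

import jax
import jax.numpy as jnp

def scale_if_even(k, logits):
    # Under jit, `k` is a tracer, so `if k % 2 == 0:` would raise a tracer error;
    # jax.lax.cond traces both branches and selects one at run time.
    return jax.lax.cond(k % 2 == 0, lambda x: x * 2.0, lambda x: x, logits)

print(jax.jit(scale_if_even)(jnp.uint32(4), jnp.array([1.0, 2.0])))  # [2. 4.]
print(jax.jit(scale_if_even)(jnp.uint32(3), jnp.array([1.0, 2.0])))  # [1. 2.]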
@@ -842,8 +842,12 @@ def infer_static(
     gen_len=80,
     soft_embeddings: Optional[np.array] = None,
     soft_tokens: Optional[np.array] = None,
+    sampler_order: Optional[List[int]] = None,
 ) -> List[np.array]:
     maps.thread_resources.env = thread_resources_env
+    if sampler_order is None:
+        sampler_order = utils.default_sampler_order.copy()
+    sampler_order = np.uint32(sampler_order)
     total_batch = 1
     tokens = context
     if(soft_tokens is not None):
@@ -854,6 +858,7 @@ def infer_static(
     batched_tokens = np.array([padded_tokens] * total_batch)
     samples = []
     batched_generator_params = {
+        "sampler_order": np.repeat(sampler_order[np.newaxis], total_batch, axis=0),
         "temp": temp * np.ones(total_batch),
         "top_p": top_p * np.ones(total_batch),
         "tfs": tfs * np.ones(total_batch),
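The np.repeat line above gives every sequence in the batch its own copy of the order; a small sketch of the shape arithmetic (values illustrative):

import numpy as np

sampler_order = np.uint32([0, 1, 2, 3, 4, 5])
total_batch = 3
batched = np.repeat(sampler_order[np.newaxis], total_batch, axis=0)
print(batched.shape)  # (3, 6) -- one row per sequence, each holding the full sampler order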
@@ -1015,6 +1020,9 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
 def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
     global thread_resources_env, seq, tokenizer, network, params
 
+    if not hasattr(vars, "sampler_order") or not vars.sampler_order:
+        vars.sampler_order = utils.default_sampler_order.copy()
+
     default_params = {
         "compat": "j",
         "layers": 28,
utils.py (2 changes)

@@ -20,6 +20,8 @@ from_pretrained_index_filename: Optional[str] = None
 from_pretrained_kwargs = {}
 bar = None
 
+default_sampler_order = [0, 1, 2, 3, 4, 5]
+
 #==================================================================#
 # Decorator to prevent a function's actions from being run until
 # at least x seconds have passed without the function being called
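Reading the two files together (commentary, not part of the diff): the integers in default_sampler_order are positions in the fixed sampler list shared by kobold_sample_dynamic/static and KoboldLogitsWarperList. A hypothetical helper makes the mapping explicit:

# Hypothetical mapping of sampler_order indices to the samplers they select,
# following the k == 0 ... k == 5 checks in the backend code above.
SAMPLERS = {0: "top_k", 1: "top_a", 2: "top_p", 3: "tfs", 4: "typical", 5: "temperature"}

def describe(order):
    return " -> ".join(SAMPLERS[k] for k in order)

print(describe([0, 1, 2, 3, 4, 5]))  # the default: top_k -> top_a -> top_p -> tfs -> typical -> temperature
print(describe([5, 0, 1, 2, 3, 4]))  # a custom order that applies temperature first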