Implement support for sampler order in the backend code
aiserver.py (27 lines changed)
@@ -306,6 +306,7 @@ class vars:
     acregex_ui  = re.compile(r'^ *(>.*)$', re.MULTILINE)    # Pattern for matching actions in the HTML-escaped story so we can apply colouring, etc (make sure to encase part to format in parentheses)
     comregex_ai = re.compile(r'(?:\n<\|(?:.|\n)*?\|>(?=\n|$))|(?:<\|(?:.|\n)*?\|>\n?)')  # Pattern for matching comments to remove them before sending them to the AI
     comregex_ui = re.compile(r'(<\|(?:.|\n)*?\|>)')  # Pattern for matching comments in the editor
+    sampler_order = utils.default_sampler_order.copy()
     chatmode    = False
     chatname    = "You"
     adventure   = False
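The new class attribute holds a mutable copy of the default sampler order. A minimal sketch of how that list is likely shaped — the index meanings are inferred from the warper list the next hunk builds, and `utils.default_sampler_order` itself is not shown in this diff:

# Sketch only: a sampler order is a permutation of indices into the six
# warpers that KoboldLogitsWarperList (next hunk) appends in this order:
#   0 = top-k, 1 = top-a, 2 = top-p, 3 = tail-free, 4 = typical, 5 = temperature
default_sampler_order = [0, 1, 2, 3, 4, 5]   # assumed default; not shown in the diff

# The .copy() matters: reordering for one session must not mutate the shared default.
sampler_order = default_sampler_order.copy()
sampler_order.remove(5)
sampler_order.insert(0, 5)                   # e.g. apply temperature first
assert default_sampler_order == [0, 1, 2, 3, 4, 5]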
@@ -1448,15 +1449,23 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
         new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
         transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
 
+        class KoboldLogitsWarperList(LogitsProcessorList):
+            def __init__(self, beams: int = 1, **kwargs):
+                self.__warper_list: List[LogitsWarper] = []
+                self.__warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1)))
+                self.__warper_list.append(TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1)))
+                self.__warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
+                self.__warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
+                self.__warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1)))
+                self.__warper_list.append(TemperatureLogitsWarper(temperature=0.5))
+
+            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, *args, **kwargs):
+                for k in vars.sampler_order:
+                    scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
+                return scores
+
         def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList:
-            warper_list = LogitsProcessorList()
-            warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1)))
-            warper_list.append(TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1)))
-            warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
-            warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
-            warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1)))
-            warper_list.append(TemperatureLogitsWarper(temperature=0.5))
-            return warper_list
+            return KoboldLogitsWarperList(beams=beams)
 
         def new_sample(self, *args, **kwargs):
             assert kwargs.pop("logits_warper", None) is not None
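For readers outside the codebase, here is a self-contained sketch of the same mechanism using only stock transformers warpers (KoboldAI's TopA, TailFree, and Typical warpers are project-specific and omitted here); the permutation-indexing mirrors how `__call__` walks `vars.sampler_order`:

import torch
from transformers import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

class OrderedWarperList(LogitsProcessorList):
    """Applies a fixed set of warpers in a caller-chosen order."""

    def __init__(self, order):
        super().__init__()
        self.order = order  # permutation of indices into self.warpers
        self.warpers = [
            TopKLogitsWarper(top_k=40),                # index 0
            TopPLogitsWarper(top_p=0.9),               # index 1
            TemperatureLogitsWarper(temperature=0.7),  # index 2
        ]

    def __call__(self, input_ids, scores, **kwargs):
        # Each warper rewrites the logits before the next one sees them,
        # so the order can change which tokens survive filtering.
        for k in self.order:
            scores = self.warpers[k](input_ids, scores)
        return scores

logits = torch.randn(1, 50)
ids = torch.zeros((1, 1), dtype=torch.long)
temp_first = OrderedWarperList([2, 0, 1])(ids, logits.clone())
temp_last  = OrderedWarperList([0, 1, 2])(ids, logits.clone())

Applying temperature before top-p sharpens the distribution first, so fewer tokens fall inside the nucleus; that difference is exactly what a user-configurable sampler order exposes.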
@@ -1816,6 +1825,7 @@ else:
     def tpumtjgenerate_settings_callback() -> dict:
         return {
+            "sampler_order": vars.sampler_order,
             "top_p": float(vars.top_p),
             "temp": float(vars.temp),
             "top_k": int(vars.top_k),
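The settings callback hands the TPU backend a fresh snapshot of the UI state each time it runs, so the added key means a reordered sampler takes effect on the next generation. A hypothetical consumer-side sketch (the real reader lives in the TPU backend and is not part of this diff):

# Hypothetical consumer; names other than "sampler_order" are illustrative.
def apply_settings(settings_callback):
    settings = settings_callback()
    order = settings["sampler_order"]
    # A valid order must be a permutation of the sampler indices.
    if sorted(order) != list(range(len(order))):
        raise ValueError(f"sampler_order must be a permutation, got {order!r}")
    return order

print(apply_settings(lambda: {"sampler_order": [5, 0, 1, 2, 3, 4]}))  # [5, 0, 1, 2, 3, 4]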
@@ -3910,6 +3920,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
             rprange=vars.rep_pen_range,
             soft_embeddings=vars.sp,
             soft_tokens=soft_tokens,
+            sampler_order=vars.sampler_order,
         )
         past = genout
         for i in range(vars.numseqs):
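Finally, the generation call simply threads the current order through as a keyword argument. A runnable mock (not the real TPU backend function, of whose signature this hunk only shows four parameters):

# Mock generator; only the kwargs visible in this hunk are real names.
def infer(txt, rprange=2048, soft_embeddings=None, soft_tokens=None,
          sampler_order=None, **kwargs):
    order = sampler_order if sampler_order is not None else [0, 1, 2, 3, 4, 5]
    return f"sampling continuation of {txt!r} with warper order {order}"

print(infer("Once upon a time", sampler_order=[5, 0, 1, 2, 3, 4]))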