diff --git a/aiserver.py b/aiserver.py
index 3e9d8290..e151a32b 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -39,7 +39,7 @@ import functools
 import traceback
 from collections.abc import Iterable
 from collections import OrderedDict
-from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
+from typing import Any, Callable, Optional, TypeVar, Tuple, Union, Dict, Set, List
 import requests
 import html
 
@@ -1341,6 +1341,71 @@ def patch_transformers():
     RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__
     RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__
 
+    class PhraseBiasLogitsProcessor(LogitsProcessor):
+        def __init__(self):
+            pass
+
+        def _rindex(self, lst: List, target) -> Optional[int]:
+            for index, item in enumerate(reversed(lst)):
+                if item == target:
+                    return len(lst) - index - 1
+            return None
+
+        def _find_intersection(self, big: List, small: List) -> int:
+            # Find the intersection of the end of "big" and the beginning of
+            # "small". A headache to think about, personally. Returns the index
+            # into "small" where the two stop intersecting.
+            start = self._rindex(big, small[0])
+
+            # No progress into the token sequence, bias the first one.
+            if start is None:
+                return 0
+
+            for i in range(len(small)):
+                try:
+                    big_i = big[start + i]
+                except IndexError:
+                    return i
+
+                # Mismatch; the phrase has not actually begun.
+                if big_i != small[i]:
+                    return 0
+
+            # It's completed :^)
+            return 0
+
+        def _get_biased_tokens(self, input_ids: List) -> Dict:
+            # TODO: Different "bias slopes"?
+
+            ret = {}
+            for phrase, _bias in koboldai_vars.biases.items():
+                bias_score, completion_threshold = _bias
+                # TODO: Cache these tokens, invalidate when model or bias is
+                # changed.
+                token_seq = tokenizer.encode(phrase)
+                bias_index = self._find_intersection(input_ids, token_seq)
+
+                # Ensure completion after completion_threshold tokens
+                if bias_index + 1 > completion_threshold:
+                    bias_score = 999
+
+                token_to_bias = token_seq[bias_index]
+                ret[token_to_bias] = bias_score
+            return ret
+
+        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+            assert scores.ndim == 2
+            assert input_ids.ndim == 2
+
+            scores_shape = scores.shape
+
+            for batch in range(scores_shape[0]):
+                for token, bias in self._get_biased_tokens(input_ids[batch]).items():
+                    scores[batch][token] += bias
+
+            return scores
+
+
     class LuaLogitsProcessor(LogitsProcessor):
 
         def __init__(self):
@@ -1373,6 +1438,7 @@ def patch_transformers():
     def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
         processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs)
         processors.insert(0, LuaLogitsProcessor())
+        processors.insert(0, PhraseBiasLogitsProcessor())
         return processors
     new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
     transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
diff --git a/koboldai_settings.py b/koboldai_settings.py
index 14e8e4b6..5310b9e8 100644
--- a/koboldai_settings.py
+++ b/koboldai_settings.py
@@ -249,7 +249,7 @@ class model_settings(settings):
         self.selected_preset = ""
         self.uid_presets = []
         self.default_preset = {}
-        self.biases = {} # should look like {"phrase": [percent, max_occurances]}
+        self.biases = {} # should look like {"phrase": [score, completion_threshold]}
 
 #dummy class to eat the tqdm output
 class ignore_tqdm(object):
diff --git a/static/koboldai.css b/static/koboldai.css
index b4f3c186..95ffad98 100644
--- a/static/koboldai.css
+++ b/static/koboldai.css
@@ -363,12 +363,12 @@
 	width: 100%;
 }
 
-.bias_percent {
+.bias_score {
 	grid-area: percent;
 	margin-right: 5px;
 }
 
-.bias_max {
+.bias_comp_threshold {
 	grid-area: max;
 	margin-right: 5px;
 }
@@ -409,7 +409,7 @@
 	font-size: small;
 }
 
-.bias_header_percent {
+.bias_header_score {
 	grid-area: percent;
 	font-size: small;
 }
diff --git a/static/koboldai.js b/static/koboldai.js
index bc1ee1a5..a1d2801b 100644
--- a/static/koboldai.js
+++ b/static/koboldai.js
@@ -1252,16 +1252,16 @@ function save_bias(item) {
 	//get all of our biases
 	for (bias of document.getElementsByClassName("bias")) {
 		//phrase
-		phrase = bias.querySelector(".bias_phrase").querySelector("input").value;
+		var phrase = bias.querySelector(".bias_phrase").querySelector("input").value;
 
 		//percent
-		percent = parseFloat(bias.querySelector(".bias_percent").querySelector("input").value);
+		var percent = parseFloat(bias.querySelector(".bias_score").querySelector("input").value);
 
-		//max occurance
-		max_occurance = parseInt(bias.querySelector(".bias_max").querySelector("input").value);
+		//completion threshold
+		var comp_threshold = parseInt(bias.querySelector(".bias_comp_threshold").querySelector("input").value);
 
 		if (phrase != "") {
-			biases[phrase] = [percent, max_occurance];
+			biases[phrase] = [percent, comp_threshold];
 		} else {
 			//mark that we have a blank line, or delete it if we have more than one
 			if (have_blank) {
@@ -1276,8 +1276,8 @@
 		console.log("Create new bias line");
 		bias_line = document.getElementsByClassName("bias")[0].cloneNode(true);
 		bias_line.querySelector(".bias_phrase").querySelector("input").value = "";
-		bias_line.querySelector(".bias_percent").querySelector("input").value = 1;
-		bias_line.querySelector(".bias_max").querySelector("input").value = 50;
+		bias_line.querySelector(".bias_score").querySelector("input").value = 0;
+		bias_line.querySelector(".bias_comp_threshold").querySelector("input").value = 50;
 		document.getElementById('biasing').append(bias_line);
 	}
 
diff --git a/templates/settings flyout.html b/templates/settings flyout.html
index fe08944c..62c08854 100644
--- a/templates/settings flyout.html
+++ b/templates/settings flyout.html
@@ -57,35 +57,35 @@
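
A minimal, self-contained sketch of the matching logic that PhraseBiasLogitsProcessor
applies on each generation step, for reviewers who want to trace it outside the server.
The standalone names (rindex, find_intersection) and all token IDs below are invented
for illustration; the real processor reads phrases from koboldai_vars.biases and
tokenizes them with tokenizer.encode.

    from typing import List, Optional

    def rindex(lst: List[int], target: int) -> Optional[int]:
        # Index of the last occurrence of target in lst, or None if absent.
        for offset, item in enumerate(reversed(lst)):
            if item == target:
                return len(lst) - offset - 1
        return None

    def find_intersection(big: List[int], small: List[int]) -> int:
        # How far the end of the context ("big") has progressed into the
        # phrase ("small"); returns the index in "small" of the next token
        # to bias.
        start = rindex(big, small[0])
        if start is None:
            return 0  # phrase not started yet: bias its first token
        for i in range(len(small)):
            try:
                big_i = big[start + i]
            except IndexError:
                return i  # context ends mid-phrase: bias small[i] next
            if big_i != small[i]:
                return 0  # false start: fall back to the first token
        return 0  # phrase already completed

    # Invented token IDs standing in for tokenizer.encode output.
    context = [5, 9, 12, 40, 41]  # generation so far ends with 40, 41
    phrase = [40, 41, 42, 43]     # phrase being steered toward
    score, completion_threshold = 3.0, 10

    i = find_intersection(context, phrase)
    if i + 1 > completion_threshold:
        score = 999  # past the threshold, force the phrase to complete
    print(f"bias token {phrase[i]} by {score}")  # -> bias token 42 by 3.0

As in the patch, the returned index always points at exactly one token per phrase,
so each call biases a single next token rather than the whole phrase at once.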