Add phrase biasing backend

This commit is contained in:
somebody
2022-08-08 11:33:17 -05:00
parent 6675769c0b
commit 8d54b8d08b
5 changed files with 85 additions and 23 deletions

View File

@@ -39,7 +39,7 @@ import functools
import traceback
from collections.abc import Iterable
from collections import OrderedDict
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
from typing import Any, Callable, Optional, TypeVar, Tuple, Union, Dict, Set, List
import requests
import html
@@ -1341,6 +1341,67 @@ def patch_transformers():
# Monkey-patch: replace the constructor and call operator of
# RepetitionPenaltyLogitsProcessor with the ones defined on
# AdvancedRepetitionPenaltyLogitsProcessor.
RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__
RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__
class PhraseBiasLogitsProcessor(LogitsProcessor):
    """Logits processor that biases generation towards configured phrases.

    ``koboldai_vars.biases`` maps each phrase to a ``[score,
    completion_threshold]`` pair. For every phrase the processor works out
    which token of the phrase should come next given the tokens generated so
    far, and adds the phrase's score to that token's logit.
    """

    # Score applied once a phrase has progressed past its completion
    # threshold, so that the remaining tokens are (practically) forced.
    _FORCE_COMPLETION_SCORE = 999

    def __init__(self):
        pass

    def _rindex(self, lst: List, target) -> Optional[int]:
        """Return the index of the last occurrence of ``target`` in ``lst``.

        Returns ``None`` when ``target`` does not occur at all.
        """
        for index, item in enumerate(reversed(lst)):
            if item == target:
                return len(lst) - index - 1
        return None

    def _find_intersection(self, big: List, small: List) -> int:
        """Return how far into ``small`` the end of ``big`` has progressed.

        Finds the longest overlap between the *end* of ``big`` (the context)
        and the *beginning* of ``small`` (the phrase tokens) and returns the
        index into ``small`` of the first token that is not yet covered.

        Returns 0 when there is no overlap, when either sequence is empty, or
        when ``big`` already ends with the whole of ``small`` (the phrase was
        just completed, so biasing starts over at its first token).
        """
        # Accept tensors as well as plain sequences; comparing Python lists
        # avoids element-wise tensor comparisons.
        big = big.tolist() if hasattr(big, "tolist") else list(big)
        small = list(small)

        if not big or not small:
            return 0

        # It's completed :^)
        if len(big) >= len(small) and big[-len(small):] == small:
            return 0

        # Try the longest partial overlap first so that phrases containing
        # repeated tokens resolve to the furthest real progress.
        for overlap in range(min(len(big), len(small) - 1), 0, -1):
            if big[-overlap:] == small[:overlap]:
                return overlap

        # No progress into the token sequence, bias the first one.
        return 0

    def _get_biased_tokens(self, input_ids: List) -> Dict:
        """Map each token id that should be biased to its bias score.

        ``input_ids`` is the token sequence of a single batch row. When two
        phrases select the same token, the phrase iterated last wins.
        """
        # TODO: Different "bias slopes"?
        # Convert once here instead of once per phrase.
        if hasattr(input_ids, "tolist"):
            context = input_ids.tolist()
        else:
            context = list(input_ids)

        ret = {}
        for phrase, _bias in koboldai_vars.biases.items():
            bias_score, completion_threshold = _bias
            # TODO: Cache these tokens, invalidate when model or bias is
            # changed.
            # NOTE(review): assumes encode() returns only the phrase's own
            # tokens (no BOS/EOS specials) -- confirm for the active tokenizer.
            token_seq = tokenizer.encode(phrase)
            if not token_seq:
                # Nothing to bias for a phrase that tokenizes to nothing.
                continue

            bias_index = self._find_intersection(context, token_seq)

            # Ensure completion after completion_threshold tokens
            if bias_index + 1 > completion_threshold:
                bias_score = self._FORCE_COMPLETION_SCORE

            ret[token_seq[bias_index]] = bias_score
        return ret

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Add the configured phrase biases to ``scores`` in place and return it.

        Both ``input_ids`` and ``scores`` must be 2-dimensional; the first
        dimension of ``scores`` is iterated as the batch.
        """
        assert scores.ndim == 2
        assert input_ids.ndim == 2

        for batch in range(scores.shape[0]):
            for token, bias in self._get_biased_tokens(input_ids[batch]).items():
                scores[batch][token] += bias

        return scores
class LuaLogitsProcessor(LogitsProcessor):
def __init__(self):
@@ -1373,6 +1434,7 @@ def patch_transformers():
def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
    """Build the stock logits-processor list and prepend the custom processors.

    The original builder is read from the ``old_get_logits_processor``
    attribute stored on this function. Each custom processor is inserted at
    index 0, so the one inserted last (PhraseBiasLogitsProcessor) ends up
    first in the returned list, directly ahead of LuaLogitsProcessor.
    """
    build_default = new_get_logits_processor.old_get_logits_processor
    chain = build_default(*args, **kwargs)
    # Instantiate inside the loop so each processor is constructed right
    # before it is inserted.
    for processor_cls in (LuaLogitsProcessor, PhraseBiasLogitsProcessor):
        chain.insert(0, processor_cls())
    return chain
# Keep a reference to the original builder on the wrapper function, then
# install the wrapper in its place. The order matters: the original must be
# saved before the attribute on GenerationMixin is overwritten.
new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor

View File

@@ -249,7 +249,7 @@ class model_settings(settings):
self.selected_preset = ""
self.uid_presets = []
self.default_preset = {}
self.biases = {} # should look like {"phrase": [percent, max_occurances]}
self.biases = {} # should look like {"phrase": [score, completion_threshold]}
#dummy class to eat the tqdm output
class ignore_tqdm(object):

View File

@@ -363,12 +363,12 @@
width: 100%;
}
/* Places the element in the grid area named "percent" with a 5px right margin.
   NOTE(review): the grid-area name still reads "percent" next to the
   .bias_score selector -- confirm this matches the grid template. */
.bias_percent {
.bias_score {
grid-area: percent;
margin-right: 5px;
}
/* Places the element in the grid area named "max" with a 5px right margin.
   NOTE(review): the grid-area name still reads "max" next to the
   .bias_comp_threshold selector -- confirm this matches the grid template. */
.bias_max {
.bias_comp_threshold {
grid-area: max;
margin-right: 5px;
}
@@ -409,7 +409,7 @@
font-size: small;
}
/* Header cell placed in the grid area named "percent", rendered in small text.
   NOTE(review): the grid-area name still reads "percent" next to the
   .bias_header_score selector -- confirm this matches the grid template. */
.bias_header_percent {
.bias_header_score {
grid-area: percent;
font-size: small;
}

View File

@@ -1252,16 +1252,16 @@ function save_bias(item) {
//get all of our biases
for (bias of document.getElementsByClassName("bias")) {
//phrase
phrase = bias.querySelector(".bias_phrase").querySelector("input").value;
var phrase = bias.querySelector(".bias_phrase").querySelector("input").value;
//percent
percent = parseFloat(bias.querySelector(".bias_percent").querySelector("input").value);
var percent = parseFloat(bias.querySelector(".bias_score").querySelector("input").value);
//max occurance
max_occurance = parseInt(bias.querySelector(".bias_max").querySelector("input").value);
//completion threshold
var comp_threshold = parseInt(bias.querySelector(".bias_comp_threshold").querySelector("input").value);
if (phrase != "") {
biases[phrase] = [percent, max_occurance];
biases[phrase] = [percent, comp_threshold];
} else {
//mark that we have a blank line, or delete it if we have more than one
if (have_blank) {
@@ -1276,8 +1276,8 @@ function save_bias(item) {
console.log("Create new bias line");
bias_line = document.getElementsByClassName("bias")[0].cloneNode(true);
bias_line.querySelector(".bias_phrase").querySelector("input").value = "";
bias_line.querySelector(".bias_percent").querySelector("input").value = 1;
bias_line.querySelector(".bias_max").querySelector("input").value = 50;
bias_line.querySelector(".bias_score").querySelector("input").value = 0;
bias_line.querySelector(".bias_comp_threshold").querySelector("input").value = 50;
document.getElementById('biasing').append(bias_line);
}

View File

@@ -57,35 +57,35 @@
<div id="biasing">
<div class="bias_header">
<div class="bias_header_phrase">Phrase</div>
<div class="bias_header_percent">Percent Chance</div>
<div class="bias_header_max">Max Occurance</div>
<div class="bias_header_score">Score</div>
<div class="bias_header_comp_threshold">Completion Threshold</div>
</div>
<div class="bias">
<div class="bias_phrase">
<input type=text placeholder="Word or Phrase to Bias" onchange="save_bias(this);"/>
</div>
<div class="bias_percent">
<div class="bias_score">
<div class="bias_slider">
<div class="bias_slider_bar">
<input type="range" min="0" max="1" step="0.01" value="1" class="setting_item_input"
<input type="range" min="-12" max="12" step="0.01" value="0" class="setting_item_input"
oninput="update_bias_slider_value(this);"
onchange="save_bias(this);"/>
</div>
<div class="bias_slider_min">-12.00</div>
<div class="bias_slider_cur">0</div>
<div class="bias_slider_max">12.00</div>
</div>
</div>
<div class="bias_comp_threshold">
<div class="bias_slider">
<div class="bias_slider_bar">
<input type="range" min="0" max="10" step="1" value="10" class="setting_item_input"
oninput="update_bias_slider_value(this);"
onchange="save_bias(this);"/>
</div>
<div class="bias_slider_min">0</div>
<div class="bias_slider_cur">1</div>
<div class="bias_slider_max">1</div>
</div>
</div>
<div class="bias_max">
<div class="bias_slider">
<div class="bias_slider_bar">
<input type="range" min="0" max="50" step="1" value="50" class="setting_item_input"
oninput="update_bias_slider_value(this);"
onchange="save_bias(this);"/>
</div>
<div class="bias_slider_min">0</div>
<div class="bias_slider_cur">50</div>
<div class="bias_slider_max">50</div>
<div class="bias_slider_cur">10</div>
<div class="bias_slider_max">10</div>
</div>
</div>
</div>