mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Add phrase biasing backend
This commit is contained in:
64
aiserver.py
64
aiserver.py
@@ -39,7 +39,7 @@ import functools
|
||||
import traceback
|
||||
from collections.abc import Iterable
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
|
||||
from typing import Any, Callable, Optional, TypeVar, Tuple, Union, Dict, Set, List
|
||||
|
||||
import requests
|
||||
import html
|
||||
@@ -1341,6 +1341,67 @@ def patch_transformers():
|
||||
RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__
|
||||
RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__
|
||||
|
||||
class PhraseBiasLogitsProcessor(LogitsProcessor):
    """Bias generation toward (or away from) user-configured phrases.

    Phrases come from ``koboldai_vars.biases`` shaped like
    ``{"phrase": [score, completion_threshold]}``. On every step, the next
    un-generated token of each phrase gets ``score`` added to its logit;
    once generation has progressed at least ``completion_threshold`` tokens
    into a phrase, completion is forced with a very large positive bias.
    """

    def __init__(self):
        pass

    def _rindex(self, lst: List, target) -> Optional[int]:
        """Return the index of the last occurrence of ``target`` in ``lst``,
        or ``None`` if it does not occur."""
        for offset, item in enumerate(reversed(lst)):
            if item == target:
                return len(lst) - offset - 1
        return None

    def _find_intersection(self, big: List, small: List) -> int:
        # Find the intersection of the end of "big" and the beginning of
        # "small". A headache to think about, personally. Returns the index
        # into "small" where the two stop intersecting.
        start = self._rindex(big, small[0])

        # BUGFIX: the original used `if not start:`, which also triggered
        # when the match legitimately began at index 0. Only None means
        # "no occurrence at all" -> no progress, bias the first token.
        if start is None:
            return 0

        for i in range(len(small)):
            try:
                # BUGFIX: the original fetched big[start + i] but never
                # compared it to small[i], so any stray token equal to
                # small[0] looked like phrase progress. A mismatch means
                # there is no real intersection; bias the first token.
                if big[start + i] != small[i]:
                    return 0
            except IndexError:
                # Ran off the end of the context: the first i tokens of
                # the phrase are already present, so token i comes next.
                return i

        # The whole phrase is already present :^) — start over.
        return 0

    def _get_biased_tokens(self, input_ids: List) -> Dict:
        """Map each phrase's next token id to the bias to apply to it."""
        # TODO: Different "bias slopes"?

        ret = {}
        for phrase, _bias in koboldai_vars.biases.items():
            bias_score, completion_threshold = _bias
            # TODO: Cache these tokens, invalidate when model or bias is
            # changed.
            token_seq = tokenizer.encode(phrase)
            bias_index = self._find_intersection(input_ids, token_seq)

            # Ensure completion after completion_threshold tokens
            if bias_index + 1 > completion_threshold:
                bias_score = 999

            token_to_bias = token_seq[bias_index]
            ret[token_to_bias] = bias_score
        return ret

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Add the configured biases to ``scores`` in place and return it.

        Both tensors are expected to be batched: ``input_ids`` is
        (batch, seq) and ``scores`` is (batch, vocab).
        """
        assert scores.ndim == 2
        assert input_ids.ndim == 2

        for batch in range(scores.shape[0]):
            for token, bias in self._get_biased_tokens(input_ids[batch]).items():
                scores[batch][token] += bias

        return scores
|
||||
|
||||
|
||||
class LuaLogitsProcessor(LogitsProcessor):
|
||||
|
||||
def __init__(self):
|
||||
@@ -1373,6 +1434,7 @@ def patch_transformers():
|
||||
# Wrapper around transformers' stock logits-processor factory: builds the
# default list, then prepends KoboldAI's custom processors so they run first.
def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
|
||||
# Delegate to the original implementation saved on the function object below.
processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs)
|
||||
processors.insert(0, LuaLogitsProcessor())
|
||||
processors.insert(0, PhraseBiasLogitsProcessor())
|
||||
return processors
|
||||
# Stash the original so the wrapper can call it (and to avoid re-wrapping).
new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
|
||||
# Monkey-patch the wrapper into transformers for all subsequent generations.
transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
|
||||
|
@@ -249,7 +249,7 @@ class model_settings(settings):
|
||||
self.selected_preset = ""
|
||||
self.uid_presets = []
|
||||
self.default_preset = {}
|
||||
self.biases = {} # should look like {"phrase": [percent, max_occurances]}
|
||||
self.biases = {} # should look like {"phrase": [score, completion_threshold]}
|
||||
|
||||
#dummy class to eat the tqdm output
|
||||
class ignore_tqdm(object):
|
||||
|
@@ -363,12 +363,12 @@
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.bias_percent {
|
||||
.bias_score {
|
||||
grid-area: percent;
|
||||
margin-right: 5px;
|
||||
}
|
||||
|
||||
.bias_max {
|
||||
.bias_comp_threshold {
|
||||
grid-area: max;
|
||||
margin-right: 5px;
|
||||
}
|
||||
@@ -409,7 +409,7 @@
|
||||
font-size: small;
|
||||
}
|
||||
|
||||
.bias_header_percent {
|
||||
.bias_header_score {
|
||||
grid-area: percent;
|
||||
font-size: small;
|
||||
}
|
||||
|
@@ -1252,16 +1252,16 @@ function save_bias(item) {
|
||||
//get all of our biases
|
||||
for (bias of document.getElementsByClassName("bias")) {
|
||||
//phrase
|
||||
phrase = bias.querySelector(".bias_phrase").querySelector("input").value;
|
||||
var phrase = bias.querySelector(".bias_phrase").querySelector("input").value;
|
||||
|
||||
//percent
|
||||
percent = parseFloat(bias.querySelector(".bias_percent").querySelector("input").value);
|
||||
var percent = parseFloat(bias.querySelector(".bias_score").querySelector("input").value);
|
||||
|
||||
//max occurance
|
||||
max_occurance = parseInt(bias.querySelector(".bias_max").querySelector("input").value);
|
||||
//completion threshold
|
||||
var comp_threshold = parseInt(bias.querySelector(".bias_comp_threshold").querySelector("input").value);
|
||||
|
||||
if (phrase != "") {
|
||||
biases[phrase] = [percent, max_occurance];
|
||||
biases[phrase] = [percent, comp_threshold];
|
||||
} else {
|
||||
//mark that we have a blank line, or delete it if we have more than one
|
||||
if (have_blank) {
|
||||
@@ -1276,8 +1276,8 @@ function save_bias(item) {
|
||||
console.log("Create new bias line");
|
||||
bias_line = document.getElementsByClassName("bias")[0].cloneNode(true);
|
||||
bias_line.querySelector(".bias_phrase").querySelector("input").value = "";
|
||||
bias_line.querySelector(".bias_percent").querySelector("input").value = 1;
|
||||
bias_line.querySelector(".bias_max").querySelector("input").value = 50;
|
||||
bias_line.querySelector(".bias_score").querySelector("input").value = 0;
|
||||
bias_line.querySelector(".bias_comp_threshold").querySelector("input").value = 50;
|
||||
document.getElementById('biasing').append(bias_line);
|
||||
}
|
||||
|
||||
|
@@ -57,35 +57,35 @@
|
||||
<div id="biasing">
|
||||
<div class="bias_header">
|
||||
<div class="bias_header_phrase">Phrase</div>
|
||||
<div class="bias_header_percent">Percent Chance</div>
|
||||
<div class="bias_header_max">Max Occurance</div>
|
||||
<div class="bias_header_score">Score</div>
|
||||
<div class="bias_header_comp_threshold">Completion Threshold</div>
|
||||
</div>
|
||||
<div class="bias">
|
||||
<div class="bias_phrase">
|
||||
<input type=text placeholder="Word or Phrase to Bias" onchange="save_bias(this);"/>
|
||||
</div>
|
||||
<div class="bias_percent">
|
||||
<div class="bias_score">
|
||||
<div class="bias_slider">
|
||||
<div class="bias_slider_bar">
|
||||
<input type="range" min="0" max="1" step="0.01" value="1" class="setting_item_input"
|
||||
<input type="range" min="-12" max="12" step="0.01" value="0" class="setting_item_input"
|
||||
oninput="update_bias_slider_value(this);"
|
||||
onchange="save_bias(this);"/>
|
||||
</div>
|
||||
<div class="bias_slider_min">-12.00</div>
|
||||
<div class="bias_slider_cur">0</div>
|
||||
<div class="bias_slider_max">12.00</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="bias_comp_threshold">
|
||||
<div class="bias_slider">
|
||||
<div class="bias_slider_bar">
|
||||
<input type="range" min="0" max="10" step="1" value="10" class="setting_item_input"
|
||||
oninput="update_bias_slider_value(this);"
|
||||
onchange="save_bias(this);"/>
|
||||
</div>
|
||||
<div class="bias_slider_min">0</div>
|
||||
<div class="bias_slider_cur">1</div>
|
||||
<div class="bias_slider_max">1</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="bias_max">
|
||||
<div class="bias_slider">
|
||||
<div class="bias_slider_bar">
|
||||
<input type="range" min="0" max="50" step="1" value="50" class="setting_item_input"
|
||||
oninput="update_bias_slider_value(this);"
|
||||
onchange="save_bias(this);"/>
|
||||
</div>
|
||||
<div class="bias_slider_min">0</div>
|
||||
<div class="bias_slider_cur">50</div>
|
||||
<div class="bias_slider_max">50</div>
|
||||
<div class="bias_slider_cur">10</div>
|
||||
<div class="bias_slider_max">10</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
Reference in New Issue
Block a user