mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Add phrase biasing backend
This commit is contained in:
64
aiserver.py
64
aiserver.py
@@ -39,7 +39,7 @@ import functools
|
|||||||
import traceback
|
import traceback
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
|
from typing import Any, Callable, Optional, TypeVar, Tuple, Union, Dict, Set, List
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import html
|
import html
|
||||||
@@ -1341,6 +1341,67 @@ def patch_transformers():
|
|||||||
RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__
|
RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__
|
||||||
RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__
|
RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__
|
||||||
|
|
||||||
|
class PhraseBiasLogitsProcessor(LogitsProcessor):
    """Bias generation toward/away from the phrases in koboldai_vars.biases.

    Each entry maps a phrase string to [bias_score, completion_threshold].
    On every step, the next un-generated token of each phrase receives
    bias_score; once generation has progressed past completion_threshold
    tokens into a phrase, the remainder is forced with an overwhelming bias.
    """

    def __init__(self):
        pass

    def _rindex(self, lst: List, target) -> Optional[int]:
        """Return the index of the last occurrence of target in lst, or None."""
        for index, item in enumerate(reversed(lst)):
            if item == target:
                return len(lst) - index - 1
        return None

    def _find_intersection(self, big: List, small: List) -> int:
        # Find the intersection of the end of "big" and the beginning of
        # "small". A headache to think about, personally. Returns the index
        # into "small" where the two stop intersecting.
        start = self._rindex(big, small[0])

        # No progress into the token sequence, bias the first one.
        # BUGFIX: was `if not start:`, which also fired when start == 0 —
        # a legitimate match at index 0 of "big". Only None means no match.
        if start is None:
            return 0

        for i in range(len(small)):
            try:
                big_i = big[start + i]
            except IndexError:
                # "big" ends i tokens into the phrase; bias token i next.
                return i
            # BUGFIX: the original assigned big_i and never compared it, so
            # any stray occurrence of small[0] counted as a partial match.
            if big_i != small[i]:
                return 0

        # It's completed :^) -- start biasing from the top again.
        return 0

    def _get_biased_tokens(self, input_ids: List) -> Dict:
        """Map each phrase's next token id to the bias to apply to it."""
        # TODO: Different "bias slopes"?

        ret = {}
        for phrase, _bias in koboldai_vars.biases.items():
            bias_score, completion_threshold = _bias

            # TODO: Cache these tokens, invalidate when model or bias is
            # changed.
            token_seq = tokenizer.encode(phrase)
            if not token_seq:
                # Phrases that tokenize to nothing (e.g. "") cannot be
                # biased and would crash _find_intersection on small[0].
                continue
            bias_index = self._find_intersection(input_ids, token_seq)

            # Ensure completion after completion_threshold tokens
            if bias_index + 1 > completion_threshold:
                bias_score = 999

            token_to_bias = token_seq[bias_index]
            ret[token_to_bias] = bias_score
        return ret

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Apply phrase biases to one generation step's scores, in place."""
        assert scores.ndim == 2
        assert input_ids.ndim == 2

        for batch in range(scores.shape[0]):
            for token, bias in self._get_biased_tokens(input_ids[batch]).items():
                scores[batch][token] += bias

        return scores
class LuaLogitsProcessor(LogitsProcessor):
|
class LuaLogitsProcessor(LogitsProcessor):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -1373,6 +1434,7 @@ def patch_transformers():
|
|||||||
def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
    """Build the stock logits-processor list, then prepend KoboldAI's own.

    Resulting order: [PhraseBiasLogitsProcessor, LuaLogitsProcessor, ...stock].
    """
    processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs)
    # Insert Lua first, then phrase bias, so phrase bias ends up in front.
    for custom in (LuaLogitsProcessor, PhraseBiasLogitsProcessor):
        processors.insert(0, custom())
    return processors

# Keep a handle on the stock factory, then monkey-patch it with our wrapper.
new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
|
@@ -249,7 +249,7 @@ class model_settings(settings):
|
|||||||
self.selected_preset = ""
|
self.selected_preset = ""
|
||||||
self.uid_presets = []
|
self.uid_presets = []
|
||||||
self.default_preset = {}
|
self.default_preset = {}
|
||||||
self.biases = {} # should look like {"phrase": [percent, max_occurances]}
|
self.biases = {} # should look like {"phrase": [score, completion_threshold]}
|
||||||
|
|
||||||
#dummy class to eat the tqdm output
|
#dummy class to eat the tqdm output
|
||||||
class ignore_tqdm(object):
|
class ignore_tqdm(object):
|
||||||
|
@@ -363,12 +363,12 @@
|
|||||||
width: 100%;
|
width: 100%;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Slider cell holding the per-phrase bias score. */
.bias_score {
    margin-right: 5px;
    grid-area: percent;
}
||||||
|
|
||||||
/* Slider cell holding the phrase-completion threshold. */
.bias_comp_threshold {
    margin-right: 5px;
    grid-area: max;
}
||||||
@@ -409,7 +409,7 @@
|
|||||||
font-size: small;
|
font-size: small;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Column header above the bias-score sliders. */
.bias_header_score {
    font-size: small;
    grid-area: percent;
}
||||||
|
@@ -1252,16 +1252,16 @@ function save_bias(item) {
|
|||||||
//get all of our biases
|
//get all of our biases
|
||||||
for (bias of document.getElementsByClassName("bias")) {
|
for (bias of document.getElementsByClassName("bias")) {
|
||||||
//phrase
|
//phrase
|
||||||
phrase = bias.querySelector(".bias_phrase").querySelector("input").value;
|
var phrase = bias.querySelector(".bias_phrase").querySelector("input").value;
|
||||||
|
|
||||||
//percent
|
//percent
|
||||||
percent = parseFloat(bias.querySelector(".bias_percent").querySelector("input").value);
|
var percent = parseFloat(bias.querySelector(".bias_score").querySelector("input").value);
|
||||||
|
|
||||||
//max occurance
|
//completion threshold
|
||||||
max_occurance = parseInt(bias.querySelector(".bias_max").querySelector("input").value);
|
var comp_threshold = parseInt(bias.querySelector(".bias_comp_threshold").querySelector("input").value);
|
||||||
|
|
||||||
if (phrase != "") {
|
if (phrase != "") {
|
||||||
biases[phrase] = [percent, max_occurance];
|
biases[phrase] = [percent, comp_threshold];
|
||||||
} else {
|
} else {
|
||||||
//mark that we have a blank line, or delete it if we have more than one
|
//mark that we have a blank line, or delete it if we have more than one
|
||||||
if (have_blank) {
|
if (have_blank) {
|
||||||
@@ -1276,8 +1276,8 @@ function save_bias(item) {
|
|||||||
console.log("Create new bias line");
|
console.log("Create new bias line");
|
||||||
bias_line = document.getElementsByClassName("bias")[0].cloneNode(true);
|
bias_line = document.getElementsByClassName("bias")[0].cloneNode(true);
|
||||||
bias_line.querySelector(".bias_phrase").querySelector("input").value = "";
|
bias_line.querySelector(".bias_phrase").querySelector("input").value = "";
|
||||||
bias_line.querySelector(".bias_percent").querySelector("input").value = 1;
|
bias_line.querySelector(".bias_score").querySelector("input").value = 0;
|
||||||
bias_line.querySelector(".bias_max").querySelector("input").value = 50;
|
bias_line.querySelector(".bias_comp_threshold").querySelector("input").value = 50;
|
||||||
document.getElementById('biasing').append(bias_line);
|
document.getElementById('biasing').append(bias_line);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -57,35 +57,35 @@
|
|||||||
<div id="biasing">
|
<div id="biasing">
|
||||||
<div class="bias_header">
|
<div class="bias_header">
|
||||||
<div class="bias_header_phrase">Phrase</div>
|
<div class="bias_header_phrase">Phrase</div>
|
||||||
<div class="bias_header_percent">Percent Chance</div>
|
<div class="bias_header_score">Score</div>
|
||||||
<div class="bias_header_max">Max Occurance</div>
|
<div class="bias_header_comp_threshold">Completion Threshold</div>
|
||||||
</div>
|
</div>
|
||||||
<div class="bias">
|
<div class="bias">
|
||||||
<div class="bias_phrase">
|
<div class="bias_phrase">
|
||||||
<input type=text placeholder="Word or Phrase to Bias" onchange="save_bias(this);"/>
|
<input type=text placeholder="Word or Phrase to Bias" onchange="save_bias(this);"/>
|
||||||
</div>
|
</div>
|
||||||
<div class="bias_percent">
|
<div class="bias_score">
|
||||||
<div class="bias_slider">
|
<div class="bias_slider">
|
||||||
<div class="bias_slider_bar">
|
<div class="bias_slider_bar">
|
||||||
<input type="range" min="0" max="1" step="0.01" value="1" class="setting_item_input"
|
<input type="range" min="-12" max="12" step="0.01" value="0" class="setting_item_input"
|
||||||
oninput="update_bias_slider_value(this);"
|
oninput="update_bias_slider_value(this);"
|
||||||
onchange="save_bias(this);"/>
|
onchange="save_bias(this);"/>
|
||||||
</div>
|
</div>
|
||||||
<div class="bias_slider_min">0</div>
|
<div class="bias_slider_min">-12.00</div>
|
||||||
<div class="bias_slider_cur">1</div>
|
<div class="bias_slider_cur">0</div>
|
||||||
<div class="bias_slider_max">1</div>
|
<div class="bias_slider_max">12.00</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class="bias_max">
|
<div class="bias_comp_threshold">
|
||||||
<div class="bias_slider">
|
<div class="bias_slider">
|
||||||
<div class="bias_slider_bar">
|
<div class="bias_slider_bar">
|
||||||
<input type="range" min="0" max="50" step="1" value="50" class="setting_item_input"
|
<input type="range" min="0" max="10" step="1" value="10" class="setting_item_input"
|
||||||
oninput="update_bias_slider_value(this);"
|
oninput="update_bias_slider_value(this);"
|
||||||
onchange="save_bias(this);"/>
|
onchange="save_bias(this);"/>
|
||||||
</div>
|
</div>
|
||||||
<div class="bias_slider_min">0</div>
|
<div class="bias_slider_min">0</div>
|
||||||
<div class="bias_slider_cur">50</div>
|
<div class="bias_slider_cur">10</div>
|
||||||
<div class="bias_slider_max">50</div>
|
<div class="bias_slider_max">10</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
Reference in New Issue
Block a user