Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Add phrase biasing backend
aiserver.py
@@ -39,7 +39,7 @@ import functools
 import traceback
 from collections.abc import Iterable
 from collections import OrderedDict
-from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
+from typing import Any, Callable, Optional, TypeVar, Tuple, Union, Dict, Set, List
 
 import requests
 import html
@@ -1341,6 +1341,72 @@ def patch_transformers():
     RepetitionPenaltyLogitsProcessor.__init__ = AdvancedRepetitionPenaltyLogitsProcessor.__init__
     RepetitionPenaltyLogitsProcessor.__call__ = AdvancedRepetitionPenaltyLogitsProcessor.__call__
 
+    class PhraseBiasLogitsProcessor(LogitsProcessor):
+        def __init__(self):
+            pass
+
+        def _rindex(self, lst: List, target) -> Optional[int]:
+            # Index of the last occurrence of target in lst, or None if
+            # target does not occur.
+            for index, item in enumerate(reversed(lst)):
+                if item == target:
+                    return len(lst) - index - 1
+            return None
+
+        def _find_intersection(self, big: List, small: List) -> int:
+            # Find the intersection of the end of "big" and the beginning of
+            # "small". A headache to think about, personally. Returns the index
+            # into "small" where the two stop intersecting.
+            start = self._rindex(big, small[0])
+
+            # No progress into the token sequence, bias the first one.
+            if start is None:
+                return 0
+
+            for i in range(len(small)):
+                try:
+                    big_i = big[start + i]
+                except IndexError:
+                    return i
+                # Mismatch: "big" stops following the phrase here.
+                if big_i != small[i]:
+                    return 0
+
+            # It's completed :^)
+            return 0
+
+        def _get_biased_tokens(self, input_ids: List) -> Dict:
+            # TODO: Different "bias slopes"?
+
+            ret = {}
+            for phrase, _bias in koboldai_vars.biases.items():
+                bias_score, completion_threshold = _bias
+                # TODO: Cache these tokens, invalidate when model or bias is
+                # changed.
+                token_seq = tokenizer.encode(phrase)
+                bias_index = self._find_intersection(input_ids, token_seq)
+
+                # Ensure completion after completion_threshold tokens
+                if bias_index + 1 > completion_threshold:
+                    bias_score = 999
+
+                token_to_bias = token_seq[bias_index]
+                ret[token_to_bias] = bias_score
+            return ret
+
+        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+            assert scores.ndim == 2
+            assert input_ids.ndim == 2
+
+            scores_shape = scores.shape
+
+            for batch in range(scores_shape[0]):
+                for token, bias in self._get_biased_tokens(input_ids[batch]).items():
+                    scores[batch][token] += bias
+
+            return scores
+
+
     class LuaLogitsProcessor(LogitsProcessor):
 
         def __init__(self):
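The hunk above keys everything off _find_intersection: each phrase in koboldai_vars.biases (a map from phrase string to a (bias_score, completion_threshold) pair) is re-tokenized, the overlap between the tail of the context and the head of the phrase is measured, and the phrase token at the returned index is the one that gets its logit boosted; once the match has progressed past completion_threshold tokens, the bias is pinned to 999 to force the phrase to finish. A minimal self-contained sketch of that matching behavior (the token ids and standalone function names are illustrative, not part of the commit):

from typing import List, Optional

def rindex(lst: List[int], target: int) -> Optional[int]:
    # Index of the last occurrence of target, or None if absent.
    for offset, item in enumerate(reversed(lst)):
        if item == target:
            return len(lst) - offset - 1
    return None

def find_intersection(big: List[int], small: List[int]) -> int:
    # How far the tail of `big` (context tokens) has progressed into
    # `small` (phrase tokens); the result indexes the next token to bias.
    start = rindex(big, small[0])
    if start is None:
        return 0
    for i in range(len(small)):
        try:
            if big[start + i] != small[i]:
                return 0
        except IndexError:
            return i
    return 0  # phrase completed, start over from its first token

# Context ends with the first two phrase tokens: bias index 2 next.
assert find_intersection([7, 8, 11, 12], [11, 12, 13]) == 2
# No overlap: bias the first phrase token.
assert find_intersection([7, 8, 9], [11, 12, 13]) == 0
# Phrase already completed at the end of the context: start over.
assert find_intersection([7, 11, 12, 13], [11, 12, 13]) == 0

Note that the match is anchored at the last occurrence of the phrase's first token, so a partial overlap that does not start there goes undetected; within a single generation step that approximation is cheap and usually good enough.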
@@ -1373,6 +1439,7 @@ def patch_transformers():
     def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
         processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs)
         processors.insert(0, LuaLogitsProcessor())
+        processors.insert(0, PhraseBiasLogitsProcessor())
         return processors
     new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
     transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
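The hook works by monkey-patching GenerationMixin._get_logits_processor, so every generate() call anywhere in aiserver.py picks the new processor up automatically. For comparison, the same effect can be had per call through transformers' public API without patching; a sketch assuming a recent transformers version, with gpt2 and the processor below as illustrative placeholders rather than KoboldAI code:

import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          LogitsProcessor, LogitsProcessorList)

class StaticBiasProcessor(LogitsProcessor):
    # Adds a fixed additive bias to chosen token ids on every step.
    def __init__(self, bias_table: dict):
        self.bias_table = bias_table  # token id -> logit bias

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        for token_id, bias in self.bias_table.items():
            scores[:, token_id] += bias
        return scores

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
ids = tokenizer("Once upon a time", return_tensors="pt").input_ids
# Discourage early end-of-text by penalizing the EOS token.
processors = LogitsProcessorList([StaticBiasProcessor({tokenizer.eos_token_id: -5.0})])
out = model.generate(ids, max_new_tokens=20, logits_processor=processors)
print(tokenizer.decode(out[0]))

The patching approach trades that explicitness for coverage: inserting at index 0 of the processor list ensures the phrase bias is applied before the stock warpers and processors run.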