New bias token control (and lack thereof)

Adds square bracket syntax for biasing raw token ids and curly
bracket syntax for strict phrasing; plain phrases now also bias
their space-prefixed variants.
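
For example, under the new syntax a bias entry is interpreted as
follows (the phrase and token ids are illustrative only):

    [1820, 318, 257]  ->  bias exactly these token ids; non-numeric entries are skipped
    {Hello world}     ->  bias exactly this phrase, with no alternative encodings
    Hello world       ->  bias the phrase and its space-prefixed variant " Hello world"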
somebody
2023-01-07 23:59:46 -06:00
parent 76c1398917
commit f6c4bfb390


@@ -2218,28 +2218,90 @@ def patch_transformers():
             # There were no matches, so just begin at the beginning.
             return 0
 
+        def _allow_leftwards_tampering(self, phrase: str) -> bool:
+            """Determines if a phrase should be tampered with from the left in
+            the "soft" token encoding mode."""
+            if phrase[0] in [".", "?", "!", ";", ":", "\n"]:
+                return False
+            return True
+
+        def _get_token_sequence(self, phrase: str) -> List[List]:
+            """Convert the phrase string into a list of encoded biases, each
+            one being a list of tokens. How this is done is determined by the
+            phrase's format:
+
+            - If the phrase is surrounded by square brackets ([]), the tokens
+              will be the phrase split by commas (,). If a "token" isn't
+              actually a number, it will be skipped. NOTE: Tokens output by
+              this may not be in the model's vocabulary, and such tokens
+              should be ignored later in the pipeline.
+            - If the phrase is surrounded by curly brackets ({}), the phrase
+              will be directly encoded with no synonym biases and no fancy
+              tricks.
+            - Otherwise, the phrase will be encoded, with close deviations
+              being included as synonym biases.
+            """
+            # TODO: Cache these tokens, invalidate when model or bias is
+            # changed.
+
+            # Handle direct token id input
+            if phrase.startswith("[") and phrase.endswith("]"):
+                no_brackets = phrase[1:-1]
+                ret = []
+                for token_id in no_brackets.split(","):
+                    try:
+                        ret.append(int(token_id))
+                    except ValueError:
+                        # Ignore non-numbers. Rascals!
+                        pass
+                return [ret]
+
+            # Handle direct phrases
+            if phrase.startswith("{") and phrase.endswith("}"):
+                no_brackets = phrase[1:-1]
+                return [tokenizer.encode(no_brackets)]
+
+            # Handle untamperable phrases
+            if not self._allow_leftwards_tampering(phrase):
+                return [tokenizer.encode(phrase)]
+
+            # Handle slight alterations to original phrase
+            phrase = phrase.strip(" ")
+            ret = []
+            for alt_phrase in [phrase, f" {phrase}"]:
+                ret.append(tokenizer.encode(alt_phrase))
+            return ret
+
         def _get_biased_tokens(self, input_ids: List) -> Dict:
             # TODO: Different "bias slopes"?
             ret = {}
             for phrase, _bias in koboldai_vars.biases.items():
                 bias_score, completion_threshold = _bias
-                # TODO: Cache these tokens, invalidate when model or bias is
-                # changed.
-                token_seq = tokenizer.encode(phrase)
-                bias_index = self._find_intersection(input_ids, token_seq)
+                token_seqs = self._get_token_sequence(phrase)
+                variant_deltas = {}
+                for token_seq in token_seqs:
+                    bias_index = self._find_intersection(input_ids, token_seq)
 
-                # Ensure completion after completion_threshold tokens
-                # Only provide a positive bias when the base bias score is positive.
-                if bias_score > 0 and bias_index + 1 > completion_threshold:
-                    bias_score = 999
+                    # Ensure completion after completion_threshold tokens
+                    # Only provide a positive bias when the base bias score is positive.
+                    if bias_score > 0 and bias_index + 1 > completion_threshold:
+                        bias_score = 999
 
-                token_to_bias = token_seq[bias_index]
-                # If multiple phrases bias the same token, add the modifiers together.
-                if token_to_bias in ret:
-                    ret[token_to_bias] += bias_score
-                else:
-                    ret[token_to_bias] = bias_score
+                    token_to_bias = token_seq[bias_index]
+                    variant_deltas[token_to_bias] = bias_score
+
+                # If multiple phrases bias the same token, add the modifiers
+                # together. This should NOT be applied to automatic variants
+                for token_to_bias, bias_score in variant_deltas.items():
+                    if token_to_bias in ret:
+                        ret[token_to_bias] += bias_score
+                    else:
+                        ret[token_to_bias] = bias_score
             return ret
 
         def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
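
Below is a rough, self-contained sketch of the parsing rules that
_get_token_sequence introduces in the diff above. toy_encode is an
assumed stand-in for the real tokenizer.encode, and the printed ids are
fake byte values rather than model vocabulary entries.

from typing import List

def toy_encode(text: str) -> List[int]:
    # Stand-in for tokenizer.encode() used in the commit; it returns one
    # fake id per byte so the sketch runs without a model tokenizer.
    return list(text.encode("utf-8"))

def get_token_sequence(phrase: str) -> List[List[int]]:
    # Mirrors the parsing rules of _get_token_sequence in the diff above.
    # [1, 2, 3] -> raw token ids; non-numeric entries are silently skipped
    if phrase.startswith("[") and phrase.endswith("]"):
        ids = []
        for part in phrase[1:-1].split(","):
            try:
                ids.append(int(part))
            except ValueError:
                pass
        return [ids]
    # {exact phrase} -> encode verbatim, no alternative encodings
    if phrase.startswith("{") and phrase.endswith("}"):
        return [toy_encode(phrase[1:-1])]
    # Phrases starting with sentence punctuation are not tampered with
    if phrase[0] in [".", "?", "!", ";", ":", "\n"]:
        return [toy_encode(phrase)]
    # Plain phrase -> bias both the phrase and its space-prefixed variant
    phrase = phrase.strip(" ")
    return [toy_encode(p) for p in (phrase, f" {phrase}")]

print(get_token_sequence("[1820, 318, x, 257]"))  # [[1820, 318, 257]]
print(get_token_sequence("{Hi}"))                 # [[72, 105]]
print(get_token_sequence("Hi"))                   # [[72, 105], [32, 72, 105]]

The space-prefixed variant matters because BPE-style tokenizers encode
"Hello" and " Hello" as different tokens. In _get_biased_tokens, the
per-phrase variant_deltas dict keeps these automatic variants from
stacking their bias on a shared token, while distinct user phrases that
land on the same token still have their modifiers added together.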