Merge pull request #251 from one-some/bias-improvements

Bias improvements
henk717, 2023-01-08 17:45:41 +01:00, committed by GitHub
6 changed files with 329 additions and 178 deletions


@@ -40,7 +40,6 @@ import packaging
 import packaging.version
 import contextlib
 import traceback
-import threading
 import markdown
 import bleach
 import itertools
@@ -63,7 +62,6 @@ import sys
 import gc
 import lupa
 import importlib
 # KoboldAI
 import fileops
@@ -83,7 +81,7 @@ import transformers.generation_utils
 # Text2img
 import base64
-from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps, PngImagePlugin
+from PIL import Image
 from io import BytesIO
 
 global tpu_mtj_backend
@@ -2220,28 +2218,90 @@ def patch_transformers():
         # There were no matches, so just begin at the beginning.
         return 0
 
+    def _allow_leftwards_tampering(self, phrase: str) -> bool:
+        """Determines if a phrase should be tampered with from the left in
+        the "soft" token encoding mode."""
+        if phrase[0] in [".", "?", "!", ";", ":", "\n"]:
+            return False
+        return True
+
+    def _get_token_sequence(self, phrase: str) -> List[List]:
+        """Convert the phrase string into a list of encoded biases, each
+        one being a list of tokens. How this is done is determined by the
+        phrase's format:
+
+        - If the phrase is surrounded by square brackets ([]), the tokens
+          will be the phrase split by commas (,). If a "token" isn't
+          actually a number, it will be skipped. NOTE: Tokens output by
+          this may not be in the model's vocabulary, and such tokens
+          should be ignored later in the pipeline.
+        - If the phrase is surrounded by curly brackets ({}), the phrase
+          will be directly encoded with no synonym biases and no fancy
+          tricks.
+        - Otherwise, the phrase will be encoded, with close deviations
+          being included as synonym biases.
+        """
+        # TODO: Cache these tokens, invalidate when model or bias is
+        # changed.
+
+        # Handle direct token id input
+        if phrase.startswith("[") and phrase.endswith("]"):
+            no_brackets = phrase[1:-1]
+            ret = []
+            for token_id in no_brackets.split(","):
+                try:
+                    ret.append(int(token_id))
+                except ValueError:
+                    # Ignore non-numbers. Rascals!
+                    pass
+            return [ret]
+
+        # Handle direct phrases
+        if phrase.startswith("{") and phrase.endswith("}"):
+            no_brackets = phrase[1:-1]
+            return [tokenizer.encode(no_brackets)]
+
+        # Handle untamperable phrases
+        if not self._allow_leftwards_tampering(phrase):
+            return [tokenizer.encode(phrase)]
+
+        # Handle slight alterations to original phrase
+        phrase = phrase.strip(" ")
+        ret = []
+        for alt_phrase in [phrase, f" {phrase}"]:
+            ret.append(tokenizer.encode(alt_phrase))
+        return ret
+
     def _get_biased_tokens(self, input_ids: List) -> Dict:
         # TODO: Different "bias slopes"?
         ret = {}
         for phrase, _bias in koboldai_vars.biases.items():
             bias_score, completion_threshold = _bias
-            # TODO: Cache these tokens, invalidate when model or bias is
-            # changed.
-            token_seq = tokenizer.encode(phrase)
-            bias_index = self._find_intersection(input_ids, token_seq)
+            token_seqs = self._get_token_sequence(phrase)
+            variant_deltas = {}
+            for token_seq in token_seqs:
+                bias_index = self._find_intersection(input_ids, token_seq)
 
-            # Ensure completion after completion_threshold tokens
-            # Only provide a positive bias when the base bias score is positive.
-            if bias_score > 0 and bias_index + 1 > completion_threshold:
-                bias_score = 999
+                # Ensure completion after completion_threshold tokens
+                # Only provide a positive bias when the base bias score is positive.
+                if bias_score > 0 and bias_index + 1 > completion_threshold:
+                    bias_score = 999
 
-            token_to_bias = token_seq[bias_index]
-            # If multiple phrases bias the same token, add the modifiers together.
-            if token_to_bias in ret:
-                ret[token_to_bias] += bias_score
-            else:
-                ret[token_to_bias] = bias_score
+                token_to_bias = token_seq[bias_index]
+                variant_deltas[token_to_bias] = bias_score
+
+            # If multiple phrases bias the same token, add the modifiers
+            # together. This should NOT be applied to automatic variants
+            for token_to_bias, bias_score in variant_deltas.items():
+                if token_to_bias in ret:
+                    ret[token_to_bias] += bias_score
+                else:
+                    ret[token_to_bias] = bias_score
 
         return ret
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
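
For orientation, here is a minimal standalone sketch of how the three bias-phrase formats described in the _get_token_sequence docstring parse. The toy_encode helper is a hypothetical stand-in for tokenizer.encode(), not KoboldAI's real tokenizer:

from typing import List

def toy_encode(text: str) -> List[int]:
    # Hypothetical stand-in: real token ids come from the model's BPE vocab.
    return [ord(c) for c in text]

def get_token_sequences(phrase: str) -> List[List[int]]:
    # "[72, 105]" -> raw token ids; non-numeric entries are skipped.
    if phrase.startswith("[") and phrase.endswith("]"):
        ids = []
        for part in phrase[1:-1].split(","):
            try:
                ids.append(int(part))
            except ValueError:
                pass
        return [ids]
    # "{Hi}" -> encode the phrase verbatim, no variants.
    if phrase.startswith("{") and phrase.endswith("}"):
        return [toy_encode(phrase[1:-1])]
    # Phrases opening with sentence punctuation are left untouched.
    if phrase[0] in [".", "?", "!", ";", ":", "\n"]:
        return [toy_encode(phrase)]
    # Otherwise bias both the bare and space-prefixed encodings, since BPE
    # vocabularies usually encode "word" and " word" as different tokens.
    phrase = phrase.strip(" ")
    return [toy_encode(v) for v in (phrase, f" {phrase}")]

print(get_token_sequences("[72, 105, oops]"))  # [[72, 105]]
print(get_token_sequences("{Hi}"))             # [[72, 105]]
print(get_token_sequences("Hi"))               # [[72, 105], [32, 72, 105]]

_get_biased_tokens then walks each variant against the current context via _find_intersection; once a positively biased phrase has progressed past its completion threshold, its bias jumps to 999 so the phrase is forced to completion.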
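The __call__ signature above is the standard transformers LogitsProcessor hook. Its body is not shown in this hunk; a plausible sketch of how the returned token-to-bias dict would be applied to the logits (apply_biases is illustrative, not the project's code):

import torch

def apply_biases(scores: torch.FloatTensor, biased_tokens: dict) -> torch.FloatTensor:
    # scores has shape (batch, vocab_size); add each delta onto its token's logit.
    for token_id, bias in biased_tokens.items():
        scores[:, token_id] += bias
    return scores

scores = torch.zeros(1, 8)
print(apply_biases(scores, {3: 2.5, 5: -4.0}))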
@@ -7847,7 +7907,6 @@ def final_startup():
             return
 
         utils.decodenewlines(tokenizer.decode([25678, 559]))
         tokenizer.encode(utils.encodenewlines("eunoia"))
-    #threading.Thread(target=__preempt_tokenizer).start()
     tpool.execute(__preempt_tokenizer)
 
     # Load soft prompt specified by the settings file, if applicable
@@ -7865,18 +7924,6 @@ def final_startup():
     if(koboldai_vars.use_colab_tpu or koboldai_vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
         soft_tokens = tpumtjgetsofttokens()
         if(koboldai_vars.dynamicscan or (not koboldai_vars.nogenmod and koboldai_vars.has_genmod)):
-            #threading.Thread(
-            #    target=tpu_mtj_backend.infer_dynamic,
-            #    args=(np.tile(np.uint32((23403, 727, 20185)), (koboldai_vars.numseqs, 1)),),
-            #    kwargs={
-            #        "soft_embeddings": koboldai_vars.sp,
-            #        "soft_tokens": soft_tokens,
-            #        "gen_len": 1,
-            #        "use_callback": False,
-            #        "numseqs": koboldai_vars.numseqs,
-            #        "excluded_world_info": list(set() for _ in range(koboldai_vars.numseqs)),
-            #    },
-            #).start()
             tpool.execute(tpu_mtj_backend.infer_dynamic, np.tile(np.uint32((23403, 727, 20185)), (koboldai_vars.numseqs, 1)),
                 soft_embeddings= koboldai_vars.sp,
                 soft_tokens= soft_tokens,
@@ -7886,16 +7933,6 @@ def final_startup():
                 excluded_world_info= list(set() for _ in range(koboldai_vars.numseqs))
             )
         else:
-            #threading.Thread(
-            #    target=tpu_mtj_backend.infer_static,
-            #    args=(np.uint32((23403, 727, 20185)),),
-            #    kwargs={
-            #        "soft_embeddings": koboldai_vars.sp,
-            #        "soft_tokens": soft_tokens,
-            #        "gen_len": 1,
-            #        "numseqs": koboldai_vars.numseqs,
-            #    },
-            #).start()
             tpool.execute(
                 tpu_mtj_backend.infer_static,
                 np.uint32((23403, 727, 20185)),
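
The deletions in final_startup() retire dead code: the commented-out threading.Thread calls had already been superseded by eventlet's tpool, which runs a blocking function on a real OS thread without stalling the green-thread hub, and removing them is what lets import threading at the top of the file go too. A minimal sketch of the pattern, with slow_load as a placeholder for a blocking call such as tokenizer warm-up or TPU inference:

import time
from eventlet import tpool

def slow_load(name: str, delay: float = 0.1) -> str:
    # Placeholder for a blocking call; tpool runs it on an OS thread.
    time.sleep(delay)
    return f"{name} ready"

# Unlike threading.Thread(target=slow_load, ...).start(), tpool.execute
# yields only the calling green thread and hands the return value (and
# any exception) straight back to the caller.
print(tpool.execute(slow_load, "tokenizer", delay=0.05))  # -> "tokenizer ready"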