mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Samplers: Part 2
This commit is contained in:
718
model.py
718
model.py
@@ -1,5 +1,4 @@
|
||||
# TODO:
|
||||
# - Intertwine stoppers and streaming and such
|
||||
# Before merge: please make sure to fix any TODOB4MERGE comments
|
||||
from __future__ import annotations
|
||||
|
||||
import bisect
|
||||
@@ -16,12 +15,15 @@ import json
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
||||
import zipfile
|
||||
from tqdm.auto import tqdm
|
||||
from logger import logger
|
||||
import torch_lazy_loader
|
||||
|
||||
import warpers
|
||||
from warpers import Warper
|
||||
|
||||
import torch
|
||||
from torch.nn import Embedding
|
||||
import numpy as np
|
||||
@@ -32,14 +34,13 @@ from transformers import (
|
||||
GPT2Tokenizer,
|
||||
GPT2LMHeadModel,
|
||||
GPTNeoForCausalLM,
|
||||
GPTNeoModel,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
PreTrainedModel,
|
||||
modeling_utils,
|
||||
AutoModelForTokenClassification,
|
||||
AutoConfig,
|
||||
LogitsProcessorList,
|
||||
LogitsProcessor,
|
||||
)
|
||||
|
||||
import utils
|
||||
@@ -399,21 +400,6 @@ def patch_transformers_generation() -> None:
|
||||
global transformers
|
||||
|
||||
# Patch transformers to use our custom logit warpers -- Only HFTorchInferenceModel uses this
|
||||
from transformers import (
|
||||
LogitsProcessorList,
|
||||
LogitsWarper,
|
||||
LogitsProcessor,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
TemperatureLogitsWarper,
|
||||
)
|
||||
from warpers import (
|
||||
AdvancedRepetitionPenaltyLogitsProcessor,
|
||||
TailFreeLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
TopALogitsWarper,
|
||||
)
|
||||
|
||||
def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
|
||||
old_call = cls.__call__
|
||||
|
||||
@@ -434,343 +420,22 @@ def patch_transformers_generation() -> None:
|
||||
cls.__call__ = new_call
|
||||
|
||||
# TODO: Make samplers generic
|
||||
dynamic_processor_wrap(
|
||||
AdvancedRepetitionPenaltyLogitsProcessor,
|
||||
("penalty", "penalty_slope", "penalty_range", "use_alt_rep_pen"),
|
||||
("rep_pen", "rep_pen_slope", "rep_pen_range", "use_alt_rep_pen"),
|
||||
cond=lambda x: x[0] != 1.0,
|
||||
)
|
||||
dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
|
||||
dynamic_processor_wrap(TopALogitsWarper, "top_a", "top_a", cond=lambda x: x > 0.0)
|
||||
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
|
||||
dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
|
||||
dynamic_processor_wrap(
|
||||
TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0
|
||||
)
|
||||
dynamic_processor_wrap(
|
||||
TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0
|
||||
)
|
||||
|
||||
class PhraseBiasLogitsProcessor(LogitsProcessor):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def _find_intersection(self, big: List, small: List) -> int:
|
||||
"""Find the maximum overlap between the beginning of small and the end of big.
|
||||
Return the index of the token in small following the overlap, or 0.
|
||||
|
||||
big: The tokens in the context (as a tensor)
|
||||
small: The tokens in the phrase to bias (as a list)
|
||||
|
||||
Both big and small are in "oldest to newest" order.
|
||||
"""
|
||||
# There are asymptotically more efficient methods for determining the overlap,
|
||||
# but typically there will be few (0-1) instances of small[0] in the last len(small)
|
||||
# elements of big, plus small will typically be fairly short. So this naive
|
||||
# approach is acceptable despite O(N^2) worst case performance.
|
||||
|
||||
num_small = len(small)
|
||||
# The small list can only ever match against at most num_small tokens of big,
|
||||
# so create a slice. Typically, this slice will be as long as small, but it
|
||||
# may be shorter if the story has just started.
|
||||
# We need to convert the big slice to list, since natively big is a tensor
|
||||
# and tensor and list don't ever compare equal. It's better to convert here
|
||||
# and then use native equality tests than to iterate repeatedly later.
|
||||
big_slice = list(big[-num_small:])
|
||||
|
||||
# It's possible that the start token appears multiple times in small
|
||||
# For example, consider the phrase:
|
||||
# [ fair is foul, and foul is fair, hover through the fog and filthy air]
|
||||
# If we merely look for the first instance of [ fair], then we would
|
||||
# generate the following output:
|
||||
# " fair is foul, and foul is fair is foul, and foul is fair..."
|
||||
start = small[0]
|
||||
for i, t in enumerate(big_slice):
|
||||
# Strictly unnecessary, but it's marginally faster to test the first
|
||||
# token before creating slices to test for a full match.
|
||||
if t == start:
|
||||
remaining = len(big_slice) - i
|
||||
if big_slice[i:] == small[:remaining]:
|
||||
# We found a match. If the small phrase has any remaining tokens
|
||||
# then return the index of the next token.
|
||||
if remaining < num_small:
|
||||
return remaining
|
||||
# In this case, the entire small phrase matched, so start over.
|
||||
return 0
|
||||
|
||||
# There were no matches, so just begin at the beginning.
|
||||
return 0
|
||||
|
||||
def _allow_leftwards_tampering(self, phrase: str) -> bool:
|
||||
"""Determines if a phrase should be tampered with from the left in
|
||||
the "soft" token encoding mode."""
|
||||
|
||||
if phrase[0] in [".", "?", "!", ";", ":", "\n"]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _get_token_sequence(self, phrase: str) -> List[List]:
|
||||
"""Convert the phrase string into a list of encoded biases, each
|
||||
one being a list of tokens. How this is done is determined by the
|
||||
phrase's format:
|
||||
|
||||
- If the phrase is surrounded by square brackets ([]), the tokens
|
||||
will be the phrase split by commas (,). If a "token" isn't
|
||||
actually a number, it will be skipped. NOTE: Tokens output by
|
||||
this may not be in the model's vocabulary, and such tokens
|
||||
should be ignored later in the pipeline.
|
||||
- If the phrase is surrounded by curly brackets ({}), the phrase
|
||||
will be directly encoded with no synonym biases and no fancy
|
||||
tricks.
|
||||
- Otherwise, the phrase will be encoded, with close deviations
|
||||
being included as synonym biases.
|
||||
"""
|
||||
|
||||
# TODO: Cache these tokens, invalidate when model or bias is
|
||||
# changed.
|
||||
|
||||
# Handle direct token id input
|
||||
if phrase.startswith("[") and phrase.endswith("]"):
|
||||
no_brackets = phrase[1:-1]
|
||||
ret = []
|
||||
for token_id in no_brackets.split(","):
|
||||
try:
|
||||
ret.append(int(token_id))
|
||||
except ValueError:
|
||||
# Ignore non-numbers. Rascals!
|
||||
pass
|
||||
return [ret]
|
||||
|
||||
# Handle direct phrases
|
||||
if phrase.startswith("{") and phrase.endswith("}"):
|
||||
no_brackets = phrase[1:-1]
|
||||
return [HACK_currentmodel.tokenizer.encode(no_brackets)]
|
||||
|
||||
# Handle untamperable phrases
|
||||
if not self._allow_leftwards_tampering(phrase):
|
||||
return [HACK_currentmodel.tokenizer.encode(phrase)]
|
||||
|
||||
# Handle slight alterations to original phrase
|
||||
phrase = phrase.strip(" ")
|
||||
ret = []
|
||||
|
||||
for alt_phrase in [phrase, f" {phrase}"]:
|
||||
ret.append(HACK_currentmodel.tokenizer.encode(alt_phrase))
|
||||
|
||||
return ret
|
||||
|
||||
def _get_biased_tokens(self, input_ids: List) -> Dict:
|
||||
# TODO: Different "bias slopes"?
|
||||
|
||||
ret = {}
|
||||
for phrase, _bias in utils.koboldai_vars.biases.items():
|
||||
bias_score, completion_threshold = _bias
|
||||
token_seqs = self._get_token_sequence(phrase)
|
||||
variant_deltas = {}
|
||||
for token_seq in token_seqs:
|
||||
bias_index = self._find_intersection(input_ids, token_seq)
|
||||
|
||||
# Ensure completion after completion_threshold tokens
|
||||
# Only provide a positive bias when the base bias score is positive.
|
||||
if bias_score > 0 and bias_index + 1 > completion_threshold:
|
||||
bias_score = 999
|
||||
|
||||
token_to_bias = token_seq[bias_index]
|
||||
variant_deltas[token_to_bias] = bias_score
|
||||
|
||||
# If multiple phrases bias the same token, add the modifiers
|
||||
# together. This should NOT be applied to automatic variants
|
||||
for token_to_bias, bias_score in variant_deltas.items():
|
||||
if token_to_bias in ret:
|
||||
ret[token_to_bias] += bias_score
|
||||
else:
|
||||
ret[token_to_bias] = bias_score
|
||||
return ret
|
||||
|
||||
def __call__(
|
||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
assert scores.ndim == 2
|
||||
assert input_ids.ndim == 2
|
||||
|
||||
scores_shape = scores.shape
|
||||
|
||||
for batch in range(scores_shape[0]):
|
||||
for token, bias in self._get_biased_tokens(input_ids[batch]).items():
|
||||
scores[batch][token] += bias
|
||||
|
||||
return scores
|
||||
|
||||
class LuaLogitsProcessor(LogitsProcessor):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(
|
||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
assert scores.ndim == 2
|
||||
assert input_ids.ndim == 2
|
||||
self.regeneration_required = False
|
||||
self.halt = False
|
||||
|
||||
if utils.koboldai_vars.standalone:
|
||||
return scores
|
||||
|
||||
scores_shape = scores.shape
|
||||
scores_list = scores.tolist()
|
||||
utils.koboldai_vars.lua_koboldbridge.logits = (
|
||||
utils.koboldai_vars.lua_state.table()
|
||||
)
|
||||
for r, row in enumerate(scores_list):
|
||||
utils.koboldai_vars.lua_koboldbridge.logits[
|
||||
r + 1
|
||||
] = utils.koboldai_vars.lua_state.table(*row)
|
||||
utils.koboldai_vars.lua_koboldbridge.vocab_size = scores_shape[-1]
|
||||
|
||||
utils.koboldai_vars.lua_koboldbridge.execute_genmod()
|
||||
|
||||
scores = torch.tensor(
|
||||
tuple(
|
||||
tuple(row.values())
|
||||
for row in utils.koboldai_vars.lua_koboldbridge.logits.values()
|
||||
),
|
||||
device=scores.device,
|
||||
dtype=scores.dtype,
|
||||
)
|
||||
assert scores.shape == scores_shape
|
||||
|
||||
return scores
|
||||
|
||||
from torch.nn import functional as F
|
||||
|
||||
def visualize_probabilities(
|
||||
model: InferenceModel,
|
||||
scores: torch.FloatTensor,
|
||||
) -> None:
|
||||
assert scores.ndim == 2
|
||||
|
||||
if utils.koboldai_vars.numseqs > 1 or not utils.koboldai_vars.show_probs:
|
||||
return
|
||||
|
||||
if not utils.koboldai_vars.show_probs:
|
||||
return scores
|
||||
|
||||
option_offset = 0
|
||||
if (
|
||||
utils.koboldai_vars.actions.action_count + 1
|
||||
in utils.koboldai_vars.actions.actions
|
||||
):
|
||||
for x in range(
|
||||
len(
|
||||
utils.koboldai_vars.actions.actions[
|
||||
utils.koboldai_vars.actions.action_count + 1
|
||||
]["Options"]
|
||||
)
|
||||
):
|
||||
option = utils.koboldai_vars.actions.actions[
|
||||
utils.koboldai_vars.actions.action_count + 1
|
||||
]["Options"][x]
|
||||
if option["Pinned"] or option["Previous Selection"] or option["Edited"]:
|
||||
option_offset = x + 1
|
||||
batch_offset = (
|
||||
int((utils.koboldai_vars.generated_tkns - 1) / utils.koboldai_vars.genamt)
|
||||
if utils.koboldai_vars.alt_multi_gen
|
||||
else 0
|
||||
)
|
||||
for batch_index, batch in enumerate(scores):
|
||||
probs = F.softmax(batch, dim=-1).cpu().numpy()
|
||||
|
||||
token_prob_info = []
|
||||
for token_id, score in sorted(
|
||||
enumerate(probs), key=lambda x: x[1], reverse=True
|
||||
)[:8]:
|
||||
token_prob_info.append(
|
||||
{
|
||||
"tokenId": token_id,
|
||||
"decoded": utils.decodenewlines(
|
||||
model.tokenizer.decode(token_id)
|
||||
),
|
||||
"score": float(score),
|
||||
}
|
||||
)
|
||||
|
||||
if utils.koboldai_vars.numseqs == 1:
|
||||
utils.koboldai_vars.actions.set_probabilities(token_prob_info)
|
||||
else:
|
||||
utils.koboldai_vars.actions.set_option_probabilities(
|
||||
token_prob_info, batch_index + option_offset + batch_offset
|
||||
)
|
||||
|
||||
return scores
|
||||
|
||||
def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
|
||||
processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs)
|
||||
processors.insert(0, LuaLogitsProcessor())
|
||||
processors.append(PhraseBiasLogitsProcessor())
|
||||
return processors
|
||||
|
||||
use_core_manipulations.get_logits_processor = new_get_logits_processor
|
||||
new_get_logits_processor.old_get_logits_processor = (
|
||||
transformers.GenerationMixin._get_logits_processor
|
||||
)
|
||||
|
||||
class KoboldLogitsWarperList(LogitsProcessorList):
|
||||
def __init__(self, beams: int = 1, **kwargs):
|
||||
self.__warper_list: List[LogitsWarper] = []
|
||||
self.__warper_list.append(
|
||||
TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1))
|
||||
)
|
||||
self.__warper_list.append(
|
||||
TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1))
|
||||
)
|
||||
self.__warper_list.append(
|
||||
TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1))
|
||||
)
|
||||
self.__warper_list.append(
|
||||
TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1))
|
||||
)
|
||||
self.__warper_list.append(
|
||||
TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1))
|
||||
)
|
||||
self.__warper_list.append(TemperatureLogitsWarper(temperature=0.5))
|
||||
self.__warper_list.append(AdvancedRepetitionPenaltyLogitsProcessor())
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
scores: torch.FloatTensor,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
sampler_order = utils.koboldai_vars.sampler_order[:]
|
||||
if (
|
||||
len(sampler_order) < 7
|
||||
): # Add repetition penalty at beginning if it's not present
|
||||
sampler_order = [6] + sampler_order
|
||||
for k in sampler_order:
|
||||
scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
|
||||
visualize_probabilities(HACK_currentmodel, scores)
|
||||
return scores
|
||||
|
||||
def new_get_logits_warper(
|
||||
beams: int = 1,
|
||||
) -> LogitsProcessorList:
|
||||
return KoboldLogitsWarperList(beams=beams)
|
||||
|
||||
def new_sample(self, *args, **kwargs):
|
||||
assert kwargs.pop("logits_warper", None) is not None
|
||||
kwargs["logits_warper"] = new_get_logits_warper(
|
||||
beams=1,
|
||||
)
|
||||
if (utils.koboldai_vars.newlinemode == "s") or (
|
||||
utils.koboldai_vars.newlinemode == "ns"
|
||||
):
|
||||
kwargs["eos_token_id"] = -1
|
||||
kwargs.setdefault("pad_token_id", 2)
|
||||
return new_sample.old_sample(self, *args, **kwargs)
|
||||
|
||||
new_sample.old_sample = transformers.GenerationMixin.sample
|
||||
use_core_manipulations.sample = new_sample
|
||||
# dynamic_processor_wrap(
|
||||
# AdvancedRepetitionPenaltyLogitsProcessor,
|
||||
# ("penalty", "penalty_slope", "penalty_range", "use_alt_rep_pen"),
|
||||
# ("rep_pen", "rep_pen_slope", "rep_pen_range", "use_alt_rep_pen"),
|
||||
# cond=lambda x: x[0] != 1.0,
|
||||
# )
|
||||
# dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
|
||||
# dynamic_processor_wrap(TopALogitsWarper, "top_a", "top_a", cond=lambda x: x > 0.0)
|
||||
# dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
|
||||
# dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
|
||||
# dynamic_processor_wrap(
|
||||
# TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0
|
||||
# )
|
||||
# dynamic_processor_wrap(
|
||||
# TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0
|
||||
# )
|
||||
|
||||
# Allow bad words filter to ban <|endoftext|> token
|
||||
import transformers.generation.logits_process
|
||||
@@ -852,10 +517,12 @@ class InferenceModel:
|
||||
global HACK_currentmodel
|
||||
HACK_currentmodel = self
|
||||
|
||||
print(self.raw_generate("Hi guys,", 20).__dict__)
|
||||
|
||||
def _post_load(self) -> None:
|
||||
pass
|
||||
|
||||
def _load(self, save_model: bool, inital_load: bool) -> None:
|
||||
def _load(self, save_model: bool, initial_load: bool) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_tokenizer(self, location: str):
|
||||
@@ -1520,8 +1187,22 @@ class HFTorchInferenceModel(InferenceModel):
|
||||
)
|
||||
self._old_stopping_criteria = None
|
||||
|
||||
def _apply_warpers(
|
||||
self, scores: torch.Tensor, input_ids: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
warpers.update_settings()
|
||||
for sid in utils.koboldai_vars.sampler_order:
|
||||
warper = Warper.from_id(sid)
|
||||
if warper == warpers.RepetitionPenalty:
|
||||
# Rep pen needs more data than other samplers
|
||||
print("is rep:", warper)
|
||||
scores = warper.torch(scores, input_ids=input_ids)
|
||||
else:
|
||||
print("aint rep:", warper)
|
||||
scores = warper.torch(scores)
|
||||
return scores
|
||||
|
||||
def _post_load(self) -> None:
|
||||
print("HELLLOOOOOOOOOOOOOOOOOOOOOOOOOOO")
|
||||
# Patch stopping_criteria
|
||||
|
||||
class PTHStopper(StoppingCriteria):
|
||||
@@ -1551,6 +1232,323 @@ class HFTorchInferenceModel(InferenceModel):
|
||||
|
||||
use_core_manipulations.get_stopping_criteria = _get_stopping_criteria
|
||||
|
||||
# Patch logitswarpers
|
||||
|
||||
class PhraseBiasLogitsProcessor(LogitsProcessor):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def _find_intersection(self, big: List, small: List) -> int:
|
||||
"""Find the maximum overlap between the beginning of small and the end of big.
|
||||
Return the index of the token in small following the overlap, or 0.
|
||||
|
||||
big: The tokens in the context (as a tensor)
|
||||
small: The tokens in the phrase to bias (as a list)
|
||||
|
||||
Both big and small are in "oldest to newest" order.
|
||||
"""
|
||||
# There are asymptotically more efficient methods for determining the overlap,
|
||||
# but typically there will be few (0-1) instances of small[0] in the last len(small)
|
||||
# elements of big, plus small will typically be fairly short. So this naive
|
||||
# approach is acceptable despite O(N^2) worst case performance.
|
||||
|
||||
num_small = len(small)
|
||||
# The small list can only ever match against at most num_small tokens of big,
|
||||
# so create a slice. Typically, this slice will be as long as small, but it
|
||||
# may be shorter if the story has just started.
|
||||
# We need to convert the big slice to list, since natively big is a tensor
|
||||
# and tensor and list don't ever compare equal. It's better to convert here
|
||||
# and then use native equality tests than to iterate repeatedly later.
|
||||
big_slice = list(big[-num_small:])
|
||||
|
||||
# It's possible that the start token appears multiple times in small
|
||||
# For example, consider the phrase:
|
||||
# [ fair is foul, and foul is fair, hover through the fog and filthy air]
|
||||
# If we merely look for the first instance of [ fair], then we would
|
||||
# generate the following output:
|
||||
# " fair is foul, and foul is fair is foul, and foul is fair..."
|
||||
start = small[0]
|
||||
for i, t in enumerate(big_slice):
|
||||
# Strictly unnecessary, but it's marginally faster to test the first
|
||||
# token before creating slices to test for a full match.
|
||||
if t == start:
|
||||
remaining = len(big_slice) - i
|
||||
if big_slice[i:] == small[:remaining]:
|
||||
# We found a match. If the small phrase has any remaining tokens
|
||||
# then return the index of the next token.
|
||||
if remaining < num_small:
|
||||
return remaining
|
||||
# In this case, the entire small phrase matched, so start over.
|
||||
return 0
|
||||
|
||||
# There were no matches, so just begin at the beginning.
|
||||
return 0
|
||||
|
||||
def _allow_leftwards_tampering(self, phrase: str) -> bool:
|
||||
"""Determines if a phrase should be tampered with from the left in
|
||||
the "soft" token encoding mode."""
|
||||
|
||||
if phrase[0] in [".", "?", "!", ";", ":", "\n"]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _get_token_sequence(self, phrase: str) -> List[List]:
|
||||
"""Convert the phrase string into a list of encoded biases, each
|
||||
one being a list of tokens. How this is done is determined by the
|
||||
phrase's format:
|
||||
|
||||
- If the phrase is surrounded by square brackets ([]), the tokens
|
||||
will be the phrase split by commas (,). If a "token" isn't
|
||||
actually a number, it will be skipped. NOTE: Tokens output by
|
||||
this may not be in the model's vocabulary, and such tokens
|
||||
should be ignored later in the pipeline.
|
||||
- If the phrase is surrounded by curly brackets ({}), the phrase
|
||||
will be directly encoded with no synonym biases and no fancy
|
||||
tricks.
|
||||
- Otherwise, the phrase will be encoded, with close deviations
|
||||
being included as synonym biases.
|
||||
"""
|
||||
|
||||
# TODO: Cache these tokens, invalidate when model or bias is
|
||||
# changed.
|
||||
|
||||
# Handle direct token id input
|
||||
if phrase.startswith("[") and phrase.endswith("]"):
|
||||
no_brackets = phrase[1:-1]
|
||||
ret = []
|
||||
for token_id in no_brackets.split(","):
|
||||
try:
|
||||
ret.append(int(token_id))
|
||||
except ValueError:
|
||||
# Ignore non-numbers. Rascals!
|
||||
pass
|
||||
return [ret]
|
||||
|
||||
# Handle direct phrases
|
||||
if phrase.startswith("{") and phrase.endswith("}"):
|
||||
no_brackets = phrase[1:-1]
|
||||
return [HACK_currentmodel.tokenizer.encode(no_brackets)]
|
||||
|
||||
# Handle untamperable phrases
|
||||
if not self._allow_leftwards_tampering(phrase):
|
||||
return [HACK_currentmodel.tokenizer.encode(phrase)]
|
||||
|
||||
# Handle slight alterations to original phrase
|
||||
phrase = phrase.strip(" ")
|
||||
ret = []
|
||||
|
||||
for alt_phrase in [phrase, f" {phrase}"]:
|
||||
ret.append(HACK_currentmodel.tokenizer.encode(alt_phrase))
|
||||
|
||||
return ret
|
||||
|
||||
def _get_biased_tokens(self, input_ids: List) -> Dict:
|
||||
# TODO: Different "bias slopes"?
|
||||
|
||||
ret = {}
|
||||
for phrase, _bias in utils.koboldai_vars.biases.items():
|
||||
bias_score, completion_threshold = _bias
|
||||
token_seqs = self._get_token_sequence(phrase)
|
||||
variant_deltas = {}
|
||||
for token_seq in token_seqs:
|
||||
bias_index = self._find_intersection(input_ids, token_seq)
|
||||
|
||||
# Ensure completion after completion_threshold tokens
|
||||
# Only provide a positive bias when the base bias score is positive.
|
||||
if bias_score > 0 and bias_index + 1 > completion_threshold:
|
||||
bias_score = 999
|
||||
|
||||
token_to_bias = token_seq[bias_index]
|
||||
variant_deltas[token_to_bias] = bias_score
|
||||
|
||||
# If multiple phrases bias the same token, add the modifiers
|
||||
# together. This should NOT be applied to automatic variants
|
||||
for token_to_bias, bias_score in variant_deltas.items():
|
||||
if token_to_bias in ret:
|
||||
ret[token_to_bias] += bias_score
|
||||
else:
|
||||
ret[token_to_bias] = bias_score
|
||||
return ret
|
||||
|
||||
def __call__(
|
||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
assert scores.ndim == 2
|
||||
assert input_ids.ndim == 2
|
||||
|
||||
scores_shape = scores.shape
|
||||
|
||||
for batch in range(scores_shape[0]):
|
||||
for token, bias in self._get_biased_tokens(
|
||||
input_ids[batch]
|
||||
).items():
|
||||
scores[batch][token] += bias
|
||||
|
||||
return scores
|
||||
|
||||
class LuaLogitsProcessor(LogitsProcessor):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(
|
||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
assert scores.ndim == 2
|
||||
assert input_ids.ndim == 2
|
||||
self.regeneration_required = False
|
||||
self.halt = False
|
||||
|
||||
if utils.koboldai_vars.standalone:
|
||||
return scores
|
||||
|
||||
scores_shape = scores.shape
|
||||
scores_list = scores.tolist()
|
||||
utils.koboldai_vars.lua_koboldbridge.logits = (
|
||||
utils.koboldai_vars.lua_state.table()
|
||||
)
|
||||
for r, row in enumerate(scores_list):
|
||||
utils.koboldai_vars.lua_koboldbridge.logits[
|
||||
r + 1
|
||||
] = utils.koboldai_vars.lua_state.table(*row)
|
||||
utils.koboldai_vars.lua_koboldbridge.vocab_size = scores_shape[-1]
|
||||
|
||||
utils.koboldai_vars.lua_koboldbridge.execute_genmod()
|
||||
|
||||
scores = torch.tensor(
|
||||
tuple(
|
||||
tuple(row.values())
|
||||
for row in utils.koboldai_vars.lua_koboldbridge.logits.values()
|
||||
),
|
||||
device=scores.device,
|
||||
dtype=scores.dtype,
|
||||
)
|
||||
assert scores.shape == scores_shape
|
||||
|
||||
return scores
|
||||
|
||||
from torch.nn import functional as F
|
||||
|
||||
def visualize_probabilities(
|
||||
model: InferenceModel,
|
||||
scores: torch.FloatTensor,
|
||||
) -> None:
|
||||
assert scores.ndim == 2
|
||||
|
||||
if utils.koboldai_vars.numseqs > 1 or not utils.koboldai_vars.show_probs:
|
||||
return
|
||||
|
||||
if not utils.koboldai_vars.show_probs:
|
||||
return scores
|
||||
|
||||
option_offset = 0
|
||||
if (
|
||||
utils.koboldai_vars.actions.action_count + 1
|
||||
in utils.koboldai_vars.actions.actions
|
||||
):
|
||||
for x in range(
|
||||
len(
|
||||
utils.koboldai_vars.actions.actions[
|
||||
utils.koboldai_vars.actions.action_count + 1
|
||||
]["Options"]
|
||||
)
|
||||
):
|
||||
option = utils.koboldai_vars.actions.actions[
|
||||
utils.koboldai_vars.actions.action_count + 1
|
||||
]["Options"][x]
|
||||
if (
|
||||
option["Pinned"]
|
||||
or option["Previous Selection"]
|
||||
or option["Edited"]
|
||||
):
|
||||
option_offset = x + 1
|
||||
batch_offset = (
|
||||
int(
|
||||
(utils.koboldai_vars.generated_tkns - 1)
|
||||
/ utils.koboldai_vars.genamt
|
||||
)
|
||||
if utils.koboldai_vars.alt_multi_gen
|
||||
else 0
|
||||
)
|
||||
for batch_index, batch in enumerate(scores):
|
||||
probs = F.softmax(batch, dim=-1).cpu().numpy()
|
||||
|
||||
token_prob_info = []
|
||||
for token_id, score in sorted(
|
||||
enumerate(probs), key=lambda x: x[1], reverse=True
|
||||
)[:8]:
|
||||
token_prob_info.append(
|
||||
{
|
||||
"tokenId": token_id,
|
||||
"decoded": utils.decodenewlines(
|
||||
model.tokenizer.decode(token_id)
|
||||
),
|
||||
"score": float(score),
|
||||
}
|
||||
)
|
||||
|
||||
if utils.koboldai_vars.numseqs == 1:
|
||||
utils.koboldai_vars.actions.set_probabilities(token_prob_info)
|
||||
else:
|
||||
utils.koboldai_vars.actions.set_option_probabilities(
|
||||
token_prob_info, batch_index + option_offset + batch_offset
|
||||
)
|
||||
|
||||
return scores
|
||||
|
||||
def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
|
||||
processors = new_get_logits_processor.old_get_logits_processor(
|
||||
*args, **kwargs
|
||||
)
|
||||
# TODOB4MERGE: These two
|
||||
# processors.insert(0, LuaLogitsProcessor())
|
||||
# processors.append(PhraseBiasLogitsProcessor())
|
||||
return processors
|
||||
|
||||
use_core_manipulations.get_logits_processor = new_get_logits_processor
|
||||
new_get_logits_processor.old_get_logits_processor = (
|
||||
transformers.GenerationMixin._get_logits_processor
|
||||
)
|
||||
|
||||
class KoboldLogitsWarperList(LogitsProcessorList):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(
|
||||
lw_self,
|
||||
input_ids: torch.LongTensor,
|
||||
scores: torch.FloatTensor,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
# sampler_order = utils.koboldai_vars.sampler_order[:]
|
||||
# if (
|
||||
# len(sampler_order) < 7
|
||||
# ): # Add repetition penalty at beginning if it's not present
|
||||
# sampler_order = [6] + sampler_order
|
||||
# for k in sampler_order:
|
||||
# scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
|
||||
scores = self._apply_warpers(scores=scores, input_ids=input_ids)
|
||||
visualize_probabilities(HACK_currentmodel, scores)
|
||||
return scores
|
||||
|
||||
def new_get_logits_warper(
|
||||
beams: int = 1,
|
||||
) -> LogitsProcessorList:
|
||||
return KoboldLogitsWarperList()
|
||||
|
||||
def new_sample(self, *args, **kwargs):
|
||||
assert kwargs.pop("logits_warper", None) is not None
|
||||
kwargs["logits_warper"] = new_get_logits_warper(
|
||||
beams=1,
|
||||
)
|
||||
if utils.koboldai_vars.newlinemode in ["s", "ns"]:
|
||||
kwargs["eos_token_id"] = -1
|
||||
kwargs.setdefault("pad_token_id", 2)
|
||||
return new_sample.old_sample(self, *args, **kwargs)
|
||||
|
||||
new_sample.old_sample = transformers.GenerationMixin.sample
|
||||
use_core_manipulations.sample = new_sample
|
||||
|
||||
def _raw_generate(
|
||||
self,
|
||||
prompt_tokens: Union[List[int], torch.Tensor],
|
||||
|
@@ -349,57 +349,6 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarra
|
||||
# probability distribution)
|
||||
return jax.random.categorical(key, logits, -1).astype(np.uint32)
|
||||
|
||||
def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
|
||||
'''
|
||||
This gets called by generate_loop_fn to apply repetition penalty
|
||||
to the 1D array logits using the provided 1D array of tokens to penalize
|
||||
'''
|
||||
rpslope = jnp.int32(rpslope)
|
||||
rprange = jnp.int32(rprange)
|
||||
clipped_rprange = jax.lax.cond(rprange > 0, lambda x: x, lambda x: tokens.shape[-1], rprange)
|
||||
penalty_arange = jnp.roll(jnp.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]), generated_index, axis=-1)
|
||||
# Make a new array with the same length as the tokens array but with
|
||||
# each element replaced by the value at the corresponding index in the
|
||||
# logits array; e.g.
|
||||
# if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
|
||||
# then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
|
||||
penalty_logits = jnp.take(logits, tokens)
|
||||
# Repetition penalty slope
|
||||
def apply_slope(carry):
|
||||
repetition_penalty, rprange = carry
|
||||
_penalty = (penalty_arange/(rprange - 1)) * 2 - 1
|
||||
_penalty = (rpslope * _penalty) / (1 + jnp.abs(_penalty) * (rpslope - 1))
|
||||
_penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
|
||||
return _penalty
|
||||
repetition_penalty = jax.lax.cond(
|
||||
(rpslope != 0.0) & (rprange > 0), # Not a typo; do not use `and` here, it makes JAX crash
|
||||
apply_slope,
|
||||
lambda carry: jnp.full(tokens.shape, carry[0]),
|
||||
(repetition_penalty, rprange),
|
||||
)
|
||||
# Divide positive values by repetition_penalty and multiply negative
|
||||
# values by repetition_penalty (the academic publication that described
|
||||
# this technique actually just only divided, but that would cause tokens
|
||||
# with negative logits to become more likely, which is obviously wrong)
|
||||
if koboldai_vars.use_alt_rep_pen:
|
||||
penalty_logits = jnp.where(
|
||||
penalty_arange >= 0,
|
||||
penalty_logits - jnp.log(repetition_penalty),
|
||||
penalty_logits,
|
||||
)
|
||||
else:
|
||||
penalty_logits = jnp.where(
|
||||
penalty_arange >= 0,
|
||||
jnp.where(
|
||||
penalty_logits > 0,
|
||||
penalty_logits/repetition_penalty,
|
||||
penalty_logits*repetition_penalty,
|
||||
),
|
||||
penalty_logits,
|
||||
)
|
||||
# Finally, put those penalized logit values back into their original
|
||||
# positions in the logits array
|
||||
return logits.at[tokens].set(penalty_logits)
|
||||
|
||||
def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
|
||||
'''
|
||||
|
34
warpers.py
34
warpers.py
@@ -38,11 +38,31 @@ file are mostly porting warper code to the torch methods.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import utils
|
||||
import torch
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import tpu_mtj_backend
|
||||
|
||||
try:
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import tpu_mtj_backend
|
||||
except ImportError:
|
||||
assert not utils.koboldai_vars.use_colab_tpu
|
||||
|
||||
|
||||
def update_settings():
|
||||
# This feels like a bad way to structure this
|
||||
koboldai_vars = utils.koboldai_vars
|
||||
Temperature.temperature = koboldai_vars.temp
|
||||
TopP.top_p = koboldai_vars.top_p
|
||||
TopK.top_k = koboldai_vars.top_k
|
||||
TopA.top_a = koboldai_vars.top_a
|
||||
TailFree.tfs = koboldai_vars.tfs
|
||||
Typical.typical = koboldai_vars.typical
|
||||
RepetitionPenalty.rep_pen = koboldai_vars.rep_pen
|
||||
RepetitionPenalty.rep_pen_range = koboldai_vars.rep_pen_range
|
||||
RepetitionPenalty.rep_pen_slope = koboldai_vars.rep_pen_slope
|
||||
RepetitionPenalty.use_alt_rep_pen = koboldai_vars.use_alt_rep_pen
|
||||
|
||||
|
||||
class Warper:
|
||||
@@ -70,7 +90,7 @@ class Temperature(Warper):
|
||||
|
||||
@classmethod
|
||||
def jax(cls, scores: jnp.array) -> jnp.array:
|
||||
return scores / cls.value
|
||||
return scores / cls.temperature
|
||||
|
||||
|
||||
class TopP(Warper):
|
||||
@@ -88,7 +108,7 @@ class TopP(Warper):
|
||||
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
||||
|
||||
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
||||
sorted_indices_to_remove = cumulative_probs <= (1 - cls.value)
|
||||
sorted_indices_to_remove = cumulative_probs <= (1 - cls.top_p)
|
||||
|
||||
# scatter sorted tensors to original indexing
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||
@@ -109,7 +129,7 @@ class TopP(Warper):
|
||||
cumulative_probabilities = jnp.cumsum(probabilities, axis=-1)
|
||||
# We want to remove tokens with cumulative probability higher
|
||||
# than top_p
|
||||
sorted_indices_to_remove = cumulative_probabilities > cls.value
|
||||
sorted_indices_to_remove = cumulative_probabilities > cls.top_p
|
||||
# Don't ever remove the token with the highest logit, even if
|
||||
# the probability is higher than top_p
|
||||
sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
|
||||
@@ -358,7 +378,7 @@ class RepetitionPenalty(Warper):
|
||||
use_alt_rep_pen: bool = False
|
||||
|
||||
@classmethod
|
||||
def torch(cls, scores: torch.Tensor) -> torch.Tensor:
|
||||
def torch(cls, scores: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
cls.rep_pen_range = int(cls.rep_pen_range)
|
||||
clipped_penalty_range = min(input_ids.shape[-1], cls.rep_pen_range)
|
||||
|
||||
|
Reference in New Issue
Block a user