Samplers: Part 2

somebody
2023-02-26 17:22:54 -06:00
parent f882979c88
commit af73527be0
3 changed files with 385 additions and 418 deletions

model.py

@@ -1,5 +1,4 @@
# TODO:
# - Intertwine stoppers and streaming and such
# Before merge: please make sure to fix any TODOB4MERGE comments
from __future__ import annotations
import bisect
@@ -16,12 +15,15 @@ import json
import os
import time
import traceback
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import Dict, Iterable, List, Optional, Tuple, Union
import zipfile
from tqdm.auto import tqdm
from logger import logger
import torch_lazy_loader
import warpers
from warpers import Warper
import torch
from torch.nn import Embedding
import numpy as np
@@ -32,14 +34,13 @@ from transformers import (
GPT2Tokenizer,
GPT2LMHeadModel,
GPTNeoForCausalLM,
GPTNeoModel,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoTokenizer,
PreTrainedModel,
modeling_utils,
AutoModelForTokenClassification,
AutoConfig,
LogitsProcessorList,
LogitsProcessor,
)
import utils
@@ -399,21 +400,6 @@ def patch_transformers_generation() -> None:
global transformers
# Patch transformers to use our custom logit warpers -- Only HFTorchInferenceModel uses this
from transformers import (
LogitsProcessorList,
LogitsWarper,
LogitsProcessor,
TopKLogitsWarper,
TopPLogitsWarper,
TemperatureLogitsWarper,
)
from warpers import (
AdvancedRepetitionPenaltyLogitsProcessor,
TailFreeLogitsWarper,
TypicalLogitsWarper,
TopALogitsWarper,
)
def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
old_call = cls.__call__
@@ -434,343 +420,22 @@ def patch_transformers_generation() -> None:
cls.__call__ = new_call
# TODO: Make samplers generic
dynamic_processor_wrap(
AdvancedRepetitionPenaltyLogitsProcessor,
("penalty", "penalty_slope", "penalty_range", "use_alt_rep_pen"),
("rep_pen", "rep_pen_slope", "rep_pen_range", "use_alt_rep_pen"),
cond=lambda x: x[0] != 1.0,
)
dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
dynamic_processor_wrap(TopALogitsWarper, "top_a", "top_a", cond=lambda x: x > 0.0)
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
dynamic_processor_wrap(
TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0
)
dynamic_processor_wrap(
TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0
)
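Each dynamic_processor_wrap call above binds one stock transformers warper to its live koboldai_vars setting, plus a condition that decides whether the warper should run at all (e.g. top_k only when top_k > 0). The helper's body is partly elided by the hunk boundary, so the following is only a minimal sketch of that monkey-patching pattern; the name wrap_with_live_setting, the scalar-only handling, and the exact call shape are assumptions, not the real implementation.

import utils  # KoboldAI settings module, already imported at the top of model.py

def wrap_with_live_setting(cls, field_name, var_name, cond=None):
    # Sketch only: refresh one scalar setting before every __call__ and skip the
    # warper when its setting is a no-op. Tuple settings such as the
    # repetition-penalty group are omitted for brevity.
    old_call = cls.__call__

    def new_call(self, input_ids, scores, *args, **kwargs):
        value = getattr(utils.koboldai_vars, var_name)  # read the live slider value
        setattr(self, field_name, value)                # push it onto the warper
        if cond is None or cond(value):
            return old_call(self, input_ids, scores, *args, **kwargs)
        return scores                                   # inactive -> leave logits untouched

    cls.__call__ = new_call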
class PhraseBiasLogitsProcessor(LogitsProcessor):
def __init__(self):
pass
def _find_intersection(self, big: List, small: List) -> int:
"""Find the maximum overlap between the beginning of small and the end of big.
Return the index of the token in small following the overlap, or 0.
big: The tokens in the context (as a tensor)
small: The tokens in the phrase to bias (as a list)
Both big and small are in "oldest to newest" order.
"""
# There are asymptotically more efficient methods for determining the overlap,
# but typically there will be few (0-1) instances of small[0] in the last len(small)
# elements of big, plus small will typically be fairly short. So this naive
# approach is acceptable despite O(N^2) worst case performance.
num_small = len(small)
# The small list can only ever match against at most num_small tokens of big,
# so create a slice. Typically, this slice will be as long as small, but it
# may be shorter if the story has just started.
# We need to convert the big slice to list, since natively big is a tensor
# and tensor and list don't ever compare equal. It's better to convert here
# and then use native equality tests than to iterate repeatedly later.
big_slice = list(big[-num_small:])
# It's possible that the start token appears multiple times in small
# For example, consider the phrase:
# [ fair is foul, and foul is fair, hover through the fog and filthy air]
# If we merely look for the first instance of [ fair], then we would
# generate the following output:
# " fair is foul, and foul is fair is foul, and foul is fair..."
start = small[0]
for i, t in enumerate(big_slice):
# Strictly unnecessary, but it's marginally faster to test the first
# token before creating slices to test for a full match.
if t == start:
remaining = len(big_slice) - i
if big_slice[i:] == small[:remaining]:
# We found a match. If the small phrase has any remaining tokens
# then return the index of the next token.
if remaining < num_small:
return remaining
# In this case, the entire small phrase matched, so start over.
return 0
# There were no matches, so just begin at the beginning.
return 0
def _allow_leftwards_tampering(self, phrase: str) -> bool:
"""Determines if a phrase should be tampered with from the left in
the "soft" token encoding mode."""
if phrase[0] in [".", "?", "!", ";", ":", "\n"]:
return False
return True
def _get_token_sequence(self, phrase: str) -> List[List]:
"""Convert the phrase string into a list of encoded biases, each
one being a list of tokens. How this is done is determined by the
phrase's format:
- If the phrase is surrounded by square brackets ([]), the tokens
will be the phrase split by commas (,). If a "token" isn't
actually a number, it will be skipped. NOTE: Tokens output by
this may not be in the model's vocabulary, and such tokens
should be ignored later in the pipeline.
- If the phrase is surrounded by curly brackets ({}), the phrase
will be directly encoded with no synonym biases and no fancy
tricks.
- Otherwise, the phrase will be encoded, with close deviations
being included as synonym biases.
"""
# TODO: Cache these tokens, invalidate when model or bias is
# changed.
# Handle direct token id input
if phrase.startswith("[") and phrase.endswith("]"):
no_brackets = phrase[1:-1]
ret = []
for token_id in no_brackets.split(","):
try:
ret.append(int(token_id))
except ValueError:
# Ignore non-numbers. Rascals!
pass
return [ret]
# Handle direct phrases
if phrase.startswith("{") and phrase.endswith("}"):
no_brackets = phrase[1:-1]
return [HACK_currentmodel.tokenizer.encode(no_brackets)]
# Handle untamperable phrases
if not self._allow_leftwards_tampering(phrase):
return [HACK_currentmodel.tokenizer.encode(phrase)]
# Handle slight alterations to original phrase
phrase = phrase.strip(" ")
ret = []
for alt_phrase in [phrase, f" {phrase}"]:
ret.append(HACK_currentmodel.tokenizer.encode(alt_phrase))
return ret
def _get_biased_tokens(self, input_ids: List) -> Dict:
# TODO: Different "bias slopes"?
ret = {}
for phrase, _bias in utils.koboldai_vars.biases.items():
bias_score, completion_threshold = _bias
token_seqs = self._get_token_sequence(phrase)
variant_deltas = {}
for token_seq in token_seqs:
bias_index = self._find_intersection(input_ids, token_seq)
# Ensure completion after completion_threshold tokens
# Only provide a positive bias when the base bias score is positive.
if bias_score > 0 and bias_index + 1 > completion_threshold:
bias_score = 999
token_to_bias = token_seq[bias_index]
variant_deltas[token_to_bias] = bias_score
# If multiple phrases bias the same token, add the modifiers
# together. This should NOT be applied to automatic variants
for token_to_bias, bias_score in variant_deltas.items():
if token_to_bias in ret:
ret[token_to_bias] += bias_score
else:
ret[token_to_bias] = bias_score
return ret
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
assert scores.ndim == 2
assert input_ids.ndim == 2
scores_shape = scores.shape
for batch in range(scores_shape[0]):
for token, bias in self._get_biased_tokens(input_ids[batch]).items():
scores[batch][token] += bias
return scores
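_find_intersection above decides which token of a biased phrase should receive the bias next, based on how much of the phrase already sits at the end of the context. A self-contained sketch of the same matching logic, with a tiny worked example (token ids are made up):

def find_intersection(big, small):
    # Same naive suffix/prefix overlap search as _find_intersection above.
    num_small = len(small)
    big_slice = list(big[-num_small:])
    start = small[0]
    for i, t in enumerate(big_slice):
        if t == start:
            remaining = len(big_slice) - i
            if big_slice[i:] == small[:remaining]:
                if remaining < num_small:
                    return remaining
                return 0
    return 0

# Context ends with the first two tokens of the phrase, so the next token
# to bias is small[2] (token id 30).
assert find_intersection([7, 7, 10, 20], [10, 20, 30]) == 2
# The whole phrase already matched, so biasing starts over at small[0].
assert find_intersection([10, 20, 30], [10, 20, 30]) == 0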
class LuaLogitsProcessor(LogitsProcessor):
def __init__(self):
pass
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
assert scores.ndim == 2
assert input_ids.ndim == 2
self.regeneration_required = False
self.halt = False
if utils.koboldai_vars.standalone:
return scores
scores_shape = scores.shape
scores_list = scores.tolist()
utils.koboldai_vars.lua_koboldbridge.logits = (
utils.koboldai_vars.lua_state.table()
)
for r, row in enumerate(scores_list):
utils.koboldai_vars.lua_koboldbridge.logits[
r + 1
] = utils.koboldai_vars.lua_state.table(*row)
utils.koboldai_vars.lua_koboldbridge.vocab_size = scores_shape[-1]
utils.koboldai_vars.lua_koboldbridge.execute_genmod()
scores = torch.tensor(
tuple(
tuple(row.values())
for row in utils.koboldai_vars.lua_koboldbridge.logits.values()
),
device=scores.device,
dtype=scores.dtype,
)
assert scores.shape == scores_shape
return scores
from torch.nn import functional as F
def visualize_probabilities(
model: InferenceModel,
scores: torch.FloatTensor,
) -> None:
assert scores.ndim == 2
if utils.koboldai_vars.numseqs > 1 or not utils.koboldai_vars.show_probs:
return
if not utils.koboldai_vars.show_probs:
return scores
option_offset = 0
if (
utils.koboldai_vars.actions.action_count + 1
in utils.koboldai_vars.actions.actions
):
for x in range(
len(
utils.koboldai_vars.actions.actions[
utils.koboldai_vars.actions.action_count + 1
]["Options"]
)
):
option = utils.koboldai_vars.actions.actions[
utils.koboldai_vars.actions.action_count + 1
]["Options"][x]
if option["Pinned"] or option["Previous Selection"] or option["Edited"]:
option_offset = x + 1
batch_offset = (
int((utils.koboldai_vars.generated_tkns - 1) / utils.koboldai_vars.genamt)
if utils.koboldai_vars.alt_multi_gen
else 0
)
for batch_index, batch in enumerate(scores):
probs = F.softmax(batch, dim=-1).cpu().numpy()
token_prob_info = []
for token_id, score in sorted(
enumerate(probs), key=lambda x: x[1], reverse=True
)[:8]:
token_prob_info.append(
{
"tokenId": token_id,
"decoded": utils.decodenewlines(
model.tokenizer.decode(token_id)
),
"score": float(score),
}
)
if utils.koboldai_vars.numseqs == 1:
utils.koboldai_vars.actions.set_probabilities(token_prob_info)
else:
utils.koboldai_vars.actions.set_option_probabilities(
token_prob_info, batch_index + option_offset + batch_offset
)
return scores
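visualize_probabilities reduces each row of logits to the eight most probable tokens for the UI. The extraction step in isolation looks like this (a standalone sketch; the vocabulary size and the decode call are placeholders):

import torch
from torch.nn import functional as F

logits = torch.randn(50257)                 # one batch row of next-token scores
probs = F.softmax(logits, dim=-1).cpu().numpy()
top8 = sorted(enumerate(probs), key=lambda x: x[1], reverse=True)[:8]
for token_id, p in top8:
    # the real code runs model.tokenizer.decode(token_id) here
    print(token_id, float(p))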
def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs)
processors.insert(0, LuaLogitsProcessor())
processors.append(PhraseBiasLogitsProcessor())
return processors
use_core_manipulations.get_logits_processor = new_get_logits_processor
new_get_logits_processor.old_get_logits_processor = (
transformers.GenerationMixin._get_logits_processor
)
class KoboldLogitsWarperList(LogitsProcessorList):
def __init__(self, beams: int = 1, **kwargs):
self.__warper_list: List[LogitsWarper] = []
self.__warper_list.append(
TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1))
)
self.__warper_list.append(
TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1))
)
self.__warper_list.append(
TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1))
)
self.__warper_list.append(
TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1))
)
self.__warper_list.append(
TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1))
)
self.__warper_list.append(TemperatureLogitsWarper(temperature=0.5))
self.__warper_list.append(AdvancedRepetitionPenaltyLogitsProcessor())
def __call__(
self,
input_ids: torch.LongTensor,
scores: torch.FloatTensor,
*args,
**kwargs,
):
sampler_order = utils.koboldai_vars.sampler_order[:]
if (
len(sampler_order) < 7
): # Add repetition penalty at beginning if it's not present
sampler_order = [6] + sampler_order
for k in sampler_order:
scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
visualize_probabilities(HACK_currentmodel, scores)
return scores
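In KoboldLogitsWarperList, every integer in koboldai_vars.sampler_order is an index into __warper_list in the order it was built above, and a 6-entry legacy order is padded so the repetition penalty always runs. A small illustration (the user order shown is hypothetical):

# Index -> warper, following the construction order of __warper_list:
#   0 top_k, 1 top_a, 2 top_p, 3 tfs, 4 typical, 5 temperature, 6 repetition penalty
sampler_order = [5, 0, 2]                  # hypothetical user order: temp, top-k, top-p
if len(sampler_order) < 7:                 # no repetition-penalty entry present,
    sampler_order = [6] + sampler_order    # so it is prepended and runs first
print(sampler_order)                       # -> [6, 5, 0, 2]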
def new_get_logits_warper(
beams: int = 1,
) -> LogitsProcessorList:
return KoboldLogitsWarperList(beams=beams)
def new_sample(self, *args, **kwargs):
assert kwargs.pop("logits_warper", None) is not None
kwargs["logits_warper"] = new_get_logits_warper(
beams=1,
)
if (utils.koboldai_vars.newlinemode == "s") or (
utils.koboldai_vars.newlinemode == "ns"
):
kwargs["eos_token_id"] = -1
kwargs.setdefault("pad_token_id", 2)
return new_sample.old_sample(self, *args, **kwargs)
new_sample.old_sample = transformers.GenerationMixin.sample
use_core_manipulations.sample = new_sample
# dynamic_processor_wrap(
# AdvancedRepetitionPenaltyLogitsProcessor,
# ("penalty", "penalty_slope", "penalty_range", "use_alt_rep_pen"),
# ("rep_pen", "rep_pen_slope", "rep_pen_range", "use_alt_rep_pen"),
# cond=lambda x: x[0] != 1.0,
# )
# dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
# dynamic_processor_wrap(TopALogitsWarper, "top_a", "top_a", cond=lambda x: x > 0.0)
# dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
# dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
# dynamic_processor_wrap(
# TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0
# )
# dynamic_processor_wrap(
# TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0
# )
# Allow bad words filter to ban <|endoftext|> token
import transformers.generation.logits_process
@@ -852,10 +517,12 @@ class InferenceModel:
global HACK_currentmodel
HACK_currentmodel = self
print(self.raw_generate("Hi guys,", 20).__dict__)
def _post_load(self) -> None:
pass
def _load(self, save_model: bool, inital_load: bool) -> None:
def _load(self, save_model: bool, initial_load: bool) -> None:
raise NotImplementedError
def _get_tokenizer(self, location: str):
@@ -1520,8 +1187,22 @@ class HFTorchInferenceModel(InferenceModel):
)
self._old_stopping_criteria = None
def _apply_warpers(
self, scores: torch.Tensor, input_ids: torch.Tensor
) -> torch.Tensor:
warpers.update_settings()
for sid in utils.koboldai_vars.sampler_order:
warper = Warper.from_id(sid)
if warper == warpers.RepetitionPenalty:
# Rep pen needs more data than other samplers
print("is rep:", warper)
scores = warper.torch(scores, input_ids=input_ids)
else:
print("aint rep:", warper)
scores = warper.torch(scores)
return scores
def _post_load(self) -> None:
print("HELLLOOOOOOOOOOOOOOOOOOOOOOOOOOO")
# Patch stopping_criteria
class PTHStopper(StoppingCriteria):
@@ -1551,6 +1232,323 @@ class HFTorchInferenceModel(InferenceModel):
use_core_manipulations.get_stopping_criteria = _get_stopping_criteria
# Patch logitswarpers
class PhraseBiasLogitsProcessor(LogitsProcessor):
def __init__(self):
pass
def _find_intersection(self, big: List, small: List) -> int:
"""Find the maximum overlap between the beginning of small and the end of big.
Return the index of the token in small following the overlap, or 0.
big: The tokens in the context (as a tensor)
small: The tokens in the phrase to bias (as a list)
Both big and small are in "oldest to newest" order.
"""
# There are asymptotically more efficient methods for determining the overlap,
# but typically there will be few (0-1) instances of small[0] in the last len(small)
# elements of big, plus small will typically be fairly short. So this naive
# approach is acceptable despite O(N^2) worst case performance.
num_small = len(small)
# The small list can only ever match against at most num_small tokens of big,
# so create a slice. Typically, this slice will be as long as small, but it
# may be shorter if the story has just started.
# We need to convert the big slice to list, since natively big is a tensor
# and tensor and list don't ever compare equal. It's better to convert here
# and then use native equality tests than to iterate repeatedly later.
big_slice = list(big[-num_small:])
# It's possible that the start token appears multiple times in small
# For example, consider the phrase:
# [ fair is foul, and foul is fair, hover through the fog and filthy air]
# If we merely look for the first instance of [ fair], then we would
# generate the following output:
# " fair is foul, and foul is fair is foul, and foul is fair..."
start = small[0]
for i, t in enumerate(big_slice):
# Strictly unnecessary, but it's marginally faster to test the first
# token before creating slices to test for a full match.
if t == start:
remaining = len(big_slice) - i
if big_slice[i:] == small[:remaining]:
# We found a match. If the small phrase has any remaining tokens
# then return the index of the next token.
if remaining < num_small:
return remaining
# In this case, the entire small phrase matched, so start over.
return 0
# There were no matches, so just begin at the beginning.
return 0
def _allow_leftwards_tampering(self, phrase: str) -> bool:
"""Determines if a phrase should be tampered with from the left in
the "soft" token encoding mode."""
if phrase[0] in [".", "?", "!", ";", ":", "\n"]:
return False
return True
def _get_token_sequence(self, phrase: str) -> List[List]:
"""Convert the phrase string into a list of encoded biases, each
one being a list of tokens. How this is done is determined by the
phrase's format:
- If the phrase is surrounded by square brackets ([]), the tokens
will be the phrase split by commas (,). If a "token" isn't
actually a number, it will be skipped. NOTE: Tokens output by
this may not be in the model's vocabulary, and such tokens
should be ignored later in the pipeline.
- If the phrase is surrounded by curly brackets ({}), the phrase
will be directly encoded with no synonym biases and no fancy
tricks.
- Otherwise, the phrase will be encoded, with close deviations
being included as synonym biases.
"""
# TODO: Cache these tokens, invalidate when model or bias is
# changed.
# Handle direct token id input
if phrase.startswith("[") and phrase.endswith("]"):
no_brackets = phrase[1:-1]
ret = []
for token_id in no_brackets.split(","):
try:
ret.append(int(token_id))
except ValueError:
# Ignore non-numbers. Rascals!
pass
return [ret]
# Handle direct phrases
if phrase.startswith("{") and phrase.endswith("}"):
no_brackets = phrase[1:-1]
return [HACK_currentmodel.tokenizer.encode(no_brackets)]
# Handle untamperable phrases
if not self._allow_leftwards_tampering(phrase):
return [HACK_currentmodel.tokenizer.encode(phrase)]
# Handle slight alterations to original phrase
phrase = phrase.strip(" ")
ret = []
for alt_phrase in [phrase, f" {phrase}"]:
ret.append(HACK_currentmodel.tokenizer.encode(alt_phrase))
return ret
def _get_biased_tokens(self, input_ids: List) -> Dict:
# TODO: Different "bias slopes"?
ret = {}
for phrase, _bias in utils.koboldai_vars.biases.items():
bias_score, completion_threshold = _bias
token_seqs = self._get_token_sequence(phrase)
variant_deltas = {}
for token_seq in token_seqs:
bias_index = self._find_intersection(input_ids, token_seq)
# Ensure completion after completion_threshold tokens
# Only provide a positive bias when the base bias score is positive.
if bias_score > 0 and bias_index + 1 > completion_threshold:
bias_score = 999
token_to_bias = token_seq[bias_index]
variant_deltas[token_to_bias] = bias_score
# If multiple phrases bias the same token, add the modifiers
# together. This should NOT be applied to automatic variants
for token_to_bias, bias_score in variant_deltas.items():
if token_to_bias in ret:
ret[token_to_bias] += bias_score
else:
ret[token_to_bias] = bias_score
return ret
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
assert scores.ndim == 2
assert input_ids.ndim == 2
scores_shape = scores.shape
for batch in range(scores_shape[0]):
for token, bias in self._get_biased_tokens(
input_ids[batch]
).items():
scores[batch][token] += bias
return scores
class LuaLogitsProcessor(LogitsProcessor):
def __init__(self):
pass
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
assert scores.ndim == 2
assert input_ids.ndim == 2
self.regeneration_required = False
self.halt = False
if utils.koboldai_vars.standalone:
return scores
scores_shape = scores.shape
scores_list = scores.tolist()
utils.koboldai_vars.lua_koboldbridge.logits = (
utils.koboldai_vars.lua_state.table()
)
for r, row in enumerate(scores_list):
utils.koboldai_vars.lua_koboldbridge.logits[
r + 1
] = utils.koboldai_vars.lua_state.table(*row)
utils.koboldai_vars.lua_koboldbridge.vocab_size = scores_shape[-1]
utils.koboldai_vars.lua_koboldbridge.execute_genmod()
scores = torch.tensor(
tuple(
tuple(row.values())
for row in utils.koboldai_vars.lua_koboldbridge.logits.values()
),
device=scores.device,
dtype=scores.dtype,
)
assert scores.shape == scores_shape
return scores
from torch.nn import functional as F
def visualize_probabilities(
model: InferenceModel,
scores: torch.FloatTensor,
) -> None:
assert scores.ndim == 2
if utils.koboldai_vars.numseqs > 1 or not utils.koboldai_vars.show_probs:
return
if not utils.koboldai_vars.show_probs:
return scores
option_offset = 0
if (
utils.koboldai_vars.actions.action_count + 1
in utils.koboldai_vars.actions.actions
):
for x in range(
len(
utils.koboldai_vars.actions.actions[
utils.koboldai_vars.actions.action_count + 1
]["Options"]
)
):
option = utils.koboldai_vars.actions.actions[
utils.koboldai_vars.actions.action_count + 1
]["Options"][x]
if (
option["Pinned"]
or option["Previous Selection"]
or option["Edited"]
):
option_offset = x + 1
batch_offset = (
int(
(utils.koboldai_vars.generated_tkns - 1)
/ utils.koboldai_vars.genamt
)
if utils.koboldai_vars.alt_multi_gen
else 0
)
for batch_index, batch in enumerate(scores):
probs = F.softmax(batch, dim=-1).cpu().numpy()
token_prob_info = []
for token_id, score in sorted(
enumerate(probs), key=lambda x: x[1], reverse=True
)[:8]:
token_prob_info.append(
{
"tokenId": token_id,
"decoded": utils.decodenewlines(
model.tokenizer.decode(token_id)
),
"score": float(score),
}
)
if utils.koboldai_vars.numseqs == 1:
utils.koboldai_vars.actions.set_probabilities(token_prob_info)
else:
utils.koboldai_vars.actions.set_option_probabilities(
token_prob_info, batch_index + option_offset + batch_offset
)
return scores
def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
processors = new_get_logits_processor.old_get_logits_processor(
*args, **kwargs
)
# TODOB4MERGE: These two
# processors.insert(0, LuaLogitsProcessor())
# processors.append(PhraseBiasLogitsProcessor())
return processors
use_core_manipulations.get_logits_processor = new_get_logits_processor
new_get_logits_processor.old_get_logits_processor = (
transformers.GenerationMixin._get_logits_processor
)
class KoboldLogitsWarperList(LogitsProcessorList):
def __init__(self):
pass
def __call__(
lw_self,
input_ids: torch.LongTensor,
scores: torch.FloatTensor,
*args,
**kwargs,
):
# sampler_order = utils.koboldai_vars.sampler_order[:]
# if (
# len(sampler_order) < 7
# ): # Add repetition penalty at beginning if it's not present
# sampler_order = [6] + sampler_order
# for k in sampler_order:
# scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
scores = self._apply_warpers(scores=scores, input_ids=input_ids)
visualize_probabilities(HACK_currentmodel, scores)
return scores
def new_get_logits_warper(
beams: int = 1,
) -> LogitsProcessorList:
return KoboldLogitsWarperList()
def new_sample(self, *args, **kwargs):
assert kwargs.pop("logits_warper", None) is not None
kwargs["logits_warper"] = new_get_logits_warper(
beams=1,
)
if utils.koboldai_vars.newlinemode in ["s", "ns"]:
kwargs["eos_token_id"] = -1
kwargs.setdefault("pad_token_id", 2)
return new_sample.old_sample(self, *args, **kwargs)
new_sample.old_sample = transformers.GenerationMixin.sample
use_core_manipulations.sample = new_sample
def _raw_generate(
self,
prompt_tokens: Union[List[int], torch.Tensor],


@@ -349,57 +349,6 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarra
# probability distribution)
return jax.random.categorical(key, logits, -1).astype(np.uint32)
def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
'''
This gets called by generate_loop_fn to apply repetition penalty
to the 1D array logits using the provided 1D array of tokens to penalize
'''
rpslope = jnp.int32(rpslope)
rprange = jnp.int32(rprange)
clipped_rprange = jax.lax.cond(rprange > 0, lambda x: x, lambda x: tokens.shape[-1], rprange)
penalty_arange = jnp.roll(jnp.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]), generated_index, axis=-1)
# Make a new array with the same length as the tokens array but with
# each element replaced by the value at the corresponding index in the
# logits array; e.g.
# if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
# then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
penalty_logits = jnp.take(logits, tokens)
# Repetition penalty slope
def apply_slope(carry):
repetition_penalty, rprange = carry
_penalty = (penalty_arange/(rprange - 1)) * 2 - 1
_penalty = (rpslope * _penalty) / (1 + jnp.abs(_penalty) * (rpslope - 1))
_penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
return _penalty
repetition_penalty = jax.lax.cond(
(rpslope != 0.0) & (rprange > 0), # Not a typo; do not use `and` here, it makes JAX crash
apply_slope,
lambda carry: jnp.full(tokens.shape, carry[0]),
(repetition_penalty, rprange),
)
# Divide positive values by repetition_penalty and multiply negative
# values by repetition_penalty (the academic publication that described
# this technique actually just only divided, but that would cause tokens
# with negative logits to become more likely, which is obviously wrong)
if koboldai_vars.use_alt_rep_pen:
penalty_logits = jnp.where(
penalty_arange >= 0,
penalty_logits - jnp.log(repetition_penalty),
penalty_logits,
)
else:
penalty_logits = jnp.where(
penalty_arange >= 0,
jnp.where(
penalty_logits > 0,
penalty_logits/repetition_penalty,
penalty_logits*repetition_penalty,
),
penalty_logits,
)
# Finally, put those penalized logit values back into their original
# positions in the logits array
return logits.at[tokens].set(penalty_logits)
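The divide-positive / multiply-negative asymmetry in apply_repetition_penalty_static is what keeps a penalized token from ever becoming more likely. A tiny numpy illustration of that core step (values are made up; the slope and range handling above are omitted):

import numpy as np

penalty = 1.5
penalty_logits = np.array([2.0, -1.0])     # logits of two recently generated tokens
penalized = np.where(
    penalty_logits > 0,
    penalty_logits / penalty,              # positive logits shrink toward 0
    penalty_logits * penalty,              # negative logits move further below 0
)
print(penalized)                           # -> [ 1.3333... -1.5 ]
# Dividing -1.0 instead would have raised it to -0.666, i.e. rewarded the token.
# The "alt" mode replaces both branches with penalty_logits - np.log(penalty).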
def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
'''


@@ -38,11 +38,31 @@ file are mostly porting warper code to the torch methods.
from __future__ import annotations
import utils
import torch
import jax
import jax.numpy as jnp
import numpy as np
import tpu_mtj_backend
try:
import jax
import jax.numpy as jnp
import tpu_mtj_backend
except ImportError:
assert not utils.koboldai_vars.use_colab_tpu
def update_settings():
# This feels like a bad way to structure this
koboldai_vars = utils.koboldai_vars
Temperature.temperature = koboldai_vars.temp
TopP.top_p = koboldai_vars.top_p
TopK.top_k = koboldai_vars.top_k
TopA.top_a = koboldai_vars.top_a
TailFree.tfs = koboldai_vars.tfs
Typical.typical = koboldai_vars.typical
RepetitionPenalty.rep_pen = koboldai_vars.rep_pen
RepetitionPenalty.rep_pen_range = koboldai_vars.rep_pen_range
RepetitionPenalty.rep_pen_slope = koboldai_vars.rep_pen_slope
RepetitionPenalty.use_alt_rep_pen = koboldai_vars.use_alt_rep_pen
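update_settings copies the live koboldai_vars values onto the warper classes as class attributes, which is what lets _apply_warpers in model.py call each warper with nothing but the scores tensor. A minimal usage sketch (assumes a configured koboldai_vars; the per-class torch() signature for non-rep-pen warpers is inferred from _apply_warpers):

import torch
import warpers

warpers.update_settings()                  # sync class attributes from koboldai_vars
scores = torch.randn(1, 50257)             # fake (batch, vocab) next-token logits
scores = warpers.TopK.torch(scores)        # each warper filters/reshapes the logits
scores = warpers.Temperature.torch(scores)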
class Warper:
@@ -70,7 +90,7 @@ class Temperature(Warper):
@classmethod
def jax(cls, scores: jnp.array) -> jnp.array:
return scores / cls.value
return scores / cls.temperature
class TopP(Warper):
@@ -88,7 +108,7 @@ class TopP(Warper):
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs <= (1 - cls.value)
sorted_indices_to_remove = cumulative_probs <= (1 - cls.top_p)
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(
@@ -109,7 +129,7 @@ class TopP(Warper):
cumulative_probabilities = jnp.cumsum(probabilities, axis=-1)
# We want to remove tokens with cumulative probability higher
# than top_p
sorted_indices_to_remove = cumulative_probabilities > cls.value
sorted_indices_to_remove = cumulative_probabilities > cls.top_p
# Don't ever remove the token with the highest logit, even if
# the probability is higher than top_p
sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
@@ -358,7 +378,7 @@ class RepetitionPenalty(Warper):
use_alt_rep_pen: bool = False
@classmethod
def torch(cls, scores: torch.Tensor) -> torch.Tensor:
def torch(cls, scores: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
cls.rep_pen_range = int(cls.rep_pen_range)
clipped_penalty_range = min(input_ids.shape[-1], cls.rep_pen_range)