Samplers: Part 2

model.py (718 changed lines)
@@ -1,5 +1,4 @@
-# TODO:
-# - Intertwine stoppers and streaming and such
+# Before merge: please make sure to fix any TODOB4MERGE comments
 from __future__ import annotations
 
 import bisect
@@ -16,12 +15,15 @@ import json
 import os
 import time
 import traceback
-from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Dict, Iterable, List, Optional, Tuple, Union
 import zipfile
 from tqdm.auto import tqdm
 from logger import logger
 import torch_lazy_loader
 
+import warpers
+from warpers import Warper
+
 import torch
 from torch.nn import Embedding
 import numpy as np
@@ -32,14 +34,13 @@ from transformers import (
     GPT2Tokenizer,
     GPT2LMHeadModel,
     GPTNeoForCausalLM,
-    GPTNeoModel,
     AutoModelForCausalLM,
-    AutoModelForSeq2SeqLM,
     AutoTokenizer,
     PreTrainedModel,
     modeling_utils,
-    AutoModelForTokenClassification,
     AutoConfig,
+    LogitsProcessorList,
+    LogitsProcessor,
 )
 
 import utils
@@ -399,21 +400,6 @@ def patch_transformers_generation() -> None:
     global transformers
 
     # Patch transformers to use our custom logit warpers -- Only HFTorchInferenceModel uses this
-    from transformers import (
-        LogitsProcessorList,
-        LogitsWarper,
-        LogitsProcessor,
-        TopKLogitsWarper,
-        TopPLogitsWarper,
-        TemperatureLogitsWarper,
-    )
-    from warpers import (
-        AdvancedRepetitionPenaltyLogitsProcessor,
-        TailFreeLogitsWarper,
-        TypicalLogitsWarper,
-        TopALogitsWarper,
-    )
-
     def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
         old_call = cls.__call__
 
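Note: for readers outside the repo, a minimal standalone sketch of the idiom dynamic_processor_wrap implements — monkey-patching a warper's __call__ so its parameter is re-read from utils.koboldai_vars on every call, and skipping the warper when its setting is neutral. Illustrative only; the tuple-of-fields variant used for repetition penalty is omitted:

def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
    # Replace cls.__call__ with a shim that refreshes the setting first.
    old_call = cls.__call__

    def new_call(self, input_ids, scores, *args, **kwargs):
        value = getattr(utils.koboldai_vars, var_name)
        setattr(self, field_name, value)
        if cond is not None and not cond(value):
            return scores  # setting is at its neutral value; pass through
        return old_call(self, input_ids, scores, *args, **kwargs)

    cls.__call__ = new_call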
@@ -434,343 +420,22 @@ def patch_transformers_generation() -> None:
         cls.__call__ = new_call
 
     # TODO: Make samplers generic
-    dynamic_processor_wrap(
-        AdvancedRepetitionPenaltyLogitsProcessor,
-        ("penalty", "penalty_slope", "penalty_range", "use_alt_rep_pen"),
-        ("rep_pen", "rep_pen_slope", "rep_pen_range", "use_alt_rep_pen"),
-        cond=lambda x: x[0] != 1.0,
-    )
-    dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
-    dynamic_processor_wrap(TopALogitsWarper, "top_a", "top_a", cond=lambda x: x > 0.0)
-    dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
-    dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
-    dynamic_processor_wrap(
-        TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0
-    )
-    dynamic_processor_wrap(
-        TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0
-    )
+    # dynamic_processor_wrap(
+    #     AdvancedRepetitionPenaltyLogitsProcessor,
+    #     ("penalty", "penalty_slope", "penalty_range", "use_alt_rep_pen"),
+    #     ("rep_pen", "rep_pen_slope", "rep_pen_range", "use_alt_rep_pen"),
+    #     cond=lambda x: x[0] != 1.0,
+    # )
+    # dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
+    # dynamic_processor_wrap(TopALogitsWarper, "top_a", "top_a", cond=lambda x: x > 0.0)
+    # dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
+    # dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
+    # dynamic_processor_wrap(
+    #     TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0
+    # )
+    # dynamic_processor_wrap(
+    #     TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0
+    # )
 
-    class PhraseBiasLogitsProcessor(LogitsProcessor):
-        def __init__(self):
-            pass
-
-        def _find_intersection(self, big: List, small: List) -> int:
-            """Find the maximum overlap between the beginning of small and the end of big.
-            Return the index of the token in small following the overlap, or 0.
-
-            big: The tokens in the context (as a tensor)
-            small: The tokens in the phrase to bias (as a list)
-
-            Both big and small are in "oldest to newest" order.
-            """
-            # There are asymptotically more efficient methods for determining the overlap,
-            # but typically there will be few (0-1) instances of small[0] in the last len(small)
-            # elements of big, plus small will typically be fairly short. So this naive
-            # approach is acceptable despite O(N^2) worst case performance.
-
-            num_small = len(small)
-            # The small list can only ever match against at most num_small tokens of big,
-            # so create a slice. Typically, this slice will be as long as small, but it
-            # may be shorter if the story has just started.
-            # We need to convert the big slice to list, since natively big is a tensor
-            # and tensor and list don't ever compare equal. It's better to convert here
-            # and then use native equality tests than to iterate repeatedly later.
-            big_slice = list(big[-num_small:])
-
-            # It's possible that the start token appears multiple times in small
-            # For example, consider the phrase:
-            # [ fair is foul, and foul is fair, hover through the fog and filthy air]
-            # If we merely look for the first instance of [ fair], then we would
-            # generate the following output:
-            # " fair is foul, and foul is fair is foul, and foul is fair..."
-            start = small[0]
-            for i, t in enumerate(big_slice):
-                # Strictly unnecessary, but it's marginally faster to test the first
-                # token before creating slices to test for a full match.
-                if t == start:
-                    remaining = len(big_slice) - i
-                    if big_slice[i:] == small[:remaining]:
-                        # We found a match. If the small phrase has any remaining tokens
-                        # then return the index of the next token.
-                        if remaining < num_small:
-                            return remaining
-                        # In this case, the entire small phrase matched, so start over.
-                        return 0
-
-            # There were no matches, so just begin at the beginning.
-            return 0
-
-        def _allow_leftwards_tampering(self, phrase: str) -> bool:
-            """Determines if a phrase should be tampered with from the left in
-            the "soft" token encoding mode."""
-
-            if phrase[0] in [".", "?", "!", ";", ":", "\n"]:
-                return False
-            return True
-
-        def _get_token_sequence(self, phrase: str) -> List[List]:
-            """Convert the phrase string into a list of encoded biases, each
-            one being a list of tokens. How this is done is determined by the
-            phrase's format:
-
-            - If the phrase is surrounded by square brackets ([]), the tokens
-                will be the phrase split by commas (,). If a "token" isn't
-                actually a number, it will be skipped. NOTE: Tokens output by
-                this may not be in the model's vocabulary, and such tokens
-                should be ignored later in the pipeline.
-            - If the phrase is surrounded by curly brackets ({}), the phrase
-                will be directly encoded with no synonym biases and no fancy
-                tricks.
-            - Otherwise, the phrase will be encoded, with close deviations
-                being included as synonym biases.
-            """
-
-            # TODO: Cache these tokens, invalidate when model or bias is
-            # changed.
-
-            # Handle direct token id input
-            if phrase.startswith("[") and phrase.endswith("]"):
-                no_brackets = phrase[1:-1]
-                ret = []
-                for token_id in no_brackets.split(","):
-                    try:
-                        ret.append(int(token_id))
-                    except ValueError:
-                        # Ignore non-numbers. Rascals!
-                        pass
-                return [ret]
-
-            # Handle direct phrases
-            if phrase.startswith("{") and phrase.endswith("}"):
-                no_brackets = phrase[1:-1]
-                return [HACK_currentmodel.tokenizer.encode(no_brackets)]
-
-            # Handle untamperable phrases
-            if not self._allow_leftwards_tampering(phrase):
-                return [HACK_currentmodel.tokenizer.encode(phrase)]
-
-            # Handle slight alterations to original phrase
-            phrase = phrase.strip(" ")
-            ret = []
-
-            for alt_phrase in [phrase, f" {phrase}"]:
-                ret.append(HACK_currentmodel.tokenizer.encode(alt_phrase))
-
-            return ret
-
-        def _get_biased_tokens(self, input_ids: List) -> Dict:
-            # TODO: Different "bias slopes"?
-
-            ret = {}
-            for phrase, _bias in utils.koboldai_vars.biases.items():
-                bias_score, completion_threshold = _bias
-                token_seqs = self._get_token_sequence(phrase)
-                variant_deltas = {}
-                for token_seq in token_seqs:
-                    bias_index = self._find_intersection(input_ids, token_seq)
-
-                    # Ensure completion after completion_threshold tokens
-                    # Only provide a positive bias when the base bias score is positive.
-                    if bias_score > 0 and bias_index + 1 > completion_threshold:
-                        bias_score = 999
-
-                    token_to_bias = token_seq[bias_index]
-                    variant_deltas[token_to_bias] = bias_score
-
-                # If multiple phrases bias the same token, add the modifiers
-                # together. This should NOT be applied to automatic variants
-                for token_to_bias, bias_score in variant_deltas.items():
-                    if token_to_bias in ret:
-                        ret[token_to_bias] += bias_score
-                    else:
-                        ret[token_to_bias] = bias_score
-            return ret
-
-        def __call__(
-            self, input_ids: torch.LongTensor, scores: torch.FloatTensor
-        ) -> torch.FloatTensor:
-            assert scores.ndim == 2
-            assert input_ids.ndim == 2
-
-            scores_shape = scores.shape
-
-            for batch in range(scores_shape[0]):
-                for token, bias in self._get_biased_tokens(input_ids[batch]).items():
-                    scores[batch][token] += bias
-
-            return scores
-
-    class LuaLogitsProcessor(LogitsProcessor):
-        def __init__(self):
-            pass
-
-        def __call__(
-            self, input_ids: torch.LongTensor, scores: torch.FloatTensor
-        ) -> torch.FloatTensor:
-            assert scores.ndim == 2
-            assert input_ids.ndim == 2
-            self.regeneration_required = False
-            self.halt = False
-
-            if utils.koboldai_vars.standalone:
-                return scores
-
-            scores_shape = scores.shape
-            scores_list = scores.tolist()
-            utils.koboldai_vars.lua_koboldbridge.logits = (
-                utils.koboldai_vars.lua_state.table()
-            )
-            for r, row in enumerate(scores_list):
-                utils.koboldai_vars.lua_koboldbridge.logits[
-                    r + 1
-                ] = utils.koboldai_vars.lua_state.table(*row)
-            utils.koboldai_vars.lua_koboldbridge.vocab_size = scores_shape[-1]
-
-            utils.koboldai_vars.lua_koboldbridge.execute_genmod()
-
-            scores = torch.tensor(
-                tuple(
-                    tuple(row.values())
-                    for row in utils.koboldai_vars.lua_koboldbridge.logits.values()
-                ),
-                device=scores.device,
-                dtype=scores.dtype,
-            )
-            assert scores.shape == scores_shape
-
-            return scores
-
-    from torch.nn import functional as F
-
-    def visualize_probabilities(
-        model: InferenceModel,
-        scores: torch.FloatTensor,
-    ) -> None:
-        assert scores.ndim == 2
-
-        if utils.koboldai_vars.numseqs > 1 or not utils.koboldai_vars.show_probs:
-            return
-
-        if not utils.koboldai_vars.show_probs:
-            return scores
-
-        option_offset = 0
-        if (
-            utils.koboldai_vars.actions.action_count + 1
-            in utils.koboldai_vars.actions.actions
-        ):
-            for x in range(
-                len(
-                    utils.koboldai_vars.actions.actions[
-                        utils.koboldai_vars.actions.action_count + 1
-                    ]["Options"]
-                )
-            ):
-                option = utils.koboldai_vars.actions.actions[
-                    utils.koboldai_vars.actions.action_count + 1
-                ]["Options"][x]
-                if option["Pinned"] or option["Previous Selection"] or option["Edited"]:
-                    option_offset = x + 1
-        batch_offset = (
-            int((utils.koboldai_vars.generated_tkns - 1) / utils.koboldai_vars.genamt)
-            if utils.koboldai_vars.alt_multi_gen
-            else 0
-        )
-        for batch_index, batch in enumerate(scores):
-            probs = F.softmax(batch, dim=-1).cpu().numpy()
-
-            token_prob_info = []
-            for token_id, score in sorted(
-                enumerate(probs), key=lambda x: x[1], reverse=True
-            )[:8]:
-                token_prob_info.append(
-                    {
-                        "tokenId": token_id,
-                        "decoded": utils.decodenewlines(
-                            model.tokenizer.decode(token_id)
-                        ),
-                        "score": float(score),
-                    }
-                )
-
-            if utils.koboldai_vars.numseqs == 1:
-                utils.koboldai_vars.actions.set_probabilities(token_prob_info)
-            else:
-                utils.koboldai_vars.actions.set_option_probabilities(
-                    token_prob_info, batch_index + option_offset + batch_offset
-                )
-
-        return scores
-
-    def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
-        processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs)
-        processors.insert(0, LuaLogitsProcessor())
-        processors.append(PhraseBiasLogitsProcessor())
-        return processors
-
-    use_core_manipulations.get_logits_processor = new_get_logits_processor
-    new_get_logits_processor.old_get_logits_processor = (
-        transformers.GenerationMixin._get_logits_processor
-    )
-
-    class KoboldLogitsWarperList(LogitsProcessorList):
-        def __init__(self, beams: int = 1, **kwargs):
-            self.__warper_list: List[LogitsWarper] = []
-            self.__warper_list.append(
-                TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1))
-            )
-            self.__warper_list.append(
-                TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1))
-            )
-            self.__warper_list.append(
-                TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1))
-            )
-            self.__warper_list.append(
-                TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1))
-            )
-            self.__warper_list.append(
-                TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1))
-            )
-            self.__warper_list.append(TemperatureLogitsWarper(temperature=0.5))
-            self.__warper_list.append(AdvancedRepetitionPenaltyLogitsProcessor())
-
-        def __call__(
-            self,
-            input_ids: torch.LongTensor,
-            scores: torch.FloatTensor,
-            *args,
-            **kwargs,
-        ):
-            sampler_order = utils.koboldai_vars.sampler_order[:]
-            if (
-                len(sampler_order) < 7
-            ):  # Add repetition penalty at beginning if it's not present
-                sampler_order = [6] + sampler_order
-            for k in sampler_order:
-                scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
-            visualize_probabilities(HACK_currentmodel, scores)
-            return scores
-
-    def new_get_logits_warper(
-        beams: int = 1,
-    ) -> LogitsProcessorList:
-        return KoboldLogitsWarperList(beams=beams)
-
-    def new_sample(self, *args, **kwargs):
-        assert kwargs.pop("logits_warper", None) is not None
-        kwargs["logits_warper"] = new_get_logits_warper(
-            beams=1,
-        )
-        if (utils.koboldai_vars.newlinemode == "s") or (
-            utils.koboldai_vars.newlinemode == "ns"
-        ):
-            kwargs["eos_token_id"] = -1
-            kwargs.setdefault("pad_token_id", 2)
-        return new_sample.old_sample(self, *args, **kwargs)
-
-    new_sample.old_sample = transformers.GenerationMixin.sample
-    use_core_manipulations.sample = new_sample
-
     # Allow bad words filter to ban <|endoftext|> token
     import transformers.generation.logits_process
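Note: the deleted _find_intersection (re-added further down inside HFTorchInferenceModel._post_load) is easiest to follow on concrete data. A standalone rendering of the same naive overlap search:

def find_intersection(big: list, small: list) -> int:
    # Find where the end of `big` (context) overlaps the start of `small`
    # (bias phrase); return the index of the next phrase token to bias, or 0.
    num_small = len(small)
    big_slice = list(big[-num_small:])
    start = small[0]
    for i, t in enumerate(big_slice):
        if t == start:
            remaining = len(big_slice) - i
            if big_slice[i:] == small[:remaining]:
                if remaining < num_small:
                    return remaining
                return 0  # whole phrase already present; start over
    return 0

print(find_intersection([9, 9, 5, 7], [5, 7, 8]))  # 2: bias small[2] next
print(find_intersection([9, 9, 9, 9], [5, 7, 8]))  # 0: no overlap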
@@ -852,10 +517,12 @@ class InferenceModel:
         global HACK_currentmodel
         HACK_currentmodel = self
 
+        print(self.raw_generate("Hi guys,", 20).__dict__)
+
     def _post_load(self) -> None:
         pass
 
-    def _load(self, save_model: bool, inital_load: bool) -> None:
+    def _load(self, save_model: bool, initial_load: bool) -> None:
         raise NotImplementedError
 
     def _get_tokenizer(self, location: str):
@@ -1520,8 +1187,22 @@ class HFTorchInferenceModel(InferenceModel):
         )
         self._old_stopping_criteria = None
 
+    def _apply_warpers(
+        self, scores: torch.Tensor, input_ids: torch.Tensor
+    ) -> torch.Tensor:
+        warpers.update_settings()
+        for sid in utils.koboldai_vars.sampler_order:
+            warper = Warper.from_id(sid)
+            if warper == warpers.RepetitionPenalty:
+                # Rep pen needs more data than other samplers
+                print("is rep:", warper)
+                scores = warper.torch(scores, input_ids=input_ids)
+            else:
+                print("aint rep:", warper)
+                scores = warper.torch(scores)
+        return scores
+
     def _post_load(self) -> None:
-        print("HELLLOOOOOOOOOOOOOOOOOOOOOOOOOOO")
         # Patch stopping_criteria
 
         class PTHStopper(StoppingCriteria):
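Note: Warper.from_id resolves the numeric ids stored in koboldai_vars.sampler_order to warper classes, and only RepetitionPenalty needs the token history. A runnable sketch of that dispatch with stub warpers; the id-to-class mapping (5=temperature, 6=rep pen) follows KoboldAI's usual sampler numbering and is an assumption here:

import torch

class Temperature:
    temperature = 2.0

    @classmethod
    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
        return scores / cls.temperature

class RepetitionPenalty:
    rep_pen = 1.5

    @classmethod
    def torch(cls, scores: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
        # Toy penalty: push already-seen tokens toward less likely.
        out = scores.clone()
        for t in set(input_ids.flatten().tolist()):
            v = out[..., t]
            out[..., t] = torch.where(v > 0, v / cls.rep_pen, v * cls.rep_pen)
        return out

def apply_warpers(scores, input_ids, order=(6, 5)):
    by_id = {5: Temperature, 6: RepetitionPenalty}  # assumed mapping
    for sid in order:
        warper = by_id[sid]
        if warper is RepetitionPenalty:  # rep pen needs more data
            scores = warper.torch(scores, input_ids=input_ids)
        else:
            scores = warper.torch(scores)
    return scores

print(apply_warpers(torch.tensor([[1.0, 2.0, -1.0]]), torch.tensor([[1]])))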
@@ -1551,6 +1232,323 @@ class HFTorchInferenceModel(InferenceModel):
 
         use_core_manipulations.get_stopping_criteria = _get_stopping_criteria
 
+        # Patch logitswarpers
+
+        class PhraseBiasLogitsProcessor(LogitsProcessor):
+            def __init__(self):
+                pass
+
+            def _find_intersection(self, big: List, small: List) -> int:
+                """Find the maximum overlap between the beginning of small and the end of big.
+                Return the index of the token in small following the overlap, or 0.
+
+                big: The tokens in the context (as a tensor)
+                small: The tokens in the phrase to bias (as a list)
+
+                Both big and small are in "oldest to newest" order.
+                """
+                # There are asymptotically more efficient methods for determining the overlap,
+                # but typically there will be few (0-1) instances of small[0] in the last len(small)
+                # elements of big, plus small will typically be fairly short. So this naive
+                # approach is acceptable despite O(N^2) worst case performance.
+
+                num_small = len(small)
+                # The small list can only ever match against at most num_small tokens of big,
+                # so create a slice. Typically, this slice will be as long as small, but it
+                # may be shorter if the story has just started.
+                # We need to convert the big slice to list, since natively big is a tensor
+                # and tensor and list don't ever compare equal. It's better to convert here
+                # and then use native equality tests than to iterate repeatedly later.
+                big_slice = list(big[-num_small:])
+
+                # It's possible that the start token appears multiple times in small
+                # For example, consider the phrase:
+                # [ fair is foul, and foul is fair, hover through the fog and filthy air]
+                # If we merely look for the first instance of [ fair], then we would
+                # generate the following output:
+                # " fair is foul, and foul is fair is foul, and foul is fair..."
+                start = small[0]
+                for i, t in enumerate(big_slice):
+                    # Strictly unnecessary, but it's marginally faster to test the first
+                    # token before creating slices to test for a full match.
+                    if t == start:
+                        remaining = len(big_slice) - i
+                        if big_slice[i:] == small[:remaining]:
+                            # We found a match. If the small phrase has any remaining tokens
+                            # then return the index of the next token.
+                            if remaining < num_small:
+                                return remaining
+                            # In this case, the entire small phrase matched, so start over.
+                            return 0
+
+                # There were no matches, so just begin at the beginning.
+                return 0
+
+            def _allow_leftwards_tampering(self, phrase: str) -> bool:
+                """Determines if a phrase should be tampered with from the left in
+                the "soft" token encoding mode."""
+
+                if phrase[0] in [".", "?", "!", ";", ":", "\n"]:
+                    return False
+                return True
+
+            def _get_token_sequence(self, phrase: str) -> List[List]:
+                """Convert the phrase string into a list of encoded biases, each
+                one being a list of tokens. How this is done is determined by the
+                phrase's format:
+
+                - If the phrase is surrounded by square brackets ([]), the tokens
+                    will be the phrase split by commas (,). If a "token" isn't
+                    actually a number, it will be skipped. NOTE: Tokens output by
+                    this may not be in the model's vocabulary, and such tokens
+                    should be ignored later in the pipeline.
+                - If the phrase is surrounded by curly brackets ({}), the phrase
+                    will be directly encoded with no synonym biases and no fancy
+                    tricks.
+                - Otherwise, the phrase will be encoded, with close deviations
+                    being included as synonym biases.
+                """
+
+                # TODO: Cache these tokens, invalidate when model or bias is
+                # changed.
+
+                # Handle direct token id input
+                if phrase.startswith("[") and phrase.endswith("]"):
+                    no_brackets = phrase[1:-1]
+                    ret = []
+                    for token_id in no_brackets.split(","):
+                        try:
+                            ret.append(int(token_id))
+                        except ValueError:
+                            # Ignore non-numbers. Rascals!
+                            pass
+                    return [ret]
+
+                # Handle direct phrases
+                if phrase.startswith("{") and phrase.endswith("}"):
+                    no_brackets = phrase[1:-1]
+                    return [HACK_currentmodel.tokenizer.encode(no_brackets)]
+
+                # Handle untamperable phrases
+                if not self._allow_leftwards_tampering(phrase):
+                    return [HACK_currentmodel.tokenizer.encode(phrase)]
+
+                # Handle slight alterations to original phrase
+                phrase = phrase.strip(" ")
+                ret = []
+
+                for alt_phrase in [phrase, f" {phrase}"]:
+                    ret.append(HACK_currentmodel.tokenizer.encode(alt_phrase))
+
+                return ret
+
+            def _get_biased_tokens(self, input_ids: List) -> Dict:
+                # TODO: Different "bias slopes"?
+
+                ret = {}
+                for phrase, _bias in utils.koboldai_vars.biases.items():
+                    bias_score, completion_threshold = _bias
+                    token_seqs = self._get_token_sequence(phrase)
+                    variant_deltas = {}
+                    for token_seq in token_seqs:
+                        bias_index = self._find_intersection(input_ids, token_seq)
+
+                        # Ensure completion after completion_threshold tokens
+                        # Only provide a positive bias when the base bias score is positive.
+                        if bias_score > 0 and bias_index + 1 > completion_threshold:
+                            bias_score = 999
+
+                        token_to_bias = token_seq[bias_index]
+                        variant_deltas[token_to_bias] = bias_score
+
+                    # If multiple phrases bias the same token, add the modifiers
+                    # together. This should NOT be applied to automatic variants
+                    for token_to_bias, bias_score in variant_deltas.items():
+                        if token_to_bias in ret:
+                            ret[token_to_bias] += bias_score
+                        else:
+                            ret[token_to_bias] = bias_score
+                return ret
+
+            def __call__(
+                self, input_ids: torch.LongTensor, scores: torch.FloatTensor
+            ) -> torch.FloatTensor:
+                assert scores.ndim == 2
+                assert input_ids.ndim == 2
+
+                scores_shape = scores.shape
+
+                for batch in range(scores_shape[0]):
+                    for token, bias in self._get_biased_tokens(
+                        input_ids[batch]
+                    ).items():
+                        scores[batch][token] += bias
+
+                return scores
+
+        class LuaLogitsProcessor(LogitsProcessor):
+            def __init__(self):
+                pass
+
+            def __call__(
+                self, input_ids: torch.LongTensor, scores: torch.FloatTensor
+            ) -> torch.FloatTensor:
+                assert scores.ndim == 2
+                assert input_ids.ndim == 2
+                self.regeneration_required = False
+                self.halt = False
+
+                if utils.koboldai_vars.standalone:
+                    return scores
+
+                scores_shape = scores.shape
+                scores_list = scores.tolist()
+                utils.koboldai_vars.lua_koboldbridge.logits = (
+                    utils.koboldai_vars.lua_state.table()
+                )
+                for r, row in enumerate(scores_list):
+                    utils.koboldai_vars.lua_koboldbridge.logits[
+                        r + 1
+                    ] = utils.koboldai_vars.lua_state.table(*row)
+                utils.koboldai_vars.lua_koboldbridge.vocab_size = scores_shape[-1]
+
+                utils.koboldai_vars.lua_koboldbridge.execute_genmod()
+
+                scores = torch.tensor(
+                    tuple(
+                        tuple(row.values())
+                        for row in utils.koboldai_vars.lua_koboldbridge.logits.values()
+                    ),
+                    device=scores.device,
+                    dtype=scores.dtype,
+                )
+                assert scores.shape == scores_shape
+
+                return scores
+
+        from torch.nn import functional as F
+
+        def visualize_probabilities(
+            model: InferenceModel,
+            scores: torch.FloatTensor,
+        ) -> None:
+            assert scores.ndim == 2
+
+            if utils.koboldai_vars.numseqs > 1 or not utils.koboldai_vars.show_probs:
+                return
+
+            if not utils.koboldai_vars.show_probs:
+                return scores
+
+            option_offset = 0
+            if (
+                utils.koboldai_vars.actions.action_count + 1
+                in utils.koboldai_vars.actions.actions
+            ):
+                for x in range(
+                    len(
+                        utils.koboldai_vars.actions.actions[
+                            utils.koboldai_vars.actions.action_count + 1
+                        ]["Options"]
+                    )
+                ):
+                    option = utils.koboldai_vars.actions.actions[
+                        utils.koboldai_vars.actions.action_count + 1
+                    ]["Options"][x]
+                    if (
+                        option["Pinned"]
+                        or option["Previous Selection"]
+                        or option["Edited"]
+                    ):
+                        option_offset = x + 1
+            batch_offset = (
+                int(
+                    (utils.koboldai_vars.generated_tkns - 1)
+                    / utils.koboldai_vars.genamt
+                )
+                if utils.koboldai_vars.alt_multi_gen
+                else 0
+            )
+            for batch_index, batch in enumerate(scores):
+                probs = F.softmax(batch, dim=-1).cpu().numpy()
+
+                token_prob_info = []
+                for token_id, score in sorted(
+                    enumerate(probs), key=lambda x: x[1], reverse=True
+                )[:8]:
+                    token_prob_info.append(
+                        {
+                            "tokenId": token_id,
+                            "decoded": utils.decodenewlines(
+                                model.tokenizer.decode(token_id)
+                            ),
+                            "score": float(score),
+                        }
+                    )
+
+                if utils.koboldai_vars.numseqs == 1:
+                    utils.koboldai_vars.actions.set_probabilities(token_prob_info)
+                else:
+                    utils.koboldai_vars.actions.set_option_probabilities(
+                        token_prob_info, batch_index + option_offset + batch_offset
+                    )
+
+            return scores
+
+        def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
+            processors = new_get_logits_processor.old_get_logits_processor(
+                *args, **kwargs
+            )
+            # TODOB4MERGE: These two
+            # processors.insert(0, LuaLogitsProcessor())
+            # processors.append(PhraseBiasLogitsProcessor())
+            return processors
+
+        use_core_manipulations.get_logits_processor = new_get_logits_processor
+        new_get_logits_processor.old_get_logits_processor = (
+            transformers.GenerationMixin._get_logits_processor
+        )
+
+        class KoboldLogitsWarperList(LogitsProcessorList):
+            def __init__(self):
+                pass
+
+            def __call__(
+                lw_self,
+                input_ids: torch.LongTensor,
+                scores: torch.FloatTensor,
+                *args,
+                **kwargs,
+            ):
+                # sampler_order = utils.koboldai_vars.sampler_order[:]
+                # if (
+                #     len(sampler_order) < 7
+                # ):  # Add repetition penalty at beginning if it's not present
+                #     sampler_order = [6] + sampler_order
+                # for k in sampler_order:
+                #     scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
+                scores = self._apply_warpers(scores=scores, input_ids=input_ids)
+                visualize_probabilities(HACK_currentmodel, scores)
+                return scores
+
+        def new_get_logits_warper(
+            beams: int = 1,
+        ) -> LogitsProcessorList:
+            return KoboldLogitsWarperList()
+
+        def new_sample(self, *args, **kwargs):
+            assert kwargs.pop("logits_warper", None) is not None
+            kwargs["logits_warper"] = new_get_logits_warper(
+                beams=1,
+            )
+            if utils.koboldai_vars.newlinemode in ["s", "ns"]:
+                kwargs["eos_token_id"] = -1
+                kwargs.setdefault("pad_token_id", 2)
+            return new_sample.old_sample(self, *args, **kwargs)
+
+        new_sample.old_sample = transformers.GenerationMixin.sample
+        use_core_manipulations.sample = new_sample
+
     def _raw_generate(
         self,
         prompt_tokens: Union[List[int], torch.Tensor],
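Note: in KoboldLogitsWarperList.__call__ above, the receiver is deliberately named lw_self so that the bare `self` inside the body still refers to the HFTorchInferenceModel instance captured from the enclosing _post_load scope — that is what makes self._apply_warpers(...) resolve. A minimal demonstration of the closure trick:

class Outer:
    def helper(self, x):
        return x + 1

    def make_callable(self):
        class Inner:
            def __call__(inner_self, x):
                # `self` is the Outer instance, closed over from make_callable
                return self.helper(x)

        return Inner()

print(Outer().make_callable()(41))  # 42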
tpu_mtj_backend.py

@@ -349,57 +349,6 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarray]
     # probability distribution)
     return jax.random.categorical(key, logits, -1).astype(np.uint32)
 
-def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
-    '''
-    This gets called by generate_loop_fn to apply repetition penalty
-    to the 1D array logits using the provided 1D array of tokens to penalize
-    '''
-    rpslope = jnp.int32(rpslope)
-    rprange = jnp.int32(rprange)
-    clipped_rprange = jax.lax.cond(rprange > 0, lambda x: x, lambda x: tokens.shape[-1], rprange)
-    penalty_arange = jnp.roll(jnp.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]), generated_index, axis=-1)
-    # Make a new array with the same length as the tokens array but with
-    # each element replaced by the value at the corresponding index in the
-    # logits array; e.g.
-    # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
-    # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
-    penalty_logits = jnp.take(logits, tokens)
-    # Repetition penalty slope
-    def apply_slope(carry):
-        repetition_penalty, rprange = carry
-        _penalty = (penalty_arange/(rprange - 1)) * 2 - 1
-        _penalty = (rpslope * _penalty) / (1 + jnp.abs(_penalty) * (rpslope - 1))
-        _penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
-        return _penalty
-    repetition_penalty = jax.lax.cond(
-        (rpslope != 0.0) & (rprange > 0),  # Not a typo; do not use `and` here, it makes JAX crash
-        apply_slope,
-        lambda carry: jnp.full(tokens.shape, carry[0]),
-        (repetition_penalty, rprange),
-    )
-    # Divide positive values by repetition_penalty and multiply negative
-    # values by repetition_penalty (the academic publication that described
-    # this technique actually just only divided, but that would cause tokens
-    # with negative logits to become more likely, which is obviously wrong)
-    if koboldai_vars.use_alt_rep_pen:
-        penalty_logits = jnp.where(
-            penalty_arange >= 0,
-            penalty_logits - jnp.log(repetition_penalty),
-            penalty_logits,
-        )
-    else:
-        penalty_logits = jnp.where(
-            penalty_arange >= 0,
-            jnp.where(
-                penalty_logits > 0,
-                penalty_logits/repetition_penalty,
-                penalty_logits*repetition_penalty,
-            ),
-            penalty_logits,
-        )
-    # Finally, put those penalized logit values back into their original
-    # positions in the logits array
-    return logits.at[tokens].set(penalty_logits)
-
 def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
     '''
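Note: the divide-positive/multiply-negative rule described in the deleted comments keeps penalized tokens less likely regardless of sign. A toy torch rendering with made-up values:

import torch

logits = torch.tensor([2.0, -1.0, 0.5])
rep_pen = 1.5
penalized = torch.where(logits > 0, logits / rep_pen, logits * rep_pen)
print(penalized)  # tensor([ 1.3333, -1.5000,  0.3333])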
warpers.py (30 changed lines)
@@ -38,11 +38,31 @@ file are mostly porting warper code to the torch methods.
 
 from __future__ import annotations
 
+import utils
 import torch
+import numpy as np
 
-import jax
-import jax.numpy as jnp
-import numpy as np
-import tpu_mtj_backend
+try:
+    import jax
+    import jax.numpy as jnp
+
+    import tpu_mtj_backend
+except ImportError:
+    assert not utils.koboldai_vars.use_colab_tpu
+
+
+def update_settings():
+    # This feels like a bad way to structure this
+    koboldai_vars = utils.koboldai_vars
+    Temperature.temperature = koboldai_vars.temp
+    TopP.top_p = koboldai_vars.top_p
+    TopK.top_k = koboldai_vars.top_k
+    TopA.top_a = koboldai_vars.top_a
+    TailFree.tfs = koboldai_vars.tfs
+    Typical.typical = koboldai_vars.typical
+    RepetitionPenalty.rep_pen = koboldai_vars.rep_pen
+    RepetitionPenalty.rep_pen_range = koboldai_vars.rep_pen_range
+    RepetitionPenalty.rep_pen_slope = koboldai_vars.rep_pen_slope
+    RepetitionPenalty.use_alt_rep_pen = koboldai_vars.use_alt_rep_pen
 
 
 class Warper:
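Note: update_settings snapshots the live UI settings onto the warper classes themselves, so the torch/jax classmethods carry no instance state. A toy run of the pattern (the torch path is assumed to mirror the jax division fixed below):

import torch

class Temperature:
    temperature = 1.0  # overwritten by update_settings()

    @classmethod
    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
        return scores / cls.temperature

Temperature.temperature = 0.5  # what update_settings() does, per class
print(Temperature.torch(torch.tensor([2.0, 0.0, -2.0])))  # tensor([ 4.,  0., -4.])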
@@ -70,7 +90,7 @@ class Temperature(Warper):
 
     @classmethod
     def jax(cls, scores: jnp.array) -> jnp.array:
-        return scores / cls.value
+        return scores / cls.temperature
 
 
 class TopP(Warper):
@@ -88,7 +108,7 @@ class TopP(Warper):
         cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
 
         # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
-        sorted_indices_to_remove = cumulative_probs <= (1 - cls.value)
+        sorted_indices_to_remove = cumulative_probs <= (1 - cls.top_p)
 
         # scatter sorted tensors to original indexing
         indices_to_remove = sorted_indices_to_remove.scatter(
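Note: a self-contained worked example of the top_p path above (ascending-sort convention, matching the cumulative_probs <= (1 - top_p) test); values are made up:

import torch

logits = torch.tensor([[1.0, 2.0, 3.0, -1.0]])
top_p = 0.9
sorted_logits, sorted_indices = torch.sort(logits, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# The lowest ~10% of probability mass gets dropped: here only the -1.0 logit.
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
indices_to_remove = sorted_indices_to_remove.scatter(
    1, sorted_indices, sorted_indices_to_remove
)
print(logits.masked_fill(indices_to_remove, float("-inf")))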
@@ -109,7 +129,7 @@ class TopP(Warper):
         cumulative_probabilities = jnp.cumsum(probabilities, axis=-1)
         # We want to remove tokens with cumulative probability higher
         # than top_p
-        sorted_indices_to_remove = cumulative_probabilities > cls.value
+        sorted_indices_to_remove = cumulative_probabilities > cls.top_p
         # Don't ever remove the token with the highest logit, even if
         # the probability is higher than top_p
         sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
@@ -358,7 +378,7 @@ class RepetitionPenalty(Warper):
     use_alt_rep_pen: bool = False
 
     @classmethod
-    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
+    def torch(cls, scores: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
         cls.rep_pen_range = int(cls.rep_pen_range)
         clipped_penalty_range = min(input_ids.shape[-1], cls.rep_pen_range)
 
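Note: the new input_ids parameter is what lets _apply_warpers in model.py special-case this warper; a hypothetical call site mirroring that hunk:

# Repetition penalty is the one warper handed the token history
# in addition to the scores.
scores = RepetitionPenalty.torch(scores, input_ids=input_ids)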