Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)

Modeling: Fix logits processors (probs, biasing, lua)

@@ -13,6 +13,7 @@ from transformers import (
    AutoTokenizer,
)
from modeling.tokenizer import GenericTokenizer
from modeling import logits_processors

import utils

@@ -160,6 +161,11 @@ class InferenceModel:
        self.gen_state = {}
        self.post_token_hooks = []
        self.stopper_hooks = []
        self.logits_processors = [
            logits_processors.LuaIntegration(),
            logits_processors.PhraseBiasLogitsProcessor(),
            logits_processors.ProbabilityVisualization(),
        ]
        self.tokenizer = None
        self.capabilties = ModelCapabilities()
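
Each entry in this list is a plain callable rather than a Hugging Face LogitsProcessor: it is handed the owning InferenceModel together with the current scores and input_ids and must return the (possibly modified) scores. A minimal sketch of that calling convention, mirroring the loop added to HFTorchInferenceModel further down in this diff (the helper name is illustrative, not part of the commit):

def apply_logits_processors(model, scores, input_ids):
    # model is the InferenceModel that owns the list built above.
    for processor in model.logits_processors:
        scores = processor(model, scores=scores, input_ids=input_ids)
        assert scores is not None, f"Scores are None; processor '{processor}' is to blame"
    return scores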

@@ -36,7 +36,6 @@ from modeling.inference_models.hf import HFInferenceModel
from modeling.inference_model import (
    GenerationResult,
    GenerationSettings,
    InferenceModel,
    ModelCapabilities,
    use_core_manipulations,
)

@@ -104,6 +103,8 @@ class HFTorchInferenceModel(HFInferenceModel):
            else:
                scores = warper.torch(scores)

            assert scores is not None, f"Scores are None; warper '{warper}' is to blame"

            if LOG_SAMPLER_NO_EFFECT:
                if torch.equal(pre, scores):
                    logger.info(warper, "had no effect on the scores.")

@@ -166,274 +167,10 @@ class HFTorchInferenceModel(HFInferenceModel):

        # Patch logitswarpers

        class PhraseBiasLogitsProcessor(LogitsProcessor):
            def __init__(self):
                pass

            def _find_intersection(self, big: List, small: List) -> int:
                """Find the maximum overlap between the beginning of small and the end of big.
                Return the index of the token in small following the overlap, or 0.

                big: The tokens in the context (as a tensor)
                small: The tokens in the phrase to bias (as a list)

                Both big and small are in "oldest to newest" order.
                """
                # There are asymptotically more efficient methods for determining the overlap,
                # but typically there will be few (0-1) instances of small[0] in the last len(small)
                # elements of big, plus small will typically be fairly short. So this naive
                # approach is acceptable despite O(N^2) worst case performance.

                num_small = len(small)
                # The small list can only ever match against at most num_small tokens of big,
                # so create a slice. Typically, this slice will be as long as small, but it
                # may be shorter if the story has just started.
                # We need to convert the big slice to list, since natively big is a tensor
                # and tensor and list don't ever compare equal. It's better to convert here
                # and then use native equality tests than to iterate repeatedly later.
                big_slice = list(big[-num_small:])

                # It's possible that the start token appears multiple times in small
                # For example, consider the phrase:
                # [ fair is foul, and foul is fair, hover through the fog and filthy air]
                # If we merely look for the first instance of [ fair], then we would
                # generate the following output:
                # " fair is foul, and foul is fair is foul, and foul is fair..."
                start = small[0]
                for i, t in enumerate(big_slice):
                    # Strictly unnecessary, but it's marginally faster to test the first
                    # token before creating slices to test for a full match.
                    if t == start:
                        remaining = len(big_slice) - i
                        if big_slice[i:] == small[:remaining]:
                            # We found a match. If the small phrase has any remaining tokens
                            # then return the index of the next token.
                            if remaining < num_small:
                                return remaining
                            # In this case, the entire small phrase matched, so start over.
                            return 0

                # There were no matches, so just begin at the beginning.
                return 0

            def _allow_leftwards_tampering(self, phrase: str) -> bool:
                """Determines if a phrase should be tampered with from the left in
                the "soft" token encoding mode."""

                if phrase[0] in [".", "?", "!", ";", ":", "\n"]:
                    return False
                return True

            def _get_token_sequence(self, phrase: str) -> List[List]:
                """Convert the phrase string into a list of encoded biases, each
                one being a list of tokens. How this is done is determined by the
                phrase's format:

                - If the phrase is surrounded by square brackets ([]), the tokens
                  will be the phrase split by commas (,). If a "token" isn't
                  actually a number, it will be skipped. NOTE: Tokens output by
                  this may not be in the model's vocabulary, and such tokens
                  should be ignored later in the pipeline.
                - If the phrase is surrounded by curly brackets ({}), the phrase
                  will be directly encoded with no synonym biases and no fancy
                  tricks.
                - Otherwise, the phrase will be encoded, with close deviations
                  being included as synonym biases.
                """

                # TODO: Cache these tokens, invalidate when model or bias is
                # changed.

                # Handle direct token id input
                if phrase.startswith("[") and phrase.endswith("]"):
                    no_brackets = phrase[1:-1]
                    ret = []
                    for token_id in no_brackets.split(","):
                        try:
                            ret.append(int(token_id))
                        except ValueError:
                            # Ignore non-numbers. Rascals!
                            pass
                    return [ret]

                # Handle direct phrases
                if phrase.startswith("{") and phrase.endswith("}"):
                    no_brackets = phrase[1:-1]
                    return [m_self.tokenizer.encode(no_brackets)]

                # Handle untamperable phrases
                if not self._allow_leftwards_tampering(phrase):
                    return [m_self.tokenizer.encode(phrase)]

                # Handle slight alterations to original phrase
                phrase = phrase.strip(" ")
                ret = []

                for alt_phrase in [phrase, f" {phrase}"]:
                    ret.append(m_self.tokenizer.encode(alt_phrase))

                return ret

            def _get_biased_tokens(self, input_ids: List) -> Dict:
                # TODO: Different "bias slopes"?

                ret = {}
                for phrase, _bias in utils.koboldai_vars.biases.items():
                    bias_score, completion_threshold = _bias
                    token_seqs = self._get_token_sequence(phrase)
                    variant_deltas = {}
                    for token_seq in token_seqs:
                        bias_index = self._find_intersection(input_ids, token_seq)

                        # Ensure completion after completion_threshold tokens
                        # Only provide a positive bias when the base bias score is positive.
                        if bias_score > 0 and bias_index + 1 > completion_threshold:
                            bias_score = 999

                        token_to_bias = token_seq[bias_index]
                        variant_deltas[token_to_bias] = bias_score

                    # If multiple phrases bias the same token, add the modifiers
                    # together. This should NOT be applied to automatic variants
                    for token_to_bias, bias_score in variant_deltas.items():
                        if token_to_bias in ret:
                            ret[token_to_bias] += bias_score
                        else:
                            ret[token_to_bias] = bias_score
                return ret

            def __call__(
                self, input_ids: torch.LongTensor, scores: torch.FloatTensor
            ) -> torch.FloatTensor:
                assert scores.ndim == 2
                assert input_ids.ndim == 2

                scores_shape = scores.shape

                for batch in range(scores_shape[0]):
                    for token, bias in self._get_biased_tokens(
                        input_ids[batch]
                    ).items():
                        scores[batch][token] += bias

                return scores

        class LuaLogitsProcessor(LogitsProcessor):
            def __init__(self):
                pass

            def __call__(
                self, input_ids: torch.LongTensor, scores: torch.FloatTensor
            ) -> torch.FloatTensor:
                assert scores.ndim == 2
                assert input_ids.ndim == 2
                self.regeneration_required = False
                self.halt = False

                if utils.koboldai_vars.standalone:
                    return scores

                scores_shape = scores.shape
                scores_list = scores.tolist()
                utils.koboldai_vars.lua_koboldbridge.logits = (
                    utils.koboldai_vars.lua_state.table()
                )
                for r, row in enumerate(scores_list):
                    utils.koboldai_vars.lua_koboldbridge.logits[
                        r + 1
                    ] = utils.koboldai_vars.lua_state.table(*row)
                utils.koboldai_vars.lua_koboldbridge.vocab_size = scores_shape[-1]

                utils.koboldai_vars.lua_koboldbridge.execute_genmod()

                scores = torch.Tensor(
                    tuple(
                        tuple(row.values())
                        for row in utils.koboldai_vars.lua_koboldbridge.logits.values()
                    ),
                    device=scores.device,
                    dtype=scores.dtype,
                )
                assert scores.shape == scores_shape

                return scores

        from torch.nn import functional as F

        def visualize_probabilities(
            model: InferenceModel,
            scores: torch.FloatTensor,
        ) -> None:
            assert scores.ndim == 2

            if utils.koboldai_vars.numseqs > 1 or not utils.koboldai_vars.show_probs:
                return

            if not utils.koboldai_vars.show_probs:
                return scores

            option_offset = 0
            if (
                utils.koboldai_vars.actions.action_count + 1
                in utils.koboldai_vars.actions.actions
            ):
                for x in range(
                    len(
                        utils.koboldai_vars.actions.actions[
                            utils.koboldai_vars.actions.action_count + 1
                        ]["Options"]
                    )
                ):
                    option = utils.koboldai_vars.actions.actions[
                        utils.koboldai_vars.actions.action_count + 1
                    ]["Options"][x]
                    if (
                        option["Pinned"]
                        or option["Previous Selection"]
                        or option["Edited"]
                    ):
                        option_offset = x + 1
            batch_offset = (
                int(
                    (utils.koboldai_vars.generated_tkns - 1)
                    / utils.koboldai_vars.genamt
                )
                if utils.koboldai_vars.alt_multi_gen
                else 0
            )
            for batch_index, batch in enumerate(scores):
                probs = F.softmax(batch, dim=-1).cpu().numpy()

                token_prob_info = []
                for token_id, score in sorted(
                    enumerate(probs), key=lambda x: x[1], reverse=True
                )[:8]:
                    token_prob_info.append(
                        {
                            "tokenId": token_id,
                            "decoded": utils.decodenewlines(
                                model.tokenizer.decode(token_id)
                            ),
                            "score": float(score),
                        }
                    )

                if utils.koboldai_vars.numseqs == 1:
                    utils.koboldai_vars.actions.set_probabilities(token_prob_info)
                else:
                    utils.koboldai_vars.actions.set_option_probabilities(
                        token_prob_info, batch_index + option_offset + batch_offset
                    )

            return scores

        def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
            processors = new_get_logits_processor.old_get_logits_processor(
                *args, **kwargs
            )
            # TODOB4MERGE: These two
            # processors.insert(0, LuaLogitsProcessor())
            # processors.append(PhraseBiasLogitsProcessor())
            return processors

        use_core_manipulations.get_logits_processor = new_get_logits_processor
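
The _get_token_sequence docstring above describes three bias-phrase formats. A standalone sketch of how those branches behave, with a toy encode() standing in for the model tokenizer and the leftwards-tampering check omitted (all names and token ids below are illustrative):

def sketch_token_sequences(phrase, encode):
    # "[50256, 198]" -> literal token ids; non-numbers are skipped
    if phrase.startswith("[") and phrase.endswith("]"):
        ids = []
        for part in phrase[1:-1].split(","):
            try:
                ids.append(int(part))
            except ValueError:
                pass
        return [ids]
    # "{Hello}" -> encode exactly as written, no variants
    if phrase.startswith("{") and phrase.endswith("}"):
        return [encode(phrase[1:-1])]
    # plain phrase -> encode both "phrase" and " phrase" as variants
    phrase = phrase.strip(" ")
    return [encode(phrase), encode(" " + phrase)]

encode = lambda s: [ord(c) for c in s]  # toy stand-in for tokenizer.encode
print(sketch_token_sequences("[50256, 198, oops]", encode))  # [[50256, 198]]
print(sketch_token_sequences("{Hi}", encode))                # [[72, 105]]
print(sketch_token_sequences("Hi", encode))                  # [[72, 105], [32, 72, 105]]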

@@ -453,7 +190,12 @@ class HFTorchInferenceModel(HFInferenceModel):
                **kwargs,
            ):
                scores = m_self._apply_warpers(scores=scores, input_ids=input_ids)
                visualize_probabilities(m_self, scores)

                for processor in m_self.logits_processors:
                    scores = processor(m_self, scores=scores, input_ids=input_ids)
                    assert (
                        scores is not None
                    ), f"Scores are None; processor '{processor}' is to blame"
                return scores

        def new_get_logits_warper(
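
Both the class removed above and its replacement in the new file below rely on the same _find_intersection overlap search. A compact re-implementation with concrete toy values shows what it returns (function name and numbers are illustrative):

def find_intersection(big, small):
    # Index in `small` of the next token to bias, or 0 when the end of `big`
    # does not already contain the start of `small`.
    num_small = len(small)
    big_slice = list(big[-num_small:])
    start = small[0]
    for i, t in enumerate(big_slice):
        if t == start:
            remaining = len(big_slice) - i
            if big_slice[i:] == small[:remaining]:
                return remaining if remaining < num_small else 0
    return 0

context = [11, 22, 33, 7, 8]   # story so far, oldest to newest
phrase = [7, 8, 9, 10]         # tokens of the biased phrase
print(find_intersection(context, phrase))    # 2 -> token 9 should be biased next
print(find_intersection([1, 2, 3], phrase))  # 0 -> no overlap, bias token 7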

modeling/logits_processors.py (new file, 272 lines)
@@ -0,0 +1,272 @@
from __future__ import annotations

from typing import Dict, List
import torch
from torch.nn import functional as F

import utils

# Weird annotations to avoid cyclic import
from modeling import inference_model


class ProbabilityVisualization:
    def __call__(
        self,
        model: inference_model.InferenceModel,
        scores: torch.FloatTensor,
        input_ids: torch.LongTensor,
    ) -> torch.FloatTensor:
        assert scores.ndim == 2

        if utils.koboldai_vars.numseqs > 1 or not utils.koboldai_vars.show_probs:
            return scores

        if not utils.koboldai_vars.show_probs:
            return scores

        option_offset = 0
        if (
            utils.koboldai_vars.actions.action_count + 1
            in utils.koboldai_vars.actions.actions
        ):
            for x in range(
                len(
                    utils.koboldai_vars.actions.actions[
                        utils.koboldai_vars.actions.action_count + 1
                    ]["Options"]
                )
            ):
                option = utils.koboldai_vars.actions.actions[
                    utils.koboldai_vars.actions.action_count + 1
                ]["Options"][x]
                if option["Pinned"] or option["Previous Selection"] or option["Edited"]:
                    option_offset = x + 1
        batch_offset = (
            int((utils.koboldai_vars.generated_tkns - 1) / utils.koboldai_vars.genamt)
            if utils.koboldai_vars.alt_multi_gen
            else 0
        )
        for batch_index, batch in enumerate(scores):
            probs = F.softmax(batch, dim=-1).cpu().numpy()

            token_prob_info = []
            for token_id, score in sorted(
                enumerate(probs), key=lambda x: x[1], reverse=True
            )[:8]:
                token_prob_info.append(
                    {
                        "tokenId": token_id,
                        "decoded": utils.decodenewlines(
                            model.tokenizer.decode(token_id)
                        ),
                        "score": float(score),
                    }
                )

            if utils.koboldai_vars.numseqs == 1:
                utils.koboldai_vars.actions.set_probabilities(token_prob_info)
            else:
                utils.koboldai_vars.actions.set_option_probabilities(
                    token_prob_info, batch_index + option_offset + batch_offset
                )

        return scores
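
As an aside, the top-8 extraction that ProbabilityVisualization performs per batch row can be reproduced on a toy scores tensor (values below are made up; only torch is assumed):

import torch
from torch.nn import functional as F

scores = torch.tensor([[2.0, 0.5, 1.0, 3.0]])  # one batch row over a 4-token vocab
probs = F.softmax(scores[0], dim=-1).cpu().numpy()
top = sorted(enumerate(probs), key=lambda x: x[1], reverse=True)[:8]
# top is a list of (token_id, probability) pairs, highest first; the class
# decodes each id and forwards the result to the story UI.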

class LuaIntegration:
    def __call__(
        self,
        model: inference_model.InferenceModel,
        scores: torch.FloatTensor,
        input_ids: torch.LongTensor,
    ) -> torch.FloatTensor:
        assert scores.ndim == 2
        assert input_ids.ndim == 2
        model.gen_state["regeneration_required"] = False
        model.gen_state["halt"] = False

        if utils.koboldai_vars.standalone:
            return scores

        scores_shape = scores.shape
        scores_list = scores.tolist()
        utils.koboldai_vars.lua_koboldbridge.logits = (
            utils.koboldai_vars.lua_state.table()
        )
        for r, row in enumerate(scores_list):
            utils.koboldai_vars.lua_koboldbridge.logits[
                r + 1
            ] = utils.koboldai_vars.lua_state.table(*row)
        utils.koboldai_vars.lua_koboldbridge.vocab_size = scores_shape[-1]

        utils.koboldai_vars.lua_koboldbridge.execute_genmod()

        scores = torch.tensor(
            tuple(
                tuple(row.values())
                for row in utils.koboldai_vars.lua_koboldbridge.logits.values()
            ),
            device=scores.device,
            dtype=scores.dtype,
        )
        assert scores.shape == scores_shape

        return scores
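
LuaIntegration above copies the scores into Lua's 1-indexed tables, runs the user generation modifier, and rebuilds the tensor from whatever comes back. A rough pure-Python illustration of that 1-indexed round trip, using dicts to stand in for Lua tables:

import torch

scores = torch.tensor([[0.1, 0.2, 0.3]])
# Python rows (0-indexed) -> "Lua" tables (1-indexed)
lua_logits = {
    r + 1: {c + 1: v for c, v in enumerate(row)}
    for r, row in enumerate(scores.tolist())
}
# ... a genmod script could rewrite lua_logits here ...
rebuilt = torch.tensor(
    tuple(tuple(row.values()) for row in lua_logits.values()),
    dtype=scores.dtype,
)
assert rebuilt.shape == scores.shape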

class PhraseBiasLogitsProcessor:
    def __init__(self) -> None:
        # Hack
        self.model = None

    def _find_intersection(self, big: List, small: List) -> int:
        """Find the maximum overlap between the beginning of small and the end of big.
        Return the index of the token in small following the overlap, or 0.

        big: The tokens in the context (as a tensor)
        small: The tokens in the phrase to bias (as a list)

        Both big and small are in "oldest to newest" order.
        """
        # There are asymptotically more efficient methods for determining the overlap,
        # but typically there will be few (0-1) instances of small[0] in the last len(small)
        # elements of big, plus small will typically be fairly short. So this naive
        # approach is acceptable despite O(N^2) worst case performance.

        num_small = len(small)
        # The small list can only ever match against at most num_small tokens of big,
        # so create a slice. Typically, this slice will be as long as small, but it
        # may be shorter if the story has just started.
        # We need to convert the big slice to list, since natively big is a tensor
        # and tensor and list don't ever compare equal. It's better to convert here
        # and then use native equality tests than to iterate repeatedly later.
        big_slice = list(big[-num_small:])

        # It's possible that the start token appears multiple times in small
        # For example, consider the phrase:
        # [ fair is foul, and foul is fair, hover through the fog and filthy air]
        # If we merely look for the first instance of [ fair], then we would
        # generate the following output:
        # " fair is foul, and foul is fair is foul, and foul is fair..."
        start = small[0]
        for i, t in enumerate(big_slice):
            # Strictly unnecessary, but it's marginally faster to test the first
            # token before creating slices to test for a full match.
            if t == start:
                remaining = len(big_slice) - i
                if big_slice[i:] == small[:remaining]:
                    # We found a match. If the small phrase has any remaining tokens
                    # then return the index of the next token.
                    if remaining < num_small:
                        return remaining
                    # In this case, the entire small phrase matched, so start over.
                    return 0

        # There were no matches, so just begin at the beginning.
        return 0

    def _allow_leftwards_tampering(self, phrase: str) -> bool:
        """Determines if a phrase should be tampered with from the left in
        the "soft" token encoding mode."""

        if phrase[0] in [".", "?", "!", ";", ":", "\n"]:
            return False
        return True

    def _get_token_sequence(self, phrase: str) -> List[List]:
        """Convert the phrase string into a list of encoded biases, each
        one being a list of tokens. How this is done is determined by the
        phrase's format:

        - If the phrase is surrounded by square brackets ([]), the tokens
          will be the phrase split by commas (,). If a "token" isn't
          actually a number, it will be skipped. NOTE: Tokens output by
          this may not be in the model's vocabulary, and such tokens
          should be ignored later in the pipeline.
        - If the phrase is surrounded by curly brackets ({}), the phrase
          will be directly encoded with no synonym biases and no fancy
          tricks.
        - Otherwise, the phrase will be encoded, with close deviations
          being included as synonym biases.
        """

        # TODO: Cache these tokens, invalidate when model or bias is
        # changed.

        # Handle direct token id input
        if phrase.startswith("[") and phrase.endswith("]"):
            no_brackets = phrase[1:-1]
            ret = []
            for token_id in no_brackets.split(","):
                try:
                    ret.append(int(token_id))
                except ValueError:
                    # Ignore non-numbers. Rascals!
                    pass
            return [ret]

        # Handle direct phrases
        if phrase.startswith("{") and phrase.endswith("}"):
            no_brackets = phrase[1:-1]
            return [self.model.tokenizer.encode(no_brackets)]

        # Handle untamperable phrases
        if not self._allow_leftwards_tampering(phrase):
            return [self.model.tokenizer.encode(phrase)]

        # Handle slight alterations to original phrase
        phrase = phrase.strip(" ")
        ret = []

        for alt_phrase in [phrase, f" {phrase}"]:
            ret.append(self.model.tokenizer.encode(alt_phrase))

        return ret

    def _get_biased_tokens(self, input_ids: List) -> Dict:
        # TODO: Different "bias slopes"?

        ret = {}
        for phrase, _bias in utils.koboldai_vars.biases.items():
            bias_score, completion_threshold = _bias
            token_seqs = self._get_token_sequence(phrase)
            variant_deltas = {}
            for token_seq in token_seqs:
                bias_index = self._find_intersection(input_ids, token_seq)

                # Ensure completion after completion_threshold tokens
                # Only provide a positive bias when the base bias score is positive.
                if bias_score > 0 and bias_index + 1 > completion_threshold:
                    bias_score = 999

                token_to_bias = token_seq[bias_index]
                variant_deltas[token_to_bias] = bias_score

            # If multiple phrases bias the same token, add the modifiers
            # together. This should NOT be applied to automatic variants
            for token_to_bias, bias_score in variant_deltas.items():
                if token_to_bias in ret:
                    ret[token_to_bias] += bias_score
                else:
                    ret[token_to_bias] = bias_score
        return ret

    def __call__(
        self,
        model: inference_model.InferenceModel,
        scores: torch.FloatTensor,
        input_ids: torch.LongTensor,
    ) -> torch.FloatTensor:
        self.model = model

        assert scores.ndim == 2
        assert input_ids.ndim == 2

        scores_shape = scores.shape

        for batch in range(scores_shape[0]):
            for token, bias in self._get_biased_tokens(input_ids[batch]).items():
                scores[batch][token] += bias

        return scores
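
Taken together, each processor receives the full (batch, vocab) scores tensor and edits it in place before sampling. The core of PhraseBiasLogitsProcessor.__call__ reduces to per-token additions, which can be tried on a toy tensor (token ids and bias values below are made up):

import torch

scores = torch.zeros((2, 10))   # (batch, vocab_size)
biased = {3: 4.5, 7: -2.0}      # token_id -> bias, as _get_biased_tokens would return
for batch in range(scores.shape[0]):
    for token, bias in biased.items():
        scores[batch][token] += bias
print(scores[0])  # token 3 boosted, token 7 suppressed, in every batch row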