Merge branch 'ebolam:UI2' into master

Authored by Viningr on 2022-11-01 23:29:42 +10:00, committed by GitHub
2 changed files with 52 additions and 25 deletions


@@ -2017,29 +2017,50 @@ def patch_transformers():
         def __init__(self):
             pass
 
-        def _rindex(self, lst: List, target) -> Optional[int]:
-            for index, item in enumerate(reversed(lst)):
-                if item == target:
-                    return len(lst) - index - 1
-            return None
-
         def _find_intersection(self, big: List, small: List) -> int:
-            # Find the intersection of the end of "big" and the beginning of
-            # "small". A headache to think about, personally. Returns the index
-            # into "small" where the two stop intersecting.
-            start = self._rindex(big, small[0])
+            """Find the maximum overlap between the beginning of small and the end of big.
+            Return the index of the token in small following the overlap, or 0.
 
-            # No progress into the token sequence, bias the first one.
-            if not start:
-                return 0
+            big: The tokens in the context (as a tensor)
+            small: The tokens in the phrase to bias (as a list)
 
-            for i in range(len(small)):
-                try:
-                    big_i = big[start + i]
-                except IndexError:
-                    return i
-            # It's completed :^)
+            Both big and small are in "oldest to newest" order.
+            """
+            # There are asymptotically more efficient methods for determining the overlap,
+            # but typically there will be few (0-1) instances of small[0] in the last len(small)
+            # elements of big, plus small will typically be fairly short, so this naive
+            # approach is acceptable despite its O(N^2) worst-case performance.
+            num_small = len(small)
+            # The small list can only ever match against at most num_small tokens of big,
+            # so create a slice. Typically, this slice will be as long as small, but it
+            # may be shorter if the story has just started.
+            # We need to convert the big slice to a list, since big is natively a tensor,
+            # and a tensor and a list never compare equal. It's better to convert here
+            # and then use native equality tests than to iterate repeatedly later.
+            big_slice = list(big[-num_small:])
+
+            # It's possible that the start token appears multiple times in small.
+            # For example, consider the phrase:
+            #   [ fair is foul, and foul is fair, hover through the fog and filthy air]
+            # If we merely looked for the first instance of [ fair], then we would
+            # generate the following output:
+            #   " fair is foul, and foul is fair is foul, and foul is fair..."
+            start = small[0]
+            for i, t in enumerate(big_slice):
+                # Strictly unnecessary, but it's marginally faster to test the first
+                # token before creating slices to test for a full match.
+                if t == start:
+                    remaining = len(big_slice) - i
+                    if big_slice[i:] == small[:remaining]:
+                        # We found a match. If the small phrase has any remaining tokens,
+                        # return the index of the next token.
+                        if remaining < num_small:
+                            return remaining
+                        # In this case, the entire small phrase matched, so start over.
+                        return 0
+
+            # There were no matches, so just begin at the beginning.
             return 0
 
         def _get_biased_tokens(self, input_ids: List) -> Dict:
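Read in isolation, the replacement algorithm is easy to exercise. Below is a minimal sketch of the same overlap search as a standalone function on plain Python lists, with a few worked cases; the standalone form and the example token ids are illustrative only, while in the commit this logic is a method on the bias processor and big arrives as a tensor.

from typing import List

def find_intersection(big: List[int], small: List[int]) -> int:
    """Return the index in small of the token following the longest
    overlap between the end of big and the start of small, or 0."""
    num_small = len(small)
    big_slice = list(big[-num_small:])  # only the tail of big can match
    start = small[0]
    for i, t in enumerate(big_slice):
        if t == start:
            remaining = len(big_slice) - i
            if big_slice[i:] == small[:remaining]:
                # Partial match: bias the next token of the phrase.
                if remaining < num_small:
                    return remaining
                # Full match: start the phrase over.
                return 0
    return 0

# The phrase [5, 6, 7] is two tokens deep into the context,
# so the token to bias next is small[2]:
assert find_intersection([1, 2, 3, 5, 6], [5, 6, 7]) == 2
# No overlap at all: bias the first token of the phrase.
assert find_intersection([1, 2, 3], [5, 6, 7]) == 0
# The whole phrase already appears at the end: start over.
assert find_intersection([1, 5, 6, 7], [5, 6, 7]) == 0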
@@ -2054,11 +2075,16 @@ def patch_transformers():
             bias_index = self._find_intersection(input_ids, token_seq)
 
             # Ensure completion after completion_threshold tokens
-            if bias_index + 1 > completion_threshold:
+            # Only provide a positive bias when the base bias score is positive.
+            if bias_score > 0 and bias_index + 1 > completion_threshold:
                 bias_score = 999
 
             token_to_bias = token_seq[bias_index]
-            ret[token_to_bias] = bias_score
+            # If multiple phrases bias the same token, add the modifiers together.
+            if token_to_bias in ret:
+                ret[token_to_bias] += bias_score
+            else:
+                ret[token_to_bias] = bias_score
         return ret
 
         def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
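This hunk makes two behavioral fixes: the forced-completion pin of 999 now applies only when the phrase's bias is positive, so a negatively biased ("avoid this") phrase is never accidentally forced to complete, and phrases that converge on the same next token now have their scores summed instead of the last phrase silently overwriting the others. A rough sketch of the resulting scoring loop, assuming a hypothetical phrases list of (token_seq, bias_score) pairs and reusing find_intersection from the sketch above:

def biased_tokens(phrases, input_ids, completion_threshold=2):
    ret = {}
    for token_seq, bias_score in phrases:
        bias_index = find_intersection(input_ids, token_seq)
        # Pin only positively biased phrases; a negative score must
        # never be promoted to 999.
        if bias_score > 0 and bias_index + 1 > completion_threshold:
            bias_score = 999
        token_to_bias = token_seq[bias_index]
        # dict.get(..., 0) is equivalent to the diff's if/else accumulation.
        ret[token_to_bias] = ret.get(token_to_bias, 0) + bias_score
    return ret

# Two phrases that both start with token 5 now stack to a combined
# bias of 6.5 instead of the second clobbering the first:
assert biased_tokens([([5, 6, 7], 4.0), ([5, 9], 2.5)], [1, 2, 3]) == {5: 6.5}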
@@ -2141,6 +2167,7 @@ def patch_transformers():
     def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
         processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs)
         processors.insert(0, LuaLogitsProcessor())
+        processors.append(PhraseBiasLogitsProcessor())
         processors.append(ProbabilityVisualizerLogitsProcessor())
         return processors
     new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
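The one-line fix registers PhraseBiasLogitsProcessor, which was defined but never added to the processor list, so the phrase biases computed above never reached the logits. The surrounding code monkey-patches transformers' processor factory; a minimal sketch of that wrap-and-append pattern, with the hypothetical ExampleBiasProcessor standing in for the real processor class:

import torch
from transformers import LogitsProcessor, LogitsProcessorList

class ExampleBiasProcessor(LogitsProcessor):
    """Stand-in for PhraseBiasLogitsProcessor: adds fixed biases to logits."""
    def __init__(self, biases=None):
        self.biases = biases or {}  # token id -> additive logit bias

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        for token_id, bias in self.biases.items():
            scores[:, token_id] += bias
        return scores

def wrap(old_get_logits_processor):
    def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
        processors = old_get_logits_processor(*args, **kwargs)
        # Without this append, the processor exists but never runs.
        processors.append(ExampleBiasProcessor())
        return processors
    return new_get_logits_processor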


@@ -78,13 +78,13 @@
<div class="bias_score">
<div class="bias_slider">
<div class="bias_slider_bar">
<input type="range" min="-12" max="12" step="0.01" value="0" class="setting_item_input"
<input type="range" min="-50" max="50" step="0.01" value="0" class="setting_item_input"
oninput="update_bias_slider_value(this);"
onchange="save_bias(this);"/>
</div>
<div class="bias_slider_min">-12</div>
<div class="bias_slider_min">-50</div>
<div class="bias_slider_cur">0</div>
<div class="bias_slider_max">12</div>
<div class="bias_slider_max">50</div>
</div>
</div>
<div class="bias_comp_threshold">
@@ -99,4 +99,4 @@
<div class="bias_slider_max">10</div>
</div>
</div>
</div>
</div>