From 77813df2f0745527459c9a23393214a3c15f64d4 Mon Sep 17 00:00:00 2001 From: Llama <34464159+pi6am@users.noreply.github.com> Date: Tue, 1 Nov 2022 00:20:15 -0700 Subject: [PATCH] Re-enable and fix several issues with phrase bias Add the PhraseBiasLogitsProcessor to the logits processor list Fix an issue with bias phrases that contain the start token multiple times. Because we were searching backwards for the first occurrence of the start token, we would restart the phrase when we encountered a subsequent instance of the token. We now search forwards from the maximum possible overlap to find the maximum overlap. Fix an issue with the phrase bias token index not accounting for non-matching words. Previously, once we found the start token, we would apply the bias for each token in the bias phrase even if subsequent tokens in the context didn't match the bias phrase. Do not apply phrase completion if the bias score is negative. If multiple phrases apply a score modifier to the same token, add the scores rather than replacing the modifier with the last occurrence. Increase the maximum range of the bias slider. For extremely repetitive text on large models, -12 is insufficient to break the model out of its loop. -50 to 50 is potentially excessive, but it's safer to give the user some additional control over the bias score. --- aiserver.py | 69 ++++++++++++++++++++++++++++------------ templates/templates.html | 8 ++--- 2 files changed, 52 insertions(+), 25 deletions(-) diff --git a/aiserver.py b/aiserver.py index 422cbed9..9748f8ea 100644 --- a/aiserver.py +++ b/aiserver.py @@ -2017,29 +2017,50 @@ def patch_transformers(): def __init__(self): pass - def _rindex(self, lst: List, target) -> Optional[int]: - for index, item in enumerate(reversed(lst)): - if item == target: - return len(lst) - index - 1 - return None - def _find_intersection(self, big: List, small: List) -> int: - # Find the intersection of the end of "big" and the beginning of - # "small". A headache to think about, personally. Returns the index - # into "small" where the two stop intersecting. - start = self._rindex(big, small[0]) + """Find the maximum overlap between the beginning of small and the end of big. + Return the index of the token in small following the overlap, or 0. - # No progress into the token sequence, bias the first one. - if not start: - return 0 + big: The tokens in the context (as a tensor) + small: The tokens in the phrase to bias (as a list) - for i in range(len(small)): - try: - big_i = big[start + i] - except IndexError: - return i + Both big and small are in "oldest to newest" order. + """ + # There are asymptotically more efficient methods for determining the overlap, + # but typically there will be few (0-1) instances of small[0] in the last len(small) + # elements of big, plus small will typically be fairly short. So this naive + # approach is acceptable despite O(N^2) worst case performance. - # It's completed :^) + num_small = len(small) + # The small list can only ever match against at most num_small tokens of big, + # so create a slice. Typically, this slice will be as long as small, but it + # may be shorter if the story has just started. + # We need to convert the big slice to list, since natively big is a tensor + # and tensor and list don't ever compare equal. It's better to convert here + # and then use native equality tests than to iterate repeatedly later. + big_slice = list(big[-num_small:]) + + # It's possible that the start token appears multiple times in small + # For example, consider the phrase: + # [ fair is foul, and foul is fair, hover through the fog and filthy air] + # If we merely look for the first instance of [ fair], then we would + # generate the following output: + # " fair is foul, and foul is fair is foul, and foul is fair..." + start = small[0] + for i, t in enumerate(big_slice): + # Strictly unnecessary, but it's marginally faster to test the first + # token before creating slices to test for a full match. + if t == start: + remaining = len(big_slice) - i + if big_slice[i:] == small[:remaining]: + # We found a match. If the small phrase has any remaining tokens + # then return the index of the next token. + if remaining < num_small: + return remaining + # In this case, the entire small phrase matched, so start over. + return 0 + + # There were no matches, so just begin at the beginning. return 0 def _get_biased_tokens(self, input_ids: List) -> Dict: @@ -2054,11 +2075,16 @@ def patch_transformers(): bias_index = self._find_intersection(input_ids, token_seq) # Ensure completion after completion_threshold tokens - if bias_index + 1 > completion_threshold: + # Only provide a positive bias when the base bias score is positive. + if bias_score > 0 and bias_index + 1 > completion_threshold: bias_score = 999 token_to_bias = token_seq[bias_index] - ret[token_to_bias] = bias_score + # If multiple phrases bias the same token, add the modifiers together. + if token_to_bias in ret: + ret[token_to_bias] += bias_score + else: + ret[token_to_bias] = bias_score return ret def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: @@ -2141,6 +2167,7 @@ def patch_transformers(): def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList: processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs) processors.insert(0, LuaLogitsProcessor()) + processors.append(PhraseBiasLogitsProcessor()) processors.append(ProbabilityVisualizerLogitsProcessor()) return processors new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor diff --git a/templates/templates.html b/templates/templates.html index fd4ed113..5dcd56b7 100644 --- a/templates/templates.html +++ b/templates/templates.html @@ -78,13 +78,13 @@