diff --git a/aiserver.py b/aiserver.py
index d6675ebb..fbd4689e 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -2017,29 +2017,50 @@ def patch_transformers():
         def __init__(self):
             pass
 
-        def _rindex(self, lst: List, target) -> Optional[int]:
-            for index, item in enumerate(reversed(lst)):
-                if item == target:
-                    return len(lst) - index - 1
-            return None
-
         def _find_intersection(self, big: List, small: List) -> int:
-            # Find the intersection of the end of "big" and the beginning of
-            # "small". A headache to think about, personally. Returns the index
-            # into "small" where the two stop intersecting.
-            start = self._rindex(big, small[0])
+            """Find the maximum overlap between the beginning of small and the end of big.
+            Return the index of the token in small following the overlap, or 0.
 
-            # No progress into the token sequence, bias the first one.
-            if not start:
-                return 0
+            big: The tokens in the context (as a tensor)
+            small: The tokens in the phrase to bias (as a list)
 
-            for i in range(len(small)):
-                try:
-                    big_i = big[start + i]
-                except IndexError:
-                    return i
+            Both big and small are in "oldest to newest" order.
+            """
+            # There are asymptotically more efficient methods for determining the overlap,
+            # but typically there will be few (0-1) instances of small[0] in the last len(small)
+            # elements of big, plus small will typically be fairly short. So this naive
+            # approach is acceptable despite O(N^2) worst case performance.
 
-            # It's completed :^)
+            num_small = len(small)
+            # The small list can only ever match against at most num_small tokens of big,
+            # so create a slice. Typically, this slice will be as long as small, but it
+            # may be shorter if the story has just started.
+            # We need to convert the big slice to list, since natively big is a tensor
+            # and tensor and list don't ever compare equal. It's better to convert here
+            # and then use native equality tests than to iterate repeatedly later.
+            big_slice = list(big[-num_small:])
+
+            # It's possible that the start token appears multiple times in small
+            # For example, consider the phrase:
+            # [ fair is foul, and foul is fair, hover through the fog and filthy air]
+            # If we merely look for the first instance of [ fair], then we would
+            # generate the following output:
+            #   " fair is foul, and foul is fair is foul, and foul is fair..."
+            start = small[0]
+            for i, t in enumerate(big_slice):
+                # Strictly unnecessary, but it's marginally faster to test the first
+                # token before creating slices to test for a full match.
+                if t == start:
+                    remaining = len(big_slice) - i
+                    if big_slice[i:] == small[:remaining]:
+                        # We found a match. If the small phrase has any remaining tokens
+                        # then return the index of the next token.
+                        if remaining < num_small:
+                            return remaining
+                        # In this case, the entire small phrase matched, so start over.
+                        return 0
+
+            # There were no matches, so just begin at the beginning.
             return 0
 
         def _get_biased_tokens(self, input_ids: List) -> Dict:
@@ -2054,11 +2075,16 @@ def patch_transformers():
                 bias_index = self._find_intersection(input_ids, token_seq)
 
                 # Ensure completion after completion_threshold tokens
-                if bias_index + 1 > completion_threshold:
+                # Only provide a positive bias when the base bias score is positive.
+                if bias_score > 0 and bias_index + 1 > completion_threshold:
                     bias_score = 999
 
                 token_to_bias = token_seq[bias_index]
-                ret[token_to_bias] = bias_score
+                # If multiple phrases bias the same token, add the modifiers together.
+                if token_to_bias in ret:
+                    ret[token_to_bias] += bias_score
+                else:
+                    ret[token_to_bias] = bias_score
             return ret
 
         def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
@@ -2141,6 +2167,7 @@ def patch_transformers():
         def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
             processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs)
             processors.insert(0, LuaLogitsProcessor())
+            processors.append(PhraseBiasLogitsProcessor())
             processors.append(ProbabilityVisualizerLogitsProcessor())
             return processors
         new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
diff --git a/templates/templates.html b/templates/templates.html
index fd4ed113..5dcd56b7 100644
--- a/templates/templates.html
+++ b/templates/templates.html
@@ -78,13 +78,13 @@
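Below is a standalone sketch (not part of the patch) of the overlap search that _find_intersection performs, restated over plain Python lists with made-up token values; in the patch itself, big is a tensor of context tokens and small comes from tokenizer.encode.

def find_bias_index(big, small):
    # Compare the tail of "big" against the head of "small" and return the
    # index in "small" of the next token to bias (0 if there is no overlap
    # or if the whole phrase has already been completed).
    num_small = len(small)
    big_slice = list(big[-num_small:])
    start = small[0]
    for i, t in enumerate(big_slice):
        if t == start:
            remaining = len(big_slice) - i
            if big_slice[i:] == small[:remaining]:
                if remaining < num_small:
                    return remaining
                return 0
    return 0

# The context ends with the first two tokens of the phrase, so the token at
# index 2 is the next one to bias.
assert find_bias_index([10, 20, 1, 2], [1, 2, 3, 4]) == 2
# No overlap at all: bias the first token of the phrase.
assert find_bias_index([7, 8, 9], [1, 2]) == 0
# The whole phrase just matched: start the phrase over.
assert find_bias_index([5, 1, 2], [1, 2]) == 0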