diff --git a/aiserver.py b/aiserver.py
index 11fe3664..df8af18b 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -2218,28 +2218,90 @@ def patch_transformers():
             # There were no matches, so just begin at the beginning.
             return 0
 
+        def _allow_leftwards_tampering(self, phrase: str) -> bool:
+            """Determines if a phrase should be tampered with from the left in
+            the "soft" token encoding mode."""
+
+            if phrase[0] in [".", "?", "!", ";", ":", "\n"]:
+                return False
+            return True
+
+        def _get_token_sequence(self, phrase: str) -> List[List]:
+            """Convert the phrase string into a list of encoded biases, each
+            one being a list of tokens. How this is done is determined by the
+            phrase's format:
+
+            - If the phrase is surrounded by square brackets ([]), the tokens
+              will be the phrase split by commas (,). If a "token" isn't
+              actually a number, it will be skipped. NOTE: Tokens output by
+              this may not be in the model's vocabulary, and such tokens
+              should be ignored later in the pipeline.
+            - If the phrase is surrounded by curly brackets ({}), the phrase
+              will be directly encoded with no synonym biases and no fancy
+              tricks.
+            - Otherwise, the phrase will be encoded, with close deviations
+              being included as synonym biases.
+            """
+
+            # TODO: Cache these tokens, invalidate when model or bias is
+            # changed.
+
+            # Handle direct token id input
+            if phrase.startswith("[") and phrase.endswith("]"):
+                no_brackets = phrase[1:-1]
+                ret = []
+                for token_id in no_brackets.split(","):
+                    try:
+                        ret.append(int(token_id))
+                    except ValueError:
+                        # Ignore non-numbers. Rascals!
+                        pass
+                return [ret]
+
+            # Handle direct phrases
+            if phrase.startswith("{") and phrase.endswith("}"):
+                no_brackets = phrase[1:-1]
+                return [tokenizer.encode(no_brackets)]
+
+            # Handle untamperable phrases
+            if not self._allow_leftwards_tampering(phrase):
+                return [tokenizer.encode(phrase)]
+
+            # Handle slight alterations to original phrase
+            phrase = phrase.strip(" ")
+            ret = []
+
+            for alt_phrase in [phrase, f" {phrase}"]:
+                ret.append(tokenizer.encode(alt_phrase))
+
+            return ret
+
         def _get_biased_tokens(self, input_ids: List) -> Dict:
             # TODO: Different "bias slopes"?
             ret = {}
             for phrase, _bias in koboldai_vars.biases.items():
                 bias_score, completion_threshold = _bias
-                # TODO: Cache these tokens, invalidate when model or bias is
-                # changed.
-                token_seq = tokenizer.encode(phrase)
-                bias_index = self._find_intersection(input_ids, token_seq)
+                token_seqs = self._get_token_sequence(phrase)
+                variant_deltas = {}
+                for token_seq in token_seqs:
+                    bias_index = self._find_intersection(input_ids, token_seq)
 
-                # Ensure completion after completion_threshold tokens
-                # Only provide a positive bias when the base bias score is positive.
-                if bias_score > 0 and bias_index + 1 > completion_threshold:
-                    bias_score = 999
+                    # Ensure completion after completion_threshold tokens
+                    # Only provide a positive bias when the base bias score is positive.
+                    if bias_score > 0 and bias_index + 1 > completion_threshold:
+                        bias_score = 999
 
-                token_to_bias = token_seq[bias_index]
-                # If multiple phrases bias the same token, add the modifiers together.
-                if token_to_bias in ret:
-                    ret[token_to_bias] += bias_score
-                else:
-                    ret[token_to_bias] = bias_score
+                    token_to_bias = token_seq[bias_index]
+                    variant_deltas[token_to_bias] = bias_score
+
+                # If multiple phrases bias the same token, add the modifiers
+                # together. This should NOT be applied to automatic variants
+                for token_to_bias, bias_score in variant_deltas.items():
+                    if token_to_bias in ret:
+                        ret[token_to_bias] += bias_score
+                    else:
+                        ret[token_to_bias] = bias_score
 
             return ret
 
         def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
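
The three phrase formats handled by _get_token_sequence can be illustrated with a minimal standalone sketch. This is an assumption-laden illustration, not the module's actual code path: the character-level encode() below is a toy stand-in for the real tokenizer.encode, and the free functions mirror the diff's method names only for readability.

    from typing import List

    def encode(text: str) -> List[int]:
        # Toy stand-in tokenizer: one "token id" per character (demo only).
        return [ord(c) for c in text]

    def allow_leftwards_tampering(phrase: str) -> bool:
        # Phrases starting with sentence punctuation keep their exact form.
        return phrase[0] not in [".", "?", "!", ";", ":", "\n"]

    def get_token_sequence(phrase: str) -> List[List[int]]:
        if phrase.startswith("[") and phrase.endswith("]"):
            # "[72, 105, x]" -> [[72, 105]]: raw token ids, non-numbers skipped.
            ret = []
            for token_id in phrase[1:-1].split(","):
                try:
                    ret.append(int(token_id))
                except ValueError:
                    pass
            return [ret]
        if phrase.startswith("{") and phrase.endswith("}"):
            # "{Hi}" -> a single sequence, encoded verbatim, no variants.
            return [encode(phrase[1:-1])]
        if not allow_leftwards_tampering(phrase):
            return [encode(phrase)]
        # Default: bias both the bare phrase and its space-prefixed variant,
        # since BPE vocabularies often tokenize " word" and "word" differently.
        phrase = phrase.strip(" ")
        return [encode(alt) for alt in (phrase, f" {phrase}")]

    print(get_token_sequence("[72, 105, x]"))  # [[72, 105]]
    print(get_token_sequence("{Hi}"))          # [[72, 105]]
    print(get_token_sequence("Hi"))            # [[72, 105], [32, 72, 105]]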
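
The variant_deltas staging in _get_biased_tokens is worth spelling out: because each phrase's variants are written into their own dict before being merged, two automatic variants of the same phrase that resolve to the same token overwrite each other rather than summing, while separate user phrases targeting that token still add together (the "should NOT be applied to automatic variants" comment above). A sketch of just that accumulation rule, with hypothetical names and pre-resolved (token, score) pairs standing in for the real lookup:

    from typing import Dict, List, Tuple

    def accumulate(per_phrase_variants: List[List[Tuple[int, float]]]) -> Dict[int, float]:
        ret: Dict[int, float] = {}
        for variants in per_phrase_variants:
            # Stage this phrase's variants separately: same-token writes
            # collapse instead of summing.
            variant_deltas: Dict[int, float] = {}
            for token_to_bias, bias_score in variants:
                variant_deltas[token_to_bias] = bias_score
            # Across phrases, biases on the same token do add together.
            for token_to_bias, bias_score in variant_deltas.items():
                ret[token_to_bias] = ret.get(token_to_bias, 0) + bias_score
        return ret

    # Phrase A: two variants landing on token 72 -> counted once (2.0).
    # Phrase B: independently biases token 72 -> summed with phrase A.
    print(accumulate([[(72, 2.0), (72, 2.0)], [(72, 1.5)]]))  # {72: 3.5}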