From ffe85ce8a1aaf12744778cf8f9a5defa66080a7b Mon Sep 17 00:00:00 2001 From: somebody Date: Fri, 17 Mar 2023 16:56:47 -0500 Subject: [PATCH] Modeling: Fix logits processors (probs, biasing, lua) --- modeling/inference_model.py | 6 + modeling/inference_models/hf_torch.py | 274 +------------------------- modeling/logits_processors.py | 272 +++++++++++++++++++++++++ 3 files changed, 286 insertions(+), 266 deletions(-) create mode 100644 modeling/logits_processors.py diff --git a/modeling/inference_model.py b/modeling/inference_model.py index 010b9ddd..047d505d 100644 --- a/modeling/inference_model.py +++ b/modeling/inference_model.py @@ -13,6 +13,7 @@ from transformers import ( AutoTokenizer, ) from modeling.tokenizer import GenericTokenizer +from modeling import logits_processors import utils @@ -160,6 +161,11 @@ class InferenceModel: self.gen_state = {} self.post_token_hooks = [] self.stopper_hooks = [] + self.logits_processors = [ + logits_processors.LuaIntegration(), + logits_processors.PhraseBiasLogitsProcessor(), + logits_processors.ProbabilityVisualization(), + ] self.tokenizer = None self.capabilties = ModelCapabilities() diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py index 5707dd9a..376f5352 100644 --- a/modeling/inference_models/hf_torch.py +++ b/modeling/inference_models/hf_torch.py @@ -36,7 +36,6 @@ from modeling.inference_models.hf import HFInferenceModel from modeling.inference_model import ( GenerationResult, GenerationSettings, - InferenceModel, ModelCapabilities, use_core_manipulations, ) @@ -104,6 +103,8 @@ class HFTorchInferenceModel(HFInferenceModel): else: scores = warper.torch(scores) + assert scores is not None, f"Scores are None; warper '{warper}' is to blame" + if LOG_SAMPLER_NO_EFFECT: if torch.equal(pre, scores): logger.info(warper, "had no effect on the scores.") @@ -166,274 +167,10 @@ class HFTorchInferenceModel(HFInferenceModel): # Patch logitswarpers - class PhraseBiasLogitsProcessor(LogitsProcessor): - def __init__(self): - pass - - def _find_intersection(self, big: List, small: List) -> int: - """Find the maximum overlap between the beginning of small and the end of big. - Return the index of the token in small following the overlap, or 0. - - big: The tokens in the context (as a tensor) - small: The tokens in the phrase to bias (as a list) - - Both big and small are in "oldest to newest" order. - """ - # There are asymptotically more efficient methods for determining the overlap, - # but typically there will be few (0-1) instances of small[0] in the last len(small) - # elements of big, plus small will typically be fairly short. So this naive - # approach is acceptable despite O(N^2) worst case performance. - - num_small = len(small) - # The small list can only ever match against at most num_small tokens of big, - # so create a slice. Typically, this slice will be as long as small, but it - # may be shorter if the story has just started. - # We need to convert the big slice to list, since natively big is a tensor - # and tensor and list don't ever compare equal. It's better to convert here - # and then use native equality tests than to iterate repeatedly later. 
- big_slice = list(big[-num_small:]) - - # It's possible that the start token appears multiple times in small - # For example, consider the phrase: - # [ fair is foul, and foul is fair, hover through the fog and filthy air] - # If we merely look for the first instance of [ fair], then we would - # generate the following output: - # " fair is foul, and foul is fair is foul, and foul is fair..." - start = small[0] - for i, t in enumerate(big_slice): - # Strictly unnecessary, but it's marginally faster to test the first - # token before creating slices to test for a full match. - if t == start: - remaining = len(big_slice) - i - if big_slice[i:] == small[:remaining]: - # We found a match. If the small phrase has any remaining tokens - # then return the index of the next token. - if remaining < num_small: - return remaining - # In this case, the entire small phrase matched, so start over. - return 0 - - # There were no matches, so just begin at the beginning. - return 0 - - def _allow_leftwards_tampering(self, phrase: str) -> bool: - """Determines if a phrase should be tampered with from the left in - the "soft" token encoding mode.""" - - if phrase[0] in [".", "?", "!", ";", ":", "\n"]: - return False - return True - - def _get_token_sequence(self, phrase: str) -> List[List]: - """Convert the phrase string into a list of encoded biases, each - one being a list of tokens. How this is done is determined by the - phrase's format: - - - If the phrase is surrounded by square brackets ([]), the tokens - will be the phrase split by commas (,). If a "token" isn't - actually a number, it will be skipped. NOTE: Tokens output by - this may not be in the model's vocabulary, and such tokens - should be ignored later in the pipeline. - - If the phrase is surrounded by curly brackets ({}), the phrase - will be directly encoded with no synonym biases and no fancy - tricks. - - Otherwise, the phrase will be encoded, with close deviations - being included as synonym biases. - """ - - # TODO: Cache these tokens, invalidate when model or bias is - # changed. - - # Handle direct token id input - if phrase.startswith("[") and phrase.endswith("]"): - no_brackets = phrase[1:-1] - ret = [] - for token_id in no_brackets.split(","): - try: - ret.append(int(token_id)) - except ValueError: - # Ignore non-numbers. Rascals! - pass - return [ret] - - # Handle direct phrases - if phrase.startswith("{") and phrase.endswith("}"): - no_brackets = phrase[1:-1] - return [m_self.tokenizer.encode(no_brackets)] - - # Handle untamperable phrases - if not self._allow_leftwards_tampering(phrase): - return [m_self.tokenizer.encode(phrase)] - - # Handle slight alterations to original phrase - phrase = phrase.strip(" ") - ret = [] - - for alt_phrase in [phrase, f" {phrase}"]: - ret.append(m_self.tokenizer.encode(alt_phrase)) - - return ret - - def _get_biased_tokens(self, input_ids: List) -> Dict: - # TODO: Different "bias slopes"? - - ret = {} - for phrase, _bias in utils.koboldai_vars.biases.items(): - bias_score, completion_threshold = _bias - token_seqs = self._get_token_sequence(phrase) - variant_deltas = {} - for token_seq in token_seqs: - bias_index = self._find_intersection(input_ids, token_seq) - - # Ensure completion after completion_threshold tokens - # Only provide a positive bias when the base bias score is positive. 
- if bias_score > 0 and bias_index + 1 > completion_threshold: - bias_score = 999 - - token_to_bias = token_seq[bias_index] - variant_deltas[token_to_bias] = bias_score - - # If multiple phrases bias the same token, add the modifiers - # together. This should NOT be applied to automatic variants - for token_to_bias, bias_score in variant_deltas.items(): - if token_to_bias in ret: - ret[token_to_bias] += bias_score - else: - ret[token_to_bias] = bias_score - return ret - - def __call__( - self, input_ids: torch.LongTensor, scores: torch.FloatTensor - ) -> torch.FloatTensor: - assert scores.ndim == 2 - assert input_ids.ndim == 2 - - scores_shape = scores.shape - - for batch in range(scores_shape[0]): - for token, bias in self._get_biased_tokens( - input_ids[batch] - ).items(): - scores[batch][token] += bias - - return scores - - class LuaLogitsProcessor(LogitsProcessor): - def __init__(self): - pass - - def __call__( - self, input_ids: torch.LongTensor, scores: torch.FloatTensor - ) -> torch.FloatTensor: - assert scores.ndim == 2 - assert input_ids.ndim == 2 - self.regeneration_required = False - self.halt = False - - if utils.koboldai_vars.standalone: - return scores - - scores_shape = scores.shape - scores_list = scores.tolist() - utils.koboldai_vars.lua_koboldbridge.logits = ( - utils.koboldai_vars.lua_state.table() - ) - for r, row in enumerate(scores_list): - utils.koboldai_vars.lua_koboldbridge.logits[ - r + 1 - ] = utils.koboldai_vars.lua_state.table(*row) - utils.koboldai_vars.lua_koboldbridge.vocab_size = scores_shape[-1] - - utils.koboldai_vars.lua_koboldbridge.execute_genmod() - - scores = torch.Tensor( - tuple( - tuple(row.values()) - for row in utils.koboldai_vars.lua_koboldbridge.logits.values() - ), - device=scores.device, - dtype=scores.dtype, - ) - assert scores.shape == scores_shape - - return scores - - from torch.nn import functional as F - - def visualize_probabilities( - model: InferenceModel, - scores: torch.FloatTensor, - ) -> None: - assert scores.ndim == 2 - - if utils.koboldai_vars.numseqs > 1 or not utils.koboldai_vars.show_probs: - return - - if not utils.koboldai_vars.show_probs: - return scores - - option_offset = 0 - if ( - utils.koboldai_vars.actions.action_count + 1 - in utils.koboldai_vars.actions.actions - ): - for x in range( - len( - utils.koboldai_vars.actions.actions[ - utils.koboldai_vars.actions.action_count + 1 - ]["Options"] - ) - ): - option = utils.koboldai_vars.actions.actions[ - utils.koboldai_vars.actions.action_count + 1 - ]["Options"][x] - if ( - option["Pinned"] - or option["Previous Selection"] - or option["Edited"] - ): - option_offset = x + 1 - batch_offset = ( - int( - (utils.koboldai_vars.generated_tkns - 1) - / utils.koboldai_vars.genamt - ) - if utils.koboldai_vars.alt_multi_gen - else 0 - ) - for batch_index, batch in enumerate(scores): - probs = F.softmax(batch, dim=-1).cpu().numpy() - - token_prob_info = [] - for token_id, score in sorted( - enumerate(probs), key=lambda x: x[1], reverse=True - )[:8]: - token_prob_info.append( - { - "tokenId": token_id, - "decoded": utils.decodenewlines( - model.tokenizer.decode(token_id) - ), - "score": float(score), - } - ) - - if utils.koboldai_vars.numseqs == 1: - utils.koboldai_vars.actions.set_probabilities(token_prob_info) - else: - utils.koboldai_vars.actions.set_option_probabilities( - token_prob_info, batch_index + option_offset + batch_offset - ) - - return scores - def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList: processors = 
new_get_logits_processor.old_get_logits_processor(
                 *args, **kwargs
             )
-            # TODOB4MERGE: These two
-            # processors.insert(0, LuaLogitsProcessor())
-            # processors.append(PhraseBiasLogitsProcessor())
             return processors
 
         use_core_manipulations.get_logits_processor = new_get_logits_processor
@@ -453,7 +190,12 @@ class HFTorchInferenceModel(HFInferenceModel):
             **kwargs,
         ):
             scores = m_self._apply_warpers(scores=scores, input_ids=input_ids)
-            visualize_probabilities(m_self, scores)
+
+            for processor in m_self.logits_processors:
+                scores = processor(m_self, scores=scores, input_ids=input_ids)
+                assert (
+                    scores is not None
+                ), f"Scores are None; processor '{processor}' is to blame"
             return scores
 
         def new_get_logits_warper(
diff --git a/modeling/logits_processors.py b/modeling/logits_processors.py
new file mode 100644
index 00000000..20a18026
--- /dev/null
+++ b/modeling/logits_processors.py
@@ -0,0 +1,272 @@
+from __future__ import annotations
+
+from typing import Dict, List
+import torch
+from torch.nn import functional as F
+
+import utils
+
+# Weird annotations to avoid cyclic import
+from modeling import inference_model
+
+
+class ProbabilityVisualization:
+    def __call__(
+        self,
+        model: inference_model.InferenceModel,
+        scores: torch.FloatTensor,
+        input_ids: torch.LongTensor,
+    ) -> torch.FloatTensor:
+        assert scores.ndim == 2
+
+        if utils.koboldai_vars.numseqs > 1 or not utils.koboldai_vars.show_probs:
+            return scores
+
+        if not utils.koboldai_vars.show_probs:
+            return scores
+
+        option_offset = 0
+        if (
+            utils.koboldai_vars.actions.action_count + 1
+            in utils.koboldai_vars.actions.actions
+        ):
+            for x in range(
+                len(
+                    utils.koboldai_vars.actions.actions[
+                        utils.koboldai_vars.actions.action_count + 1
+                    ]["Options"]
+                )
+            ):
+                option = utils.koboldai_vars.actions.actions[
+                    utils.koboldai_vars.actions.action_count + 1
+                ]["Options"][x]
+                if option["Pinned"] or option["Previous Selection"] or option["Edited"]:
+                    option_offset = x + 1
+        batch_offset = (
+            int((utils.koboldai_vars.generated_tkns - 1) / utils.koboldai_vars.genamt)
+            if utils.koboldai_vars.alt_multi_gen
+            else 0
+        )
+        for batch_index, batch in enumerate(scores):
+            probs = F.softmax(batch, dim=-1).cpu().numpy()
+
+            token_prob_info = []
+            for token_id, score in sorted(
+                enumerate(probs), key=lambda x: x[1], reverse=True
+            )[:8]:
+                token_prob_info.append(
+                    {
+                        "tokenId": token_id,
+                        "decoded": utils.decodenewlines(
+                            model.tokenizer.decode(token_id)
+                        ),
+                        "score": float(score),
+                    }
+                )
+
+            if utils.koboldai_vars.numseqs == 1:
+                utils.koboldai_vars.actions.set_probabilities(token_prob_info)
+            else:
+                utils.koboldai_vars.actions.set_option_probabilities(
+                    token_prob_info, batch_index + option_offset + batch_offset
+                )
+
+        return scores
+
+
+class LuaIntegration:
+    def __call__(
+        self,
+        model: inference_model.InferenceModel,
+        scores: torch.FloatTensor,
+        input_ids: torch.LongTensor,
+    ) -> torch.FloatTensor:
+        assert scores.ndim == 2
+        assert input_ids.ndim == 2
+        model.gen_state["regeneration_required"] = False
+        model.gen_state["halt"] = False
+
+        if utils.koboldai_vars.standalone:
+            return scores
+
+        scores_shape = scores.shape
+        scores_list = scores.tolist()
+        utils.koboldai_vars.lua_koboldbridge.logits = (
+            utils.koboldai_vars.lua_state.table()
+        )
+        for r, row in enumerate(scores_list):
+            utils.koboldai_vars.lua_koboldbridge.logits[
+                r + 1
+            ] = utils.koboldai_vars.lua_state.table(*row)
+        utils.koboldai_vars.lua_koboldbridge.vocab_size = scores_shape[-1]
+
+        utils.koboldai_vars.lua_koboldbridge.execute_genmod()
+
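+        # Read the logits back from the Lua bridge and rebuild the scores tensor;
+        # the genmod script may have modified them in place.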
+        scores = torch.tensor(
+            tuple(
+                tuple(row.values())
+                for row in utils.koboldai_vars.lua_koboldbridge.logits.values()
+            ),
+            device=scores.device,
+            dtype=scores.dtype,
+        )
+        assert scores.shape == scores_shape
+
+        return scores
+
+
+class PhraseBiasLogitsProcessor:
+    def __init__(self) -> None:
+        # Hack
+        self.model = None
+
+    def _find_intersection(self, big: List, small: List) -> int:
+        """Find the maximum overlap between the beginning of small and the end of big.
+        Return the index of the token in small following the overlap, or 0.
+
+        big: The tokens in the context (as a tensor)
+        small: The tokens in the phrase to bias (as a list)
+
+        Both big and small are in "oldest to newest" order.
+        """
+        # There are asymptotically more efficient methods for determining the overlap,
+        # but typically there will be few (0-1) instances of small[0] in the last len(small)
+        # elements of big, plus small will typically be fairly short. So this naive
+        # approach is acceptable despite O(N^2) worst case performance.
+
+        num_small = len(small)
+        # The small list can only ever match against at most num_small tokens of big,
+        # so create a slice. Typically, this slice will be as long as small, but it
+        # may be shorter if the story has just started.
+        # We need to convert the big slice to list, since natively big is a tensor
+        # and tensor and list don't ever compare equal. It's better to convert here
+        # and then use native equality tests than to iterate repeatedly later.
+        big_slice = list(big[-num_small:])
+
+        # It's possible that the start token appears multiple times in small
+        # For example, consider the phrase:
+        # [ fair is foul, and foul is fair, hover through the fog and filthy air]
+        # If we merely look for the first instance of [ fair], then we would
+        # generate the following output:
+        #   " fair is foul, and foul is fair is foul, and foul is fair..."
+        start = small[0]
+        for i, t in enumerate(big_slice):
+            # Strictly unnecessary, but it's marginally faster to test the first
+            # token before creating slices to test for a full match.
+            if t == start:
+                remaining = len(big_slice) - i
+                if big_slice[i:] == small[:remaining]:
+                    # We found a match. If the small phrase has any remaining tokens
+                    # then return the index of the next token.
+                    if remaining < num_small:
+                        return remaining
+                    # In this case, the entire small phrase matched, so start over.
+                    return 0
+
+        # There were no matches, so just begin at the beginning.
+        return 0
+
+    def _allow_leftwards_tampering(self, phrase: str) -> bool:
+        """Determines if a phrase should be tampered with from the left in
+        the "soft" token encoding mode."""
+
+        if phrase[0] in [".", "?", "!", ";", ":", "\n"]:
+            return False
+        return True
+
+    def _get_token_sequence(self, phrase: str) -> List[List]:
+        """Convert the phrase string into a list of encoded biases, each
+        one being a list of tokens. How this is done is determined by the
+        phrase's format:
+
+        - If the phrase is surrounded by square brackets ([]), the tokens
+          will be the phrase split by commas (,). If a "token" isn't
+          actually a number, it will be skipped. NOTE: Tokens output by
+          this may not be in the model's vocabulary, and such tokens
+          should be ignored later in the pipeline.
+        - If the phrase is surrounded by curly brackets ({}), the phrase
+          will be directly encoded with no synonym biases and no fancy
+          tricks.
+        - Otherwise, the phrase will be encoded, with close deviations
+          being included as synonym biases.
+        """
+
+        # TODO: Cache these tokens, invalidate when model or bias is
+        # changed.
+
+        # Handle direct token id input
+        if phrase.startswith("[") and phrase.endswith("]"):
+            no_brackets = phrase[1:-1]
+            ret = []
+            for token_id in no_brackets.split(","):
+                try:
+                    ret.append(int(token_id))
+                except ValueError:
+                    # Ignore non-numbers. Rascals!
+                    pass
+            return [ret]
+
+        # Handle direct phrases
+        if phrase.startswith("{") and phrase.endswith("}"):
+            no_brackets = phrase[1:-1]
+            return [self.model.tokenizer.encode(no_brackets)]
+
+        # Handle untamperable phrases
+        if not self._allow_leftwards_tampering(phrase):
+            return [self.model.tokenizer.encode(phrase)]
+
+        # Handle slight alterations to original phrase
+        phrase = phrase.strip(" ")
+        ret = []
+
+        for alt_phrase in [phrase, f" {phrase}"]:
+            ret.append(self.model.tokenizer.encode(alt_phrase))
+
+        return ret
+
+    def _get_biased_tokens(self, input_ids: List) -> Dict:
+        # TODO: Different "bias slopes"?
+
+        ret = {}
+        for phrase, _bias in utils.koboldai_vars.biases.items():
+            bias_score, completion_threshold = _bias
+            token_seqs = self._get_token_sequence(phrase)
+            variant_deltas = {}
+            for token_seq in token_seqs:
+                bias_index = self._find_intersection(input_ids, token_seq)
+
+                # Ensure completion after completion_threshold tokens
+                # Only provide a positive bias when the base bias score is positive.
+                if bias_score > 0 and bias_index + 1 > completion_threshold:
+                    bias_score = 999
+
+                token_to_bias = token_seq[bias_index]
+                variant_deltas[token_to_bias] = bias_score
+
+            # If multiple phrases bias the same token, add the modifiers
+            # together. This should NOT be applied to automatic variants
+            for token_to_bias, bias_score in variant_deltas.items():
+                if token_to_bias in ret:
+                    ret[token_to_bias] += bias_score
+                else:
+                    ret[token_to_bias] = bias_score
+        return ret
+
+    def __call__(
+        self,
+        model: inference_model.InferenceModel,
+        scores: torch.FloatTensor,
+        input_ids: torch.LongTensor,
+    ) -> torch.FloatTensor:
+        self.model = model
+
+        assert scores.ndim == 2
+        assert input_ids.ndim == 2
+
+        scores_shape = scores.shape
+
+        for batch in range(scores_shape[0]):
+            for token, bias in self._get_biased_tokens(input_ids[batch]).items():
+                scores[batch][token] += bias
+
+        return scores