diff --git a/model.py b/model.py index 4e250775..ceab364d 100644 --- a/model.py +++ b/model.py @@ -1,5 +1,4 @@ -# TODO: -# - Intertwine stoppers and streaming and such +# Before merge: please make sure to fix any TODOB4MERGE comments from __future__ import annotations import bisect @@ -16,12 +15,15 @@ import json import os import time import traceback -from typing import Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union import zipfile from tqdm.auto import tqdm from logger import logger import torch_lazy_loader +import warpers +from warpers import Warper + import torch from torch.nn import Embedding import numpy as np @@ -32,14 +34,13 @@ from transformers import ( GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, - GPTNeoModel, AutoModelForCausalLM, - AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedModel, modeling_utils, - AutoModelForTokenClassification, AutoConfig, + LogitsProcessorList, + LogitsProcessor, ) import utils @@ -399,21 +400,6 @@ def patch_transformers_generation() -> None: global transformers # Patch transformers to use our custom logit warpers -- Only HFTorchInferenceModel uses this - from transformers import ( - LogitsProcessorList, - LogitsWarper, - LogitsProcessor, - TopKLogitsWarper, - TopPLogitsWarper, - TemperatureLogitsWarper, - ) - from warpers import ( - AdvancedRepetitionPenaltyLogitsProcessor, - TailFreeLogitsWarper, - TypicalLogitsWarper, - TopALogitsWarper, - ) - def dynamic_processor_wrap(cls, field_name, var_name, cond=None): old_call = cls.__call__ @@ -434,343 +420,22 @@ def patch_transformers_generation() -> None: cls.__call__ = new_call # TODO: Make samplers generic - dynamic_processor_wrap( - AdvancedRepetitionPenaltyLogitsProcessor, - ("penalty", "penalty_slope", "penalty_range", "use_alt_rep_pen"), - ("rep_pen", "rep_pen_slope", "rep_pen_range", "use_alt_rep_pen"), - cond=lambda x: x[0] != 1.0, - ) - dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0) - dynamic_processor_wrap(TopALogitsWarper, "top_a", "top_a", cond=lambda x: x > 0.0) - dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0) - dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0) - dynamic_processor_wrap( - TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0 - ) - dynamic_processor_wrap( - TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0 - ) - - class PhraseBiasLogitsProcessor(LogitsProcessor): - def __init__(self): - pass - - def _find_intersection(self, big: List, small: List) -> int: - """Find the maximum overlap between the beginning of small and the end of big. - Return the index of the token in small following the overlap, or 0. - - big: The tokens in the context (as a tensor) - small: The tokens in the phrase to bias (as a list) - - Both big and small are in "oldest to newest" order. - """ - # There are asymptotically more efficient methods for determining the overlap, - # but typically there will be few (0-1) instances of small[0] in the last len(small) - # elements of big, plus small will typically be fairly short. So this naive - # approach is acceptable despite O(N^2) worst case performance. - - num_small = len(small) - # The small list can only ever match against at most num_small tokens of big, - # so create a slice. Typically, this slice will be as long as small, but it - # may be shorter if the story has just started. 
- # We need to convert the big slice to list, since natively big is a tensor - # and tensor and list don't ever compare equal. It's better to convert here - # and then use native equality tests than to iterate repeatedly later. - big_slice = list(big[-num_small:]) - - # It's possible that the start token appears multiple times in small - # For example, consider the phrase: - # [ fair is foul, and foul is fair, hover through the fog and filthy air] - # If we merely look for the first instance of [ fair], then we would - # generate the following output: - # " fair is foul, and foul is fair is foul, and foul is fair..." - start = small[0] - for i, t in enumerate(big_slice): - # Strictly unnecessary, but it's marginally faster to test the first - # token before creating slices to test for a full match. - if t == start: - remaining = len(big_slice) - i - if big_slice[i:] == small[:remaining]: - # We found a match. If the small phrase has any remaining tokens - # then return the index of the next token. - if remaining < num_small: - return remaining - # In this case, the entire small phrase matched, so start over. - return 0 - - # There were no matches, so just begin at the beginning. - return 0 - - def _allow_leftwards_tampering(self, phrase: str) -> bool: - """Determines if a phrase should be tampered with from the left in - the "soft" token encoding mode.""" - - if phrase[0] in [".", "?", "!", ";", ":", "\n"]: - return False - return True - - def _get_token_sequence(self, phrase: str) -> List[List]: - """Convert the phrase string into a list of encoded biases, each - one being a list of tokens. How this is done is determined by the - phrase's format: - - - If the phrase is surrounded by square brackets ([]), the tokens - will be the phrase split by commas (,). If a "token" isn't - actually a number, it will be skipped. NOTE: Tokens output by - this may not be in the model's vocabulary, and such tokens - should be ignored later in the pipeline. - - If the phrase is surrounded by curly brackets ({}), the phrase - will be directly encoded with no synonym biases and no fancy - tricks. - - Otherwise, the phrase will be encoded, with close deviations - being included as synonym biases. - """ - - # TODO: Cache these tokens, invalidate when model or bias is - # changed. - - # Handle direct token id input - if phrase.startswith("[") and phrase.endswith("]"): - no_brackets = phrase[1:-1] - ret = [] - for token_id in no_brackets.split(","): - try: - ret.append(int(token_id)) - except ValueError: - # Ignore non-numbers. Rascals! - pass - return [ret] - - # Handle direct phrases - if phrase.startswith("{") and phrase.endswith("}"): - no_brackets = phrase[1:-1] - return [HACK_currentmodel.tokenizer.encode(no_brackets)] - - # Handle untamperable phrases - if not self._allow_leftwards_tampering(phrase): - return [HACK_currentmodel.tokenizer.encode(phrase)] - - # Handle slight alterations to original phrase - phrase = phrase.strip(" ") - ret = [] - - for alt_phrase in [phrase, f" {phrase}"]: - ret.append(HACK_currentmodel.tokenizer.encode(alt_phrase)) - - return ret - - def _get_biased_tokens(self, input_ids: List) -> Dict: - # TODO: Different "bias slopes"? 
- - ret = {} - for phrase, _bias in utils.koboldai_vars.biases.items(): - bias_score, completion_threshold = _bias - token_seqs = self._get_token_sequence(phrase) - variant_deltas = {} - for token_seq in token_seqs: - bias_index = self._find_intersection(input_ids, token_seq) - - # Ensure completion after completion_threshold tokens - # Only provide a positive bias when the base bias score is positive. - if bias_score > 0 and bias_index + 1 > completion_threshold: - bias_score = 999 - - token_to_bias = token_seq[bias_index] - variant_deltas[token_to_bias] = bias_score - - # If multiple phrases bias the same token, add the modifiers - # together. This should NOT be applied to automatic variants - for token_to_bias, bias_score in variant_deltas.items(): - if token_to_bias in ret: - ret[token_to_bias] += bias_score - else: - ret[token_to_bias] = bias_score - return ret - - def __call__( - self, input_ids: torch.LongTensor, scores: torch.FloatTensor - ) -> torch.FloatTensor: - assert scores.ndim == 2 - assert input_ids.ndim == 2 - - scores_shape = scores.shape - - for batch in range(scores_shape[0]): - for token, bias in self._get_biased_tokens(input_ids[batch]).items(): - scores[batch][token] += bias - - return scores - - class LuaLogitsProcessor(LogitsProcessor): - def __init__(self): - pass - - def __call__( - self, input_ids: torch.LongTensor, scores: torch.FloatTensor - ) -> torch.FloatTensor: - assert scores.ndim == 2 - assert input_ids.ndim == 2 - self.regeneration_required = False - self.halt = False - - if utils.koboldai_vars.standalone: - return scores - - scores_shape = scores.shape - scores_list = scores.tolist() - utils.koboldai_vars.lua_koboldbridge.logits = ( - utils.koboldai_vars.lua_state.table() - ) - for r, row in enumerate(scores_list): - utils.koboldai_vars.lua_koboldbridge.logits[ - r + 1 - ] = utils.koboldai_vars.lua_state.table(*row) - utils.koboldai_vars.lua_koboldbridge.vocab_size = scores_shape[-1] - - utils.koboldai_vars.lua_koboldbridge.execute_genmod() - - scores = torch.tensor( - tuple( - tuple(row.values()) - for row in utils.koboldai_vars.lua_koboldbridge.logits.values() - ), - device=scores.device, - dtype=scores.dtype, - ) - assert scores.shape == scores_shape - - return scores - - from torch.nn import functional as F - - def visualize_probabilities( - model: InferenceModel, - scores: torch.FloatTensor, - ) -> None: - assert scores.ndim == 2 - - if utils.koboldai_vars.numseqs > 1 or not utils.koboldai_vars.show_probs: - return - - if not utils.koboldai_vars.show_probs: - return scores - - option_offset = 0 - if ( - utils.koboldai_vars.actions.action_count + 1 - in utils.koboldai_vars.actions.actions - ): - for x in range( - len( - utils.koboldai_vars.actions.actions[ - utils.koboldai_vars.actions.action_count + 1 - ]["Options"] - ) - ): - option = utils.koboldai_vars.actions.actions[ - utils.koboldai_vars.actions.action_count + 1 - ]["Options"][x] - if option["Pinned"] or option["Previous Selection"] or option["Edited"]: - option_offset = x + 1 - batch_offset = ( - int((utils.koboldai_vars.generated_tkns - 1) / utils.koboldai_vars.genamt) - if utils.koboldai_vars.alt_multi_gen - else 0 - ) - for batch_index, batch in enumerate(scores): - probs = F.softmax(batch, dim=-1).cpu().numpy() - - token_prob_info = [] - for token_id, score in sorted( - enumerate(probs), key=lambda x: x[1], reverse=True - )[:8]: - token_prob_info.append( - { - "tokenId": token_id, - "decoded": utils.decodenewlines( - model.tokenizer.decode(token_id) - ), - "score": float(score), - } - ) 
- - if utils.koboldai_vars.numseqs == 1: - utils.koboldai_vars.actions.set_probabilities(token_prob_info) - else: - utils.koboldai_vars.actions.set_option_probabilities( - token_prob_info, batch_index + option_offset + batch_offset - ) - - return scores - - def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList: - processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs) - processors.insert(0, LuaLogitsProcessor()) - processors.append(PhraseBiasLogitsProcessor()) - return processors - - use_core_manipulations.get_logits_processor = new_get_logits_processor - new_get_logits_processor.old_get_logits_processor = ( - transformers.GenerationMixin._get_logits_processor - ) - - class KoboldLogitsWarperList(LogitsProcessorList): - def __init__(self, beams: int = 1, **kwargs): - self.__warper_list: List[LogitsWarper] = [] - self.__warper_list.append( - TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1)) - ) - self.__warper_list.append( - TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1)) - ) - self.__warper_list.append( - TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)) - ) - self.__warper_list.append( - TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)) - ) - self.__warper_list.append( - TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1)) - ) - self.__warper_list.append(TemperatureLogitsWarper(temperature=0.5)) - self.__warper_list.append(AdvancedRepetitionPenaltyLogitsProcessor()) - - def __call__( - self, - input_ids: torch.LongTensor, - scores: torch.FloatTensor, - *args, - **kwargs, - ): - sampler_order = utils.koboldai_vars.sampler_order[:] - if ( - len(sampler_order) < 7 - ): # Add repetition penalty at beginning if it's not present - sampler_order = [6] + sampler_order - for k in sampler_order: - scores = self.__warper_list[k](input_ids, scores, *args, **kwargs) - visualize_probabilities(HACK_currentmodel, scores) - return scores - - def new_get_logits_warper( - beams: int = 1, - ) -> LogitsProcessorList: - return KoboldLogitsWarperList(beams=beams) - - def new_sample(self, *args, **kwargs): - assert kwargs.pop("logits_warper", None) is not None - kwargs["logits_warper"] = new_get_logits_warper( - beams=1, - ) - if (utils.koboldai_vars.newlinemode == "s") or ( - utils.koboldai_vars.newlinemode == "ns" - ): - kwargs["eos_token_id"] = -1 - kwargs.setdefault("pad_token_id", 2) - return new_sample.old_sample(self, *args, **kwargs) - - new_sample.old_sample = transformers.GenerationMixin.sample - use_core_manipulations.sample = new_sample + # dynamic_processor_wrap( + # AdvancedRepetitionPenaltyLogitsProcessor, + # ("penalty", "penalty_slope", "penalty_range", "use_alt_rep_pen"), + # ("rep_pen", "rep_pen_slope", "rep_pen_range", "use_alt_rep_pen"), + # cond=lambda x: x[0] != 1.0, + # ) + # dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0) + # dynamic_processor_wrap(TopALogitsWarper, "top_a", "top_a", cond=lambda x: x > 0.0) + # dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0) + # dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0) + # dynamic_processor_wrap( + # TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0 + # ) + # dynamic_processor_wrap( + # TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0 + # ) # Allow bad words filter to ban <|endoftext|> token import transformers.generation.logits_process @@ -852,10 +517,12 @@ class InferenceModel: global 
HACK_currentmodel HACK_currentmodel = self + print(self.raw_generate("Hi guys,", 20).__dict__) + def _post_load(self) -> None: pass - def _load(self, save_model: bool, inital_load: bool) -> None: + def _load(self, save_model: bool, initial_load: bool) -> None: raise NotImplementedError def _get_tokenizer(self, location: str): @@ -1520,8 +1187,22 @@ class HFTorchInferenceModel(InferenceModel): ) self._old_stopping_criteria = None + def _apply_warpers( + self, scores: torch.Tensor, input_ids: torch.Tensor + ) -> torch.Tensor: + warpers.update_settings() + for sid in utils.koboldai_vars.sampler_order: + warper = Warper.from_id(sid) + if warper == warpers.RepetitionPenalty: + # Rep pen needs more data than other samplers + print("is rep:", warper) + scores = warper.torch(scores, input_ids=input_ids) + else: + print("aint rep:", warper) + scores = warper.torch(scores) + return scores + def _post_load(self) -> None: - print("HELLLOOOOOOOOOOOOOOOOOOOOOOOOOOO") # Patch stopping_criteria class PTHStopper(StoppingCriteria): @@ -1551,6 +1232,323 @@ class HFTorchInferenceModel(InferenceModel): use_core_manipulations.get_stopping_criteria = _get_stopping_criteria + # Patch logitswarpers + + class PhraseBiasLogitsProcessor(LogitsProcessor): + def __init__(self): + pass + + def _find_intersection(self, big: List, small: List) -> int: + """Find the maximum overlap between the beginning of small and the end of big. + Return the index of the token in small following the overlap, or 0. + + big: The tokens in the context (as a tensor) + small: The tokens in the phrase to bias (as a list) + + Both big and small are in "oldest to newest" order. + """ + # There are asymptotically more efficient methods for determining the overlap, + # but typically there will be few (0-1) instances of small[0] in the last len(small) + # elements of big, plus small will typically be fairly short. So this naive + # approach is acceptable despite O(N^2) worst case performance. + + num_small = len(small) + # The small list can only ever match against at most num_small tokens of big, + # so create a slice. Typically, this slice will be as long as small, but it + # may be shorter if the story has just started. + # We need to convert the big slice to list, since natively big is a tensor + # and tensor and list don't ever compare equal. It's better to convert here + # and then use native equality tests than to iterate repeatedly later. + big_slice = list(big[-num_small:]) + + # It's possible that the start token appears multiple times in small + # For example, consider the phrase: + # [ fair is foul, and foul is fair, hover through the fog and filthy air] + # If we merely look for the first instance of [ fair], then we would + # generate the following output: + # " fair is foul, and foul is fair is foul, and foul is fair..." + start = small[0] + for i, t in enumerate(big_slice): + # Strictly unnecessary, but it's marginally faster to test the first + # token before creating slices to test for a full match. + if t == start: + remaining = len(big_slice) - i + if big_slice[i:] == small[:remaining]: + # We found a match. If the small phrase has any remaining tokens + # then return the index of the next token. + if remaining < num_small: + return remaining + # In this case, the entire small phrase matched, so start over. + return 0 + + # There were no matches, so just begin at the beginning. 
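For orientation while reading _apply_warpers above: Warper.from_id is not shown in this diff, but the integer ids in utils.koboldai_vars.sampler_order follow the index order of the removed KoboldLogitsWarperList, which suggests roughly the mapping below. This is an assumption for illustration, not part of the patch.

# Assumed sampler-id meanings, inferred from the removed KoboldLogitsWarperList order;
# Warper.from_id(sid) is expected to resolve these same ids.
SAMPLER_IDS = {
    0: "top_k",
    1: "top_a",
    2: "top_p",
    3: "tfs",          # tail-free sampling
    4: "typical",      # typical sampling
    5: "temperature",
    6: "rep_pen",      # repetition penalty; the old code prepended 6 when it was missing
}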
+ return 0 + + def _allow_leftwards_tampering(self, phrase: str) -> bool: + """Determines if a phrase should be tampered with from the left in + the "soft" token encoding mode.""" + + if phrase[0] in [".", "?", "!", ";", ":", "\n"]: + return False + return True + + def _get_token_sequence(self, phrase: str) -> List[List]: + """Convert the phrase string into a list of encoded biases, each + one being a list of tokens. How this is done is determined by the + phrase's format: + + - If the phrase is surrounded by square brackets ([]), the tokens + will be the phrase split by commas (,). If a "token" isn't + actually a number, it will be skipped. NOTE: Tokens output by + this may not be in the model's vocabulary, and such tokens + should be ignored later in the pipeline. + - If the phrase is surrounded by curly brackets ({}), the phrase + will be directly encoded with no synonym biases and no fancy + tricks. + - Otherwise, the phrase will be encoded, with close deviations + being included as synonym biases. + """ + + # TODO: Cache these tokens, invalidate when model or bias is + # changed. + + # Handle direct token id input + if phrase.startswith("[") and phrase.endswith("]"): + no_brackets = phrase[1:-1] + ret = [] + for token_id in no_brackets.split(","): + try: + ret.append(int(token_id)) + except ValueError: + # Ignore non-numbers. Rascals! + pass + return [ret] + + # Handle direct phrases + if phrase.startswith("{") and phrase.endswith("}"): + no_brackets = phrase[1:-1] + return [HACK_currentmodel.tokenizer.encode(no_brackets)] + + # Handle untamperable phrases + if not self._allow_leftwards_tampering(phrase): + return [HACK_currentmodel.tokenizer.encode(phrase)] + + # Handle slight alterations to original phrase + phrase = phrase.strip(" ") + ret = [] + + for alt_phrase in [phrase, f" {phrase}"]: + ret.append(HACK_currentmodel.tokenizer.encode(alt_phrase)) + + return ret + + def _get_biased_tokens(self, input_ids: List) -> Dict: + # TODO: Different "bias slopes"? + + ret = {} + for phrase, _bias in utils.koboldai_vars.biases.items(): + bias_score, completion_threshold = _bias + token_seqs = self._get_token_sequence(phrase) + variant_deltas = {} + for token_seq in token_seqs: + bias_index = self._find_intersection(input_ids, token_seq) + + # Ensure completion after completion_threshold tokens + # Only provide a positive bias when the base bias score is positive. + if bias_score > 0 and bias_index + 1 > completion_threshold: + bias_score = 999 + + token_to_bias = token_seq[bias_index] + variant_deltas[token_to_bias] = bias_score + + # If multiple phrases bias the same token, add the modifiers + # together. 
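A quick worked example of the _find_intersection overlap logic above; this is a standalone copy of the same algorithm with a few checks, for illustration only.

# Standalone copy of the overlap logic in _find_intersection, plus sanity checks.
def find_intersection(big, small):
    num_small = len(small)
    big_slice = list(big[-num_small:])
    start = small[0]
    for i, t in enumerate(big_slice):
        if t == start:
            remaining = len(big_slice) - i
            if big_slice[i:] == small[:remaining]:
                if remaining < num_small:
                    return remaining
                return 0
    return 0

assert find_intersection([5, 10, 20, 30], [20, 30, 40]) == 2  # context ends mid-phrase: bias token 40 next
assert find_intersection([5, 20, 30, 40], [20, 30, 40]) == 0  # whole phrase already present: start over
assert find_intersection([5, 10, 15, 25], [20, 30, 40]) == 0  # no overlap: bias the first phrase token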
This should NOT be applied to automatic variants + for token_to_bias, bias_score in variant_deltas.items(): + if token_to_bias in ret: + ret[token_to_bias] += bias_score + else: + ret[token_to_bias] = bias_score + return ret + + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor + ) -> torch.FloatTensor: + assert scores.ndim == 2 + assert input_ids.ndim == 2 + + scores_shape = scores.shape + + for batch in range(scores_shape[0]): + for token, bias in self._get_biased_tokens( + input_ids[batch] + ).items(): + scores[batch][token] += bias + + return scores + + class LuaLogitsProcessor(LogitsProcessor): + def __init__(self): + pass + + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor + ) -> torch.FloatTensor: + assert scores.ndim == 2 + assert input_ids.ndim == 2 + self.regeneration_required = False + self.halt = False + + if utils.koboldai_vars.standalone: + return scores + + scores_shape = scores.shape + scores_list = scores.tolist() + utils.koboldai_vars.lua_koboldbridge.logits = ( + utils.koboldai_vars.lua_state.table() + ) + for r, row in enumerate(scores_list): + utils.koboldai_vars.lua_koboldbridge.logits[ + r + 1 + ] = utils.koboldai_vars.lua_state.table(*row) + utils.koboldai_vars.lua_koboldbridge.vocab_size = scores_shape[-1] + + utils.koboldai_vars.lua_koboldbridge.execute_genmod() + + scores = torch.tensor( + tuple( + tuple(row.values()) + for row in utils.koboldai_vars.lua_koboldbridge.logits.values() + ), + device=scores.device, + dtype=scores.dtype, + ) + assert scores.shape == scores_shape + + return scores + + from torch.nn import functional as F + + def visualize_probabilities( + model: InferenceModel, + scores: torch.FloatTensor, + ) -> None: + assert scores.ndim == 2 + + if utils.koboldai_vars.numseqs > 1 or not utils.koboldai_vars.show_probs: + return + + if not utils.koboldai_vars.show_probs: + return scores + + option_offset = 0 + if ( + utils.koboldai_vars.actions.action_count + 1 + in utils.koboldai_vars.actions.actions + ): + for x in range( + len( + utils.koboldai_vars.actions.actions[ + utils.koboldai_vars.actions.action_count + 1 + ]["Options"] + ) + ): + option = utils.koboldai_vars.actions.actions[ + utils.koboldai_vars.actions.action_count + 1 + ]["Options"][x] + if ( + option["Pinned"] + or option["Previous Selection"] + or option["Edited"] + ): + option_offset = x + 1 + batch_offset = ( + int( + (utils.koboldai_vars.generated_tkns - 1) + / utils.koboldai_vars.genamt + ) + if utils.koboldai_vars.alt_multi_gen + else 0 + ) + for batch_index, batch in enumerate(scores): + probs = F.softmax(batch, dim=-1).cpu().numpy() + + token_prob_info = [] + for token_id, score in sorted( + enumerate(probs), key=lambda x: x[1], reverse=True + )[:8]: + token_prob_info.append( + { + "tokenId": token_id, + "decoded": utils.decodenewlines( + model.tokenizer.decode(token_id) + ), + "score": float(score), + } + ) + + if utils.koboldai_vars.numseqs == 1: + utils.koboldai_vars.actions.set_probabilities(token_prob_info) + else: + utils.koboldai_vars.actions.set_option_probabilities( + token_prob_info, batch_index + option_offset + batch_offset + ) + + return scores + + def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList: + processors = new_get_logits_processor.old_get_logits_processor( + *args, **kwargs + ) + # TODOB4MERGE: These two + # processors.insert(0, LuaLogitsProcessor()) + # processors.append(PhraseBiasLogitsProcessor()) + return processors + + use_core_manipulations.get_logits_processor = 
new_get_logits_processor + new_get_logits_processor.old_get_logits_processor = ( + transformers.GenerationMixin._get_logits_processor + ) + + class KoboldLogitsWarperList(LogitsProcessorList): + def __init__(self): + pass + + def __call__( + lw_self, + input_ids: torch.LongTensor, + scores: torch.FloatTensor, + *args, + **kwargs, + ): + # sampler_order = utils.koboldai_vars.sampler_order[:] + # if ( + # len(sampler_order) < 7 + # ): # Add repetition penalty at beginning if it's not present + # sampler_order = [6] + sampler_order + # for k in sampler_order: + # scores = self.__warper_list[k](input_ids, scores, *args, **kwargs) + scores = self._apply_warpers(scores=scores, input_ids=input_ids) + visualize_probabilities(HACK_currentmodel, scores) + return scores + + def new_get_logits_warper( + beams: int = 1, + ) -> LogitsProcessorList: + return KoboldLogitsWarperList() + + def new_sample(self, *args, **kwargs): + assert kwargs.pop("logits_warper", None) is not None + kwargs["logits_warper"] = new_get_logits_warper( + beams=1, + ) + if utils.koboldai_vars.newlinemode in ["s", "ns"]: + kwargs["eos_token_id"] = -1 + kwargs.setdefault("pad_token_id", 2) + return new_sample.old_sample(self, *args, **kwargs) + + new_sample.old_sample = transformers.GenerationMixin.sample + use_core_manipulations.sample = new_sample + def _raw_generate( self, prompt_tokens: Union[List[int], torch.Tensor], diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 3046a2e6..855413a2 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -349,57 +349,6 @@ def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarra # probability distribution) return jax.random.categorical(key, logits, -1).astype(np.uint32) -def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange): - ''' - This gets called by generate_loop_fn to apply repetition penalty - to the 1D array logits using the provided 1D array of tokens to penalize - ''' - rpslope = jnp.int32(rpslope) - rprange = jnp.int32(rprange) - clipped_rprange = jax.lax.cond(rprange > 0, lambda x: x, lambda x: tokens.shape[-1], rprange) - penalty_arange = jnp.roll(jnp.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]), generated_index, axis=-1) - # Make a new array with the same length as the tokens array but with - # each element replaced by the value at the corresponding index in the - # logits array; e.g. 
- # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1], - # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5] - penalty_logits = jnp.take(logits, tokens) - # Repetition penalty slope - def apply_slope(carry): - repetition_penalty, rprange = carry - _penalty = (penalty_arange/(rprange - 1)) * 2 - 1 - _penalty = (rpslope * _penalty) / (1 + jnp.abs(_penalty) * (rpslope - 1)) - _penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1) - return _penalty - repetition_penalty = jax.lax.cond( - (rpslope != 0.0) & (rprange > 0), # Not a typo; do not use `and` here, it makes JAX crash - apply_slope, - lambda carry: jnp.full(tokens.shape, carry[0]), - (repetition_penalty, rprange), - ) - # Divide positive values by repetition_penalty and multiply negative - # values by repetition_penalty (the academic publication that described - # this technique actually just only divided, but that would cause tokens - # with negative logits to become more likely, which is obviously wrong) - if koboldai_vars.use_alt_rep_pen: - penalty_logits = jnp.where( - penalty_arange >= 0, - penalty_logits - jnp.log(repetition_penalty), - penalty_logits, - ) - else: - penalty_logits = jnp.where( - penalty_arange >= 0, - jnp.where( - penalty_logits > 0, - penalty_logits/repetition_penalty, - penalty_logits*repetition_penalty, - ), - penalty_logits, - ) - # Finally, put those penalized logit values back into their original - # positions in the logits array - return logits.at[tokens].set(penalty_logits) def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0): ''' diff --git a/warpers.py b/warpers.py index 0adf5740..f39e0b69 100644 --- a/warpers.py +++ b/warpers.py @@ -38,11 +38,31 @@ file are mostly porting warper code to the torch methods. 
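For reference while this JAX code is ported: the slope math in the removed apply_repetition_penalty_static ramps the penalty from no penalty on the oldest token inside rep_pen_range up to the full penalty on the newest token. A small numeric check, with the formula copied from the removed code; illustrative only.

import numpy as np

# Evaluate the removed slope formula for a 4-token range with repetition_penalty = 2.0
# and rpslope = 1.0 (a straight linear ramp).
rep_pen, rpslope, rprange = 2.0, 1.0, 4
penalty_arange = np.arange(rprange)                   # 0 = oldest token, 3 = newest
_penalty = (penalty_arange / (rprange - 1)) * 2 - 1   # [-1.0, -0.33, 0.33, 1.0]
_penalty = (rpslope * _penalty) / (1 + np.abs(_penalty) * (rpslope - 1))
_penalty = 1 + ((_penalty + 1) / 2) * (rep_pen - 1)
print(_penalty)  # [1.0, 1.33, 1.67, 2.0]: oldest token unpenalized, newest gets the full 2.0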
from __future__ import annotations +import utils import torch -import jax -import jax.numpy as jnp import numpy as np -import tpu_mtj_backend + +try: + import jax + import jax.numpy as jnp + import tpu_mtj_backend +except ImportError: + assert not utils.koboldai_vars.use_colab_tpu + + +def update_settings(): + # This feels like a bad way to structure this + koboldai_vars = utils.koboldai_vars + Temperature.temperature = koboldai_vars.temp + TopP.top_p = koboldai_vars.top_p + TopK.top_k = koboldai_vars.top_k + TopA.top_a = koboldai_vars.top_a + TailFree.tfs = koboldai_vars.tfs + Typical.typical = koboldai_vars.typical + RepetitionPenalty.rep_pen = koboldai_vars.rep_pen + RepetitionPenalty.rep_pen_range = koboldai_vars.rep_pen_range + RepetitionPenalty.rep_pen_slope = koboldai_vars.rep_pen_slope + RepetitionPenalty.use_alt_rep_pen = koboldai_vars.use_alt_rep_pen class Warper: @@ -70,7 +90,7 @@ class Temperature(Warper): @classmethod def jax(cls, scores: jnp.array) -> jnp.array: - return scores / cls.value + return scores / cls.temperature class TopP(Warper): @@ -88,7 +108,7 @@ class TopP(Warper): cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) - sorted_indices_to_remove = cumulative_probs <= (1 - cls.value) + sorted_indices_to_remove = cumulative_probs <= (1 - cls.top_p) # scatter sorted tensors to original indexing indices_to_remove = sorted_indices_to_remove.scatter( @@ -109,7 +129,7 @@ class TopP(Warper): cumulative_probabilities = jnp.cumsum(probabilities, axis=-1) # We want to remove tokens with cumulative probability higher # than top_p - sorted_indices_to_remove = cumulative_probabilities > cls.value + sorted_indices_to_remove = cumulative_probabilities > cls.top_p # Don't ever remove the token with the highest logit, even if # the probability is higher than top_p sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False) @@ -358,7 +378,7 @@ class RepetitionPenalty(Warper): use_alt_rep_pen: bool = False @classmethod - def torch(cls, scores: torch.Tensor) -> torch.Tensor: + def torch(cls, scores: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor: cls.rep_pen_range = int(cls.rep_pen_range) clipped_penalty_range = min(input_ids.shape[-1], cls.rep_pen_range)
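The RepetitionPenalty.torch hunk is cut off above. For orientation only, here is a rough torch restatement of the penalty rule from the removed JAX code (divide positive logits, multiply negative ones, over the last rep_pen_range context tokens; slope handling omitted). This is a hedged sketch, not the actual method body.

import torch

def rep_pen_sketch(scores: torch.Tensor, input_ids: torch.Tensor,
                   rep_pen: float = 1.2, rep_pen_range: int = 1024) -> torch.Tensor:
    # Only the most recent rep_pen_range context tokens are penalized.
    clipped = min(input_ids.shape[-1], int(rep_pen_range))
    window = input_ids[..., -clipped:]
    penalty_logits = scores.gather(-1, window)
    # Same rule as the removed JAX version: divide positive logits, multiply negative
    # ones, so a penalized token always becomes less likely.
    penalty_logits = torch.where(
        penalty_logits > 0,
        penalty_logits / rep_pen,
        penalty_logits * rep_pen,
    )
    return scores.scatter(-1, window, penalty_logits)

print(rep_pen_sketch(torch.tensor([[2.0, -1.0, 0.5, 3.0]]),
                     torch.tensor([[3, 1]])))  # tensor([[ 2.0000, -1.2000,  0.5000,  2.5000]])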