diff --git a/aiserver.py b/aiserver.py
index 258bb109..2fd605e3 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -3498,7 +3498,7 @@ def apiactionsubmit_generate(txt, minimum, maximum):
         torch.cuda.empty_cache()
 
     # Submit input text to generator
-    _genout, already_generated = tpool.execute(model.core_generate, txt, minimum, maximum, set())
+    _genout, already_generated = tpool.execute(model.core_generate, txt, set())
 
     genout = [utils.applyoutputformatting(utils.decodenewlines(tokenizer.decode(tokens[-already_generated:]))) for tokens in _genout]
 
@@ -3939,7 +3939,7 @@ def generate(txt, minimum, maximum, found_entries=None):
     # Submit input text to generator
     try:
         start_time = time.time()
-        genout, already_generated = tpool.execute(model.core_generate, txt, minimum, maximum, found_entries)
+        genout, already_generated = tpool.execute(model.core_generate, txt, found_entries)
         logger.debug("Generate: core_generate time {}s".format(time.time()-start_time))
     except Exception as e:
         if(issubclass(type(e), lupa.LuaError)):
diff --git a/model.py b/model.py
index bcc30cd2..37741585 100644
--- a/model.py
+++ b/model.py
@@ -1,4 +1,7 @@
-# Before merge: please make sure to fix any TODOB4MERGE comments
+# Before merge:
+# - Fix Lua
+# - Check if probabilities work
+# - Fix any TODOB4MERGE comments
 from __future__ import annotations
 
 import bisect
@@ -79,14 +82,20 @@ class OpenAIAPIError(Exception):
 
 
 class HordeException(Exception):
+    """To be used for errors on the server side of the Horde."""
+
     pass
 
 
 class ColabException(Exception):
+    """To be used for errors when using the Colab API as an interface."""
+
     pass
 
 
 class APIException(Exception):
+    """To be used for errors when using the Kobold API as an interface."""
+
     pass
 
 
@@ -424,24 +433,6 @@ def patch_transformers_generation() -> None:
 
         cls.__call__ = new_call
 
-    # TODO: Make samplers generic
-    # dynamic_processor_wrap(
-    #     AdvancedRepetitionPenaltyLogitsProcessor,
-    #     ("penalty", "penalty_slope", "penalty_range", "use_alt_rep_pen"),
-    #     ("rep_pen", "rep_pen_slope", "rep_pen_range", "use_alt_rep_pen"),
-    #     cond=lambda x: x[0] != 1.0,
-    # )
-    # dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
-    # dynamic_processor_wrap(TopALogitsWarper, "top_a", "top_a", cond=lambda x: x > 0.0)
-    # dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
-    # dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
-    # dynamic_processor_wrap(
-    #     TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0
-    # )
-    # dynamic_processor_wrap(
-    #     TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0
-    # )
-
     # Allow bad words filter to ban <|endoftext|> token
     import transformers.generation.logits_process
 
@@ -463,6 +454,8 @@ def patch_transformers() -> None:
 
 
 class GenerationResult:
+    """A container for easily accessing different forms of model outputs. Returned by most generate functions."""
+
     def __init__(
         self,
         model: InferenceModel,
@@ -506,6 +499,8 @@ class ModelCapabilities:
 
 
 class InferenceModel:
+    """Root class for all models."""
+
     def __init__(self) -> None:
         self.gen_state = {}
         self.post_token_hooks = []
@@ -514,7 +509,7 @@ class InferenceModel:
         self.capabilties = ModelCapabilities()
 
     def load(self, save_model: bool = False, initial_load: bool = False) -> None:
-        """Main load function. Do not override this. Override _load() instead."""
+        """User-facing load function. Do not override this; override `_load()` instead."""
         self._load(save_model=save_model, initial_load=initial_load)
         self._post_load()
 
@@ -525,12 +520,23 @@ class InferenceModel:
         print(self.raw_generate("Hi guys,", 20).__dict__)
 
     def _post_load(self) -> None:
+        """Post load hook. Called after `_load()`."""
        pass
 
     def _load(self, save_model: bool, initial_load: bool) -> None:
+        """Main load method. All logic related to loading the model onto the
+        selected device(s) and preparing it for inference should be implemented here."""
         raise NotImplementedError
 
-    def _get_tokenizer(self, location: str):
+    def _get_tokenizer(self, location: str) -> AutoTokenizer:
+        """Returns the appropriate tokenizer for the location. Should be run once, with the result stored in `tokenizer`.
+
+        Args:
+            location (str): Either a local model directory path or a HuggingFace model ID.
+
+        Returns:
+            AutoTokenizer: Tokenizer deemed fit for the location string. May be a fallback tokenizer.
+        """
         if utils.koboldai_vars.model_type == "xglm":
             # Default to newline mode if using XGLM
             utils.koboldai_vars.newlinemode = "s"
@@ -565,13 +571,19 @@ class InferenceModel:
     def core_generate(
         self,
         text: list,
-        _min: int,
-        _max: int,
         found_entries: set,
-        is_core: bool = False,
     ):
-        # This generation function is tangled with koboldai_vars intentionally. It
-        # is meant for the story and nothing else.
+        """Generate story text. Heavily tied to story-specific parameters; if
+        you are making a new generation-based feature, consider `raw_generate()`.
+
+        Args:
+            text (list): Encoded input tokens
+            found_entries (set): Entries found for Dynamic WI
+
+        Raises:
+            RuntimeError: If inconsistencies are detected between the internal state and the Lua state (sanity check)
+            RuntimeError: If inconsistencies are detected between the internal state and the core stopper (sanity check)
+        """
         start_time = time.time()
         gen_in = torch.tensor(text, dtype=torch.long)[None]
 
@@ -804,6 +816,18 @@ class InferenceModel:
         single_line: bool = False,
         batch_count: int = 1,
     ) -> GenerationResult:
+        """Lowest-level model-agnostic generation function. To be overridden by the model implementation.
+
+        Args:
+            prompt_tokens (Union[List[int], torch.Tensor]): Prompt as encoded token IDs
+            max_new (int): Maximum amount of new tokens to generate
+            gen_settings (GenerationSettings): State to pass in single-generation setting overrides
+            single_line (bool, optional): Generate one line only. Defaults to False.
+            batch_count (int, optional): How big of a batch to generate. Defaults to 1.
+
+        Returns:
+            GenerationResult: The model's output
+        """
         raise NotImplementedError
 
     def raw_generate(
@@ -820,7 +844,27 @@ class InferenceModel:
         single_line: bool = False,
         found_entries: set = (),
     ) -> GenerationResult:
-        """A wrapper around _raw_generate() that handles timing and some other minute stuff."""
+        """A wrapper around `_raw_generate()` that handles gen_state and other bookkeeping. Use this to generate text outside of the story.
+
+        Args:
+            prompt (Union[str, list, np.ndarray]): The prompt as a string or encoded token IDs
+            max_new (int): Maximum amount of new tokens to generate
+            do_streaming (bool, optional): Whether to stream tokens to the user or not. Defaults to False.
+            do_dynamic_wi (bool, optional): Whether to use Dynamic WI context injections. Defaults to False.
+            batch_count (int, optional): How big of a batch to generate. Defaults to 1.
+            bypass_hf_maxlength (bool, optional): Whether to ignore model-provided max length limits. Defaults to False.
+            generation_settings (GenerationSettings): State to pass in single-generation setting overrides. Defaults to None.
+            is_core (bool, optional): Whether this generation is a core story generation. Defaults to False.
+            single_line (bool, optional): Generate one line only. Defaults to False.
+            found_entries (set, optional): Entries found for Dynamic WI. Defaults to ().
+
+        Raises:
+            ValueError: If the prompt is of an unsupported type
+            NotImplementedError: If the model is ReadOnly
+
+        Returns:
+            GenerationResult: The model's output
+        """
 
         # TODO: Support singleline outside of torch
         self.gen_state["do_streaming"] = do_streaming
@@ -1101,10 +1145,8 @@ class HFMTJInferenceModel(HFInferenceModel):
             initial_load=initial_load,
             logger=logger,
             **self.model_config.to_dict()
-            # **utils.koboldai_vars.modelconfig,
         )
 
-        # tpool.execute(tpu_mtj_backend.load_model, koboldai_vars.custmodpth, hf_checkpoint=koboldai_vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and koboldai_vars.use_colab_tpu, **koboldai_vars.modelconfig)
         utils.koboldai_vars.modeldim = int(
             tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"])
         )
diff --git a/warpers.py b/warpers.py
index c63b34e2..0885842c 100644
--- a/warpers.py
+++ b/warpers.py
@@ -68,6 +68,25 @@ def update_settings():
 
 
 class Warper:
+    """The backbone for implementing code which manipulates token logits.
+    All Warpers should be singletons defined in the warpers.py file.
+
+    To make a new warper/sampler:
+    - Create your class, implementing `torch()`, `jax_dynamic()`, `jax_static()`,
+      and `value_is_valid()`. Dynamic and static methods are separated for Jax
+      due to how it does JIT compilation of functions (from what I gather).
+      These `static` methods are very picky about what you can and can't do
+      with data at runtime and thus sometimes need to be implemented
+      differently than the `dynamic` methods, which are more like the Torch
+      methods.
+    - Add it to Warper.from_id and tpu_mtj_backend.kobold_sample_static.
+    - Add it to the UI/sampler_order.
+
+    To implement the samplers on a new model type/interface, assuming you're
+    dealing with Torch tensors, iterate over Warpers from sampler_order using
+    `Warper.from_id()`, and apply changes with the `torch()` methods.
+    """
+
     @staticmethod
     def from_id(warper_id: int) -> Warper:
         return {
@@ -80,6 +99,22 @@ class Warper:
             6: RepetitionPenalty,
         }[warper_id]
 
+    @classmethod
+    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError("Please override `torch()`.")
+
+    @classmethod
+    def jax_dynamic(cls, scores: np.array) -> np.array:
+        raise NotImplementedError("Please override `jax_dynamic()`.")
+
+    @classmethod
+    def jax_static(cls, scores: jnp.array) -> jnp.array:
+        raise NotImplementedError("Please override `jax_static()`.")
+
+    @classmethod
+    def value_is_valid(cls) -> bool:
+        raise NotImplementedError("Please override `value_is_valid()`.")
+
 
 class Temperature(Warper):
     """Temperature (just divide the logits by the temperature)"""
@@ -534,8 +569,8 @@ class RepetitionPenalty(Warper):
         generated_index,
     ) -> jnp.array:
         """
-        This gets called by generate_loop_fn to apply repetition penalty
-        to the 1D array logits using the provided 1D array of tokens to penalize
+        This gets called to apply repetition penalty to the 1D array logits
+        using the provided 1D array of tokens to penalize
        """
         rpslope = jnp.int32(cls.rep_pen_slope)
         rprange = jnp.int32(cls.rep_pen_range)
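
Reviewer note, not part of the patch: the new Warper docstring describes applying samplers by iterating sampler_order and calling `torch()` on each class returned by `Warper.from_id()`. A minimal sketch of that loop, using only interfaces visible in this diff; the helper name and the exact location of `sampler_order` on `koboldai_vars` are illustrative assumptions.

    # Hypothetical helper, not part of this patch.
    import torch
    import utils
    from warpers import Warper

    def apply_warpers(scores: torch.Tensor) -> torch.Tensor:
        """Apply each enabled sampler to the logits, in the user's configured order."""
        for warper_id in utils.koboldai_vars.sampler_order:  # assumed location of sampler_order
            warper = Warper.from_id(warper_id)
            if not warper.value_is_valid():
                continue  # Skip samplers whose current setting is a no-op
            scores = warper.torch(scores)
        return scores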