from __future__ import annotations

from dataclasses import dataclass
import time
from typing import List, Optional, Union

from logger import logger
import torch
import numpy as np
import transformers
from transformers import (
    GPT2Tokenizer,
    AutoTokenizer,
)

import utils

try:
    import tpu_mtj_backend
except ModuleNotFoundError as e:
    # Not on TPU... hopefully
    if utils.koboldai_vars.use_colab_tpu:
        raise e

# I don't really like this way of pointing to the current model but I can't
# find a way around it in some areas.
current_model = None


# We only want to use logit manipulations and such on our core text model
class use_core_manipulations:
    """Use in a `with` block to patch functions for core story model sampling."""

    # These must be set by wherever they get setup
    get_logits_processor: callable = None
    sample: callable = None
    get_stopping_criteria: callable = None

    # We set these automatically
    old_get_logits_processor: callable = None
    old_sample: callable = None
    old_get_stopping_criteria: callable = None

    def __enter__(self):
        if use_core_manipulations.get_logits_processor:
            use_core_manipulations.old_get_logits_processor = (
                transformers.GenerationMixin._get_logits_processor
            )
            transformers.GenerationMixin._get_logits_processor = (
                use_core_manipulations.get_logits_processor
            )

        if use_core_manipulations.sample:
            use_core_manipulations.old_sample = transformers.GenerationMixin.sample
            transformers.GenerationMixin.sample = use_core_manipulations.sample

        if use_core_manipulations.get_stopping_criteria:
            use_core_manipulations.old_get_stopping_criteria = (
                transformers.GenerationMixin._get_stopping_criteria
            )
            transformers.GenerationMixin._get_stopping_criteria = (
                use_core_manipulations.get_stopping_criteria
            )
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        if use_core_manipulations.old_get_logits_processor:
            transformers.GenerationMixin._get_logits_processor = (
                use_core_manipulations.old_get_logits_processor
            )
        else:
            assert (
                not use_core_manipulations.get_logits_processor
            ), "Patch leak: THE MONKEYS HAVE ESCAPED"

        if use_core_manipulations.old_sample:
            transformers.GenerationMixin.sample = use_core_manipulations.old_sample
        else:
            assert (
                not use_core_manipulations.sample
            ), "Patch leak: THE MONKEYS HAVE ESCAPED"

        if use_core_manipulations.old_get_stopping_criteria:
            transformers.GenerationMixin._get_stopping_criteria = (
                use_core_manipulations.old_get_stopping_criteria
            )
        else:
            assert (
                not use_core_manipulations.get_stopping_criteria
            ), "Patch leak: THE MONKEYS HAVE ESCAPED"
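
# Illustrative sketch only (nothing in this module calls it this way; the
# callable names below are placeholders, not real attributes of any backend):
# a backend that has built its own logits processor / sample / stopping
# criteria callables assigns them to the class attributes above, then wraps
# its core generation call so the transformers patches are applied and
# reverted automatically:
#
#     use_core_manipulations.get_logits_processor = my_get_logits_processor
#     use_core_manipulations.sample = my_sample
#     use_core_manipulations.get_stopping_criteria = my_get_stopping_criteria
#     with use_core_manipulations():
#         out = hf_model.generate(...)  # patched transformers internals in effect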


class GenerationResult:
    """A container for easily accessing different forms of model outputs.
    Returned by most generate functions."""

    def __init__(
        self,
        model: InferenceModel,
        out_batches: list,
        prompt: list,
        # Controls if generate() does its looping thing. This should only be
        # done for HF models that use that StoppingCondition
        is_whole_generation: bool,
        # Controls if we should trim output by prompt length
        output_includes_prompt: bool = False,
        # Lazy filter to cut off extra lines where we can't manipulate
        # probabilities
        single_line: bool = False,
    ):
        # Shave prompt off of encoded response when needed (HF). Decoded does
        # not return prompt.
        if output_includes_prompt:
            self.encoded = out_batches[:, len(prompt) :]
        else:
            self.encoded = out_batches

        self.prompt = prompt
        self.is_whole_generation = is_whole_generation

        self.decoded = [
            utils.decodenewlines(model.tokenizer.decode(enc)) for enc in self.encoded
        ]

        if single_line:
            self.decoded = [x.split("\n", 1)[0] for x in self.decoded]
            self.encoded = np.array(model.tokenizer(self.decoded).input_ids)


class GenerationSettings:
    """Structure for holding temporarily overwritten settings."""

    def __init__(self, **overrides) -> None:
        for setting in [
            "temp",
            "top_p",
            "top_k",
            "tfs",
            "typical",
            "top_a",
            "rep_pen",
            "rep_pen_slope",
            "rep_pen_range",
            "sampler_order",
        ]:
            setattr(
                self,
                setting,
                overrides.get(setting, getattr(utils.koboldai_vars, setting)),
            )
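
# A minimal illustration (the values here are made up): overrides apply only to
# the settings you pass; every other setting falls back to the live
# koboldai_vars value at construction time.
#
#     gen_settings = GenerationSettings(temp=0.9, top_p=0.95)
#     gen_settings.temp   # -> 0.9 (override)
#     gen_settings.top_k  # -> utils.koboldai_vars.top_k (fallback)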


@dataclass
class ModelCapabilities:
    embedding_manipulation: bool = False
    post_token_hooks: bool = False
    stopper_hooks: bool = False
    # TODO: Support non-live probabilities from APIs
    post_token_probs: bool = False


class InferenceModel:
    """Root class for all models."""

    def __init__(self) -> None:
        self.gen_state = {}
        self.post_token_hooks = []
        self.stopper_hooks = []
        self.tokenizer = None
        self.capabilties = ModelCapabilities()

    def load(self, save_model: bool = False, initial_load: bool = False) -> None:
        """User-facing load function. Do not override this; try `_load()` instead."""

        self._load(save_model=save_model, initial_load=initial_load)
        self._post_load()

        global current_model
        current_model = self

    def _post_load(self) -> None:
        """Post load hook. Called after `_load()`."""

    def _load(self, save_model: bool, initial_load: bool) -> None:
        """Main load method. All logic related to loading the model onto the
        selected device(s) and preparing it for inference should be implemented here."""
        raise NotImplementedError

    def _get_tokenizer(self, location: str) -> AutoTokenizer:
        """Returns the appropriate tokenizer for the location. Should be run
        once and the result stored in `tokenizer`.

        Args:
            location (str): Either a local model directory path or a HuggingFace model ID.

        Returns:
            AutoTokenizer: Tokenizer deemed fit for the location string. May be a fallback tokenizer.
        """
        if utils.koboldai_vars.model_type == "xglm":
            # Default to newline mode if using XGLM
            utils.koboldai_vars.newlinemode = "s"
        elif utils.koboldai_vars.model_type in ["opt", "bloom"]:
            # Handle but don't convert newlines if using Fairseq models that have newlines trained in them
            utils.koboldai_vars.newlinemode = "ns"

        std_kwargs = {"revision": utils.koboldai_vars.revision, "cache_dir": "cache"}

        suppliers = [
            # Fast tokenizer disabled by default as per HF docs:
            # > Note: Make sure to pass use_fast=False when loading
            #   OPT’s tokenizer with AutoTokenizer to get the correct
            #   tokenizer.
            lambda: AutoTokenizer.from_pretrained(
                location, use_fast=False, **std_kwargs
            ),
            lambda: AutoTokenizer.from_pretrained(location, **std_kwargs),
            # Fallback to GPT2Tokenizer
            lambda: GPT2Tokenizer.from_pretrained(location, **std_kwargs),
            lambda: GPT2Tokenizer.from_pretrained("gpt2", **std_kwargs),
        ]

        for i, try_get_tokenizer in enumerate(suppliers):
            try:
                return try_get_tokenizer()
            except:
                # If we error on each attempt, raise the last one
                if i == len(suppliers) - 1:
                    raise

    def core_generate(
        self,
        text: list,
        found_entries: set,
    ):
        """Generate story text. Heavily tied to story-specific parameters; if
        you are making a new generation-based feature, consider `raw_generate()`.

        Args:
            text (list): Encoded input tokens
            found_entries (set): Entries found for Dynamic WI

        Raises:
            RuntimeError: if inconsistencies are detected between the internal state and the Lua state -- sanity check
            RuntimeError: if inconsistencies are detected between the internal state and the core stopper -- sanity check
        """

        start_time = time.time()
        gen_in = torch.tensor(text, dtype=torch.long)[None]
        logger.debug(
            "core_generate: torch.tensor time {}s".format(time.time() - start_time)
        )

        start_time = time.time()
        if utils.koboldai_vars.is_model_torch():
            # Torch stuff
            if utils.koboldai_vars.full_determinism:
                torch.manual_seed(utils.koboldai_vars.seed)

            if utils.koboldai_vars.sp is not None:
                assert self.capabilties.embedding_manipulation
                soft_tokens = torch.arange(
                    self.model.config.vocab_size,
                    self.model.config.vocab_size + utils.koboldai_vars.sp.shape[0],
                )
                gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
        elif utils.koboldai_vars.use_colab_tpu:
            if utils.koboldai_vars.full_determinism:
                tpu_mtj_backend.set_rng_seed(utils.koboldai_vars.seed)

        logger.debug(
            "core_generate: Model Setup (SP, etc) time {}s".format(
                time.time() - start_time
            )
        )

        if (
            gen_in.shape[-1] + utils.koboldai_vars.genamt
            > utils.koboldai_vars.max_length
        ):
            logger.error("gen_in.shape[-1]: {}".format(gen_in.shape[-1]))
            logger.error(
                "utils.koboldai_vars.genamt: {}".format(utils.koboldai_vars.genamt)
            )
            logger.error(
                "utils.koboldai_vars.max_length: {}".format(
                    utils.koboldai_vars.max_length
                )
            )
        assert (
            gen_in.shape[-1] + utils.koboldai_vars.genamt
            <= utils.koboldai_vars.max_length
        )

        start_time = time.time()
        gen_in = gen_in.to(utils.get_auxilary_device())
        logger.debug(
            "core_generate: gen_in to device time {}s".format(time.time() - start_time)
        )

        start_time = time.time()
        found_entries = found_entries or set()

        self.gen_state["wi_scanner_excluded_keys"] = found_entries

        utils.koboldai_vars._prompt = utils.koboldai_vars.prompt

        with torch.no_grad():
            already_generated = 0
            numseqs = utils.koboldai_vars.numseqs
            total_gens = None

            for i in range(
                utils.koboldai_vars.numseqs if utils.koboldai_vars.alt_multi_gen else 1
            ):
                while True:
                    # The reason this is a loop is due to how Dynamic WI works. We
                    # cannot simply add the WI to the context mid-generation, so we
                    # stop early, and then insert WI, then continue generating. That
                    # stopping and continuing is this loop.

                    start_time = time.time()
                    result = self.raw_generate(
                        gen_in[0],
                        max_new=utils.koboldai_vars.genamt,
                        do_streaming=utils.koboldai_vars.output_streaming,
                        do_dynamic_wi=utils.koboldai_vars.dynamicscan,
                        batch_count=numseqs
                        if not utils.koboldai_vars.alt_multi_gen
                        else 1,
                        # Real max length is handled by CoreStopper.
                        bypass_hf_maxlength=utils.koboldai_vars.dynamicscan,
                        is_core=True,
                        tpu_dynamic_inference=utils.koboldai_vars.dynamicscan
                        or (
                            not utils.koboldai_vars.nogenmod
                            and utils.koboldai_vars.has_genmod
                        ),
                    )
                    logger.debug(
                        "core_generate: run raw_generate pass {} {}s".format(
                            already_generated, time.time() - start_time
                        )
                    )

                    genout = result.encoded

                    already_generated += len(genout[0])

                    try:
                        assert already_generated <= utils.koboldai_vars.genamt * (
                            utils.koboldai_vars.numseqs
                            if utils.koboldai_vars.alt_multi_gen
                            else 1
                        )
                    except AssertionError:
                        print("AlreadyGenerated", already_generated)
                        print("genamt", utils.koboldai_vars.genamt)
                        raise

                    if result.is_whole_generation:
                        break

                    # Generation stopped; why? If we have been told to halt, if we
                    # have reached our target token amount (controlled by halt), or
                    # if Dynamic WI has not told us to stop temporarily to insert
                    # WI, we can assume that we are done generating and break.
                    if (
                        self.gen_state["halt"]
                        or not self.gen_state["regeneration_required"]
                    ):
                        break

                    # Now we are doing stuff for Dynamic WI.
                    assert genout.ndim >= 2
                    assert genout.shape[0] == utils.koboldai_vars.numseqs

                    if (
                        utils.koboldai_vars.lua_koboldbridge.generated_cols
                        and utils.koboldai_vars.generated_tkns
                        != utils.koboldai_vars.lua_koboldbridge.generated_cols
                    ):
                        raise RuntimeError(
                            f"Inconsistency detected between KoboldAI Python and Lua backends ({utils.koboldai_vars.generated_tkns} != {utils.koboldai_vars.lua_koboldbridge.generated_cols})"
                        )

                    if already_generated != utils.koboldai_vars.generated_tkns:
                        print("already_generated: {}".format(already_generated))
                        print(
                            "generated_tkns: {}".format(
                                utils.koboldai_vars.generated_tkns
                            )
                        )
                        raise RuntimeError("WI scanning error")

                    for r in range(utils.koboldai_vars.numseqs):
                        for c in range(already_generated):
                            assert (
                                utils.koboldai_vars.lua_koboldbridge.generated[r + 1][
                                    c + 1
                                ]
                                is not None
                            )
                            genout[r][
                                genout.shape[-1] - already_generated + c
                            ] = utils.koboldai_vars.lua_koboldbridge.generated[r + 1][
                                c + 1
                            ]

                    encoded = []

                    for i in range(utils.koboldai_vars.numseqs):
                        txt = utils.decodenewlines(
                            self.tokenizer.decode(genout[i, -already_generated:])
                        )
                        # winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=utils.koboldai_vars.actions)
                        # txt, _, _ = calcsubmitbudget(len(utils.koboldai_vars.actions), winfo, mem, anotetxt, utils.koboldai_vars.actions, submission=txt)
                        txt, _, _, _found_entries = utils.koboldai_vars.calc_ai_text(
                            submitted_text=txt, send_context=False
                        )
                        found_entries[i].update(_found_entries)
                        encoded.append(
                            torch.tensor(txt, dtype=torch.long, device=genout.device)
                        )

                    max_length = len(max(encoded, key=len))
                    encoded = torch.stack(
                        tuple(
                            torch.nn.functional.pad(
                                e,
                                (max_length - len(e), 0),
                                value=self.model.config.pad_token_id
                                or self.model.config.eos_token_id,
                            )
                            for e in encoded
                        )
                    )
                    genout = torch.cat(
                        (
                            encoded,
                            genout[..., -already_generated:],
                        ),
                        dim=-1,
                    )

                    if utils.koboldai_vars.sp is not None:
                        soft_tokens = torch.arange(
                            self.model.config.vocab_size,
                            self.model.config.vocab_size
                            + utils.koboldai_vars.sp.shape[0],
                            device=genout.device,
                        )
                        genout = torch.cat(
                            (soft_tokens.tile(utils.koboldai_vars.numseqs, 1), genout),
                            dim=-1,
                        )

                    assert (
                        genout.shape[-1]
                        + utils.koboldai_vars.genamt
                        - already_generated
                        <= utils.koboldai_vars.max_length
                    )

                    gen_in = genout
                    numseqs = 1

                if total_gens is None:
                    total_gens = genout
                else:
                    total_gens = torch.cat((total_gens, genout))

        return total_gens, already_generated

    def _raw_generate(
        self,
        prompt_tokens: Union[List[int], torch.Tensor],
        max_new: int,
        gen_settings: GenerationSettings,
        single_line: bool = False,
        batch_count: int = 1,
        **kwargs,
    ) -> GenerationResult:
        """Lowest level model-agnostic generation function. To be overridden by model implementation.

        Args:
            prompt_tokens (Union[List[int], torch.Tensor]): Prompt as encoded token IDs
            max_new (int): Maximum amount of new tokens to generate
            gen_settings (GenerationSettings): State to pass in single-generation setting overrides
            single_line (bool, optional): Generate one line only. Defaults to False.
            batch_count (int, optional): How big of a batch to generate. Defaults to 1.

        Returns:
            GenerationResult: The model's output
        """
        raise NotImplementedError
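
    # Illustrative sketch only (the real HF/TPU overrides live in their own
    # backend modules and `_my_backend_sample` below is a hypothetical helper):
    # an override is expected to run its backend's sampler and wrap the output
    # batches in a GenerationResult, roughly:
    #
    #     def _raw_generate(self, prompt_tokens, max_new, gen_settings,
    #                       single_line=False, batch_count=1, **kwargs):
    #         out_batches = self._my_backend_sample(
    #             prompt_tokens, max_new, batch_count, gen_settings
    #         )
    #         return GenerationResult(
    #             model=self,
    #             out_batches=out_batches,
    #             prompt=prompt_tokens,
    #             is_whole_generation=False,  # core_generate's loop may continue it
    #             output_includes_prompt=True,
    #         )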

    def raw_generate(
        self,
        # prompt is either a string (text) or a list (token ids)
        prompt: Union[str, list, np.ndarray],
        max_new: int,
        do_streaming: bool = False,
        do_dynamic_wi: bool = False,
        batch_count: int = 1,
        bypass_hf_maxlength: bool = False,
        generation_settings: Optional[dict] = None,
        is_core: bool = False,
        single_line: bool = False,
        found_entries: set = (),
        tpu_dynamic_inference: bool = False,
        **kwargs,
    ) -> GenerationResult:
        """A wrapper around `_raw_generate()` that handles gen_state and other stuff. Use this to generate text outside of the story.

        Args:
            prompt (Union[str, list, np.ndarray]): The prompt as a string or encoded token IDs
            max_new (int): Maximum amount of new tokens to generate
            do_streaming (bool, optional): Whether to stream tokens to the user or not. Defaults to False.
            do_dynamic_wi (bool, optional): Whether to use Dynamic WI context injections. Defaults to False.
            batch_count (int, optional): How big of a batch to generate. Defaults to 1.
            bypass_hf_maxlength (bool, optional): Whether to ignore model-provided max length limits. Defaults to False.
            generation_settings (dict, optional): Single-generation setting overrides, passed on to GenerationSettings. Defaults to None.
            is_core (bool, optional): Whether this generation is a core story generation. Defaults to False.
            single_line (bool, optional): Generate one line only. Defaults to False.
            found_entries (set, optional): Entries found for Dynamic WI. Defaults to ().
            tpu_dynamic_inference (bool, optional): Whether to enable dynamic inference on the TPU backend. Defaults to False.

        Raises:
            ValueError: If prompt type is weird
            NotImplementedError: If model is ReadOnly

        Returns:
            GenerationResult: The model's output
        """
        # TODO: Support singleline outside of torch

        self.gen_state["do_streaming"] = do_streaming
        self.gen_state["do_dynamic_wi"] = do_dynamic_wi

        # Dynamic WI depends on this!!! This is a main gen call.
        self.gen_state["stop_at_genamt"] = do_dynamic_wi

        # Makes stopping criteria hook happy
        self.gen_state["wi_scanner_excluded_keys"] = self.gen_state.get(
            "wi_scanner_excluded_keys", set()
        )

        utils.koboldai_vars.inference_config.do_core = is_core
        gen_settings = GenerationSettings(**(generation_settings or {}))

        if isinstance(prompt, torch.Tensor):
            prompt_tokens = prompt.cpu().numpy()
        elif isinstance(prompt, list):
            prompt_tokens = np.array(prompt)
        elif isinstance(prompt, str):
            prompt_tokens = np.array(self.tokenizer.encode(prompt))
        else:
            raise ValueError(f"Prompt is {type(prompt)}. Not a fan!")

        assert isinstance(prompt_tokens, np.ndarray)
        assert len(prompt_tokens.shape) == 1

        if utils.koboldai_vars.model == "ReadOnly":
            raise NotImplementedError("No loaded model")

        time_start = time.time()

        with use_core_manipulations():
            result = self._raw_generate(
                prompt_tokens=prompt_tokens,
                max_new=max_new,
                batch_count=batch_count,
                gen_settings=gen_settings,
                single_line=single_line,
                tpu_dynamic_inference=tpu_dynamic_inference,
            )

        time_end = round(time.time() - time_start, 2)
        tokens_per_second = round(len(result.encoded[0]) / time_end, 2)

        if not utils.koboldai_vars.quiet:
            logger.info(
                f"Generated {len(result.encoded[0])} tokens in {time_end} seconds, for an average rate of {tokens_per_second} tokens per second."
            )

        return result

    def generate(
        self,
        prompt_tokens: Union[List[int], torch.Tensor],
        max_new_tokens: int,
        do_streaming: bool = False,
        do_dynamic_wi: bool = False,
        single_line: bool = False,
        batch_count: int = 1,
    ) -> torch.Tensor:
        raise NotImplementedError

    def _post_token_gen(self, input_ids: torch.LongTensor) -> None:
        for hook in self.post_token_hooks:
            hook(self, input_ids)
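

# A minimal usage sketch (the prompt text and settings here are placeholders):
# once a concrete backend has been loaded, utility text can be generated
# outside of the story with `raw_generate()`, e.g. from other modules via the
# `current_model` global set by `load()`:
#
#     result = current_model.raw_generate(
#         "Write a short greeting:",
#         max_new=32,
#         batch_count=1,
#     )
#     print(result.decoded[0])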