# KoboldAI-Client/modeling/inference_model.py
from __future__ import annotations
from dataclasses import dataclass
import time
from typing import List, Optional, Union
from enum import Enum
from logger import logger
import torch
import numpy as np
import transformers
from transformers import (
GPT2Tokenizer,
AutoTokenizer,
LlamaTokenizer,
)
from modeling.stoppers import Stoppers
from modeling.tokenizer import GenericTokenizer
from modeling import logits_processors
import utils
# We only want to use logit manipulations and such on our core text model
class use_core_manipulations:
"""Use in a `with` block to patch functions for core story model sampling."""
# These must be set by wherever they get set up
get_logits_processor: callable = None
sample: callable = None
get_stopping_criteria: callable = None
# We set these automatically
old_get_logits_processor: callable = None
old_sample: callable = None
old_get_stopping_criteria: callable = None
def __enter__(self):
if use_core_manipulations.get_logits_processor:
use_core_manipulations.old_get_logits_processor = (
transformers.GenerationMixin._get_logits_processor
)
transformers.GenerationMixin._get_logits_processor = (
use_core_manipulations.get_logits_processor
)
if use_core_manipulations.sample:
use_core_manipulations.old_sample = transformers.GenerationMixin.sample
transformers.GenerationMixin.sample = use_core_manipulations.sample
if use_core_manipulations.get_stopping_criteria:
use_core_manipulations.old_get_stopping_criteria = (
transformers.GenerationMixin._get_stopping_criteria
)
transformers.GenerationMixin._get_stopping_criteria = (
use_core_manipulations.get_stopping_criteria
)
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
if use_core_manipulations.old_get_logits_processor:
transformers.GenerationMixin._get_logits_processor = (
use_core_manipulations.old_get_logits_processor
)
else:
assert (
not use_core_manipulations.get_logits_processor
), "Patch leak: THE MONKEYS HAVE ESCAPED"
if use_core_manipulations.old_sample:
transformers.GenerationMixin.sample = use_core_manipulations.old_sample
else:
assert (
not use_core_manipulations.sample
), "Patch leak: THE MONKEYS HAVE ESCAPED"
if use_core_manipulations.old_get_stopping_criteria:
transformers.GenerationMixin._get_stopping_criteria = (
use_core_manipulations.old_get_stopping_criteria
)
else:
assert (
not use_core_manipulations.get_stopping_criteria
), "Patch leak: THE MONKEYS HAVE ESCAPED"
class GenerationResult:
"""A container for easily accessing different forms of model outputs. Returned by most generate functions."""
def __init__(
self,
model: InferenceModel,
out_batches: list,
prompt: list,
# Controls if generate() does its looping thing. This should only be
# done for HF models that use that StoppingCondition
is_whole_generation: bool,
# Controls if we should trim output by prompt length
output_includes_prompt: bool = False,
# Lazy filter to cut off extra lines where we can't manipulate
# probabilities
single_line: bool = False,
):
# Shave prompt off of encoded response when needed (HF). Decoded does
# not return prompt.
if output_includes_prompt:
self.encoded = out_batches[:, len(prompt) :]
else:
self.encoded = out_batches
self.prompt = prompt
self.is_whole_generation = is_whole_generation
self.decoded = [
utils.decodenewlines(model.tokenizer.decode(enc)) for enc in self.encoded
]
if single_line:
self.decoded = [x.split("\n", 1)[0] for x in self.decoded]
self.encoded = np.array(model.tokenizer(self.decoded).input_ids)
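# Illustrative sketch, not part of the original module: a backend whose raw
# output still contains the prompt tokens (as HF generate() output does) could
# wrap it like this so callers get prompt-free `encoded` and `decoded` fields.
# `model`, `out_batches` and `prompt_tokens` are assumed inputs.
def _example_generation_result(model, out_batches, prompt_tokens):
    result = GenerationResult(
        model=model,
        out_batches=out_batches,
        prompt=prompt_tokens,
        is_whole_generation=True,
        output_includes_prompt=True,
    )
    # result.encoded holds per-batch token IDs with the prompt shaved off;
    # result.decoded holds the matching text for each batch entry.
    return result.decoded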
class GenerationSettings:
"""Structure for holding temporarily overwritten settings."""
def __init__(self, **overrides) -> None:
for setting in [
"temp",
"top_p",
"top_k",
"tfs",
"typical",
"top_a",
"rep_pen",
"rep_pen_slope",
"rep_pen_range",
"sampler_order",
]:
setattr(
self,
setting,
overrides.get(setting, getattr(utils.koboldai_vars, setting)),
)
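# Illustrative sketch, not part of the original module: any setting that is not
# overridden falls back to the live value in utils.koboldai_vars at
# construction time.
def _example_generation_settings():
    settings = GenerationSettings(temp=0.7, top_p=0.9)
    # settings.temp == 0.7 and settings.top_p == 0.9, while e.g. settings.top_k
    # still mirrors utils.koboldai_vars.top_k.
    return settings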
@dataclass
class ModelCapabilities:
embedding_manipulation: bool = False
post_token_hooks: bool = False
# Used to gauge if manual stopping is possible
stopper_hooks: bool = False
# TODO: Support non-live probabilities from APIs
post_token_probs: bool = False
# Some models cannot be hosted over the API, namely the API itself.
api_host: bool = True
# Some models need to warm up the TPU before use
uses_tpu: bool = False
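# Illustrative sketch, not part of the original module: a hypothetical
# full-featured local backend could advertise its capabilities like this, while
# a remote API backend would mostly keep the defaults.
_EXAMPLE_LOCAL_CAPABILITIES = ModelCapabilities(
    embedding_manipulation=True,
    post_token_hooks=True,
    stopper_hooks=True,
    post_token_probs=True,
)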
class GenerationMode(Enum):
STANDARD = "standard"
FOREVER = "forever"
UNTIL_EOS = "until_eos"
UNTIL_NEWLINE = "until_newline"
UNTIL_SENTENCE_END = "until_sentence_end"
class InferenceModel:
"""Root class for all models."""
def __init__(self) -> None:
self.abort = False
self.gen_state = {}
self.post_token_hooks = []
self.stopper_hooks = []
self.logits_processors = [
logits_processors.LuaIntegration(),
logits_processors.PhraseBiasLogitsProcessor(),
logits_processors.ProbabilityVisualization(),
]
self.tokenizer = None
self.capabilties = ModelCapabilities()
self.model_name = "Not Defined"
def is_valid(self, model_name, model_path, menu_path):
return True
def requested_parameters(self, model_name, model_path, menu_path):
return {}
def set_input_parameters(self, parameters):
for parameter in parameters:
setattr(self, parameter, parameters[parameter])
return
def get_auxilary_device(self) -> Union[str, int, torch.device]:
"""Get device auxilary tensors like inputs should be stored on."""
# NOTE: TPU isn't a torch device, so TPU stuff gets sent to CPU.
if utils.koboldai_vars.hascuda and utils.koboldai_vars.usegpu:
return utils.koboldai_vars.gpu_device
elif utils.koboldai_vars.hascuda:
return "cuda"
return "cpu"
def load(self, save_model: bool = False, initial_load: bool = False) -> None:
"""User-facing load function. Do not override this; try `_load()` instead."""
self._pre_load()
self._load(save_model=save_model, initial_load=initial_load)
self._post_load()
self._save_settings()
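# Illustrative sketch, not part of the original module: a minimal backend only
# has to implement _load() (plus _raw_generate() further below); load() then
# drives the fixed _pre_load -> _load -> _post_load -> _save_settings sequence.
# A hypothetical subclass might do:
#
#     class ExampleModel(InferenceModel):
#         def _load(self, save_model: bool, initial_load: bool) -> None:
#             self.tokenizer = self._get_tokenizer("gpt2")
#             self.model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")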
def unload(self):
return
def _pre_load(self) -> None:
"""Pre load hook. Called before `_load()`."""
def _post_load(self) -> None:
"""Post load hook. Called after `_load()`."""
def _save_settings(self) -> None:
"""Save settings hook. Called after `_post_load()`."""
def _load(self, save_model: bool, initial_load: bool) -> None:
"""Main load method. All logic related to loading the model onto the
selected device(s) and preparing it for inference should be implemented here."""
raise NotImplementedError
def _get_tokenizer(self, location: str) -> GenericTokenizer:
"""Returns the appropiate tokenizer for the location. Should be ran once and result stored in `tokenizer`.
Args:
location (str): Either a local model directory path or a HuggingFace model ID.
Returns:
GenericTokenizer: Tokenizer deemed fit for the location string. May be a fallback tokenizer.
"""
std_kwargs = {"revision": utils.koboldai_vars.revision, "cache_dir": "cache"}
suppliers = [
# Fast tokenizer disabled by default as per HF docs:
# > Note: Make sure to pass use_fast=False when loading
# OPTs tokenizer with AutoTokenizer to get the correct
# tokenizer.
lambda: AutoTokenizer.from_pretrained(
location, use_fast=False, **std_kwargs
),
lambda: AutoTokenizer.from_pretrained(location, **std_kwargs),
# Attempt more basic GPT2 Tokenizer
lambda: GPT2Tokenizer.from_pretrained(location, **std_kwargs),
# Fallback to generic LLaMA Tokenizer
lambda: LlamaTokenizer.from_pretrained("KoboldAI/llama2-tokenizer", use_fast=False, **std_kwargs),
]
for i, try_get_tokenizer in enumerate(suppliers):
try:
return GenericTokenizer(try_get_tokenizer())
except Exception as e:
logger.warning(f"Tokenizer falling back due to {e} (This can be normal behavior for some architectures that lack a slow tokenizer such as NeoX)")
# If we error on each attempt, raise the last one
if i == len(suppliers) - 1:
raise
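# Illustrative usage sketch, not part of the original module: a backend's
# _load() would typically call this once and keep the result, e.g.
#
#     self.tokenizer = self._get_tokenizer("facebook/opt-2.7b")
#
# For OPT the first supplier (use_fast=False) returns the correct tokenizer;
# for an architecture without a slow tokenizer, the loop falls through to the
# later suppliers instead.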
def core_generate(
self,
text: list,
found_entries: set,
gen_mode: GenerationMode = GenerationMode.STANDARD,
):
"""Generate story text. Heavily tied to story-specific parameters; if
you are making a new generation-based feature, consider `generate_raw()`.
Args:
text (list): Encoded input tokens
found_entries (set): Entries found for Dynamic WI
gen_mode (GenerationMode): The GenerationMode to pass to raw_generate. Defaults to GenerationMode.STANDARD
Raises:
RuntimeError: if inconsistencies are detected between the internal state and the Lua state -- sanity check
RuntimeError: if inconsistencies are detected between the internal state and the core stopper -- sanity check
"""
start_time = time.time()
gen_in = torch.tensor(text, dtype=torch.long)[None]
logger.debug(
"core_generate: torch.tensor time {}s".format(time.time() - start_time)
)
start_time = time.time()
if utils.koboldai_vars.is_model_torch():
# Torch stuff
if utils.koboldai_vars.sp is not None:
assert self.capabilties.embedding_manipulation
soft_tokens = torch.arange(
self.model.config.vocab_size,
self.model.config.vocab_size + utils.koboldai_vars.sp.shape[0],
)
gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
logger.debug(
"core_generate: Model Setup (SP, etc) time {}s".format(
time.time() - start_time
)
)
if (
gen_in.shape[-1] + utils.koboldai_vars.genamt
> utils.koboldai_vars.max_length
):
logger.error("gen_in.shape[-1]: {}".format(gen_in.shape[-1]))
logger.error(
"utils.koboldai_vars.genamt: {}".format(utils.koboldai_vars.genamt)
)
logger.error(
"utils.koboldai_vars.max_length: {}".format(
utils.koboldai_vars.max_length
)
)
assert (
gen_in.shape[-1] + utils.koboldai_vars.genamt
<= utils.koboldai_vars.max_length
)
start_time = time.time()
gen_in = gen_in.to(self.get_auxilary_device())
logger.debug(
"core_generate: gen_in to device time {}s".format(time.time() - start_time)
)
start_time = time.time()
found_entries = found_entries or set()
self.gen_state["wi_scanner_excluded_keys"] = found_entries
utils.koboldai_vars._prompt = utils.koboldai_vars.prompt
with torch.no_grad():
already_generated = 0
numseqs = utils.koboldai_vars.numseqs
total_gens = None
for i in range(
utils.koboldai_vars.numseqs if utils.koboldai_vars.alt_multi_gen else 1
):
while True:
# The reason this is a loop is due to how Dynamic WI works. We
# cannot simply add the WI to the context mid-generation, so we
# stop early, and then insert WI, then continue generating. That
# stopping and continuing is this loop.
start_time = time.time()
result = self.raw_generate(
gen_in[0],
max_new=utils.koboldai_vars.genamt,
do_streaming=utils.koboldai_vars.output_streaming,
do_dynamic_wi=utils.koboldai_vars.dynamicscan,
batch_count=numseqs
if not utils.koboldai_vars.alt_multi_gen
else 1,
# Real max length is handled by CoreStopper.
bypass_hf_maxlength=utils.koboldai_vars.dynamicscan,
is_core=True,
tpu_dynamic_inference=utils.koboldai_vars.dynamicscan
or (
not utils.koboldai_vars.nogenmod
and utils.koboldai_vars.has_genmod
),
seed=utils.koboldai_vars.seed
if utils.koboldai_vars.full_determinism
else None,
gen_mode=gen_mode
)
logger.debug(
"core_generate: run raw_generate pass {} {}s".format(
already_generated, time.time() - start_time
)
)
genout = result.encoded
already_generated += len(genout[0])
try:
assert already_generated <= utils.koboldai_vars.genamt * (
utils.koboldai_vars.numseqs
if utils.koboldai_vars.alt_multi_gen
else 1
)
except AssertionError:
print("AlreadyGenerated", already_generated)
print("genamt", utils.koboldai_vars.genamt)
raise
if result.is_whole_generation:
break
# Generation stopped; why?
# If we have been told to halt (we have reached our target token
# amount, tracked by `halt`), or Dynamic WI has not told us to
# stop temporarily to insert WI, we can assume that we are done
# generating, so we break.
if self.gen_state.get("halt") or not self.gen_state.get(
"regeneration_required"
):
break
# Now we are doing stuff for Dynamic WI.
assert genout.ndim >= 2
assert genout.shape[0] == utils.koboldai_vars.numseqs
if (
utils.koboldai_vars.lua_koboldbridge.generated_cols
and utils.koboldai_vars.generated_tkns
!= utils.koboldai_vars.lua_koboldbridge.generated_cols
):
raise RuntimeError(
f"Inconsistency detected between KoboldAI Python and Lua backends ({utils.koboldai_vars.generated_tkns} != {utils.koboldai_vars.lua_koboldbridge.generated_cols})"
)
if already_generated != utils.koboldai_vars.generated_tkns:
print("already_generated: {}".format(already_generated))
print(
"generated_tkns: {}".format(
utils.koboldai_vars.generated_tkns
)
)
raise RuntimeError("WI scanning error")
for r in range(utils.koboldai_vars.numseqs):
for c in range(already_generated):
assert (
utils.koboldai_vars.lua_koboldbridge.generated[r + 1][
c + 1
]
is not None
)
genout[r][
genout.shape[-1] - already_generated + c
] = utils.koboldai_vars.lua_koboldbridge.generated[r + 1][
c + 1
]
encoded = []
for i in range(utils.koboldai_vars.numseqs):
txt = utils.decodenewlines(
self.tokenizer.decode(genout[i, -already_generated:])
)
# winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=utils.koboldai_vars.actions)
# txt, _, _ = calcsubmitbudget(len(utils.koboldai_vars.actions), winfo, mem, anotetxt, utils.koboldai_vars.actions, submission=txt)
txt, _, _, _found_entries = utils.koboldai_vars.calc_ai_text(
submitted_text=txt, send_context=False
)
found_entries[i].update(_found_entries)
encoded.append(
torch.tensor(txt, dtype=torch.long, device=genout.device)
)
max_length = len(max(encoded, key=len))
encoded = torch.stack(
tuple(
torch.nn.functional.pad(
e,
(max_length - len(e), 0),
value=self.model.config.pad_token_id
or self.model.config.eos_token_id,
)
for e in encoded
)
)
genout = torch.cat(
(
encoded,
genout[..., -already_generated:],
),
dim=-1,
)
if utils.koboldai_vars.sp is not None:
soft_tokens = torch.arange(
self.model.config.vocab_size,
self.model.config.vocab_size
+ utils.koboldai_vars.sp.shape[0],
device=genout.device,
)
genout = torch.cat(
(soft_tokens.tile(utils.koboldai_vars.numseqs, 1), genout),
dim=-1,
)
assert (
genout.shape[-1]
+ utils.koboldai_vars.genamt
- already_generated
<= utils.koboldai_vars.max_length
)
gen_in = genout
numseqs = 1
if total_gens is None:
total_gens = genout
else:
total_gens = torch.cat((total_gens, genout))
return total_gens, already_generated
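# Illustrative sketch, not part of the original module: the re-encoded contexts
# built above are ragged, so they are left-padded to a common length before
# being stacked. For two hypothetical rows:
#
#     a = torch.tensor([5, 6, 7])
#     b = torch.tensor([8, 9])
#     torch.nn.functional.pad(b, (len(a) - len(b), 0), value=0)  # tensor([0, 8, 9])
#
# Stacking the padded rows then yields a (numseqs, max_length) batch that can
# be concatenated with the generated tokens.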
def _raw_generate(
self,
prompt_tokens: Union[List[int], torch.Tensor],
max_new: int,
gen_settings: GenerationSettings,
single_line: bool = False,
batch_count: int = 1,
**kwargs,
) -> GenerationResult:
"""Lowest level model-agnostic generation function. To be overridden by model implementation.
Args:
prompt_tokens (Union[List[int], torch.Tensor]): Prompt as encoded token IDs
max_new (int): Maximum amount of new tokens to generate
gen_settings (GenerationSettings): State to pass in single-generation setting overrides
single_line (bool, optional): Generate one line only. Defaults to False.
batch_count (int, optional): How big of a batch to generate. Defaults to 1.
seed (int, optional): If not None, this seed will be used to make reproducible generations. Defaults to None.
Returns:
GenerationResult: The model's output
"""
raise NotImplementedError
def raw_generate(
self,
# prompt is either a string (text) or a list (token ids)
prompt: Union[str, list, np.ndarray],
max_new: int,
do_streaming: bool = False,
do_dynamic_wi: bool = False,
batch_count: int = 1,
bypass_hf_maxlength: bool = False,
generation_settings: Optional[dict] = None,
is_core: bool = False,
single_line: bool = False,
found_entries: set = (),
tpu_dynamic_inference: bool = False,
seed: Optional[int] = None,
gen_mode: GenerationMode = GenerationMode.STANDARD,
**kwargs,
) -> GenerationResult:
"""A wrapper around `_raw_generate()` that handles gen_state and other stuff. Use this to generate text outside of the story.
Args:
prompt (Union[str, list, np.ndarray]): The prompt as a string or encoded token IDs
max_new (int): Maximum amount of new tokens to generate
do_streaming (bool, optional): Whether to stream tokens to the user or not. Defaults to False.
do_dynamic_wi (bool, optional): Whether to use Dynamic WI context injections. Defaults to False.
batch_count (int, optional): How big of a batch to generate. Defaults to 1.
bypass_hf_maxlength (bool, optional): Whether to ignore model-provided max length limits. Defaults to False.
generation_settings (dict, optional): Mapping of single-generation setting overrides. Defaults to None.
is_core (bool, optional): Whether this generation is a core story generation. Defaults to False.
single_line (bool, optional): Generate one line only. Defaults to False.
found_entries (set, optional): Entries found for Dynamic WI. Defaults to ().
gen_mode (GenerationMode): Special generation mode. Defaults to GenerationMode.STANDARD.
Raises:
ValueError: If prompt type is weird
NotImplementedError: If model is ReadOnly
Returns:
GenerationResult: The model's output
"""
# TODO: Support singleline outside of torch
self.gen_state["do_streaming"] = do_streaming
self.gen_state["do_dynamic_wi"] = do_dynamic_wi
# Dynamic WI depends on this!!! This is a main gen call.
self.gen_state["stop_at_genamt"] = do_dynamic_wi
# Makes stopping criteria hook happy
self.gen_state["wi_scanner_excluded_keys"] = self.gen_state.get(
"wi_scanner_excluded_keys", set()
)
self.gen_state["allow_eos"] = False
temp_stoppers = []
if gen_mode not in self.get_supported_gen_modes():
logger.warning(f"User requested unsupported GenerationMode '{gen_mode}'!")
gen_mode = GenerationMode.STANDARD
if gen_mode == GenerationMode.FOREVER:
self.gen_state["stop_at_genamt"] = False
max_new = 1e7
elif gen_mode == GenerationMode.UNTIL_EOS:
self.gen_state["allow_eos"] = True
self.gen_state["stop_at_genamt"] = False
max_new = 1e7
elif gen_mode == GenerationMode.UNTIL_NEWLINE:
# TODO: Look into replacing `single_line` with `generation_mode`
temp_stoppers.append(Stoppers.newline_stopper)
elif gen_mode == GenerationMode.UNTIL_SENTENCE_END:
temp_stoppers.append(Stoppers.sentence_end_stopper)
self.stopper_hooks += temp_stoppers
utils.koboldai_vars.inference_config.do_core = is_core
gen_settings = GenerationSettings(**(generation_settings or {}))
if isinstance(prompt, torch.Tensor):
prompt_tokens = prompt.cpu().numpy()
elif isinstance(prompt, list):
prompt_tokens = np.array(prompt)
elif isinstance(prompt, str):
prompt_tokens = np.array(self.tokenizer.encode(prompt))
else:
raise ValueError(f"Prompt is {type(prompt)}. Not a fan!")
assert isinstance(prompt_tokens, np.ndarray)
assert len(prompt_tokens.shape) == 1
time_start = time.time()
with use_core_manipulations():
result = self._raw_generate(
prompt_tokens=prompt_tokens,
max_new=max_new,
batch_count=batch_count,
gen_settings=gen_settings,
single_line=single_line,
tpu_dynamic_inference=tpu_dynamic_inference,
seed=seed,
)
time_end = round(time.time() - time_start, 2)
try:
tokens_per_second = round(len(result.encoded[0]) / time_end, 2)
except ZeroDivisionError:
# Introducing KoboldAI's fastest model: ReadOnly!
tokens_per_second = 0
if not utils.koboldai_vars.quiet:
logger.info(
f"Generated {len(result.encoded[0])} tokens in {time_end} seconds, for an average rate of {tokens_per_second} tokens per second."
)
for stopper in temp_stoppers:
self.stopper_hooks.remove(stopper)
return result
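# Illustrative usage sketch, not part of the original module: generating text
# outside of the story (for a utility feature, say) goes through raw_generate()
# rather than core_generate(), e.g.
#
#     result = model.raw_generate("Describe a dragon:", max_new=40, batch_count=2)
#     first_text = result.decoded[0]
#
# where `model` is assumed to be a loaded InferenceModel subclass instance.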
def generate(
self,
prompt_tokens: Union[List[int], torch.Tensor],
max_new_tokens: int,
do_streaming: bool = False,
do_dynamic_wi: bool = False,
single_line: bool = False,
batch_count: int = 1,
) -> torch.Tensor:
raise NotImplementedError
def _post_token_gen(self, input_ids: torch.LongTensor) -> None:
for hook in self.post_token_hooks:
hook(self, input_ids)
def abort_generation(self, abort=True):
self.abort = abort
def get_supported_gen_modes(self) -> List[GenerationMode]:
"""Returns a list of compatible `GenerationMode`s for the current model.
Returns:
List[GenerationMode]: A list of compatible `GenerationMode`s.
"""
ret = [GenerationMode.STANDARD]
if self.capabilties.stopper_hooks:
ret += [
GenerationMode.FOREVER,
GenerationMode.UNTIL_NEWLINE,
GenerationMode.UNTIL_SENTENCE_END,
]
return ret
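# Illustrative sketch, not part of the original module: the smallest useful
# backend wires a tokenizer, a _load() implementation and a _raw_generate()
# implementation into the machinery above. Everything below is hypothetical and
# heavily simplified (no sampler settings, stoppers or streaming).
class _ExampleEchoModel(InferenceModel):
    """Toy backend that 'generates' by repeating the last prompt token."""

    def _load(self, save_model: bool, initial_load: bool) -> None:
        self.tokenizer = self._get_tokenizer("gpt2")

    def _raw_generate(
        self,
        prompt_tokens: Union[List[int], torch.Tensor],
        max_new: int,
        gen_settings: GenerationSettings,
        single_line: bool = False,
        batch_count: int = 1,
        **kwargs,
    ) -> GenerationResult:
        prompt = np.asarray(prompt_tokens)
        # Fake "generation": append max_new copies of the final prompt token,
        # then tile the row to form a batch.
        batch = np.tile(
            np.concatenate([prompt, np.repeat(prompt[-1], max_new)]),
            (batch_count, 1),
        )
        return GenerationResult(
            model=self,
            out_batches=batch,
            prompt=prompt,
            is_whole_generation=True,
            output_includes_prompt=True,
        )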