Model: And another refactor

somebody
2023-03-01 19:16:35 -06:00
parent 225dcf1a0a
commit 54cecd4d5d
18 changed files with 3045 additions and 2911 deletions

modeling/inference_model.py (new file, 591 lines)

@@ -0,0 +1,591 @@
from __future__ import annotations
from dataclasses import dataclass
import time
from typing import List, Optional, Union
from logger import logger
import torch
import numpy as np
import transformers
from transformers import (
GPT2Tokenizer,
AutoTokenizer,
)
import utils
try:
import tpu_mtj_backend
except ModuleNotFoundError as e:
# Not on TPU... hopefully
if utils.koboldai_vars.use_colab_tpu:
raise e
# I don't really like this way of pointing to the current model but I can't
# find a way around it in some areas.
current_model = None
# We only want to use logit manipulations and such on our core text model
class use_core_manipulations:
"""Use in a `with` block to patch functions for core story model sampling."""
# These must be set by wherever they get set up
get_logits_processor: callable = None
sample: callable = None
get_stopping_criteria: callable = None
# We set these automatically
old_get_logits_processor: callable = None
old_sample: callable = None
old_get_stopping_criteria: callable = None
def __enter__(self):
if use_core_manipulations.get_logits_processor:
use_core_manipulations.old_get_logits_processor = (
transformers.GenerationMixin._get_logits_processor
)
transformers.GenerationMixin._get_logits_processor = (
use_core_manipulations.get_logits_processor
)
if use_core_manipulations.sample:
use_core_manipulations.old_sample = transformers.GenerationMixin.sample
transformers.GenerationMixin.sample = use_core_manipulations.sample
if use_core_manipulations.get_stopping_criteria:
use_core_manipulations.old_get_stopping_criteria = (
transformers.GenerationMixin._get_stopping_criteria
)
transformers.GenerationMixin._get_stopping_criteria = (
use_core_manipulations.get_stopping_criteria
)
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
if use_core_manipulations.old_get_logits_processor:
transformers.GenerationMixin._get_logits_processor = (
use_core_manipulations.old_get_logits_processor
)
else:
assert (
not use_core_manipulations.get_logits_processor
), "Patch leak: THE MONKEYS HAVE ESCAPED"
if use_core_manipulations.old_sample:
transformers.GenerationMixin.sample = use_core_manipulations.old_sample
else:
assert (
not use_core_manipulations.sample
), "Patch leak: THE MONKEYS HAVE ESCAPED"
if use_core_manipulations.old_get_stopping_criteria:
transformers.GenerationMixin._get_stopping_criteria = (
use_core_manipulations.old_get_stopping_criteria
)
else:
assert (
not use_core_manipulations.get_stopping_criteria
), "Patch leak: THE MONKEYS HAVE ESCAPED"
class GenerationResult:
"""A container for easily accessing different forms of model outputs. Returned by most generate functions."""
def __init__(
self,
model: InferenceModel,
out_batches: list,
prompt: list,
# Controls if generate() does its looping thing. This should only be
# done for HF models that use that StoppingCondition
is_whole_generation: bool,
# Controls if we should trim output by prompt length
output_includes_prompt: bool = False,
# Lazy filter to cut off extra lines where we can't manipulate
# probabilities
single_line: bool = False,
):
# Shave prompt off of encoded response when needed (HF). Decoded does
# not return prompt.
if output_includes_prompt:
self.encoded = out_batches[:, len(prompt) :]
else:
self.encoded = out_batches
self.prompt = prompt
self.is_whole_generation = is_whole_generation
self.decoded = [
utils.decodenewlines(model.tokenizer.decode(enc)) for enc in self.encoded
]
if single_line:
self.decoded = [x.split("\n", 1)[0] for x in self.decoded]
self.encoded = np.array(model.tokenizer(self.decoded).input_ids)
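# Sketch of typical consumption (illustrative; assumes a loaded model instance):
#
#     result = model.raw_generate("Once upon a time", max_new=32)
#     result.decoded   # list[str], one decoded continuation per batch entry
#     result.encoded   # token IDs, with the prompt trimmed when the backend includes it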
class GenerationSettings:
"""Structure for holding temporarily overwritten settings."""
def __init__(self, **overrides) -> None:
for setting in [
"temp",
"top_p",
"top_k",
"tfs",
"typical",
"top_a",
"rep_pen",
"rep_pen_slope",
"rep_pen_range",
"sampler_order",
]:
setattr(
self,
setting,
overrides.get(setting, getattr(utils.koboldai_vars, setting)),
)
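# Example (illustrative): override only the temperature for a single call; every
# other field falls back to the current value on utils.koboldai_vars.
#
#     settings = GenerationSettings(temp=0.7)
#     settings.temp   # 0.7 (per-call override)
#     settings.top_p  # whatever utils.koboldai_vars.top_p currently is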
@dataclass
class ModelCapabilities:
embedding_manipulation: bool = False
post_token_hooks: bool = False
stopper_hooks: bool = False
# TODO: Support non-live probabilities from APIs
post_token_probs: bool = False
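# Example (illustrative): a backend that supports soft prompts and hooks would
# advertise itself like this; core_generate() consults these flags (e.g. before
# applying a soft prompt).
#
#     ModelCapabilities(
#         embedding_manipulation=True,
#         post_token_hooks=True,
#         stopper_hooks=True,
#     )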
class InferenceModel:
"""Root class for all models."""
def __init__(self) -> None:
self.gen_state = {}
self.post_token_hooks = []
self.stopper_hooks = []
self.tokenizer = None
self.capabilties = ModelCapabilities()
def load(self, save_model: bool = False, initial_load: bool = False) -> None:
"""User-facing load function. Do not override this; try `_load()` instead."""
self._load(save_model=save_model, initial_load=initial_load)
self._post_load()
global current_model
current_model = self
print(self.raw_generate("Hi guys,", 20).__dict__)
def _post_load(self) -> None:
"""Post load hook. Called after `_load()`."""
def _load(self, save_model: bool, initial_load: bool) -> None:
"""Main load method. All logic related to loading the model onto the
selected device(s) and preparing it for inference should be implemented here."""
raise NotImplementedError
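# Minimal subclass sketch (illustrative only; concrete backends live elsewhere):
#
#     class MyBackend(InferenceModel):
#         def _load(self, save_model: bool, initial_load: bool) -> None:
#             self.tokenizer = self._get_tokenizer("gpt2")
#             self.model = ...  # load weights onto the selected device(s)
#
#         def _raw_generate(self, prompt_tokens, max_new, gen_settings,
#                           single_line=False, batch_count=1) -> GenerationResult:
#             ...  # run inference and wrap the output in a GenerationResult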
def _get_tokenizer(self, location: str) -> AutoTokenizer:
"""Returns the appropiate tokenizer for the location. Should be ran once and result stored in `tokenizer`.
Args:
location (str): Either a local model directory path or a HuggingFace model ID.
Returns:
AutoTokenizer: Tokenizer deemed fit for the location string. May be a fallback tokenizer.
"""
if utils.koboldai_vars.model_type == "xglm":
# Default to </s> newline mode if using XGLM
utils.koboldai_vars.newlinemode = "s"
elif utils.koboldai_vars.model_type in ["opt", "bloom"]:
# Handle </s> but don't convert newlines if using Fairseq models that have newlines trained in them
utils.koboldai_vars.newlinemode = "ns"
std_kwargs = {"revision": utils.koboldai_vars.revision, "cache_dir": "cache"}
suppliers = [
# Fast tokenizer disabled by default as per HF docs:
# > Note: Make sure to pass use_fast=False when loading
# OPT's tokenizer with AutoTokenizer to get the correct
# tokenizer.
lambda: AutoTokenizer.from_pretrained(
location, use_fast=False, **std_kwargs
),
lambda: AutoTokenizer.from_pretrained(location, **std_kwargs),
# Fallback to GPT2Tokenizer
lambda: GPT2Tokenizer.from_pretrained(location, **std_kwargs),
lambda: GPT2Tokenizer.from_pretrained("gpt2", **std_kwargs),
]
for i, try_get_tokenizer in enumerate(suppliers):
try:
return try_get_tokenizer()
except:
# If we error on each attempt, raise the last one
if i == len(suppliers) - 1:
raise
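# Illustrative call (assumes a subclass's _load() storing the result):
#
#     self.tokenizer = self._get_tokenizer("facebook/opt-1.3b")
#
# Attempts are made in order: slow AutoTokenizer, default AutoTokenizer, a
# GPT2Tokenizer for the same location, and finally the stock "gpt2" tokenizer.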
def core_generate(
self,
text: list,
found_entries: set,
):
"""Generate story text. Heavily tied to story-specific parameters; if
you are making a new generation-based feature, consider `raw_generate()`.
Args:
text (list): Encoded input tokens
found_entries (set): Entries found for Dynamic WI
Raises:
RuntimeError: if inconsistencies are detected between the internal state and the Lua state -- sanity check
RuntimeError: if inconsistencies are detected between the internal state and the core stopper -- sanity check
"""
start_time = time.time()
gen_in = torch.tensor(text, dtype=torch.long)[None]
logger.debug(
"core_generate: torch.tensor time {}s".format(time.time() - start_time)
)
start_time = time.time()
if utils.koboldai_vars.is_model_torch():
# Torch stuff
if utils.koboldai_vars.full_determinism:
torch.manual_seed(utils.koboldai_vars.seed)
if utils.koboldai_vars.sp is not None:
assert self.capabilties.embedding_manipulation
soft_tokens = torch.arange(
self.model.config.vocab_size,
self.model.config.vocab_size + utils.koboldai_vars.sp.shape[0],
)
gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
elif utils.koboldai_vars.use_colab_tpu:
if utils.koboldai_vars.full_determinism:
tpu_mtj_backend.set_rng_seed(utils.koboldai_vars.seed)
logger.debug(
"core_generate: Model Setup (SP, etc) time {}s".format(
time.time() - start_time
)
)
if (
gen_in.shape[-1] + utils.koboldai_vars.genamt
> utils.koboldai_vars.max_length
):
logger.error("gen_in.shape[-1]: {}".format(gen_in.shape[-1]))
logger.error(
"utils.koboldai_vars.genamt: {}".format(utils.koboldai_vars.genamt)
)
logger.error(
"utils.koboldai_vars.max_length: {}".format(
utils.koboldai_vars.max_length
)
)
assert (
gen_in.shape[-1] + utils.koboldai_vars.genamt
<= utils.koboldai_vars.max_length
)
start_time = time.time()
gen_in = gen_in.to(utils.get_auxilary_device())
logger.debug(
"core_generate: gen_in to device time {}s".format(time.time() - start_time)
)
start_time = time.time()
found_entries = found_entries or set()
self.gen_state["wi_scanner_excluded_keys"] = found_entries
utils.koboldai_vars._prompt = utils.koboldai_vars.prompt
with torch.no_grad():
already_generated = 0
numseqs = utils.koboldai_vars.numseqs
total_gens = None
for i in range(
utils.koboldai_vars.numseqs if utils.koboldai_vars.alt_multi_gen else 1
):
while True:
# The reason this is a loop is due to how Dynamic WI works. We
# cannot simply add the WI to the context mid-generation, so we
# stop early, and then insert WI, then continue generating. That
# stopping and continuing is this loop.
start_time = time.time()
result = self.raw_generate(
gen_in[0],
max_new=utils.koboldai_vars.genamt,
do_streaming=utils.koboldai_vars.output_streaming,
do_dynamic_wi=utils.koboldai_vars.dynamicscan,
batch_count=numseqs
if not utils.koboldai_vars.alt_multi_gen
else 1,
# Real max length is handled by CoreStopper.
bypass_hf_maxlength=utils.koboldai_vars.dynamicscan,
is_core=True,
)
logger.debug(
"core_generate: run raw_generate pass {} {}s".format(
already_generated, time.time() - start_time
)
)
genout = result.encoded
already_generated += len(genout[0])
try:
assert already_generated <= utils.koboldai_vars.genamt * (
utils.koboldai_vars.numseqs
if utils.koboldai_vars.alt_multi_gen
else 1
)
except AssertionError:
print("AlreadyGenerated", already_generated)
print("genamt", utils.koboldai_vars.genamt)
raise
if result.is_whole_generation:
break
# Generation stopped; why?
# If we have been told to halt (we have reached our target token
# amount, tracked by the halt flag), or Dynamic WI has not asked
# us to stop temporarily to insert WI, we can assume that we are
# done generating and break out of the loop.
if (
self.gen_state["halt"]
or not self.gen_state["regeneration_required"]
):
break
# Now we are doing stuff for Dynamic WI.
assert genout.ndim >= 2
assert genout.shape[0] == utils.koboldai_vars.numseqs
if (
utils.koboldai_vars.lua_koboldbridge.generated_cols
and utils.koboldai_vars.generated_tkns
!= utils.koboldai_vars.lua_koboldbridge.generated_cols
):
raise RuntimeError(
f"Inconsistency detected between KoboldAI Python and Lua backends ({utils.koboldai_vars.generated_tkns} != {utils.koboldai_vars.lua_koboldbridge.generated_cols})"
)
if already_generated != utils.koboldai_vars.generated_tkns:
print("already_generated: {}".format(already_generated))
print(
"generated_tkns: {}".format(
utils.koboldai_vars.generated_tkns
)
)
raise RuntimeError("WI scanning error")
for r in range(utils.koboldai_vars.numseqs):
for c in range(already_generated):
assert (
utils.koboldai_vars.lua_koboldbridge.generated[r + 1][
c + 1
]
is not None
)
genout[r][
genout.shape[-1] - already_generated + c
] = utils.koboldai_vars.lua_koboldbridge.generated[r + 1][
c + 1
]
encoded = []
for i in range(utils.koboldai_vars.numseqs):
txt = utils.decodenewlines(
self.tokenizer.decode(genout[i, -already_generated:])
)
# winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=utils.koboldai_vars.actions)
# txt, _, _ = calcsubmitbudget(len(utils.koboldai_vars.actions), winfo, mem, anotetxt, utils.koboldai_vars.actions, submission=txt)
txt, _, _, _found_entries = utils.koboldai_vars.calc_ai_text(
submitted_text=txt, send_context=False
)
found_entries[i].update(_found_entries)
encoded.append(
torch.tensor(txt, dtype=torch.long, device=genout.device)
)
max_length = len(max(encoded, key=len))
encoded = torch.stack(
tuple(
torch.nn.functional.pad(
e,
(max_length - len(e), 0),
value=self.model.config.pad_token_id
or self.model.config.eos_token_id,
)
for e in encoded
)
)
genout = torch.cat(
(
encoded,
genout[..., -already_generated:],
),
dim=-1,
)
if utils.koboldai_vars.sp is not None:
soft_tokens = torch.arange(
self.model.config.vocab_size,
self.model.config.vocab_size
+ utils.koboldai_vars.sp.shape[0],
device=genout.device,
)
genout = torch.cat(
(soft_tokens.tile(utils.koboldai_vars.numseqs, 1), genout),
dim=-1,
)
assert (
genout.shape[-1]
+ utils.koboldai_vars.genamt
- already_generated
<= utils.koboldai_vars.max_length
)
gen_in = genout
numseqs = 1
if total_gens is None:
total_gens = genout
else:
total_gens = torch.cat((total_gens, genout))
return total_gens, already_generated
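# Rough shape of the Dynamic WI loop above (illustrative): with dynamicscan
# enabled, one core_generate() call may run raw_generate() several times, e.g.
#
#     pass 1: generate until the stopper spots a WI key -> stop, rebuild context with WI
#     pass 2: continue generating from the rebuilt context
#     ...     until genamt tokens exist for each sequence or halt is signalled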
def _raw_generate(
self,
prompt_tokens: Union[List[int], torch.Tensor],
max_new: int,
gen_settings: GenerationSettings,
single_line: bool = False,
batch_count: int = 1,
) -> GenerationResult:
"""Lowest level model-agnostic generation function. To be overridden by model implementation.
Args:
prompt_tokens (Union[List[int], torch.Tensor]): Prompt as encoded token IDs
max_new (int): Maximum amount of new tokens to generate
gen_settings (GenerationSettings): State to pass in single-generation setting overrides
single_line (bool, optional): Generate one line only. Defaults to False.
batch_count (int, optional): How big of a batch to generate. Defaults to 1.
Returns:
GenerationResult: The model's output
"""
raise NotImplementedError
def raw_generate(
self,
# prompt is either a string (text) or a list (token ids)
prompt: Union[str, list, np.ndarray],
max_new: int,
do_streaming: bool = False,
do_dynamic_wi: bool = False,
batch_count: int = 1,
bypass_hf_maxlength: bool = False,
generation_settings: Optional[dict] = None,
is_core: bool = False,
single_line: bool = False,
found_entries: set = (),
) -> GenerationResult:
"""A wrapper around `_raw_generate()` that handles gen_state and other stuff. Use this to generate text outside of the story.
Args:
prompt (Union[str, list, np.ndarray]): The prompt as a string or encoded token IDs
max_new (int): Maximum amount of new tokens to generate
do_streaming (bool, optional): Whether to stream tokens to the user or not. Defaults to False.
do_dynamic_wi (bool, optional): Whether to use Dynamic WI context injections. Defaults to False.
batch_count (int, optional): How big of a batch to generate. Defaults to 1.
bypass_hf_maxlength (bool, optional): Whether to ignore model-provided max length limits. Defaults to False.
generation_settings (dict, optional): Single-generation setting overrides applied via GenerationSettings. Defaults to None.
is_core (bool, optional): Whether this generation is a core story generation. Defaults to False.
single_line (bool, optional): Generate one line only. Defaults to False.
found_entries (set, optional): Entries found for Dynamic WI. Defaults to ().
Raises:
ValueError: If prompt type is weird
NotImplementedError: If model is ReadOnly
Returns:
GenerationResult: The model's output
"""
# TODO: Support singleline outside of torch
self.gen_state["do_streaming"] = do_streaming
self.gen_state["do_dynamic_wi"] = do_dynamic_wi
# Dynamic WI depends on this!!! This is a main gen call.
self.gen_state["stop_at_genamt"] = do_dynamic_wi
# Makes stopping criteria hook happy
self.gen_state["wi_scanner_excluded_keys"] = self.gen_state.get(
"wi_scanner_excluded_keys", set()
)
utils.koboldai_vars.inference_config.do_core = is_core
gen_settings = GenerationSettings(**(generation_settings or {}))
if isinstance(prompt, torch.Tensor):
prompt_tokens = prompt.cpu().numpy()
elif isinstance(prompt, list):
prompt_tokens = np.array(prompt)
elif isinstance(prompt, str):
prompt_tokens = np.array(self.tokenizer.encode(prompt))
else:
raise ValueError(f"Prompt is {type(prompt)}. Not a fan!")
assert isinstance(prompt_tokens, np.ndarray)
assert len(prompt_tokens.shape) == 1
if utils.koboldai_vars.model == "ReadOnly":
raise NotImplementedError("No loaded model")
time_start = time.time()
with use_core_manipulations():
result = self._raw_generate(
prompt_tokens=prompt_tokens,
max_new=max_new,
batch_count=batch_count,
gen_settings=gen_settings,
single_line=single_line,
)
time_end = round(time.time() - time_start, 2)
tokens_per_second = round(len(result.encoded[0]) / time_end, 2)
if not utils.koboldai_vars.quiet:
logger.info(
f"Generated {len(result.encoded[0])} tokens in {time_end} seconds, for an average rate of {tokens_per_second} tokens per second."
)
return result
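# Illustrative usage (assumes a loaded concrete model; the override dict feeds
# GenerationSettings):
#
#     result = model.raw_generate(
#         "You are standing in an open field.",
#         max_new=40,
#         generation_settings={"temp": 0.8},
#     )
#     print(result.decoded[0])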
def generate(
self,
prompt_tokens: Union[List[int], torch.Tensor],
max_new_tokens: int,
do_streaming: bool = False,
do_dynamic_wi: bool = False,
single_line: bool = False,
batch_count: int = 1,
) -> torch.Tensor:
raise NotImplementedError
def _post_token_gen(self, input_ids: torch.LongTensor) -> None:
for hook in self.post_token_hooks:
hook(self, input_ids)
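# Illustrative hook registration (assumes hooks use the (model, input_ids)
# signature called above):
#
#     def log_progress(model, input_ids):
#         logger.debug(f"{input_ids.shape[-1]} tokens in context so far")
#
#     model.post_token_hooks.append(log_progress)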