from __future__ import annotations

from dataclasses import dataclass
import time
from typing import List, Optional, Union
from logger import logger

import torch
import numpy as np
import transformers
from transformers import (
    GPT2Tokenizer,
    AutoTokenizer,
)

import utils

try:
    import tpu_mtj_backend
except ModuleNotFoundError as e:
    # Not on TPU... hopefully
    if utils.koboldai_vars.use_colab_tpu:
        raise e

# I don't really like this way of pointing to the current model but I can't
# find a way around it in some areas.
current_model = None

# We only want to use logit manipulations and such on our core text model
class use_core_manipulations:
    """Use in a `with` block to patch functions for core story model sampling."""

    # These must be set by wherever they get setup
    get_logits_processor: callable = None
    sample: callable = None
    get_stopping_criteria: callable = None

    # We set these automatically
    old_get_logits_processor: callable = None
    old_sample: callable = None
    old_get_stopping_criteria: callable = None

    def __enter__(self):
        if use_core_manipulations.get_logits_processor:
            use_core_manipulations.old_get_logits_processor = (
                transformers.GenerationMixin._get_logits_processor
            )
            transformers.GenerationMixin._get_logits_processor = (
                use_core_manipulations.get_logits_processor
            )

        if use_core_manipulations.sample:
            use_core_manipulations.old_sample = transformers.GenerationMixin.sample
            transformers.GenerationMixin.sample = use_core_manipulations.sample

        if use_core_manipulations.get_stopping_criteria:
            use_core_manipulations.old_get_stopping_criteria = (
                transformers.GenerationMixin._get_stopping_criteria
            )
            transformers.GenerationMixin._get_stopping_criteria = (
                use_core_manipulations.get_stopping_criteria
            )
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        if use_core_manipulations.old_get_logits_processor:
            transformers.GenerationMixin._get_logits_processor = (
                use_core_manipulations.old_get_logits_processor
            )
        else:
            assert (
                not use_core_manipulations.get_logits_processor
            ), "Patch leak: THE MONKEYS HAVE ESCAPED"

        if use_core_manipulations.old_sample:
            transformers.GenerationMixin.sample = use_core_manipulations.old_sample
        else:
            assert (
                not use_core_manipulations.sample
            ), "Patch leak: THE MONKEYS HAVE ESCAPED"

        if use_core_manipulations.old_get_stopping_criteria:
            transformers.GenerationMixin._get_stopping_criteria = (
                use_core_manipulations.old_get_stopping_criteria
            )
        else:
            assert (
                not use_core_manipulations.get_stopping_criteria
            ), "Patch leak: THE MONKEYS HAVE ESCAPED"
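
# Illustrative sketch (not executed): how a backend might use the context
# manager above. `my_logits_processor` and `my_sample` are hypothetical
# stand-ins for whatever callables a model implementation assigns before
# entering the block; the class swaps them into transformers.GenerationMixin
# for the duration of the `with` block and restores the originals on exit.
#
#   use_core_manipulations.get_logits_processor = my_logits_processor
#   use_core_manipulations.sample = my_sample
#   with use_core_manipulations():
#       ...  # calls into model.generate() see the patched functions
#   # outside the block, GenerationMixin is back to stock behaviour
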
class GenerationResult:
    """A container for easily accessing different forms of model outputs. Returned by most generate functions."""

    def __init__(
        self,
        model: InferenceModel,
        out_batches: list,
        prompt: list,
        # Controls if generate() does its looping thing. This should only be
        # done for HF models that use that StoppingCondition
        is_whole_generation: bool,
        # Controls if we should trim output by prompt length
        output_includes_prompt: bool = False,
        # Lazy filter to cut off extra lines where we can't manipulate
        # probabilities
        single_line: bool = False,
    ):
        # Shave prompt off of encoded response when needed (HF). Decoded does
        # not return prompt.
        if output_includes_prompt:
            self.encoded = out_batches[:, len(prompt) :]
        else:
            self.encoded = out_batches

        self.prompt = prompt
        self.is_whole_generation = is_whole_generation

        self.decoded = [
            utils.decodenewlines(model.tokenizer.decode(enc)) for enc in self.encoded
        ]

        if single_line:
            self.decoded = [x.split("\n", 1)[0] for x in self.decoded]
            self.encoded = np.array(model.tokenizer(self.decoded).input_ids)
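
# Illustrative sketch (not executed): what a GenerationResult typically looks
# like to a caller. `result` here is a hypothetical return value of
# InferenceModel.raw_generate(); `encoded` holds token IDs per batch entry and
# `decoded` holds the corresponding text with newlines restored.
#
#   result = model.raw_generate("Once upon a time", max_new=32)
#   first_text = result.decoded[0]   # str continuation
#   first_ids = result.encoded[0]    # token IDs for the same continuation
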
class GenerationSettings:
    """Structure for holding temporarily overwritten settings."""

    def __init__(self, **overrides) -> None:
        for setting in [
            "temp",
            "top_p",
            "top_k",
            "tfs",
            "typical",
            "top_a",
            "rep_pen",
            "rep_pen_slope",
            "rep_pen_range",
            "sampler_order",
        ]:
            setattr(
                self,
                setting,
                overrides.get(setting, getattr(utils.koboldai_vars, setting)),
            )
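
# Illustrative sketch (not executed): GenerationSettings copies each sampler
# setting from utils.koboldai_vars unless an override is supplied, so a
# one-off low-temperature generation could be requested like this (the values
# are examples only):
#
#   gen_settings = GenerationSettings(temp=0.5, top_p=0.9)
#   gen_settings.temp   # 0.5 (override)
#   gen_settings.top_k  # whatever utils.koboldai_vars.top_k currently is
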
@dataclass
class ModelCapabilities:
    embedding_manipulation: bool = False
    post_token_hooks: bool = False
    stopper_hooks: bool = False
    # TODO: Support non-live probabilities from APIs
    post_token_probs: bool = False
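
# Illustrative sketch (not executed): a hypothetical backend would declare
# what it supports by assigning a ModelCapabilities instance in its __init__,
# e.g. a local torch backend might enable everything:
#
#   self.capabilties = ModelCapabilities(
#       embedding_manipulation=True,
#       post_token_hooks=True,
#       stopper_hooks=True,
#       post_token_probs=True,
#   )
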
class InferenceModel:
    """Root class for all models."""

    def __init__(self) -> None:
        self.gen_state = {}
        self.post_token_hooks = []
        self.stopper_hooks = []
        self.tokenizer = None
        self.capabilties = ModelCapabilities()

    def load(self, save_model: bool = False, initial_load: bool = False) -> None:
        """User-facing load function. Do not override this; try `_load()` instead."""

        self._load(save_model=save_model, initial_load=initial_load)
        self._post_load()

        global current_model
        current_model = self

        print(self.raw_generate("Hi guys,", 20).__dict__)

    def _post_load(self) -> None:
        """Post load hook. Called after `_load()`."""

    def _load(self, save_model: bool, initial_load: bool) -> None:
        """Main load method. All logic related to loading the model onto the
        selected device(s) and preparing it for inference should be implemented here."""
        raise NotImplementedError

    def _get_tokenizer(self, location: str) -> AutoTokenizer:
        """Returns the appropriate tokenizer for the location. Should be run once and the result stored in `tokenizer`.

        Args:
            location (str): Either a local model directory path or a HuggingFace model ID.

        Returns:
            AutoTokenizer: Tokenizer deemed fit for the location string. May be a fallback tokenizer.
        """
        if utils.koboldai_vars.model_type == "xglm":
            # Default to </s> newline mode if using XGLM
            utils.koboldai_vars.newlinemode = "s"
        elif utils.koboldai_vars.model_type in ["opt", "bloom"]:
            # Handle </s> but don't convert newlines if using Fairseq models that have newlines trained in them
            utils.koboldai_vars.newlinemode = "ns"

        std_kwargs = {"revision": utils.koboldai_vars.revision, "cache_dir": "cache"}

        suppliers = [
            # Fast tokenizer disabled by default as per HF docs:
            # > Note: Make sure to pass use_fast=False when loading
            # OPT’s tokenizer with AutoTokenizer to get the correct
            # tokenizer.
            lambda: AutoTokenizer.from_pretrained(
                location, use_fast=False, **std_kwargs
            ),
            lambda: AutoTokenizer.from_pretrained(location, **std_kwargs),
            # Fallback to GPT2Tokenizer
            lambda: GPT2Tokenizer.from_pretrained(location, **std_kwargs),
            lambda: GPT2Tokenizer.from_pretrained("gpt2", **std_kwargs),
        ]

        for i, try_get_tokenizer in enumerate(suppliers):
            try:
                return try_get_tokenizer()
            except:
                # If we error on each attempt, raise the last one
                if i == len(suppliers) - 1:
                    raise
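
    # Illustrative sketch (not executed): a hypothetical subclass's _load()
    # would typically resolve the tokenizer through the fallback chain above
    # and keep the result on the instance. The model ID is only an example.
    #
    #   def _load(self, save_model: bool, initial_load: bool) -> None:
    #       self.tokenizer = self._get_tokenizer("EleutherAI/gpt-neo-2.7B")
    #       ...  # load model weights onto the selected device(s)
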
    def core_generate(
        self,
        text: list,
        found_entries: set,
    ):
        """Generate story text. Heavily tied to story-specific parameters; if
        you are making a new generation-based feature, consider `raw_generate()`.

        Args:
            text (list): Encoded input tokens
            found_entries (set): Entries found for Dynamic WI

        Raises:
            RuntimeError: if inconsistencies are detected between the internal state and the Lua state -- sanity check
            RuntimeError: if inconsistencies are detected between the internal state and the core stopper -- sanity check
        """

        start_time = time.time()
        gen_in = torch.tensor(text, dtype=torch.long)[None]
        logger.debug(
            "core_generate: torch.tensor time {}s".format(time.time() - start_time)
        )

        start_time = time.time()
        if utils.koboldai_vars.is_model_torch():
            # Torch stuff
            if utils.koboldai_vars.full_determinism:
                torch.manual_seed(utils.koboldai_vars.seed)

            if utils.koboldai_vars.sp is not None:
                assert self.capabilties.embedding_manipulation
                soft_tokens = torch.arange(
                    self.model.config.vocab_size,
                    self.model.config.vocab_size + utils.koboldai_vars.sp.shape[0],
                )
                gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
        elif utils.koboldai_vars.use_colab_tpu:
            if utils.koboldai_vars.full_determinism:
                tpu_mtj_backend.set_rng_seed(utils.koboldai_vars.seed)

        logger.debug(
            "core_generate: Model Setup (SP, etc) time {}s".format(
                time.time() - start_time
            )
        )

        if (
            gen_in.shape[-1] + utils.koboldai_vars.genamt
            > utils.koboldai_vars.max_length
        ):
            logger.error("gen_in.shape[-1]: {}".format(gen_in.shape[-1]))
            logger.error(
                "utils.koboldai_vars.genamt: {}".format(utils.koboldai_vars.genamt)
            )
            logger.error(
                "utils.koboldai_vars.max_length: {}".format(
                    utils.koboldai_vars.max_length
                )
            )
        assert (
            gen_in.shape[-1] + utils.koboldai_vars.genamt
            <= utils.koboldai_vars.max_length
        )

        start_time = time.time()
        gen_in = gen_in.to(utils.get_auxilary_device())

        logger.debug(
            "core_generate: gen_in to device time {}s".format(time.time() - start_time)
        )
        start_time = time.time()

        found_entries = found_entries or set()

        self.gen_state["wi_scanner_excluded_keys"] = found_entries

        utils.koboldai_vars._prompt = utils.koboldai_vars.prompt

        with torch.no_grad():
            already_generated = 0
            numseqs = utils.koboldai_vars.numseqs
            total_gens = None

            for i in range(
                utils.koboldai_vars.numseqs if utils.koboldai_vars.alt_multi_gen else 1
            ):
                while True:
                    # The reason this is a loop is due to how Dynamic WI works. We
                    # cannot simply add the WI to the context mid-generation, so we
                    # stop early, and then insert WI, then continue generating. That
                    # stopping and continuing is this loop.

                    start_time = time.time()
                    result = self.raw_generate(
                        gen_in[0],
                        max_new=utils.koboldai_vars.genamt,
                        do_streaming=utils.koboldai_vars.output_streaming,
                        do_dynamic_wi=utils.koboldai_vars.dynamicscan,
                        batch_count=numseqs
                        if not utils.koboldai_vars.alt_multi_gen
                        else 1,
                        # Real max length is handled by CoreStopper.
                        bypass_hf_maxlength=utils.koboldai_vars.dynamicscan,
                        is_core=True,
                    )
                    logger.debug(
                        "core_generate: run raw_generate pass {} {}s".format(
                            already_generated, time.time() - start_time
                        )
                    )

                    genout = result.encoded

                    already_generated += len(genout[0])

                    try:
                        assert (
                            already_generated
                            <= utils.koboldai_vars.genamt * utils.koboldai_vars.numseqs
                            if utils.koboldai_vars.alt_multi_gen
                            else 1
                        )
                    except AssertionError:
                        print("AlreadyGenerated", already_generated)
                        print("genamt", utils.koboldai_vars.genamt)
                        raise

                    if result.is_whole_generation:
                        break

                    # Generation stopped; why?
                    # If we have been told to halt, have reached our target token
                    # amount (controlled by halt), or Dynamic WI has not told us to
                    # stop temporarily to insert WI, we can assume that we are done
                    # generating. We shall break.
                    if (
                        self.gen_state["halt"]
                        or not self.gen_state["regeneration_required"]
                    ):
                        break

                    # Now we are doing stuff for Dynamic WI.
                    assert genout.ndim >= 2
                    assert genout.shape[0] == utils.koboldai_vars.numseqs

                    if (
                        utils.koboldai_vars.lua_koboldbridge.generated_cols
                        and utils.koboldai_vars.generated_tkns
                        != utils.koboldai_vars.lua_koboldbridge.generated_cols
                    ):
                        raise RuntimeError(
                            f"Inconsistency detected between KoboldAI Python and Lua backends ({utils.koboldai_vars.generated_tkns} != {utils.koboldai_vars.lua_koboldbridge.generated_cols})"
                        )

                    if already_generated != utils.koboldai_vars.generated_tkns:
                        print("already_generated: {}".format(already_generated))
                        print(
                            "generated_tkns: {}".format(
                                utils.koboldai_vars.generated_tkns
                            )
                        )
                        raise RuntimeError("WI scanning error")

                    for r in range(utils.koboldai_vars.numseqs):
                        for c in range(already_generated):
                            assert (
                                utils.koboldai_vars.lua_koboldbridge.generated[r + 1][
                                    c + 1
                                ]
                                is not None
                            )
                            genout[r][
                                genout.shape[-1] - already_generated + c
                            ] = utils.koboldai_vars.lua_koboldbridge.generated[r + 1][
                                c + 1
                            ]

                    encoded = []

                    for i in range(utils.koboldai_vars.numseqs):
                        txt = utils.decodenewlines(
                            self.tokenizer.decode(genout[i, -already_generated:])
                        )
                        # winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=utils.koboldai_vars.actions)
                        # txt, _, _ = calcsubmitbudget(len(utils.koboldai_vars.actions), winfo, mem, anotetxt, utils.koboldai_vars.actions, submission=txt)
                        txt, _, _, _found_entries = utils.koboldai_vars.calc_ai_text(
                            submitted_text=txt, send_context=False
                        )
                        found_entries[i].update(_found_entries)
                        encoded.append(
                            torch.tensor(txt, dtype=torch.long, device=genout.device)
                        )

                    max_length = len(max(encoded, key=len))
                    encoded = torch.stack(
                        tuple(
                            torch.nn.functional.pad(
                                e,
                                (max_length - len(e), 0),
                                value=self.model.config.pad_token_id
                                or self.model.config.eos_token_id,
                            )
                            for e in encoded
                        )
                    )
                    genout = torch.cat(
                        (
                            encoded,
                            genout[..., -already_generated:],
                        ),
                        dim=-1,
                    )

                    if utils.koboldai_vars.sp is not None:
                        soft_tokens = torch.arange(
                            self.model.config.vocab_size,
                            self.model.config.vocab_size
                            + utils.koboldai_vars.sp.shape[0],
                            device=genout.device,
                        )
                        genout = torch.cat(
                            (soft_tokens.tile(utils.koboldai_vars.numseqs, 1), genout),
                            dim=-1,
                        )

                    assert (
                        genout.shape[-1]
                        + utils.koboldai_vars.genamt
                        - already_generated
                        <= utils.koboldai_vars.max_length
                    )
                    gen_in = genout
                    numseqs = 1
                if total_gens is None:
                    total_gens = genout
                else:
                    total_gens = torch.cat((total_gens, genout))

        return total_gens, already_generated
    def _raw_generate(
        self,
        prompt_tokens: Union[List[int], torch.Tensor],
        max_new: int,
        gen_settings: GenerationSettings,
        single_line: bool = False,
        batch_count: int = 1,
    ) -> GenerationResult:
        """Lowest level model-agnostic generation function. To be overridden by model implementation.

        Args:
            prompt_tokens (Union[List[int], torch.Tensor]): Prompt as encoded token IDs
            max_new (int): Maximum amount of new tokens to generate
            gen_settings (GenerationSettings): State to pass in single-generation setting overrides
            single_line (bool, optional): Generate one line only. Defaults to False.
            batch_count (int, optional): How big of a batch to generate. Defaults to 1.

        Returns:
            GenerationResult: The model's output
        """
        raise NotImplementedError
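
    # Illustrative sketch (not executed): the minimum a hypothetical subclass
    # has to provide is an override of _raw_generate() that runs its backend
    # and wraps the output in a GenerationResult, e.g.:
    #
    #   def _raw_generate(self, prompt_tokens, max_new, gen_settings,
    #                     single_line=False, batch_count=1) -> GenerationResult:
    #       out = self._backend_generate(prompt_tokens, max_new, batch_count)
    #       return GenerationResult(
    #           model=self,
    #           out_batches=out,
    #           prompt=prompt_tokens,
    #           is_whole_generation=False,
    #           output_includes_prompt=True,
    #       )
    #
    # `_backend_generate` is a placeholder for whatever the backend actually
    # calls; real implementations also honour gen_settings and single_line.
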
    def raw_generate(
        self,
        # prompt is either a string (text) or a list (token ids)
        prompt: Union[str, list, np.ndarray],
        max_new: int,
        do_streaming: bool = False,
        do_dynamic_wi: bool = False,
        batch_count: int = 1,
        bypass_hf_maxlength: bool = False,
        generation_settings: Optional[dict] = None,
        is_core: bool = False,
        single_line: bool = False,
        found_entries: set = (),
    ) -> GenerationResult:
        """A wrapper around `_raw_generate()` that handles gen_state and other stuff. Use this to generate text outside of the story.

        Args:
            prompt (Union[str, list, np.ndarray]): The prompt as a string or encoded token IDs
            max_new (int): Maximum amount of new tokens to generate
            do_streaming (bool, optional): Whether to stream tokens to the user or not. Defaults to False.
            do_dynamic_wi (bool, optional): Whether to use Dynamic WI context injections. Defaults to False.
            batch_count (int, optional): How big of a batch to generate. Defaults to 1.
            bypass_hf_maxlength (bool, optional): Whether to ignore model-provided max length limits. Defaults to False.
            generation_settings (dict, optional): Single-generation setting overrides. Defaults to None.
            is_core (bool, optional): Whether this generation is a core story generation. Defaults to False.
            single_line (bool, optional): Generate one line only. Defaults to False.
            found_entries (set, optional): Entries found for Dynamic WI. Defaults to ().

        Raises:
            ValueError: If prompt type is weird
            NotImplementedError: If model is ReadOnly

        Returns:
            GenerationResult: The model's output
        """
        # TODO: Support singleline outside of torch

        self.gen_state["do_streaming"] = do_streaming
        self.gen_state["do_dynamic_wi"] = do_dynamic_wi

        # Dynamic WI depends on this!!! This is a main gen call.
        self.gen_state["stop_at_genamt"] = do_dynamic_wi

        # Makes stopping criteria hook happy
        self.gen_state["wi_scanner_excluded_keys"] = self.gen_state.get(
            "wi_scanner_excluded_keys", set()
        )

        utils.koboldai_vars.inference_config.do_core = is_core
        gen_settings = GenerationSettings(**(generation_settings or {}))

        if isinstance(prompt, torch.Tensor):
            prompt_tokens = prompt.cpu().numpy()
        elif isinstance(prompt, list):
            prompt_tokens = np.array(prompt)
        elif isinstance(prompt, str):
            prompt_tokens = np.array(self.tokenizer.encode(prompt))
        else:
            raise ValueError(f"Prompt is {type(prompt)}. Not a fan!")

        assert isinstance(prompt_tokens, np.ndarray)
        assert len(prompt_tokens.shape) == 1

        if utils.koboldai_vars.model == "ReadOnly":
            raise NotImplementedError("No loaded model")

        time_start = time.time()

        with use_core_manipulations():
            result = self._raw_generate(
                prompt_tokens=prompt_tokens,
                max_new=max_new,
                batch_count=batch_count,
                gen_settings=gen_settings,
                single_line=single_line,
            )

        time_end = round(time.time() - time_start, 2)
        tokens_per_second = round(len(result.encoded[0]) / time_end, 2)

        if not utils.koboldai_vars.quiet:
            logger.info(
                f"Generated {len(result.encoded[0])} tokens in {time_end} seconds, for an average rate of {tokens_per_second} tokens per second."
            )

        return result
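
    # Illustrative sketch (not executed): raw_generate() is the entry point
    # for generation outside of the story flow, e.g. a utility feature that
    # just needs text back from the loaded model. Prompt and settings values
    # are examples only:
    #
    #   result = current_model.raw_generate(
    #       "Write a haiku about rain:",
    #       max_new=30,
    #       generation_settings={"temp": 0.7},
    #   )
    #   text = result.decoded[0]
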
    def generate(
        self,
        prompt_tokens: Union[List[int], torch.Tensor],
        max_new_tokens: int,
        do_streaming: bool = False,
        do_dynamic_wi: bool = False,
        single_line: bool = False,
        batch_count: int = 1,
    ) -> torch.Tensor:
        raise NotImplementedError

    def _post_token_gen(self, input_ids: torch.LongTensor) -> None:
        for hook in self.post_token_hooks:
            hook(self, input_ids)
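
# Illustrative sketch (not executed): post-token hooks and stoppers are plain
# callables appended to the lists created in InferenceModel.__init__(). A
# hypothetical token streamer could be attached like this:
#
#   def stream_tokens(model: InferenceModel, input_ids: torch.LongTensor) -> None:
#       ...  # e.g. push the newest token of each sequence to the UI
#
#   current_model.post_token_hooks.append(stream_tokens)
#   # _post_token_gen() will then call it after every generated token.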