Model: Documentation part 1

somebody committed on 2023-02-28 19:26:25 -06:00
parent ef1155291f, commit 225dcf1a0a
3 changed files with 110 additions and 33 deletions


@@ -3498,7 +3498,7 @@ def apiactionsubmit_generate(txt, minimum, maximum):
torch.cuda.empty_cache()
# Submit input text to generator
_genout, already_generated = tpool.execute(model.core_generate, txt, minimum, maximum, set())
_genout, already_generated = tpool.execute(model.core_generate, txt, set())
genout = [utils.applyoutputformatting(utils.decodenewlines(tokenizer.decode(tokens[-already_generated:]))) for tokens in _genout]
@@ -3939,7 +3939,7 @@ def generate(txt, minimum, maximum, found_entries=None):
# Submit input text to generator
try:
start_time = time.time()
genout, already_generated = tpool.execute(model.core_generate, txt, minimum, maximum, found_entries)
genout, already_generated = tpool.execute(model.core_generate, txt, found_entries)
logger.debug("Generate: core_generate time {}s".format(time.time()-start_time))
except Exception as e:
if(issubclass(type(e), lupa.LuaError)):

model.py

@@ -1,4 +1,7 @@
# Before merge: please make sure to fix any TODOB4MERGE comments
# Before merge:
# - Fix Lua
# - Check if probabilities work
# - Fix any TODOB4MERGE comments
from __future__ import annotations
import bisect
@@ -79,14 +82,20 @@ class OpenAIAPIError(Exception):
class HordeException(Exception):
"""To be used for errors on server side of the Horde."""
pass
class ColabException(Exception):
"""To be used for errors when using the Colab API as an interface."""
pass
class APIException(Exception):
"""To be used for errors when using the Kobold API as an interface."""
pass
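For orientation, a hypothetical usage sketch (the `check_horde_reply` helper and `response` dict are illustrative, not part of this commit): interface code raises the matching exception type so callers can surface a targeted error for each backend.
def check_horde_reply(response: dict) -> str:
    # Illustrative only: raise the interface-specific exception on failure so
    # callers can distinguish Horde errors from Colab or Kobold API errors.
    if response.get("error"):
        raise HordeException(response["error"])
    return response["text"]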
@@ -424,24 +433,6 @@ def patch_transformers_generation() -> None:
cls.__call__ = new_call
# TODO: Make samplers generic
# dynamic_processor_wrap(
# AdvancedRepetitionPenaltyLogitsProcessor,
# ("penalty", "penalty_slope", "penalty_range", "use_alt_rep_pen"),
# ("rep_pen", "rep_pen_slope", "rep_pen_range", "use_alt_rep_pen"),
# cond=lambda x: x[0] != 1.0,
# )
# dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
# dynamic_processor_wrap(TopALogitsWarper, "top_a", "top_a", cond=lambda x: x > 0.0)
# dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
# dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
# dynamic_processor_wrap(
# TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0
# )
# dynamic_processor_wrap(
# TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0
# )
# Allow bad words filter to ban <|endoftext|> token
import transformers.generation.logits_process
@@ -463,6 +454,8 @@ def patch_transformers() -> None:
class GenerationResult:
"""A container for easily accessing different forms of model outputs. Returned by most generate functions."""
def __init__(
self,
model: InferenceModel,
@@ -506,6 +499,8 @@ class ModelCapabilities:
class InferenceModel:
"""Root class for all models."""
def __init__(self) -> None:
self.gen_state = {}
self.post_token_hooks = []
@@ -514,7 +509,7 @@ class InferenceModel:
self.capabilties = ModelCapabilities()
def load(self, save_model: bool = False, initial_load: bool = False) -> None:
"""Main load function. Do not override this. Override _load() instead."""
"""User-facing load function. Do not override this; try `_load()` instead."""
self._load(save_model=save_model, initial_load=initial_load)
self._post_load()
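A hypothetical subclass sketch of the override contract described above (the class name and model ID are illustrative, not part of this commit): backends implement `_load()` and optionally `_post_load()`, while callers always go through `load()`.
class ExampleBackendModel(InferenceModel):
    def _load(self, save_model: bool, initial_load: bool) -> None:
        # Backend-specific weight loading and device placement goes here.
        self.tokenizer = self._get_tokenizer("facebook/opt-125m")

    def _post_load(self) -> None:
        # Optional hook; runs after _load() completes.
        pass

# Callers never override or call _load() directly:
ExampleBackendModel().load(initial_load=True)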
@@ -525,12 +520,23 @@ class InferenceModel:
print(self.raw_generate("Hi guys,", 20).__dict__)
def _post_load(self) -> None:
"""Post load hook. Called after `_load()`."""
pass
def _load(self, save_model: bool, initial_load: bool) -> None:
"""Main load method. All logic related to loading the model onto the
selected device(s) and preparing it for inference should be implemented here."""
raise NotImplementedError
def _get_tokenizer(self, location: str):
def _get_tokenizer(self, location: str) -> AutoTokenizer:
"""Returns the appropiate tokenizer for the location. Should be ran once and result stored in `tokenizer`.
Args:
location (str): Either a local model directory path or a HuggingFace model ID.
Returns:
AutoTokenizer: Tokenizer deemed fit for the location string. May be a fallback tokenizer.
"""
if utils.koboldai_vars.model_type == "xglm":
# Default to </s> newline mode if using XGLM
utils.koboldai_vars.newlinemode = "s"
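Intended usage, per the docstring (a sketch of the call inside a subclass's `_load()`; the location strings are illustrative):
        # Either form is accepted; the result is stored once on the instance.
        self.tokenizer = self._get_tokenizer("models/gpt-j-6b")        # local directory
        # self.tokenizer = self._get_tokenizer("EleutherAI/gpt-j-6b")  # or a HuggingFace model ID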
@@ -565,13 +571,19 @@ class InferenceModel:
def core_generate(
self,
text: list,
_min: int,
_max: int,
found_entries: set,
is_core: bool = False,
):
# This generation function is tangled with koboldai_vars intentionally. It
# is meant for the story and nothing else.
"""Generate story text. Heavily tied to story-specific parameters; if
you are making a new generation-based feature, consider `raw_generate()`.
Args:
text (list): Encoded input tokens
found_entries (set): Entries found for Dynamic WI
Raises:
RuntimeError: if inconsistencies are detected between the internal state and the Lua state -- sanity check
RuntimeError: if inconsistencies are detected between the internal state and the core stopper -- sanity check
"""
start_time = time.time()
gen_in = torch.tensor(text, dtype=torch.long)[None]
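The aiserver.py hunks at the top of this commit show the new call shape from the story path; roughly as follows (the length limits that used to be passed explicitly are presumably now read from story state inside `core_generate()`):
# txt is already-encoded input tokens; found_entries tracks Dynamic WI hits.
genout, already_generated = tpool.execute(model.core_generate, txt, found_entries)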
@@ -804,6 +816,18 @@ class InferenceModel:
single_line: bool = False,
batch_count: int = 1,
) -> GenerationResult:
"""Lowest level model-agnostic generation function. To be overridden by model implementation.
Args:
prompt_tokens (Union[List[int], torch.Tensor]): Prompt as encoded token IDs
max_new (int): Maximum number of new tokens to generate
gen_settings (GenerationSettings): State to pass in single-generation setting overrides
single_line (bool, optional): Generate one line only. Defaults to False.
batch_count (int, optional): How big of a batch to generate. Defaults to 1.
Returns:
GenerationResult: The model's output
"""
raise NotImplementedError
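Per the docstrings, `raw_generate()` (below) is the public wrapper that eventually calls this method; illustrative pseudocode of the hand-off, not the actual body:
# raw_generate() normalizes the prompt and gen_state, then defers to the
# backend-specific _raw_generate() with encoded token IDs.
prompt_tokens = self.tokenizer.encode(prompt) if isinstance(prompt, str) else prompt
result = self._raw_generate(
    prompt_tokens=prompt_tokens,
    max_new=max_new,
    gen_settings=generation_settings,
    single_line=single_line,
    batch_count=batch_count,
)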
def raw_generate(
@@ -820,7 +844,27 @@ class InferenceModel:
single_line: bool = False,
found_entries: set = (),
) -> GenerationResult:
"""A wrapper around _raw_generate() that handles timing and some other minute stuff."""
"""A wrapper around `_raw_generate()` that handles gen_state and other stuff. Use this to generate text outside of the story.
Args:
prompt (Union[str, list, np.ndarray]): The prompt as a string or encoded token IDs
max_new (int): Maximum number of new tokens to generate
do_streaming (bool, optional): Whether to stream tokens to the user or not. Defaults to False.
do_dynamic_wi (bool, optional): Whether to use Dynamic WI context injections. Defaults to False.
batch_count (int, optional): How big of a batch to generate. Defaults to 1.
bypass_hf_maxlength (bool, optional): Whether to ignore model-provided max length limits. Defaults to False.
generation_settings (GenerationSettings, optional): State to pass in single-generation setting overrides. Defaults to None.
is_core (bool, optional): Whether this generation is a core story generation. Defaults to False.
single_line (bool, optional): Generate one line only. Defaults to False.
found_entries (set, optional): Entries found for Dynamic WI. Defaults to ().
Raises:
ValueError: If the prompt is of an unsupported type
NotImplementedError: If model is ReadOnly
Returns:
GenerationResult: The model's output
"""
# TODO: Support singleline outside of torch
self.gen_state["do_streaming"] = do_streaming
@@ -1101,10 +1145,8 @@ class HFMTJInferenceModel(HFInferenceModel):
initial_load=initial_load,
logger=logger,
**self.model_config.to_dict()
# **utils.koboldai_vars.modelconfig,
)
# tpool.execute(tpu_mtj_backend.load_model, koboldai_vars.custmodpth, hf_checkpoint=koboldai_vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and koboldai_vars.use_colab_tpu, **koboldai_vars.modelconfig)
utils.koboldai_vars.modeldim = int(
tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"])
)

warpers.py

@@ -68,6 +68,25 @@ def update_settings():
class Warper:
"""The backbone for implementing code which manipulates token logits.
All Warpers should be singletons defined in the warpers.py file.
To make a new warper/sampler:
- Create your class, implementing `torch()`, `jax_dynamic`, `jax_static`,
and `value_is_valid()`. Dynamic and static methods are separated for Jax
due to how it does JIT compilation of functions (from what I gather).
These `static` methods are very picky about what you can and can't do
with data at runtime and thus sometimes need to be implemented
differently than the `dynamic` methods, which are more like the Torch
methods.
- Add it to Warper.from_id and tpu_mtj_backend.kobold_sample_static.
- Add it to the UI/sampler_order.
To implement the samplers on a new model type/interface, assuming you're
dealing with Torch tensors, iterate over Warpers from sampler_order using
`Warper.from_id()`, and apply changes with the `torch()` methods.
"""
@staticmethod
def from_id(warper_id: int) -> Warper:
return {
@@ -80,6 +99,22 @@ class Warper:
6: RepetitionPenalty,
}[warper_id]
@classmethod
def torch(cls, scores: torch.Tensor) -> torch.Tensor:
raise NotImplementedError("Please override `torch()`.")
@classmethod
def jax_dynamic(cls, scores: np.array) -> np.array:
raise NotImplementedError("Please override `jax_dynamic()`.")
@classmethod
def jax_static(cls, scores: jnp.array) -> jnp.array:
raise NotImplementedError("Please override `jax_static()`.")
@classmethod
def value_is_valid(cls) -> bool:
raise NotImplementedError("Please override `value_is_valid()`.")
class Temperature(Warper):
"""Temperature (just divide the logits by the temperature)"""
@@ -534,8 +569,8 @@ class RepetitionPenalty(Warper):
generated_index,
) -> jnp.array:
"""
This gets called by generate_loop_fn to apply repetition penalty
to the 1D array logits using the provided 1D array of tokens to penalize
This gets called to apply repetition penalty to the 1D array logits
using the provided 1D array of tokens to penalize
"""
rpslope = jnp.int32(cls.rep_pen_slope)
rprange = jnp.int32(cls.rep_pen_range)
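For orientation, the classic form of the penalty this applies, sketched in torch terms (a sketch only; the Jax code here additionally ramps the penalty strength across `rep_pen_range` using `rep_pen_slope`):
# tokens is a 1D tensor of already-generated token IDs; penalty is a scalar > 1.0.
picked = scores[tokens]
scores[tokens] = torch.where(picked < 0, picked * penalty, picked / penalty)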