Model: Documentation part 1

somebody
2023-02-28 19:26:25 -06:00
parent ef1155291f
commit 225dcf1a0a
3 changed files with 110 additions and 33 deletions

aiserver.py

@@ -3498,7 +3498,7 @@ def apiactionsubmit_generate(txt, minimum, maximum):
     torch.cuda.empty_cache()
 
     # Submit input text to generator
-    _genout, already_generated = tpool.execute(model.core_generate, txt, minimum, maximum, set())
+    _genout, already_generated = tpool.execute(model.core_generate, txt, set())
 
     genout = [utils.applyoutputformatting(utils.decodenewlines(tokenizer.decode(tokens[-already_generated:]))) for tokens in _genout]
@@ -3939,7 +3939,7 @@ def generate(txt, minimum, maximum, found_entries=None):
     # Submit input text to generator
     try:
         start_time = time.time()
-        genout, already_generated = tpool.execute(model.core_generate, txt, minimum, maximum, found_entries)
+        genout, already_generated = tpool.execute(model.core_generate, txt, found_entries)
         logger.debug("Generate: core_generate time {}s".format(time.time()-start_time))
     except Exception as e:
         if(issubclass(type(e), lupa.LuaError)):

model.py

@@ -1,4 +1,7 @@
-# Before merge: please make sure to fix any TODOB4MERGE comments
+# Before merge:
+# - Fix Lua
+# - Check if probabilities work
+# - Fix any TODOB4MERGE comments
 from __future__ import annotations
 
 import bisect
@@ -79,14 +82,20 @@ class OpenAIAPIError(Exception):
 class HordeException(Exception):
+    """To be used for errors on the server side of the Horde."""
+
     pass
 
 
 class ColabException(Exception):
+    """To be used for errors when using the Colab API as an interface."""
+
     pass
 
 
 class APIException(Exception):
+    """To be used for errors when using the Kobold API as an interface."""
+
     pass
@@ -424,24 +433,6 @@ def patch_transformers_generation() -> None:
     cls.__call__ = new_call
 
-    # TODO: Make samplers generic
-    # dynamic_processor_wrap(
-    #     AdvancedRepetitionPenaltyLogitsProcessor,
-    #     ("penalty", "penalty_slope", "penalty_range", "use_alt_rep_pen"),
-    #     ("rep_pen", "rep_pen_slope", "rep_pen_range", "use_alt_rep_pen"),
-    #     cond=lambda x: x[0] != 1.0,
-    # )
-    # dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
-    # dynamic_processor_wrap(TopALogitsWarper, "top_a", "top_a", cond=lambda x: x > 0.0)
-    # dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
-    # dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
-    # dynamic_processor_wrap(
-    #     TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0
-    # )
-    # dynamic_processor_wrap(
-    #     TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0
-    # )
 
     # Allow bad words filter to ban <|endoftext|> token
     import transformers.generation.logits_process
@@ -463,6 +454,8 @@ def patch_transformers() -> None:
 class GenerationResult:
+    """A container for easily accessing different forms of model outputs. Returned by most generate functions."""
+
     def __init__(
         self,
         model: InferenceModel,
@@ -506,6 +499,8 @@ class ModelCapabilities:
 class InferenceModel:
+    """Root class for all models."""
+
     def __init__(self) -> None:
         self.gen_state = {}
         self.post_token_hooks = []
@@ -514,7 +509,7 @@ class InferenceModel:
         self.capabilties = ModelCapabilities()
 
     def load(self, save_model: bool = False, initial_load: bool = False) -> None:
-        """Main load function. Do not override this. Override _load() instead."""
+        """User-facing load function. Do not override this; try `_load()` instead."""
         self._load(save_model=save_model, initial_load=initial_load)
         self._post_load()
@@ -525,12 +520,23 @@ class InferenceModel:
         print(self.raw_generate("Hi guys,", 20).__dict__)
 
     def _post_load(self) -> None:
+        """Post load hook. Called after `_load()`."""
         pass
 
     def _load(self, save_model: bool, initial_load: bool) -> None:
+        """Main load method. All logic related to loading the model onto the
+        selected device(s) and preparing it for inference should be implemented here."""
         raise NotImplementedError
 
-    def _get_tokenizer(self, location: str):
+    def _get_tokenizer(self, location: str) -> AutoTokenizer:
+        """Returns the appropriate tokenizer for the location. Should be run once and the result stored in `tokenizer`.
+
+        Args:
+            location (str): Either a local model directory path or a HuggingFace model ID.
+
+        Returns:
+            AutoTokenizer: Tokenizer deemed fit for the location string. May be a fallback tokenizer.
+        """
         if utils.koboldai_vars.model_type == "xglm":
             # Default to </s> newline mode if using XGLM
             utils.koboldai_vars.newlinemode = "s"
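
Note: the fallback behaviour mentioned in the `_get_tokenizer()` docstring is not shown in this hunk. A minimal sketch of what such a fallback chain could look like; the helper name and the bare `GPT2Tokenizer` fallback are illustrative assumptions, not the repository's actual implementation:

from transformers import AutoTokenizer, GPT2Tokenizer

def _get_tokenizer_sketch(location: str):
    # Hypothetical: try the tokenizer that ships with the model first...
    try:
        return AutoTokenizer.from_pretrained(location)
    except Exception:
        # ...and fall back to a generic GPT-2 tokenizer if that fails
        # (missing files, unknown model type, network error, etc.).
        return GPT2Tokenizer.from_pretrained("gpt2")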
@@ -565,13 +571,19 @@ class InferenceModel:
     def core_generate(
         self,
         text: list,
-        _min: int,
-        _max: int,
         found_entries: set,
-        is_core: bool = False,
     ):
-        # This generation function is tangled with koboldai_vars intentionally. It
-        # is meant for the story and nothing else.
+        """Generate story text. Heavily tied to story-specific parameters; if
+        you are making a new generation-based feature, consider `raw_generate()`.
+
+        Args:
+            text (list): Encoded input tokens
+            found_entries (set): Entries found for Dynamic WI
+
+        Raises:
+            RuntimeError: if inconsistencies are detected between the internal state and the Lua state -- sanity check
+            RuntimeError: if inconsistencies are detected between the internal state and the core stopper -- sanity check
+        """
         start_time = time.time()
         gen_in = torch.tensor(text, dtype=torch.long)[None]
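
With `_min`, `_max`, and `is_core` removed from the signature, callers now pass only the encoded prompt and the Dynamic WI entries, as the aiserver.py hunks above show; the length bounds come from story state instead. A call-site sketch, assuming `model`, `tokenizer`, and `found_entries` are already in scope:

from eventlet import tpool

# core_generate blocks, so offload it to a thread pool as the callers above do.
prompt_tokens = tokenizer.encode("You find yourself in a dimly lit tavern.")
genout, already_generated = tpool.execute(model.core_generate, prompt_tokens, found_entries)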
@@ -804,6 +816,18 @@ class InferenceModel:
         single_line: bool = False,
         batch_count: int = 1,
     ) -> GenerationResult:
+        """Lowest-level model-agnostic generation function. To be overridden by the model implementation.
+
+        Args:
+            prompt_tokens (Union[List[int], torch.Tensor]): Prompt as encoded token IDs
+            max_new (int): Maximum number of new tokens to generate
+            gen_settings (GenerationSettings): State to pass in single-generation setting overrides
+            single_line (bool, optional): Generate one line only. Defaults to False.
+            batch_count (int, optional): How big of a batch to generate. Defaults to 1.
+
+        Returns:
+            GenerationResult: The model's output
+        """
         raise NotImplementedError
 
     def raw_generate(
def raw_generate( def raw_generate(
@@ -820,7 +844,27 @@ class InferenceModel:
         single_line: bool = False,
         found_entries: set = (),
     ) -> GenerationResult:
-        """A wrapper around _raw_generate() that handles timing and some other minute stuff."""
+        """A wrapper around `_raw_generate()` that handles gen_state and other bookkeeping. Use this to generate text outside of the story.
+
+        Args:
+            prompt (Union[str, list, np.ndarray]): The prompt as a string or encoded token IDs
+            max_new (int): Maximum number of new tokens to generate
+            do_streaming (bool, optional): Whether to stream tokens to the user or not. Defaults to False.
+            do_dynamic_wi (bool, optional): Whether to use Dynamic WI context injections. Defaults to False.
+            batch_count (int, optional): How big of a batch to generate. Defaults to 1.
+            bypass_hf_maxlength (bool, optional): Whether to ignore model-provided max length limits. Defaults to False.
+            generation_settings (GenerationSettings): State to pass in single-generation setting overrides. Defaults to None.
+            is_core (bool, optional): Whether this generation is a core story generation. Defaults to False.
+            single_line (bool, optional): Generate one line only. Defaults to False.
+            found_entries (set, optional): Entries found for Dynamic WI. Defaults to ().
+
+        Raises:
+            ValueError: If the prompt is of an unsupported type
+            NotImplementedError: If the model is ReadOnly
+
+        Returns:
+            GenerationResult: The model's output
+        """
 
         # TODO: Support singleline outside of torch
         self.gen_state["do_streaming"] = do_streaming
@@ -1101,10 +1145,8 @@ class HFMTJInferenceModel(HFInferenceModel):
             initial_load=initial_load,
             logger=logger,
             **self.model_config.to_dict()
-            # **utils.koboldai_vars.modelconfig,
         )
-        # tpool.execute(tpu_mtj_backend.load_model, koboldai_vars.custmodpth, hf_checkpoint=koboldai_vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and koboldai_vars.use_colab_tpu, **koboldai_vars.modelconfig)
 
         utils.koboldai_vars.modeldim = int(
             tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"])
         )

warpers.py

@@ -68,6 +68,25 @@ def update_settings():
 class Warper:
+    """The backbone for implementing code which manipulates token logits.
+
+    All Warpers should be singletons defined in the warpers.py file.
+
+    To make a new warper/sampler:
+    - Create your class, implementing `torch()`, `jax_dynamic()`, `jax_static()`,
+      and `value_is_valid()`. Dynamic and static methods are separated for Jax
+      due to how it does JIT compilation of functions (from what I gather).
+      These `static` methods are very picky about what you can and can't do
+      with data at runtime and thus sometimes need to be implemented
+      differently than the `dynamic` methods, which are more like the Torch
+      methods.
+    - Add it to `Warper.from_id` and `tpu_mtj_backend.kobold_sample_static`.
+    - Add it to the UI/sampler_order.
+
+    To implement the samplers on a new model type/interface, assuming you're
+    dealing with Torch tensors, iterate over Warpers from sampler_order using
+    `Warper.from_id()`, and apply changes with the `torch()` methods.
+    """
+
     @staticmethod
     def from_id(warper_id: int) -> Warper:
         return {
@@ -80,6 +99,22 @@ class Warper:
             6: RepetitionPenalty,
         }[warper_id]
 
+    @classmethod
+    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError("Please override `torch()`.")
+
+    @classmethod
+    def jax_dynamic(cls, scores: np.array) -> np.array:
+        raise NotImplementedError("Please override `jax_dynamic()`.")
+
+    @classmethod
+    def jax_static(cls, scores: jnp.array) -> jnp.array:
+        raise NotImplementedError("Please override `jax_static()`.")
+
+    @classmethod
+    def value_is_valid(cls) -> bool:
+        raise NotImplementedError("Please override `value_is_valid()`.")
+
 
 class Temperature(Warper):
     """Temperature (just divide the logits by the temperature)"""
@@ -534,8 +569,8 @@ class RepetitionPenalty(Warper):
         generated_index,
     ) -> jnp.array:
         """
-        This gets called by generate_loop_fn to apply repetition penalty
-        to the 1D array logits using the provided 1D array of tokens to penalize
+        This gets called to apply repetition penalty to the 1D array logits
+        using the provided 1D array of tokens to penalize
         """
         rpslope = jnp.int32(cls.rep_pen_slope)
         rprange = jnp.int32(cls.rep_pen_range)
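
For reference, the flat (slope- and range-free) form of repetition penalty that this function generalizes can be sketched in a few lines of numpy. This is the classic CTRL-style scheme; whether it matches the exact formula used by this backend is an assumption:

import numpy as np

def apply_rep_pen_flat(logits: np.ndarray, penalized_tokens: np.ndarray, penalty: float) -> np.ndarray:
    # 1D logits, 1D array of token IDs to penalize, single flat penalty value.
    picked = logits[penalized_tokens]
    # Divide positive logits and multiply negative ones so the penalty always
    # pushes the penalized tokens' probabilities down.
    logits[penalized_tokens] = np.where(picked > 0, picked / penalty, picked * penalty)
    return logits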