Mirror of https://github.com/KoboldAI/KoboldAI-Client.git

Commit: Model: Documentation part 1
@@ -3498,7 +3498,7 @@ def apiactionsubmit_generate(txt, minimum, maximum):
     torch.cuda.empty_cache()

     # Submit input text to generator
-    _genout, already_generated = tpool.execute(model.core_generate, txt, minimum, maximum, set())
+    _genout, already_generated = tpool.execute(model.core_generate, txt, set())

     genout = [utils.applyoutputformatting(utils.decodenewlines(tokenizer.decode(tokens[-already_generated:]))) for tokens in _genout]

@@ -3939,7 +3939,7 @@ def generate(txt, minimum, maximum, found_entries=None):
     # Submit input text to generator
     try:
         start_time = time.time()
-        genout, already_generated = tpool.execute(model.core_generate, txt, minimum, maximum, found_entries)
+        genout, already_generated = tpool.execute(model.core_generate, txt, found_entries)
         logger.debug("Generate: core_generate time {}s".format(time.time()-start_time))
     except Exception as e:
         if(issubclass(type(e), lupa.LuaError)):
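Both call sites above hand core_generate to eventlet's tpool.execute, which runs the blocking generation on a real OS thread so the green-thread web server stays responsive; only the minimum/maximum arguments were dropped from the call. A minimal sketch of that tpool pattern, independent of KoboldAI (slow_generate is a hypothetical stand-in for the blocking call):

# Minimal sketch of eventlet's thread-pool offloading pattern.
# `slow_generate` is a hypothetical blocking function, not KoboldAI code.
import time

from eventlet import tpool


def slow_generate(prompt: str, budget: int) -> str:
    # Stand-in for a blocking, GPU-heavy generation call.
    time.sleep(1)
    return prompt + " ..." * budget


# tpool.execute(func, *args) runs func on a real OS thread and returns its
# result, so green threads serving other requests are not blocked meanwhile.
result = tpool.execute(slow_generate, "Once upon a time", 3)
print(result)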
model.py (100)
@@ -1,4 +1,7 @@
-# Before merge: please make sure to fix any TODOB4MERGE comments
+# Before merge:
+# - Fix Lua
+# - Check if probabilities work
+# - Fix any TODOB4MERGE comments
 from __future__ import annotations

 import bisect
@@ -79,14 +82,20 @@ class OpenAIAPIError(Exception):


 class HordeException(Exception):
+    """To be used for errors on the server side of the Horde."""
+
     pass


 class ColabException(Exception):
+    """To be used for errors when using the Colab API as an interface."""
+
     pass


 class APIException(Exception):
+    """To be used for errors when using the Kobold API as an interface."""
+
     pass


@@ -424,24 +433,6 @@ def patch_transformers_generation() -> None:

     cls.__call__ = new_call

-    # TODO: Make samplers generic
-    # dynamic_processor_wrap(
-    #     AdvancedRepetitionPenaltyLogitsProcessor,
-    #     ("penalty", "penalty_slope", "penalty_range", "use_alt_rep_pen"),
-    #     ("rep_pen", "rep_pen_slope", "rep_pen_range", "use_alt_rep_pen"),
-    #     cond=lambda x: x[0] != 1.0,
-    # )
-    # dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
-    # dynamic_processor_wrap(TopALogitsWarper, "top_a", "top_a", cond=lambda x: x > 0.0)
-    # dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
-    # dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
-    # dynamic_processor_wrap(
-    #     TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0
-    # )
-    # dynamic_processor_wrap(
-    #     TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0
-    # )
-
     # Allow bad words filter to ban <|endoftext|> token
     import transformers.generation.logits_process

@@ -463,6 +454,8 @@ def patch_transformers() -> None:


 class GenerationResult:
+    """A container for easily accessing different forms of model outputs. Returned by most generate functions."""
+
     def __init__(
         self,
         model: InferenceModel,
@@ -506,6 +499,8 @@ class ModelCapabilities:


 class InferenceModel:
+    """Root class for all models."""
+
     def __init__(self) -> None:
         self.gen_state = {}
         self.post_token_hooks = []
@@ -514,7 +509,7 @@ class InferenceModel:
         self.capabilties = ModelCapabilities()

     def load(self, save_model: bool = False, initial_load: bool = False) -> None:
-        """Main load function. Do not override this. Override _load() instead."""
+        """User-facing load function. Do not override this; try `_load()` instead."""

         self._load(save_model=save_model, initial_load=initial_load)
         self._post_load()
@@ -525,12 +520,23 @@ class InferenceModel:
         print(self.raw_generate("Hi guys,", 20).__dict__)

     def _post_load(self) -> None:
+        """Post load hook. Called after `_load()`."""
         pass

     def _load(self, save_model: bool, initial_load: bool) -> None:
+        """Main load method. All logic related to loading the model onto the
+        selected device(s) and preparing it for inference should be implemented here."""
         raise NotImplementedError

-    def _get_tokenizer(self, location: str):
+    def _get_tokenizer(self, location: str) -> AutoTokenizer:
+        """Returns the appropriate tokenizer for the location. Should be run once and the result stored in `tokenizer`.
+
+        Args:
+            location (str): Either a local model directory path or a HuggingFace model ID.
+
+        Returns:
+            AutoTokenizer: Tokenizer deemed fit for the location string. May be a fallback tokenizer.
+        """
         if utils.koboldai_vars.model_type == "xglm":
             # Default to </s> newline mode if using XGLM
             utils.koboldai_vars.newlinemode = "s"
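As documented above, load() is the fixed entry point that calls _load() and then _post_load(), so a new backend only overrides those hooks (plus _raw_generate(), covered further down). A hedged sketch of what a minimal subclass could look like under that contract; the DummyModel class, its attributes, and the "gpt2" location are illustrative and not part of this commit:

# Hypothetical backend illustrating the documented override contract:
# implement _load() (and optionally _post_load()); never override load().
class DummyModel(InferenceModel):
    def _load(self, save_model: bool, initial_load: bool) -> None:
        # All device placement / weight loading logic belongs here.
        self.tokenizer = self._get_tokenizer("gpt2")  # HF model ID or local path
        self.model = ...  # load weights onto the selected device(s)

    def _post_load(self) -> None:
        # Optional hook, runs right after _load(); a good place for warm-up work.
        print("DummyModel ready")

# A real backend would also override _raw_generate() (see below) so that
# raw_generate() and core_generate() have something to call.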
@@ -565,13 +571,19 @@ class InferenceModel:
     def core_generate(
         self,
         text: list,
-        _min: int,
-        _max: int,
         found_entries: set,
-        is_core: bool = False,
     ):
-        # This generation function is tangled with koboldai_vars intentionally. It
-        # is meant for the story and nothing else.
+        """Generate story text. Heavily tied to story-specific parameters; if
+        you are making a new generation-based feature, consider `raw_generate()`.
+
+        Args:
+            text (list): Encoded input tokens
+            found_entries (set): Entries found for Dynamic WI
+
+        Raises:
+            RuntimeError: if inconsistencies are detected between the internal state and the Lua state -- sanity check
+            RuntimeError: if inconsistencies are detected between the internal state and the core stopper -- sanity check
+        """

         start_time = time.time()
         gen_in = torch.tensor(text, dtype=torch.long)[None]
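With _min and _max removed from the signature, callers now pass only the encoded prompt and the Dynamic WI entries, matching the two call-site changes at the top of this commit. A small hedged sketch of the new call shape, assuming tokenizer and model are already loaded:

# Sketch of the updated call shape. minimum/maximum are no longer passed in;
# per the docstring they presumably come from story state (koboldai_vars).
prompt_tokens = tokenizer.encode("You enter the cave.")  # encoded input tokens
genout, already_generated = model.core_generate(prompt_tokens, set())  # set() = no Dynamic WI entries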
@@ -804,6 +816,18 @@ class InferenceModel:
         single_line: bool = False,
         batch_count: int = 1,
     ) -> GenerationResult:
+        """Lowest level model-agnostic generation function. To be overridden by the model implementation.
+
+        Args:
+            prompt_tokens (Union[List[int], torch.Tensor]): Prompt as encoded token IDs
+            max_new (int): Maximum amount of new tokens to generate
+            gen_settings (GenerationSettings): State to pass in single-generation setting overrides
+            single_line (bool, optional): Generate one line only. Defaults to False.
+            batch_count (int, optional): How big of a batch to generate. Defaults to 1.
+
+        Returns:
+            GenerationResult: The model's output
+        """
         raise NotImplementedError

     def raw_generate(
@@ -820,7 +844,27 @@ class InferenceModel:
         single_line: bool = False,
         found_entries: set = (),
     ) -> GenerationResult:
-        """A wrapper around _raw_generate() that handles timing and some other minute stuff."""
+        """A wrapper around `_raw_generate()` that handles gen_state and other bookkeeping. Use this to generate text outside of the story.
+
+        Args:
+            prompt (Union[str, list, np.ndarray]): The prompt as a string or encoded token IDs
+            max_new (int): Maximum amount of new tokens to generate
+            do_streaming (bool, optional): Whether to stream tokens to the user or not. Defaults to False.
+            do_dynamic_wi (bool, optional): Whether to use Dynamic WI context injections. Defaults to False.
+            batch_count (int, optional): How big of a batch to generate. Defaults to 1.
+            bypass_hf_maxlength (bool, optional): Whether to ignore model-provided max length limits. Defaults to False.
+            generation_settings (GenerationSettings): State to pass in single-generation setting overrides. Defaults to None.
+            is_core (bool, optional): Whether this generation is a core story generation. Defaults to False.
+            single_line (bool, optional): Generate one line only. Defaults to False.
+            found_entries (set, optional): Entries found for Dynamic WI. Defaults to ().
+
+        Raises:
+            ValueError: If the prompt type is unsupported
+            NotImplementedError: If the model is ReadOnly
+
+        Returns:
+            GenerationResult: The model's output
+        """
         # TODO: Support singleline outside of torch

         self.gen_state["do_streaming"] = do_streaming
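Per this docstring, raw_generate() is the generation entry point for anything outside the story flow. A hedged usage sketch built only from the Args list above, assuming model is a loaded InferenceModel subclass:

# Hedged usage sketch based only on the documented arguments; assumes
# `model` is an already-loaded InferenceModel subclass.
result = model.raw_generate(
    "Write a short greeting:",  # prompt: string or encoded token IDs
    max_new=32,                 # generate at most 32 new tokens
    batch_count=2,              # two independent completions
    single_line=True,           # stop at the first newline
)
print(result.__dict__)          # same trick as the debug print shown earlier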
@@ -1101,10 +1145,8 @@ class HFMTJInferenceModel(HFInferenceModel):
            initial_load=initial_load,
            logger=logger,
            **self.model_config.to_dict()
-           # **utils.koboldai_vars.modelconfig,
        )

-       # tpool.execute(tpu_mtj_backend.load_model, koboldai_vars.custmodpth, hf_checkpoint=koboldai_vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and koboldai_vars.use_colab_tpu, **koboldai_vars.modelconfig)
        utils.koboldai_vars.modeldim = int(
            tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"])
        )
warpers.py (39)
@@ -68,6 +68,25 @@ def update_settings():


 class Warper:
+    """The backbone for implementing code which manipulates token logits.
+    All Warpers should be singletons defined in the warpers.py file.
+
+    To make a new warper/sampler:
+    - Create your class, implementing `torch()`, `jax_dynamic()`, `jax_static()`,
+      and `value_is_valid()`. Dynamic and static methods are separated for Jax
+      due to how it does JIT compilation of functions (from what I gather).
+      These `static` methods are very picky about what you can and can't do
+      with data at runtime and thus sometimes need to be implemented
+      differently than the `dynamic` methods, which are more like the Torch
+      methods.
+    - Add it to Warper.from_id and tpu_mtj_backend.kobold_sample_static.
+    - Add it to the UI/sampler_order.
+
+    To implement the samplers on a new model type/interface, assuming you're
+    dealing with Torch tensors, iterate over Warpers from sampler_order using
+    `Warper.from_id()`, and apply changes with the `torch()` methods.
+    """
+
     @staticmethod
     def from_id(warper_id: int) -> Warper:
         return {
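A hedged sketch of step one of that recipe; the LogitBias class and its bias value are invented for illustration (only the Torch path is shown, and the Jax methods plus the from_id/UI wiring are omitted):

# Hypothetical warper following the documented recipe (Torch path only).
# The class, its `bias` value, and its would-be sampler id are invented here.
class LogitBias(Warper):
    """Add a constant bias to every logit (a do-nothing demo sampler)."""

    bias: float = 0.0

    @classmethod
    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
        return scores + cls.bias

    @classmethod
    def value_is_valid(cls) -> bool:
        # Only apply the warper when it would actually change the logits.
        return cls.bias != 0.0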
@@ -80,6 +99,22 @@ class Warper:
             6: RepetitionPenalty,
         }[warper_id]

+    @classmethod
+    def torch(cls, scores: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError("Please override `torch()`.")
+
+    @classmethod
+    def jax_dynamic(cls, scores: np.array) -> np.array:
+        raise NotImplementedError("Please override `jax_dynamic()`.")
+
+    @classmethod
+    def jax_static(cls, scores: jnp.array) -> jnp.array:
+        raise NotImplementedError("Please override `jax_static()`.")
+
+    @classmethod
+    def value_is_valid(cls) -> bool:
+        raise NotImplementedError("Please override `value_is_valid()`.")
+

 class Temperature(Warper):
     """Temperature (just divide the logits by the temperature)"""
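And a hedged sketch of the last paragraph of the recipe, applying the sampler stack on a Torch-based interface; sampler_order is assumed to be the usual list of warper ids kept in the settings:

# Hedged sketch; assumes it lives alongside the Warper class in warpers.py,
# so `torch` and `Warper` are already available in that module.
from typing import List


def apply_warpers(scores: torch.Tensor, sampler_order: List[int]) -> torch.Tensor:
    # Walk the user's sampler order and apply each enabled warper in turn.
    for warper_id in sampler_order:
        warper = Warper.from_id(warper_id)
        if warper.value_is_valid():
            scores = warper.torch(scores)
    return scores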
@@ -534,8 +569,8 @@ class RepetitionPenalty(Warper):
        generated_index,
    ) -> jnp.array:
        """
-       This gets called by generate_loop_fn to apply repetition penalty
-       to the 1D array logits using the provided 1D array of tokens to penalize
+       This gets called to apply repetition penalty to the 1D array logits
+       using the provided 1D array of tokens to penalize
        """
        rpslope = jnp.int32(cls.rep_pen_slope)
        rprange = jnp.int32(cls.rep_pen_range)
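For context, the classic repetition penalty (as in the CTRL paper) pushes down the logits of tokens that have already been generated; the snippet below is a generic hedged illustration of that idea in Torch and is not a claim about how this class's slope/range variant computes the penalty:

# Generic repetition-penalty illustration, not KoboldAI's exact variant.
import torch


def penalize(logits: torch.Tensor, generated: torch.Tensor, penalty: float = 1.2) -> torch.Tensor:
    # Gather the logits of tokens that were already generated...
    picked = logits[generated]
    # ...and push them down: divide positive logits, multiply negative ones.
    picked = torch.where(picked > 0, picked / penalty, picked * penalty)
    logits[generated] = picked
    return logits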