Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)

Commit: Model: And another refactor

aiserver.py (11 lines changed)
@@ -534,8 +534,15 @@ koboldai_vars = koboldai_settings.koboldai_vars(socketio)
 utils.koboldai_vars = koboldai_vars
 utils.socketio = socketio
 
-# HACK: Weird import position to steal koboldai_vars from utils
-from model import APIInferenceModel, GenericHFTorchInferenceModel, CustomGPT2HFTorchInferenceModel, HFMTJInferenceModel, HordeInferenceModel, OpenAIAPIInferenceModel, patch_transformers
+# Weird import position to steal koboldai_vars from utils
+from modeling.patches import patch_transformers
+from modeling.inference_models.api import APIInferenceModel
+from modeling.inference_models.generic_hf_torch import GenericHFTorchInferenceModel
+from modeling.inference_models.legacy_gpt2_hf import CustomGPT2HFTorchInferenceModel
+from modeling.inference_models.hf_mtj import HFMTJInferenceModel
+from modeling.inference_models.horde import HordeInferenceModel
+from modeling.inference_models.openai import OpenAIAPIInferenceModel
+
 
 old_socketio_on = socketio.on
 def new_socketio_on(*a, **k):
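The monolithic model module is gone; each backend now lives in its own module under modeling/inference_models and is driven through the shared InferenceModel interface introduced later in this commit. As a rough usage sketch (not part of the commit; the backend choice, prompt, and token count are illustrative, and koboldai_vars is assumed to already be configured for that backend):

    from modeling.inference_models.horde import HordeInferenceModel

    model = HordeInferenceModel()
    model.load()                                       # runs _load() and then _post_load()
    result = model.raw_generate("Hi guys,", max_new=20)
    print(result.decoded)                              # decoded text for each batch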
logger.py (12 lines changed)
@@ -2,6 +2,18 @@ import sys
 from functools import partialmethod
 from loguru import logger
 
+# Yes this shouldn't be here but I couldn't really find a better place to put
+# it barring creating a whole file just for this which is rather silly
+class Colors:
+    PURPLE = "\033[95m"
+    BLUE = "\033[94m"
+    CYAN = "\033[96m"
+    GREEN = "\033[92m"
+    YELLOW = "\033[93m"
+    RED = "\033[91m"
+    END = "\033[0m"
+    UNDERLINE = "\033[4m"
+
 STDOUT_LEVELS = ["GENERATION", "PROMPT"]
 INIT_LEVELS = ["INIT", "INIT_OK", "INIT_WARN", "INIT_ERR"]
 MESSAGE_LEVELS = ["MESSAGE"]
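The Colors helper is used elsewhere in this commit for plain ANSI terminal output, for example (taken verbatim from hf_mtj.py further down):

    print(Colors.GREEN + "TPU backend compilation triggered" + Colors.END)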
modeling/inference_model.py (new file, 591 lines)
@@ -0,0 +1,591 @@
from __future__ import annotations

from dataclasses import dataclass
import time
from typing import List, Optional, Union
from logger import logger

import torch
import numpy as np
import transformers
from transformers import (
    GPT2Tokenizer,
    AutoTokenizer,
)

import utils

try:
    import tpu_mtj_backend
except ModuleNotFoundError as e:
    # Not on TPU... hopefully
    if utils.koboldai_vars.use_colab_tpu:
        raise e

# I don't really like this way of pointing to the current model but I can't
# find a way around it in some areas.
current_model = None


# We only want to use logit manipulations and such on our core text model
class use_core_manipulations:
    """Use in a `with` block to patch functions for core story model sampling."""

    # These must be set by wherever they get setup
    get_logits_processor: callable = None
    sample: callable = None
    get_stopping_criteria: callable = None

    # We set these automatically
    old_get_logits_processor: callable = None
    old_sample: callable = None
    old_get_stopping_criteria: callable = None

    def __enter__(self):
        if use_core_manipulations.get_logits_processor:
            use_core_manipulations.old_get_logits_processor = (
                transformers.GenerationMixin._get_logits_processor
            )
            transformers.GenerationMixin._get_logits_processor = (
                use_core_manipulations.get_logits_processor
            )

        if use_core_manipulations.sample:
            use_core_manipulations.old_sample = transformers.GenerationMixin.sample
            transformers.GenerationMixin.sample = use_core_manipulations.sample

        if use_core_manipulations.get_stopping_criteria:
            use_core_manipulations.old_get_stopping_criteria = (
                transformers.GenerationMixin._get_stopping_criteria
            )
            transformers.GenerationMixin._get_stopping_criteria = (
                use_core_manipulations.get_stopping_criteria
            )
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        if use_core_manipulations.old_get_logits_processor:
            transformers.GenerationMixin._get_logits_processor = (
                use_core_manipulations.old_get_logits_processor
            )
        else:
            assert (
                not use_core_manipulations.get_logits_processor
            ), "Patch leak: THE MONKEYS HAVE ESCAPED"

        if use_core_manipulations.old_sample:
            transformers.GenerationMixin.sample = use_core_manipulations.old_sample
        else:
            assert (
                not use_core_manipulations.sample
            ), "Patch leak: THE MONKEYS HAVE ESCAPED"

        if use_core_manipulations.old_get_stopping_criteria:
            transformers.GenerationMixin._get_stopping_criteria = (
                use_core_manipulations.old_get_stopping_criteria
            )
        else:
            assert (
                not use_core_manipulations.get_stopping_criteria
            ), "Patch leak: THE MONKEYS HAVE ESCAPED"


class GenerationResult:
    """A container for easily accessing different forms of model outputs. Returned by most generate functions."""

    def __init__(
        self,
        model: InferenceModel,
        out_batches: list,
        prompt: list,
        # Controls if generate() does it's looping thing. This should only be
        # done for HF models that use that StoppingCondition
        is_whole_generation: bool,
        # Controls if we should trim output by prompt length
        output_includes_prompt: bool = False,
        # Lazy filter to cut off extra lines where we can't manipulate
        # probabilities
        single_line: bool = False,
    ):
        # Shave prompt off of encoded response when needed (HF). Decoded does
        # not return prompt.
        if output_includes_prompt:
            self.encoded = out_batches[:, len(prompt) :]
        else:
            self.encoded = out_batches

        self.prompt = prompt
        self.is_whole_generation = is_whole_generation

        self.decoded = [
            utils.decodenewlines(model.tokenizer.decode(enc)) for enc in self.encoded
        ]

        if single_line:
            self.decoded = [x.split("\n", 1)[0] for x in self.decoded]
            self.encoded = np.array(model.tokenizer(self.decoded).input_ids)


class GenerationSettings:
    """Structure for holding temporarily overwritten settings."""

    def __init__(self, **overrides) -> None:
        for setting in [
            "temp",
            "top_p",
            "top_k",
            "tfs",
            "typical",
            "top_a",
            "rep_pen",
            "rep_pen_slope",
            "rep_pen_range",
            "sampler_order",
        ]:
            setattr(
                self,
                setting,
                overrides.get(setting, getattr(utils.koboldai_vars, setting)),
            )


@dataclass
class ModelCapabilities:
    embedding_manipulation: bool = False
    post_token_hooks: bool = False
    stopper_hooks: bool = False
    # TODO: Support non-live probabilities from APIs
    post_token_probs: bool = False


class InferenceModel:
    """Root class for all models."""

    def __init__(self) -> None:
        self.gen_state = {}
        self.post_token_hooks = []
        self.stopper_hooks = []
        self.tokenizer = None
        self.capabilties = ModelCapabilities()

    def load(self, save_model: bool = False, initial_load: bool = False) -> None:
        """User-facing load function. Do not override this; try `_load()` instead."""

        self._load(save_model=save_model, initial_load=initial_load)
        self._post_load()

        global current_model
        current_model = self

        print(self.raw_generate("Hi guys,", 20).__dict__)

    def _post_load(self) -> None:
        """Post load hook. Called after `_load()`."""

    def _load(self, save_model: bool, initial_load: bool) -> None:
        """Main load method. All logic related to loading the model onto the
        selected device(s) and preparing it for inference should be implemented here."""
        raise NotImplementedError

    def _get_tokenizer(self, location: str) -> AutoTokenizer:
        """Returns the appropiate tokenizer for the location. Should be ran once and result stored in `tokenizer`.

        Args:
            location (str): Either a local model directory path or a HuggingFace model ID.

        Returns:
            AutoTokenizer: Tokenizer deemed fit for the location string. May be a fallback tokenizer.
        """
        if utils.koboldai_vars.model_type == "xglm":
            # Default to </s> newline mode if using XGLM
            utils.koboldai_vars.newlinemode = "s"
        elif utils.koboldai_vars.model_type in ["opt", "bloom"]:
            # Handle </s> but don't convert newlines if using Fairseq models that have newlines trained in them
            utils.koboldai_vars.newlinemode = "ns"

        std_kwargs = {"revision": utils.koboldai_vars.revision, "cache_dir": "cache"}

        suppliers = [
            # Fast tokenizer disabled by default as per HF docs:
            # > Note: Make sure to pass use_fast=False when loading
            # OPT’s tokenizer with AutoTokenizer to get the correct
            # tokenizer.
            lambda: AutoTokenizer.from_pretrained(
                location, use_fast=False, **std_kwargs
            ),
            lambda: AutoTokenizer.from_pretrained(location, **std_kwargs),
            # Fallback to GPT2Tokenizer
            lambda: GPT2Tokenizer.from_pretrained(location, **std_kwargs),
            lambda: GPT2Tokenizer.from_pretrained("gpt2", **std_kwargs),
        ]

        for i, try_get_tokenizer in enumerate(suppliers):
            try:
                return try_get_tokenizer()
            except:
                # If we error on each attempt, raise the last one
                if i == len(suppliers) - 1:
                    raise

    def core_generate(
        self,
        text: list,
        found_entries: set,
    ):
        """Generate story text. Heavily tied to story-specific parameters; if
        you are making a new generation-based feature, consider `generate_raw()`.

        Args:
            text (list): Encoded input tokens
            found_entries (set): Entries found for Dynamic WI

        Raises:
            RuntimeError: if inconsistancies are detected with the internal state and Lua state -- sanity check
            RuntimeError: if inconsistancies are detected with the internal state and core stopper -- sanity check
        """

        start_time = time.time()
        gen_in = torch.tensor(text, dtype=torch.long)[None]
        logger.debug(
            "core_generate: torch.tensor time {}s".format(time.time() - start_time)
        )

        start_time = time.time()
        if utils.koboldai_vars.is_model_torch():
            # Torch stuff
            if utils.koboldai_vars.full_determinism:
                torch.manual_seed(utils.koboldai_vars.seed)

            if utils.koboldai_vars.sp is not None:
                assert self.capabilties.embedding_manipulation
                soft_tokens = torch.arange(
                    self.model.config.vocab_size,
                    self.model.config.vocab_size + utils.koboldai_vars.sp.shape[0],
                )
                gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
        elif utils.koboldai_vars.use_colab_tpu:
            if utils.koboldai_vars.full_determinism:
                tpu_mtj_backend.set_rng_seed(utils.koboldai_vars.seed)

        logger.debug(
            "core_generate: Model Setup (SP, etc) time {}s".format(
                time.time() - start_time
            )
        )

        if (
            gen_in.shape[-1] + utils.koboldai_vars.genamt
            > utils.koboldai_vars.max_length
        ):
            logger.error("gen_in.shape[-1]: {}".format(gen_in.shape[-1]))
            logger.error(
                "utils.koboldai_vars.genamt: {}".format(utils.koboldai_vars.genamt)
            )
            logger.error(
                "utils.koboldai_vars.max_length: {}".format(
                    utils.koboldai_vars.max_length
                )
            )
        assert (
            gen_in.shape[-1] + utils.koboldai_vars.genamt
            <= utils.koboldai_vars.max_length
        )

        start_time = time.time()
        gen_in = gen_in.to(utils.get_auxilary_device())

        logger.debug(
            "core_generate: gen_in to device time {}s".format(time.time() - start_time)
        )
        start_time = time.time()

        found_entries = found_entries or set()

        self.gen_state["wi_scanner_excluded_keys"] = found_entries

        utils.koboldai_vars._prompt = utils.koboldai_vars.prompt

        with torch.no_grad():
            already_generated = 0
            numseqs = utils.koboldai_vars.numseqs
            total_gens = None

            for i in range(
                utils.koboldai_vars.numseqs if utils.koboldai_vars.alt_multi_gen else 1
            ):
                while True:
                    # The reason this is a loop is due to how Dynamic WI works. We
                    # cannot simply add the WI to the context mid-generation, so we
                    # stop early, and then insert WI, then continue generating. That
                    # stopping and continuing is this loop.

                    start_time = time.time()
                    result = self.raw_generate(
                        gen_in[0],
                        max_new=utils.koboldai_vars.genamt,
                        do_streaming=utils.koboldai_vars.output_streaming,
                        do_dynamic_wi=utils.koboldai_vars.dynamicscan,
                        batch_count=numseqs
                        if not utils.koboldai_vars.alt_multi_gen
                        else 1,
                        # Real max length is handled by CoreStopper.
                        bypass_hf_maxlength=utils.koboldai_vars.dynamicscan,
                        is_core=True,
                    )
                    logger.debug(
                        "core_generate: run raw_generate pass {} {}s".format(
                            already_generated, time.time() - start_time
                        )
                    )

                    genout = result.encoded

                    already_generated += len(genout[0])

                    try:
                        assert (
                            already_generated
                            <= utils.koboldai_vars.genamt * utils.koboldai_vars.numseqs
                            if utils.koboldai_vars.alt_multi_gen
                            else 1
                        )
                    except AssertionError:
                        print("AlreadyGenerated", already_generated)
                        print("genamt", utils.koboldai_vars.genamt)
                        raise

                    if result.is_whole_generation:
                        break

                    # Generation stopped; why?
                    # If we have been told to halt, we have reached our target token
                    # amount (controlled by halt), or Dynamic WI has not told us to
                    # stop temporarily to insert WI, we can assume that we are done
                    # generating. We shall break.
                    if (
                        self.gen_state["halt"]
                        or not self.gen_state["regeneration_required"]
                    ):
                        break

                    # Now we are doing stuff for Dynamic WI.
                    assert genout.ndim >= 2
                    assert genout.shape[0] == utils.koboldai_vars.numseqs

                    if (
                        utils.koboldai_vars.lua_koboldbridge.generated_cols
                        and utils.koboldai_vars.generated_tkns
                        != utils.koboldai_vars.lua_koboldbridge.generated_cols
                    ):
                        raise RuntimeError(
                            f"Inconsistency detected between KoboldAI Python and Lua backends ({utils.koboldai_vars.generated_tkns} != {utils.koboldai_vars.lua_koboldbridge.generated_cols})"
                        )

                    if already_generated != utils.koboldai_vars.generated_tkns:
                        print("already_generated: {}".format(already_generated))
                        print(
                            "generated_tkns: {}".format(
                                utils.koboldai_vars.generated_tkns
                            )
                        )
                        raise RuntimeError("WI scanning error")

                    for r in range(utils.koboldai_vars.numseqs):
                        for c in range(already_generated):
                            assert (
                                utils.koboldai_vars.lua_koboldbridge.generated[r + 1][
                                    c + 1
                                ]
                                is not None
                            )
                            genout[r][
                                genout.shape[-1] - already_generated + c
                            ] = utils.koboldai_vars.lua_koboldbridge.generated[r + 1][
                                c + 1
                            ]

                    encoded = []

                    for i in range(utils.koboldai_vars.numseqs):
                        txt = utils.decodenewlines(
                            self.tokenizer.decode(genout[i, -already_generated:])
                        )
                        # winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=utils.koboldai_vars.actions)
                        # txt, _, _ = calcsubmitbudget(len(utils.koboldai_vars.actions), winfo, mem, anotetxt, utils.koboldai_vars.actions, submission=txt)
                        txt, _, _, _found_entries = utils.koboldai_vars.calc_ai_text(
                            submitted_text=txt, send_context=False
                        )
                        found_entries[i].update(_found_entries)
                        encoded.append(
                            torch.tensor(txt, dtype=torch.long, device=genout.device)
                        )

                    max_length = len(max(encoded, key=len))
                    encoded = torch.stack(
                        tuple(
                            torch.nn.functional.pad(
                                e,
                                (max_length - len(e), 0),
                                value=self.model.config.pad_token_id
                                or self.model.config.eos_token_id,
                            )
                            for e in encoded
                        )
                    )
                    genout = torch.cat(
                        (
                            encoded,
                            genout[..., -already_generated:],
                        ),
                        dim=-1,
                    )

                    if utils.koboldai_vars.sp is not None:
                        soft_tokens = torch.arange(
                            self.model.config.vocab_size,
                            self.model.config.vocab_size
                            + utils.koboldai_vars.sp.shape[0],
                            device=genout.device,
                        )
                        genout = torch.cat(
                            (soft_tokens.tile(utils.koboldai_vars.numseqs, 1), genout),
                            dim=-1,
                        )

                    assert (
                        genout.shape[-1]
                        + utils.koboldai_vars.genamt
                        - already_generated
                        <= utils.koboldai_vars.max_length
                    )
                    gen_in = genout
                    numseqs = 1

                if total_gens is None:
                    total_gens = genout
                else:
                    total_gens = torch.cat((total_gens, genout))

        return total_gens, already_generated

    def _raw_generate(
        self,
        prompt_tokens: Union[List[int], torch.Tensor],
        max_new: int,
        gen_settings: GenerationSettings,
        single_line: bool = False,
        batch_count: int = 1,
    ) -> GenerationResult:
        """Lowest level model-agnostic generation function. To be overridden by model implementation.

        Args:
            prompt_tokens (Union[List[int], torch.Tensor]): Prompt as encoded token IDs
            max_new (int): Maximum amount of new tokens to generate
            gen_settings (GenerationSettings): State to pass in single-generation setting overrides
            single_line (bool, optional): Generate one line only. Defaults to False.
            batch_count (int, optional): How big of a batch to generate. Defaults to 1.

        Returns:
            GenerationResult: The model's output
        """
        raise NotImplementedError

    def raw_generate(
        self,
        # prompt is either a string (text) or a list (token ids)
        prompt: Union[str, list, np.ndarray],
        max_new: int,
        do_streaming: bool = False,
        do_dynamic_wi: bool = False,
        batch_count: int = 1,
        bypass_hf_maxlength: bool = False,
        generation_settings: Optional[dict] = None,
        is_core: bool = False,
        single_line: bool = False,
        found_entries: set = (),
    ) -> GenerationResult:
        """A wrapper around `_raw_generate()` that handles gen_state and other stuff. Use this to generate text outside of the story.

        Args:
            prompt (Union[str, list, np.ndarray]): The prompt as a string or encoded token IDs
            max_new (int): Maximum amount of new tokens to generate
            do_streaming (bool, optional): Whether to stream tokens to the user or not. Defaults to False.
            do_dynamic_wi (bool, optional): Whether to use Dynamic WI context injections. Defaults to False.
            batch_count (int, optional): How big of a batch to generate. Defaults to 1.
            bypass_hf_maxlength (bool, optional): Whether to ignore model-provided max length limits. Defaults to False.
            generation_settings (GenerationSettings): State to pass in single-generation setting overrides. Defaults to None
            is_core (bool, optional): Whether this generation is a core story generation. Defaults to False.
            single_line (bool, optional): Generate one line only. Defaults to False.
            found_entries (set, optional): Entries found for Dynamic WI. Defaults to ().

        Raises:
            ValueError: If prompt type is weird
            NotImplementedError: If model is ReadOnly

        Returns:
            GenerationResult: The model's output
        """
        # TODO: Support singleline outside of torch

        self.gen_state["do_streaming"] = do_streaming
        self.gen_state["do_dynamic_wi"] = do_dynamic_wi

        # Dynamic WI depends on this!!! This is a main gen call.
        self.gen_state["stop_at_genamt"] = do_dynamic_wi

        # Makes stopping criteria hook happy
        self.gen_state["wi_scanner_excluded_keys"] = self.gen_state.get(
            "wi_scanner_excluded_keys", set()
        )

        utils.koboldai_vars.inference_config.do_core = is_core
        gen_settings = GenerationSettings(*(generation_settings or {}))

        if isinstance(prompt, torch.Tensor):
            prompt_tokens = prompt.cpu().numpy()
        elif isinstance(prompt, list):
            prompt_tokens = np.array(prompt)
        elif isinstance(prompt, str):
            prompt_tokens = np.array(self.tokenizer.encode(prompt))
        else:
            raise ValueError(f"Prompt is {type(prompt)}. Not a fan!")

        assert isinstance(prompt_tokens, np.ndarray)
        assert len(prompt_tokens.shape) == 1

        if utils.koboldai_vars.model == "ReadOnly":
            raise NotImplementedError("No loaded model")

        time_start = time.time()

        with use_core_manipulations():
            result = self._raw_generate(
                prompt_tokens=prompt_tokens,
                max_new=max_new,
                batch_count=batch_count,
                gen_settings=gen_settings,
                single_line=single_line,
            )

        time_end = round(time.time() - time_start, 2)
        tokens_per_second = round(len(result.encoded[0]) / time_end, 2)

        if not utils.koboldai_vars.quiet:
            logger.info(
                f"Generated {len(result.encoded[0])} tokens in {time_end} seconds, for an average rate of {tokens_per_second} tokens per second."
            )

        return result

    def generate(
        self,
        prompt_tokens: Union[List[int], torch.Tensor],
        max_new_tokens: int,
        do_streaming: bool = False,
        do_dynamic_wi: bool = False,
        single_line: bool = False,
        batch_count: int = 1,
    ) -> torch.Tensor:
        raise NotImplementedError

    def _post_token_gen(self, input_ids: torch.LongTensor) -> None:
        for hook in self.post_token_hooks:
            hook(self, input_ids)
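A backend only has to provide _load() and _raw_generate(); gen_state bookkeeping, the use_core_manipulations patching, prompt encoding, and timing are all handled by raw_generate() above. A minimal sketch of a hypothetical backend against this interface (the class name and the echo behaviour are illustrative, not part of the commit):

    import numpy as np

    from modeling.inference_model import (
        GenerationResult,
        GenerationSettings,
        InferenceModel,
    )


    class EchoInferenceModel(InferenceModel):
        def _load(self, save_model: bool, initial_load: bool) -> None:
            # Reuse the fallback tokenizer logic from the base class.
            self.tokenizer = self._get_tokenizer("gpt2")

        def _raw_generate(
            self,
            prompt_tokens,
            max_new: int,
            gen_settings: GenerationSettings,
            single_line: bool = False,
            batch_count: int = 1,
        ) -> GenerationResult:
            # "Generate" by repeating the last prompt token max_new times.
            out = [list(prompt_tokens) + [int(prompt_tokens[-1])] * max_new] * batch_count
            return GenerationResult(
                model=self,
                out_batches=np.array(out),
                prompt=prompt_tokens,
                is_whole_generation=True,
                single_line=single_line,
            )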
modeling/inference_models/api.py (new file, 85 lines)
@@ -0,0 +1,85 @@
from __future__ import annotations

import time
import json
import torch
import requests
import numpy as np
from typing import List, Union

import utils
from logger import logger

from modeling.inference_model import (
    GenerationResult,
    GenerationSettings,
    InferenceModel,
)


class APIException(Exception):
    """To be used for errors when using the Kobold API as an interface."""


class APIInferenceModel(InferenceModel):
    def _load(self, save_model: bool, initial_load: bool) -> None:
        tokenizer_id = requests.get(
            utils.koboldai_vars.colaburl[:-8] + "/api/v1/model",
        ).json()["result"]
        self.tokenizer = self._get_tokenizer(tokenizer_id)

    def _raw_generate(
        self,
        prompt_tokens: Union[List[int], torch.Tensor],
        max_new: int,
        gen_settings: GenerationSettings,
        single_line: bool = False,
        batch_count: int = 1,
    ):
        decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))

        # Store context in memory to use it for comparison with generated content
        utils.koboldai_vars.lastctx = decoded_prompt

        # Build request JSON data
        reqdata = {
            "prompt": decoded_prompt,
            "max_length": max_new,
            "max_context_length": utils.koboldai_vars.max_length,
            "rep_pen": gen_settings.rep_pen,
            "rep_pen_slope": gen_settings.rep_pen_slope,
            "rep_pen_range": gen_settings.rep_pen_range,
            "temperature": gen_settings.temp,
            "top_p": gen_settings.top_p,
            "top_k": gen_settings.top_k,
            "top_a": gen_settings.top_a,
            "tfs": gen_settings.tfs,
            "typical": gen_settings.typical,
            "n": batch_count,
        }

        # Create request
        while True:
            req = requests.post(
                utils.koboldai_vars.colaburl[:-8] + "/api/v1/generate",
                json=reqdata,
            )
            if (
                req.status_code == 503
            ):  # Server is currently generating something else so poll until it's our turn
                time.sleep(1)
                continue

            js = req.json()
            if req.status_code != 200:
                logger.error(json.dumps(js, indent=4))
                raise APIException(f"Bad API status code {req.status_code}")

            genout = [obj["text"] for obj in js["results"]]
            return GenerationResult(
                model=self,
                out_batches=np.array([self.tokenizer.encode(x) for x in genout]),
                prompt=prompt_tokens,
                is_whole_generation=True,
                single_line=single_line,
            )
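The colaburl[:-8] slicing above assumes the stored URL ends with the 8-character "/request" suffix, which is swapped out for the /api/v1 routes. A hedged usage sketch (the host URL is made up; koboldai_vars must already be populated):

    import utils
    from modeling.inference_models.api import APIInferenceModel

    utils.koboldai_vars.colaburl = "https://my-kobold-host:5000/request"  # illustrative
    model = APIInferenceModel()
    model.load()   # fetches the remote model id and picks a matching tokenizer
    print(model.raw_generate("Once upon a time", max_new=32).decoded[0])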
modeling/inference_models/colab.py (new file, 77 lines)
@@ -0,0 +1,77 @@
from __future__ import annotations

import torch
import requests
import numpy as np
from typing import List, Union

import utils

from modeling.inference_model import (
    GenerationResult,
    GenerationSettings,
    InferenceModel,
)


class ColabException(Exception):
    """To be used for errors when using the Colab API as an interface."""


class ColabInferenceModel(InferenceModel):
    def _load(self, save_model: bool, initial_load: bool) -> None:
        self.tokenizer = self._get_tokenizer("EleutherAI/gpt-neo-2.7B")

    def _raw_generate(
        self,
        prompt_tokens: Union[List[int], torch.Tensor],
        max_new: int,
        gen_settings: GenerationSettings,
        single_line: bool = False,
        batch_count: int = 1,
    ):
        decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))

        # Store context in memory to use it for comparison with generated content
        utils.koboldai_vars.lastctx = decoded_prompt

        # Build request JSON data
        reqdata = {
            "text": decoded_prompt,
            "min": 0,
            "max": max_new,
            "rep_pen": gen_settings.rep_pen,
            "rep_pen_slope": gen_settings.rep_pen_slope,
            "rep_pen_range": gen_settings.rep_pen_range,
            "temperature": gen_settings.temp,
            "top_p": gen_settings.top_p,
            "top_k": gen_settings.top_k,
            "tfs": gen_settings.tfs,
            "typical": gen_settings.typical,
            "topa": gen_settings.top_a,
            "numseqs": batch_count,
            "retfultxt": False,
        }

        # Create request
        req = requests.post(utils.koboldai_vars.colaburl, json=reqdata)

        if req.status_code != 200:
            raise ColabException(f"Bad status code {req.status_code}")

        # Deal with the response
        js = req.json()["data"]

        # Try to be backwards compatible with outdated colab
        if "text" in js:
            genout = [utils.getnewcontent(js["text"], self.tokenizer)]
        else:
            genout = js["seqs"]

        return GenerationResult(
            model=self,
            out_batches=np.array([self.tokenizer.encode(x) for x in genout]),
            prompt=prompt_tokens,
            is_whole_generation=True,
            single_line=single_line,
        )
modeling/inference_models/generic_hf_torch.py (new file, 262 lines)
@@ -0,0 +1,262 @@
from __future__ import annotations

import os
import json
import torch
import shutil
from typing import Union

from transformers import AutoModelForCausalLM, GPTNeoForCausalLM

import utils
import breakmodel
import torch_lazy_loader
import koboldai_settings

from modeling.inference_models.hf_torch import HFTorchInferenceModel


class GenericHFTorchInferenceModel(HFTorchInferenceModel):
    def _load(self, save_model: bool, initial_load: bool) -> None:
        utils.koboldai_vars.allowsp = True

        # Make model path the same as the model name to make this consistent
        # with the other loading method if it isn't a known model type. This
        # code is not just a workaround for below, it is also used to make the
        # behavior consistent with other loading methods - Henk717
        # if utils.koboldai_vars.model not in ["NeoCustom", "GPT2Custom"]:
        #     utils.koboldai_vars.custmodpth = utils.koboldai_vars.model

        if utils.koboldai_vars.model == "NeoCustom":
            utils.koboldai_vars.model = os.path.basename(
                os.path.normpath(utils.koboldai_vars.custmodpth)
            )

        # If we specify a model and it's in the root directory, we need to move
        # it to the models directory (legacy folder structure to new)
        if self.get_local_model_path(legacy=True):
            shutil.move(
                self.get_local_model_path(legacy=True, ignore_existance=True),
                self.get_local_model_path(ignore_existance=True),
            )

        self.init_model_config()

        tf_kwargs = {
            "low_cpu_mem_usage": True,
        }

        if utils.koboldai_vars.model_type == "gpt2":
            # We must disable low_cpu_mem_usage and if using a GPT-2 model
            # because GPT-2 is not compatible with this feature yet.
            tf_kwargs.pop("low_cpu_mem_usage", None)

            # Also, lazy loader doesn't support GPT-2 models
            utils.koboldai_vars.lazy_load = False

        # If we're using torch_lazy_loader, we need to get breakmodel config
        # early so that it knows where to load the individual model tensors
        if (
            utils.koboldai_vars.lazy_load
            and utils.koboldai_vars.hascuda
            and utils.koboldai_vars.breakmodel
            and not utils.koboldai_vars.nobreakmodel
        ):
            self.breakmodel_device_config(self.model_config)

        if utils.koboldai_vars.lazy_load:
            # If we're using lazy loader, we need to figure out what the model's hidden layers are called
            with torch_lazy_loader.use_lazy_torch_load(
                dematerialized_modules=True, use_accelerate_init_empty_weights=True
            ):
                try:
                    metamodel = AutoModelForCausalLM.from_config(self.model_config)
                except Exception as e:
                    metamodel = GPTNeoForCausalLM.from_config(self.model_config)
                utils.layers_module_names = utils.get_layers_module_names(metamodel)
                utils.module_names = list(metamodel.state_dict().keys())
                utils.named_buffers = list(metamodel.named_buffers(recurse=True))

        # Download model from Huggingface if it does not exist, otherwise load locally
        with self._maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(
            enable=utils.koboldai_vars.lazy_load,
            callback=self._get_lazy_load_callback(utils.num_layers(self.model_config))
            if utils.koboldai_vars.lazy_load
            else None,
            dematerialized_modules=True,
        ):
            if utils.koboldai_vars.lazy_load:
                # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time
                tf_kwargs.pop("low_cpu_mem_usage", None)

            if self.get_local_model_path():
                # Model is stored locally, load it.
                self.model = self._get_model(self.get_local_model_path(), tf_kwargs)
                self.tokenizer = self._get_tokenizer(self.get_local_model_path())
            else:
                # Model not stored locally, we need to download it.

                # _rebuild_tensor patch for casting dtype and supporting LazyTensors
                old_rebuild_tensor = torch._utils._rebuild_tensor

                def new_rebuild_tensor(
                    storage: Union[torch_lazy_loader.LazyTensor, torch.Storage],
                    storage_offset,
                    shape,
                    stride,
                ):
                    if not isinstance(storage, torch_lazy_loader.LazyTensor):
                        dtype = storage.dtype
                    else:
                        dtype = storage.storage_type.dtype
                        if not isinstance(dtype, torch.dtype):
                            dtype = storage.storage_type(0).dtype
                    if dtype is torch.float32 and len(shape) >= 2:
                        utils.koboldai_vars.fp32_model = True
                    return old_rebuild_tensor(storage, storage_offset, shape, stride)

                torch._utils._rebuild_tensor = new_rebuild_tensor
                self.model = self._get_model(utils.koboldai_vars.model, tf_kwargs)
                self.tokenizer = self._get_tokenizer(utils.koboldai_vars.model)
                torch._utils._rebuild_tensor = old_rebuild_tensor

                if save_model:
                    self.tokenizer.save_pretrained(
                        self.get_local_model_path(ignore_existance=True)
                    )

                    if utils.koboldai_vars.fp32_model and not breakmodel.disk_blocks:
                        # Use save_pretrained to convert fp32 models to fp16,
                        # unless we are using disk cache because save_pretrained
                        # is not supported in that case
                        model = model.half()
                        model.save_pretrained(
                            self.get_local_model_path(ignore_existance=True),
                            max_shard_size="500MiB",
                        )

                    else:
                        # For fp16 models, we can just copy the model files directly
                        import transformers.configuration_utils
                        import transformers.modeling_utils
                        import transformers.file_utils
                        import huggingface_hub

                        # Save the config.json
                        shutil.move(
                            os.path.realpath(
                                huggingface_hub.hf_hub_download(
                                    utils.koboldai_vars.model,
                                    transformers.configuration_utils.CONFIG_NAME,
                                    revision=utils.koboldai_vars.revision,
                                    cache_dir="cache",
                                    local_files_only=True,
                                    legacy_cache_layout=False,
                                )
                            ),
                            os.path.join(
                                self.get_local_model_path(ignore_existance=True),
                                transformers.configuration_utils.CONFIG_NAME,
                            ),
                        )

                        if utils.num_shards is None:
                            # Save the pytorch_model.bin or model.safetensors of an unsharded model
                            for possible_weight_name in [
                                transformers.modeling_utils.WEIGHTS_NAME,
                                "model.safetensors",
                            ]:
                                try:
                                    shutil.move(
                                        os.path.realpath(
                                            huggingface_hub.hf_hub_download(
                                                utils.koboldai_vars.model,
                                                possible_weight_name,
                                                revision=utils.koboldai_vars.revision,
                                                cache_dir="cache",
                                                local_files_only=True,
                                                legacy_cache_layout=False,
                                            )
                                        ),
                                        os.path.join(
                                            self.get_local_model_path(
                                                ignore_existance=True
                                            ),
                                            possible_weight_name,
                                        ),
                                    )
                                except Exception:
                                    if possible_weight_name == "model.safetensors":
                                        raise
                        else:
                            # Handle saving sharded models

                            with open(utils.from_pretrained_index_filename) as f:
                                map_data = json.load(f)
                            filenames = set(map_data["weight_map"].values())
                            # Save the pytorch_model.bin.index.json of a sharded model
                            shutil.move(
                                os.path.realpath(utils.from_pretrained_index_filename),
                                os.path.join(
                                    self.get_local_model_path(ignore_existance=True),
                                    transformers.modeling_utils.WEIGHTS_INDEX_NAME,
                                ),
                            )
                            # Then save the pytorch_model-#####-of-#####.bin files
                            for filename in filenames:
                                shutil.move(
                                    os.path.realpath(
                                        huggingface_hub.hf_hub_download(
                                            utils.koboldai_vars.model,
                                            filename,
                                            revision=utils.koboldai_vars.revision,
                                            cache_dir="cache",
                                            local_files_only=True,
                                            legacy_cache_layout=False,
                                        )
                                    ),
                                    os.path.join(
                                        self.get_local_model_path(
                                            ignore_existance=True
                                        ),
                                        filename,
                                    ),
                                )
                    shutil.rmtree("cache/")

        if (
            utils.koboldai_vars.badwordsids is koboldai_settings.badwordsids_default
            and utils.koboldai_vars.model_type not in ("gpt2", "gpt_neo", "gptj")
        ):
            utils.koboldai_vars.badwordsids = [
                [v]
                for k, v in self.tokenizer.get_vocab().items()
                if any(c in str(k) for c in "<>[]")
                if utils.koboldai_vars.newlinemode != "s" or str(k) != "</s>"
            ]

        self.patch_embedding()

        if utils.koboldai_vars.hascuda:
            if utils.koboldai_vars.usegpu:
                # Use just VRAM
                self.model = self.model.half().to(utils.koboldai_vars.gpu_device)
            elif utils.koboldai_vars.breakmodel:
                # Use both RAM and VRAM (breakmodel)
                if not utils.koboldai_vars.lazy_load:
                    self.breakmodel_device_config(model.config)
                self._move_to_devices()
            elif breakmodel.disk_blocks > 0:
                # Use disk
                self._move_to_devices()
            elif breakmodel.disk_blocks > 0:
                self._move_to_devices()
            else:
                # Use CPU
                self.model = self.model.to("cpu").float()
        elif breakmodel.disk_blocks > 0:
            self._move_to_devices()
        else:
            self.model = self.model.to("cpu").float()
        self.model.kai_model = self
        utils.koboldai_vars.modeldim = self.get_hidden_size()
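The badwordsids comprehension near the end bans every vocabulary entry that contains angle or square brackets, except "</s>" while newline mode "s" is active. A self-contained sketch with a made-up vocabulary:

    vocab = {"hello": 1, "<|endoftext|>": 2, "[WP]": 3, "</s>": 4}  # illustrative ids
    newlinemode = "s"
    badwordsids = [
        [v]
        for k, v in vocab.items()
        if any(c in str(k) for c in "<>[]")
        if newlinemode != "s" or str(k) != "</s>"
    ]
    # -> [[2], [3]]; "</s>" stays usable because newline mode "s" relies on it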
modeling/inference_models/hf.py (new file, 52 lines)
@@ -0,0 +1,52 @@
import os
from typing import Optional
from transformers import AutoConfig

import utils
from logger import logger
from modeling.inference_model import InferenceModel


class HFInferenceModel(InferenceModel):
    def __init__(self) -> None:
        super().__init__()
        self.model_config = None

    def get_local_model_path(
        self, legacy: bool = False, ignore_existance: bool = False
    ) -> Optional[str]:
        """
        Returns a string of the model's path locally, or None if it is not downloaded.
        If ignore_existance is true, it will always return a path.
        """

        basename = utils.koboldai_vars.model.replace("/", "_")
        if legacy:
            ret = basename
        else:
            ret = os.path.join("models", basename)

        if os.path.isdir(ret) or ignore_existance:
            return ret
        return None

    def init_model_config(self) -> None:
        # Get the model_type from the config or assume a model type if it isn't present
        try:
            self.model_config = AutoConfig.from_pretrained(
                self.get_local_model_path() or utils.koboldai_vars.model,
                revision=utils.koboldai_vars.revision,
                cache_dir="cache",
            )
            utils.koboldai_vars.model_type = self.model_config.model_type
        except ValueError:
            utils.koboldai_vars.model_type = {
                "NeoCustom": "gpt_neo",
                "GPT2Custom": "gpt2",
            }.get(utils.koboldai_vars.model)

            if not utils.koboldai_vars.model_type:
                logger.warning(
                    "No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)"
                )
                utils.koboldai_vars.model_type = "gpt_neo"
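For illustration, with a hypothetical model id the path helper resolves as follows (POSIX separators assumed; whether the non-legacy path is actually returned by get_local_model_path() still depends on the models/ directory existing):

    import os

    model_id = "KoboldAI/fairseq-dense-2.7B"       # illustrative, not part of the commit
    basename = model_id.replace("/", "_")          # "KoboldAI_fairseq-dense-2.7B"
    print(os.path.join("models", basename))        # non-legacy location under models/
    print(basename)                                # legacy location in the repository root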
modeling/inference_models/hf_mtj.py (new file, 289 lines)
@@ -0,0 +1,289 @@
from __future__ import annotations

import os
import torch
import numpy as np
from eventlet import tpool
from typing import List, Tuple, Union

import utils
import koboldai_settings
from logger import logger, Colors

from modeling.inference_model import ModelCapabilities
from modeling.inference_models.hf import HFInferenceModel

try:
    import tpu_mtj_backend
except ModuleNotFoundError as e:
    # Not on TPU... hopefully
    if utils.koboldai_vars.use_colab_tpu:
        raise e


class HFMTJInferenceModel(HFInferenceModel):
    def __init__(
        self,
        model_name: str,
    ) -> None:
        super().__init__()

        self.model_name = model_name

        self.model = None
        self.tokenizer = None
        self.model_config = None
        self.capabilties = ModelCapabilities(
            embedding_manipulation=False,
            post_token_hooks=False,
            stopper_hooks=False,
            post_token_probs=False,
        )

    def setup_mtj(self) -> None:
        def mtj_warper_callback(scores) -> "np.array":
            scores_shape = scores.shape
            scores_list = scores.tolist()
            utils.koboldai_vars.lua_koboldbridge.logits = (
                utils.koboldai_vars.lua_state.table()
            )
            for r, row in enumerate(scores_list):
                utils.koboldai_vars.lua_koboldbridge.logits[
                    r + 1
                ] = utils.koboldai_vars.lua_state.table(*row)
            utils.koboldai_vars.lua_koboldbridge.vocab_size = scores_shape[-1]

            utils.koboldai_vars.lua_koboldbridge.execute_genmod()

            scores = np.array(
                tuple(
                    tuple(row.values())
                    for row in utils.koboldai_vars.lua_koboldbridge.logits.values()
                ),
                dtype=scores.dtype,
            )
            assert scores.shape == scores_shape

            return scores

        def mtj_stopping_callback(
            generated, n_generated, excluded_world_info
        ) -> Tuple[List[set], bool, bool]:
            utils.koboldai_vars.generated_tkns += 1

            assert len(excluded_world_info) == len(generated)
            regeneration_required = (
                utils.koboldai_vars.lua_koboldbridge.regeneration_required
            )
            halt = (
                utils.koboldai_vars.abort
                or not utils.koboldai_vars.lua_koboldbridge.generating
                or utils.koboldai_vars.generated_tkns >= utils.koboldai_vars.genamt
            )
            utils.koboldai_vars.lua_koboldbridge.regeneration_required = False

            # Not sure what the deal is with this variable. It's been undefined
            # as far back as I can trace it.
            global past

            for i in range(utils.koboldai_vars.numseqs):
                utils.koboldai_vars.lua_koboldbridge.generated[i + 1][
                    utils.koboldai_vars.generated_tkns
                ] = int(
                    generated[i, tpu_mtj_backend.params["seq"] + n_generated - 1].item()
                )

            if not utils.koboldai_vars.dynamicscan or halt:
                return excluded_world_info, regeneration_required, halt

            for i, t in enumerate(generated):
                decoded = utils.decodenewlines(
                    self.tokenizer.decode(past[i])
                ) + utils.decodenewlines(
                    self.tokenizer.decode(
                        t[
                            tpu_mtj_backend.params["seq"] : tpu_mtj_backend.params[
                                "seq"
                            ]
                            + n_generated
                        ]
                    )
                )
                # _, found = checkworldinfo(decoded, force_use_txt=True, actions=koboldai_vars.actions)
                _, _, _, found = utils.koboldai_vars.calc_ai_text(
                    submitted_text=decoded
                )
                found -= excluded_world_info[i]
                if len(found) != 0:
                    regeneration_required = True
                    break
            return excluded_world_info, regeneration_required, halt

        def mtj_compiling_callback() -> None:
            print(Colors.GREEN + "TPU backend compilation triggered" + Colors.END)
            utils.koboldai_vars.compiling = True

        def mtj_stopped_compiling_callback() -> None:
            print(Colors.GREEN + "TPU backend compilation stopped" + Colors.END)
            utils.koboldai_vars.compiling = False

        def mtj_settings_callback() -> dict:
            sampler_order = utils.koboldai_vars.sampler_order[:]
            if (
                len(sampler_order) < 7
            ):  # Add repetition penalty at beginning if it's not present
                sampler_order = [6] + sampler_order
            return {
                "sampler_order": utils.koboldai_vars.sampler_order,
                "top_p": float(utils.koboldai_vars.top_p),
                "temp": float(utils.koboldai_vars.temp),
                "top_k": int(utils.koboldai_vars.top_k),
                "tfs": float(utils.koboldai_vars.tfs),
                "typical": float(utils.koboldai_vars.typical),
                "top_a": float(utils.koboldai_vars.top_a),
                "repetition_penalty": float(utils.koboldai_vars.rep_pen),
                "rpslope": float(utils.koboldai_vars.rep_pen_slope),
                "rprange": int(utils.koboldai_vars.rep_pen_range),
            }

        tpu_mtj_backend.socketio = utils.socketio

        if utils.koboldai_vars.model == "TPUMeshTransformerGPTNeoX":
            utils.koboldai_vars.badwordsids = utils.koboldai_vars.badwordsids_neox

        print(
            "{0}Initializing Mesh Transformer JAX, please wait...{1}".format(
                Colors.PURPLE, Colors.END
            )
        )
        if utils.koboldai_vars.model in (
            "TPUMeshTransformerGPTJ",
            "TPUMeshTransformerGPTNeoX",
        ) and (
            not utils.koboldai_vars.custmodpth
            or not os.path.isdir(utils.koboldai_vars.custmodpth)
        ):
            raise FileNotFoundError(
                f"The specified model path {repr(utils.koboldai_vars.custmodpth)} is not the path to a valid folder"
            )
        if utils.koboldai_vars.model == "TPUMeshTransformerGPTNeoX":
            tpu_mtj_backend.pad_token_id = 2

        tpu_mtj_backend.koboldai_vars = utils.koboldai_vars
        tpu_mtj_backend.warper_callback = mtj_warper_callback
        tpu_mtj_backend.stopping_callback = mtj_stopping_callback
        tpu_mtj_backend.compiling_callback = mtj_compiling_callback
        tpu_mtj_backend.stopped_compiling_callback = mtj_stopped_compiling_callback
        tpu_mtj_backend.settings_callback = mtj_settings_callback

    def _load(self, save_model: bool, initial_load: bool) -> None:
        self.setup_mtj()
        self.init_model_config()
        utils.koboldai_vars.allowsp = True

        tpu_mtj_backend.load_model(
            utils.koboldai_vars.model,
            hf_checkpoint=utils.koboldai_vars.model
            not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")
            and utils.koboldai_vars.use_colab_tpu,
            socketio_queue=koboldai_settings.queue,
            initial_load=initial_load,
            logger=logger,
            **self.model_config.to_dict(),
        )

        utils.koboldai_vars.modeldim = int(
            tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"])
        )

        self.tokenizer = tpu_mtj_backend.tokenizer
        if (
            utils.koboldai_vars.badwordsids is koboldai_settings.badwordsids_default
            and utils.koboldai_vars.model_type not in ("gpt2", "gpt_neo", "gptj")
        ):
            utils.koboldai_vars.badwordsids = [
                [v]
                for k, v in self.tokenizer.get_vocab().items()
                if any(c in str(k) for c in "<>[]")
                if utils.koboldai_vars.newlinemode != "s" or str(k) != "</s>"
            ]

    def get_soft_tokens(self) -> np.array:
        soft_tokens = None

        if utils.koboldai_vars.sp is None:
            tensor = np.zeros(
                (
                    1,
                    tpu_mtj_backend.params.get(
                        "d_embed", tpu_mtj_backend.params["d_model"]
                    ),
                ),
                dtype=np.float32,
            )
            rows = tensor.shape[0]
            padding_amount = (
                tpu_mtj_backend.params["seq"]
                - (
                    tpu_mtj_backend.params["seq"]
                    % -tpu_mtj_backend.params["cores_per_replica"]
                )
                - rows
            )
            tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
            tensor = tensor.reshape(
                tpu_mtj_backend.params["cores_per_replica"],
                -1,
                tpu_mtj_backend.params.get(
                    "d_embed", tpu_mtj_backend.params["d_model"]
                ),
            )
            utils.koboldai_vars.sp = tpu_mtj_backend.shard_xmap(tensor)

        soft_tokens = np.arange(
            tpu_mtj_backend.params["n_vocab"]
            + tpu_mtj_backend.params["n_vocab_padding"],
            tpu_mtj_backend.params["n_vocab"]
            + tpu_mtj_backend.params["n_vocab_padding"]
            + utils.koboldai_vars.sp_length,
            dtype=np.uint32,
        )
        return soft_tokens

    def _raw_generate(
        self,
        prompt_tokens: Union[List[int], torch.Tensor],
        max_new: int,
        gen_settings: GenerationSettings,
        single_line: bool = False,
        batch_count: int = 1,
    ) -> GenerationResult:
        soft_tokens = self.get_soft_tokens()

        genout = tpool.execute(
            tpu_mtj_backend.infer_static,
            np.uint32(prompt_tokens),
            gen_len=max_new,
            temp=gen_settings.temp,
            top_p=gen_settings.top_p,
            top_k=gen_settings.top_k,
            tfs=gen_settings.tfs,
            typical=gen_settings.typical,
            top_a=gen_settings.top_a,
            numseqs=batch_count,
            repetition_penalty=gen_settings.rep_pen,
            rpslope=gen_settings.rep_pen_slope,
            rprange=gen_settings.rep_pen_range,
            soft_embeddings=utils.koboldai_vars.sp,
            soft_tokens=soft_tokens,
            sampler_order=gen_settings.sampler_order,
        )
        genout = np.array(genout)

        return GenerationResult(
            self,
            out_batches=genout,
            prompt=prompt_tokens,
            is_whole_generation=True,
            single_line=single_line,
        )
modeling/inference_models/hf_torch.py (new file, 1053 lines)
File diff suppressed because it is too large.
167
modeling/inference_models/horde.py
Normal file
167
modeling/inference_models/horde.py
Normal file
@@ -0,0 +1,167 @@
from __future__ import annotations

import time
import torch
import requests
import numpy as np
from typing import List, Union

import utils
from logger import logger

from modeling.inference_model import (
    GenerationResult,
    GenerationSettings,
    InferenceModel,
)


class HordeException(Exception):
    """To be used for errors on server side of the Horde."""


class HordeInferenceModel(InferenceModel):
    def _load(self, save_model: bool, initial_load: bool) -> None:
        self.tokenizer = self._get_tokenizer(
            utils.koboldai_vars.cluster_requested_models[0]
            if len(utils.koboldai_vars.cluster_requested_models) > 0
            else "gpt2",
        )

    def _raw_generate(
        self,
        prompt_tokens: Union[List[int], torch.Tensor],
        max_new: int,
        gen_settings: GenerationSettings,
        single_line: bool = False,
        batch_count: int = 1,
    ) -> GenerationResult:
        decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))

        # Store context in memory to use it for comparison with generated content
        utils.koboldai_vars.lastctx = decoded_prompt

        # Build request JSON data
        reqdata = {
            "max_length": max_new,
            "max_context_length": utils.koboldai_vars.max_length,
            "rep_pen": gen_settings.rep_pen,
            "rep_pen_slope": gen_settings.rep_pen_slope,
            "rep_pen_range": gen_settings.rep_pen_range,
            "temperature": gen_settings.temp,
            "top_p": gen_settings.top_p,
            "top_k": int(gen_settings.top_k),
            "top_a": gen_settings.top_a,
            "tfs": gen_settings.tfs,
            "typical": gen_settings.typical,
            "n": batch_count,
        }

        cluster_metadata = {
            "prompt": decoded_prompt,
            "params": reqdata,
            "models": [x for x in utils.koboldai_vars.cluster_requested_models if x],
            "trusted_workers": False,
        }

        client_agent = "KoboldAI:2.0.0:koboldai.org"
        cluster_headers = {
            "apikey": utils.koboldai_vars.horde_api_key,
            "Client-Agent": client_agent,
        }

        try:
            # Create request
            req = requests.post(
                utils.koboldai_vars.colaburl[:-8] + "/api/v2/generate/text/async",
                json=cluster_metadata,
                headers=cluster_headers,
            )
        except requests.exceptions.ConnectionError:
            errmsg = f"Horde unavailable. Please try again later"
            logger.error(errmsg)
            raise HordeException(errmsg)

        if req.status_code == 503:
            errmsg = f"KoboldAI API Error: No available KoboldAI servers found in Horde to fulfil this request using the selected models or other properties."
            logger.error(errmsg)
            raise HordeException(errmsg)
        elif not req.ok:
            errmsg = f"KoboldAI API Error: Failed to get a standard reply from the Horde. Please check the console."
            logger.error(errmsg)
            logger.error(f"HTTP {req.status_code}!!!")
            logger.error(req.text)
            raise HordeException(errmsg)

        try:
            req_status = req.json()
        except requests.exceptions.JSONDecodeError:
            errmsg = f"Unexpected message received from the Horde: '{req.text}'"
            logger.error(errmsg)
            raise HordeException(errmsg)

        request_id = req_status["id"]
        logger.debug("Horde Request ID: {}".format(request_id))

        # We've sent the request and got the ID back, now we need to watch it to see when it finishes
        finished = False

        cluster_agent_headers = {"Client-Agent": client_agent}

        while not finished:
            try:
                req = requests.get(
                    f"{utils.koboldai_vars.colaburl[:-8]}/api/v2/generate/text/status/{request_id}",
                    headers=cluster_agent_headers,
                )
            except requests.exceptions.ConnectionError:
                errmsg = f"Horde unavailable. Please try again later"
                logger.error(errmsg)
                raise HordeException(errmsg)

            if not req.ok:
                errmsg = f"KoboldAI API Error: Failed to get a standard reply from the Horde. Please check the console."
                logger.error(req.text)
                raise HordeException(errmsg)

            try:
                req_status = req.json()
            except requests.exceptions.JSONDecodeError:
                errmsg = (
                    f"Unexpected message received from the KoboldAI Horde: '{req.text}'"
                )
                logger.error(errmsg)
                raise HordeException(errmsg)

            if "done" not in req_status:
                errmsg = f"Unexpected response received from the KoboldAI Horde: '{req_status}'"
                logger.error(errmsg)
                raise HordeException(errmsg)

            finished = req_status["done"]
            utils.koboldai_vars.horde_wait_time = req_status["wait_time"]
            utils.koboldai_vars.horde_queue_position = req_status["queue_position"]
            utils.koboldai_vars.horde_queue_size = req_status["waiting"]

            if not finished:
                logger.debug(req_status)
                time.sleep(1)

        logger.debug("Last Horde Status Message: {}".format(req_status))

        if req_status["faulted"]:
            raise HordeException("Horde Text generation faulted! Please try again.")

        generations = req_status["generations"]
        gen_servers = [(cgen["worker_name"], cgen["worker_id"]) for cgen in generations]
        logger.info(f"Generations by: {gen_servers}")

        return GenerationResult(
            model=self,
            out_batches=np.array(
                [self.tokenizer.encode(cgen["text"]) for cgen in generations]
            ),
            prompt=prompt_tokens,
            is_whole_generation=True,
            single_line=single_line,
        )
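
For reference, a minimal standalone sketch of the submit-then-poll flow that HordeInferenceModel._raw_generate implements above. The base URL, API key handling, and payload values here are illustrative placeholders, not values taken from this diff:

import time
import requests

BASE_URL = "https://example-horde.invalid"  # placeholder for a Horde-compatible endpoint
CLIENT_AGENT = "KoboldAI:2.0.0:koboldai.org"

def horde_generate(prompt: str, api_key: str) -> list:
    # Submit the job asynchronously and keep the returned request id...
    submit = requests.post(
        f"{BASE_URL}/api/v2/generate/text/async",
        json={"prompt": prompt, "params": {"max_length": 80}, "trusted_workers": False},
        headers={"apikey": api_key, "Client-Agent": CLIENT_AGENT},
    )
    submit.raise_for_status()
    request_id = submit.json()["id"]

    # ...then poll the status endpoint until the cluster reports it done.
    while True:
        status = requests.get(
            f"{BASE_URL}/api/v2/generate/text/status/{request_id}",
            headers={"Client-Agent": CLIENT_AGENT},
        ).json()
        if status.get("done"):
            return [gen["text"] for gen in status["generations"]]
        time.sleep(1)
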
72
modeling/inference_models/legacy_gpt2_hf.py
Normal file
72
modeling/inference_models/legacy_gpt2_hf.py
Normal file
@@ -0,0 +1,72 @@
from __future__ import annotations

import os
import json
import traceback

from transformers import GPT2LMHeadModel, GPT2Tokenizer

import utils
from modeling.inference_models.hf_torch import HFTorchInferenceModel


class CustomGPT2HFTorchInferenceModel(HFTorchInferenceModel):
    def _load(self, save_model: bool, initial_load: bool) -> None:
        utils.koboldai_vars.lazy_load = False

        model_path = None

        for possible_config_path in [
            utils.koboldai_vars.custmodpth,
            os.path.join("models", utils.koboldai_vars.custmodpth),
        ]:
            try:
                with open(
                    os.path.join(possible_config_path, "config.json"), "r"
                ) as file:
                    self.model_config = json.load(file)
                model_path = possible_config_path
                break
            except FileNotFoundError:
                pass

        if not model_path:
            raise RuntimeError("Empty model_path!")

        with self._maybe_use_float16():
            try:
                self.model = GPT2LMHeadModel.from_pretrained(
                    utils.koboldai_vars.custmodpth,
                    revision=utils.koboldai_vars.revision,
                    cache_dir="cache",
                )
                self.tokenizer = GPT2Tokenizer.from_pretrained(
                    utils.koboldai_vars.custmodpth,
                    revision=utils.koboldai_vars.revision,
                    cache_dir="cache",
                )
            except Exception as e:
                if "out of memory" in traceback.format_exc().lower():
                    raise RuntimeError(
                        "One of your GPUs ran out of memory when KoboldAI tried to load your model."
                    ) from e
                raise e

        if save_model:
            self.model.save_pretrained(
                self.get_local_model_path(ignore_existance=True),
                max_shard_size="500MiB",
            )
            self.tokenizer.save_pretrained(
                self.get_local_model_path(ignore_existance=True)
            )

        utils.koboldai_vars.modeldim = self.get_hidden_size()

        # Is CUDA available? If so, use GPU, otherwise fall back to CPU
        if utils.koboldai_vars.hascuda and utils.koboldai_vars.usegpu:
            self.model = self.model.half().to(utils.koboldai_vars.gpu_device)
        else:
            self.model = self.model.to("cpu").float()

        self.patch_causal_lm()
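
A condensed standalone sketch of the same load path: probe candidate directories for config.json, then load weights and tokenizer from the first match. The "my-gpt2" path is hypothetical, and the revision pinning and float16 handling of the class above are omitted:

import json
import os

from transformers import GPT2LMHeadModel, GPT2Tokenizer

candidates = ["my-gpt2", os.path.join("models", "my-gpt2")]
model_path = next(
    (p for p in candidates if os.path.isfile(os.path.join(p, "config.json"))),
    None,
)
if model_path is None:
    raise RuntimeError("Empty model_path!")

# Keep only the happy path: read the config, then load model and tokenizer.
with open(os.path.join(model_path, "config.json")) as file:
    model_config = json.load(file)

model = GPT2LMHeadModel.from_pretrained(model_path, cache_dir="cache")
tokenizer = GPT2Tokenizer.from_pretrained(model_path, cache_dir="cache")
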
98
modeling/inference_models/openai.py
Normal file
98
modeling/inference_models/openai.py
Normal file
@@ -0,0 +1,98 @@
import torch
import requests
import numpy as np
from typing import List, Union

import utils
from modeling.inference_model import (
    GenerationResult,
    GenerationSettings,
    InferenceModel,
)


class OpenAIAPIError(Exception):
    def __init__(self, error_type: str, error_message) -> None:
        super().__init__(f"{error_type}: {error_message}")


class OpenAIAPIInferenceModel(InferenceModel):
    """InferenceModel for interfacing with OpenAI's generation API."""
    def _load(self, save_model: bool, initial_load: bool) -> None:
        self.tokenizer = self._get_tokenizer("gpt2")

    def _raw_generate(
        self,
        prompt_tokens: Union[List[int], torch.Tensor],
        max_new: int,
        gen_settings: GenerationSettings,
        single_line: bool = False,
        batch_count: int = 1,
    ) -> GenerationResult:
        # Taken mainly from oairequest()

        decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))

        # Store context in memory to use it for comparison with generated content
        utils.koboldai_vars.lastctx = decoded_prompt

        # Build request JSON data
        # GooseAI is a subtype of OAI. To check if it's this type, we check the configname as a workaround,
        # as koboldai_vars.model will always be OAI
        if "GooseAI" in utils.koboldai_vars.configname:
            reqdata = {
                "prompt": decoded_prompt,
                "max_tokens": max_new,
                "temperature": gen_settings.temp,
                "top_a": gen_settings.top_a,
                "top_p": gen_settings.top_p,
                "top_k": gen_settings.top_k,
                "tfs": gen_settings.tfs,
                "typical_p": gen_settings.typical,
                "repetition_penalty": gen_settings.rep_pen,
                "repetition_penalty_slope": gen_settings.rep_pen_slope,
                "repetition_penalty_range": gen_settings.rep_pen_range,
                "n": batch_count,
                # TODO: Implement streaming
                "stream": False,
            }
        else:
            reqdata = {
                "prompt": decoded_prompt,
                "max_tokens": max_new,
                "temperature": gen_settings.temp,
                "top_p": gen_settings.top_p,
                "frequency_penalty": gen_settings.rep_pen,
                "n": batch_count,
                "stream": False,
            }

        req = requests.post(
            utils.koboldai_vars.oaiurl,
            json=reqdata,
            headers={
                "Authorization": "Bearer " + utils.koboldai_vars.oaiapikey,
                "Content-Type": "application/json",
            },
        )

        j = req.json()

        if not req.ok:
            # Send error message to web client
            if "error" in j:
                error_type = j["error"]["type"]
                error_message = j["error"]["message"]
            else:
                error_type = "Unknown"
                error_message = "Unknown"
            raise OpenAIAPIError(error_type, error_message)

        outputs = [out["text"] for out in j["choices"]]
        return GenerationResult(
            model=self,
            out_batches=np.array([self.tokenizer.encode(x) for x in outputs]),
            prompt=prompt_tokens,
            is_whole_generation=True,
            single_line=single_line,
        )
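
A standalone sketch of the non-GooseAI request built above; the endpoint URL, API key, and sampling values are placeholders rather than values from koboldai_vars:

import requests

def openai_complete(prompt: str, url: str, api_key: str, n: int = 1) -> list:
    # Mirror of the non-GooseAI payload assembled above; the sampling values
    # are arbitrary examples, not KoboldAI defaults.
    reqdata = {
        "prompt": prompt,
        "max_tokens": 80,
        "temperature": 0.7,
        "top_p": 0.9,
        "frequency_penalty": 0.5,
        "n": n,
        "stream": False,
    }
    req = requests.post(
        url,
        json=reqdata,
        headers={
            "Authorization": "Bearer " + api_key,
            "Content-Type": "application/json",
        },
    )
    j = req.json()
    if not req.ok:
        err = j.get("error", {})
        raise RuntimeError(f"{err.get('type', 'Unknown')}: {err.get('message', 'Unknown')}")
    return [choice["text"] for choice in j["choices"]]
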
133
modeling/patches.py
Normal file
133
modeling/patches.py
Normal file
@@ -0,0 +1,133 @@
from __future__ import annotations

import copy
import requests
from typing import Iterable, List
from tqdm.auto import tqdm

import transformers
from transformers import (
    PreTrainedModel,
    modeling_utils,
)

import utils


def patch_transformers_download():
    def http_get(
        url: str,
        temp_file,
        proxies=None,
        resume_size=0,
        headers=None,
        file_name=None,
    ):
        """
        Download remote file. Do not gobble up errors.
        """
        headers = copy.deepcopy(headers)
        if resume_size > 0:
            headers["Range"] = f"bytes={resume_size}-"
        r = requests.get(url, stream=True, proxies=proxies, headers=headers)
        transformers.utils.hub._raise_for_status(r)
        content_length = r.headers.get("Content-Length")
        total = (
            resume_size + int(content_length) if content_length is not None else None
        )

        # `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()`
        # and can be set using `utils.logging.enable/disable_progress_bar()`
        if url[-11:] != "config.json":
            progress = tqdm(
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
                total=total,
                initial=resume_size,
                desc=f"Downloading {file_name}"
                if file_name is not None
                else "Downloading",
                file=utils.UIProgressBarFile(),
            )
            utils.koboldai_vars.status_message = "Download Model"
            utils.koboldai_vars.total_download_chunks = total

        for chunk in r.iter_content(chunk_size=1024):
            if chunk:  # filter out keep-alive new chunks
                if url[-11:] != "config.json":
                    progress.update(len(chunk))
                    utils.koboldai_vars.downloaded_chunks += len(chunk)
                temp_file.write(chunk)

        if url[-11:] != "config.json":
            progress.close()

        utils.koboldai_vars.status_message = ""

    transformers.utils.hub.http_get = http_get


def patch_transformers_loader() -> None:
    """
    Patch the Transformers loader to use aria2 and our shard tracking.
    Universal for TPU/MTJ and Torch.
    """
    old_from_pretrained = PreTrainedModel.from_pretrained.__func__

    @classmethod
    def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        utils.koboldai_vars.fp32_model = False
        utils.num_shards = None
        utils.current_shard = 0
        utils.from_pretrained_model_name = pretrained_model_name_or_path
        utils.from_pretrained_index_filename = None
        utils.from_pretrained_kwargs = kwargs
        utils.bar = None
        if not utils.args.no_aria2:
            utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
        return old_from_pretrained(
            cls, pretrained_model_name_or_path, *model_args, **kwargs
        )

    if not hasattr(PreTrainedModel, "_kai_patched"):
        PreTrainedModel.from_pretrained = new_from_pretrained
        PreTrainedModel._kai_patched = True

    if hasattr(modeling_utils, "get_checkpoint_shard_files"):
        old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files

        def new_get_checkpoint_shard_files(
            pretrained_model_name_or_path, index_filename, *args, **kwargs
        ):
            utils.num_shards = utils.get_num_shards(index_filename)
            utils.from_pretrained_index_filename = index_filename
            return old_get_checkpoint_shard_files(
                pretrained_model_name_or_path, index_filename, *args, **kwargs
            )

        modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files


def patch_transformers_generation() -> None:
    # Not sure why this global is needed...
    global transformers

    # Allow bad words filter to ban <|endoftext|> token
    import transformers.generation.logits_process

    def new_init(self, bad_words_ids: List[List[int]], eos_token_id: int):
        return new_init.old_init(self, bad_words_ids, -1)

    new_init.old_init = (
        transformers.generation.logits_process.NoBadWordsLogitsProcessor.__init__
    )
    transformers.generation.logits_process.NoBadWordsLogitsProcessor.__init__ = new_init


def patch_transformers() -> None:
    patch_transformers_download()
    patch_transformers_loader()

    # Doesn't do anything for TPU
    patch_transformers_generation()
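
The loader patch follows a keep-old, wrap, guard-with-sentinel pattern. A generic sketch of that pattern with a stand-in class (not a real transformers type):

class Target:
    @classmethod
    def from_pretrained(cls, name):
        return f"loaded {name}"

# Keep a reference to the original underlying function before patching.
old_from_pretrained = Target.from_pretrained.__func__

@classmethod
def new_from_pretrained(cls, name):
    print(f"about to load {name}")        # extra bookkeeping before delegating
    return old_from_pretrained(cls, name)  # always defer to the original

# The sentinel attribute guards against applying the patch twice.
if not hasattr(Target, "_kai_patched"):
    Target.from_pretrained = new_from_pretrained
    Target._kai_patched = True

assert Target.from_pretrained("gpt2") == "loaded gpt2"
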
27
modeling/post_token_hooks.py
Normal file
27
modeling/post_token_hooks.py
Normal file
@@ -0,0 +1,27 @@
import torch

import utils
from modeling.inference_model import InferenceModel


class PostTokenHooks:
    @staticmethod
    def stream_tokens(
        model: InferenceModel,
        input_ids: torch.LongTensor,
    ) -> None:
        if not model.gen_state["do_streaming"]:
            return

        if not utils.koboldai_vars.output_streaming:
            return

        data = [
            utils.applyoutputformatting(
                utils.decodenewlines(model.tokenizer.decode(x[-1])),
                no_sentence_trimming=True,
                no_single_line=True,
            )
            for x in input_ids
        ]
        utils.koboldai_vars.actions.stream_tokens(data)
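
The hook above only receives the model and the current batch of token ids; how it is invoked is not shown in this diff. A sketch of one plausible wiring, with stub classes that are assumptions rather than the real InferenceModel:

import torch

class StubTokenizer:
    def decode(self, ids):
        # Accept either a tensor or a plain int and render tokens as <id>.
        if hasattr(ids, "tolist"):
            ids = ids.tolist()
        if not isinstance(ids, list):
            ids = [ids]
        return " ".join(f"<{i}>" for i in ids)

class StubModel:
    def __init__(self):
        self.tokenizer = StubTokenizer()
        self.gen_state = {"do_streaming": True}
        self.post_token_hooks = [self.print_newest_token]

    @staticmethod
    def print_newest_token(model, input_ids: torch.LongTensor) -> None:
        # Like stream_tokens above, look only at the newest token of each row.
        if not model.gen_state["do_streaming"]:
            return
        print([model.tokenizer.decode(row[-1]) for row in input_ids])

    def _post_token(self, input_ids: torch.LongTensor) -> None:
        for hook in self.post_token_hooks:
            hook(self, input_ids)

StubModel()._post_token(torch.tensor([[1, 2, 3], [4, 5, 6]]))  # prints ['<3>', '<6>']
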
117
modeling/stoppers.py
Normal file
117
modeling/stoppers.py
Normal file
@@ -0,0 +1,117 @@
from __future__ import annotations

import torch

import utils
from modeling.inference_model import (
    InferenceModel,
)


class Stoppers:
    @staticmethod
    def core_stopper(
        model: InferenceModel,
        input_ids: torch.LongTensor,
    ) -> bool:
        if not utils.koboldai_vars.inference_config.do_core:
            return False

        utils.koboldai_vars.generated_tkns += 1

        if (
            not utils.koboldai_vars.standalone
            and utils.koboldai_vars.lua_koboldbridge.generated_cols
            and utils.koboldai_vars.generated_tkns
            != utils.koboldai_vars.lua_koboldbridge.generated_cols
        ):
            raise RuntimeError(
                f"Inconsistency detected between KoboldAI Python and Lua backends ({utils.koboldai_vars.generated_tkns} != {utils.koboldai_vars.lua_koboldbridge.generated_cols})"
            )

        if utils.koboldai_vars.abort or (
            utils.koboldai_vars.inference_config.stop_at_genamt
            and utils.koboldai_vars.generated_tkns >= utils.koboldai_vars.genamt
        ):
            utils.koboldai_vars.abort = False
            model.gen_state["regeneration_required"] = False
            model.gen_state["halt"] = False
            return True

        if utils.koboldai_vars.standalone:
            return False

        assert input_ids.ndim == 2

        model.gen_state[
            "regeneration_required"
        ] = utils.koboldai_vars.lua_koboldbridge.regeneration_required
        model.gen_state["halt"] = not utils.koboldai_vars.lua_koboldbridge.generating
        utils.koboldai_vars.lua_koboldbridge.regeneration_required = False

        for i in (
            range(utils.koboldai_vars.numseqs)
            if not utils.koboldai_vars.alt_multi_gen
            else range(1)
        ):
            utils.koboldai_vars.lua_koboldbridge.generated[i + 1][
                utils.koboldai_vars.generated_tkns
            ] = int(input_ids[i, -1].item())

        return model.gen_state["regeneration_required"] or model.gen_state["halt"]

    @staticmethod
    def dynamic_wi_scanner(
        model: InferenceModel,
        input_ids: torch.LongTensor,
    ) -> bool:
        if not utils.koboldai_vars.inference_config.do_dynamic_wi:
            return False

        if not utils.koboldai_vars.dynamicscan:
            return False

        if len(model.gen_state["wi_scanner_excluded_keys"]) != input_ids.shape[0]:
            print(model.tokenizer.decode(model.gen_state["wi_scanner_excluded_keys"]))
            print(model.tokenizer.decode(input_ids.shape[0]))

        assert len(model.gen_state["wi_scanner_excluded_keys"]) == input_ids.shape[0]

        tail = input_ids[..., -utils.koboldai_vars.generated_tkns :]
        for i, t in enumerate(tail):
            decoded = utils.decodenewlines(model.tokenizer.decode(t))
            _, _, _, found = utils.koboldai_vars.calc_ai_text(
                submitted_text=decoded, send_context=False
            )
            found = list(
                set(found) - set(model.gen_state["wi_scanner_excluded_keys"][i])
            )
            if found:
                print("FOUNDWI", found)
                return True
        return False

    @staticmethod
    def chat_mode_stopper(
        model: InferenceModel,
        input_ids: torch.LongTensor,
    ) -> bool:
        if not utils.koboldai_vars.chatmode:
            return False

        data = [model.tokenizer.decode(x) for x in input_ids]
        # null_character = model.tokenizer.encode(chr(0))[0]
        if "completed" not in model.gen_state:
            model.gen_state["completed"] = [False] * len(input_ids)

        for i in range(len(input_ids)):
            if (
                data[i][-1 * (len(utils.koboldai_vars.chatname) + 1) :]
                == utils.koboldai_vars.chatname + ":"
            ):
                model.gen_state["completed"][i] = True
        if all(model.gen_state["completed"]):
            utils.koboldai_vars.generated_tkns = utils.koboldai_vars.genamt
            del model.gen_state["completed"]
            return True
        return False
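
The core check in chat_mode_stopper is that a batch row counts as completed once its decoded text ends with the chat name plus a colon, and generation stops only when every row has completed. A self-contained sketch of just that check, with plain strings standing in for decoded token ids and a completed list standing in for gen_state["completed"]:

def chat_rows_done(decoded_batch, chatname: str, completed: list) -> bool:
    # Mark a row as completed once it ends with "<chatname>:"; only report
    # done when every row in the batch has completed.
    suffix = chatname + ":"
    for i, text in enumerate(decoded_batch):
        if text.endswith(suffix):
            completed[i] = True
    return all(completed)

completed = [False, False]
assert not chat_rows_done(["Alice: hi\nYou:", "still writing"], "You", completed)
assert chat_rows_done(["Alice: hi\nYou:", "done now\nYou:"], "You", completed)
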
@@ -54,7 +54,7 @@ from mesh_transformer.transformer_shard import CausalTransformer, CausalTransfor
 from mesh_transformer.util import to_bf16
 import time
 
-import warpers
+import modeling.warpers as warpers
 
 socketio = None