Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
Model: And another refactor
modeling/stoppers.py | 117 lines | Normal file
@@ -0,0 +1,117 @@
from __future__ import annotations

import torch

import utils
from modeling.inference_model import (
    InferenceModel,
)


class Stoppers:
    @staticmethod
    def core_stopper(
        model: InferenceModel,
        input_ids: torch.LongTensor,
    ) -> bool:
        """Track generated tokens, keep the Lua bridge in sync, and stop on
        abort or once the requested number of tokens has been generated."""
        if not utils.koboldai_vars.inference_config.do_core:
            return False

        utils.koboldai_vars.generated_tkns += 1

        if (
            not utils.koboldai_vars.standalone
            and utils.koboldai_vars.lua_koboldbridge.generated_cols
            and utils.koboldai_vars.generated_tkns
            != utils.koboldai_vars.lua_koboldbridge.generated_cols
        ):
            raise RuntimeError(
                f"Inconsistency detected between KoboldAI Python and Lua backends ({utils.koboldai_vars.generated_tkns} != {utils.koboldai_vars.lua_koboldbridge.generated_cols})"
            )

        if utils.koboldai_vars.abort or (
            utils.koboldai_vars.inference_config.stop_at_genamt
            and utils.koboldai_vars.generated_tkns >= utils.koboldai_vars.genamt
        ):
            utils.koboldai_vars.abort = False
            model.gen_state["regeneration_required"] = False
            model.gen_state["halt"] = False
            return True

        if utils.koboldai_vars.standalone:
            return False

        assert input_ids.ndim == 2

        model.gen_state[
            "regeneration_required"
        ] = utils.koboldai_vars.lua_koboldbridge.regeneration_required
        model.gen_state["halt"] = not utils.koboldai_vars.lua_koboldbridge.generating
        utils.koboldai_vars.lua_koboldbridge.regeneration_required = False

        # Mirror the newest token of each sequence into the Lua bridge
        # (only the first sequence when alt_multi_gen is enabled).
        for i in (
            range(utils.koboldai_vars.numseqs)
            if not utils.koboldai_vars.alt_multi_gen
            else range(1)
        ):
            utils.koboldai_vars.lua_koboldbridge.generated[i + 1][
                utils.koboldai_vars.generated_tkns
            ] = int(input_ids[i, -1].item())

        return model.gen_state["regeneration_required"] or model.gen_state["halt"]

    @staticmethod
    def dynamic_wi_scanner(
        model: InferenceModel,
        input_ids: torch.LongTensor,
    ) -> bool:
        """Stop generation when the freshly generated tail of the text
        triggers a world-info key that was not already excluded, so the
        context can be rebuilt with that entry included."""
        if not utils.koboldai_vars.inference_config.do_dynamic_wi:
            return False

        if not utils.koboldai_vars.dynamicscan:
            return False

        if len(model.gen_state["wi_scanner_excluded_keys"]) != input_ids.shape[0]:
            # Debug output: the exclusion list should hold one entry per sequence.
            print(model.gen_state["wi_scanner_excluded_keys"])
            print(input_ids.shape[0])

        assert len(model.gen_state["wi_scanner_excluded_keys"]) == input_ids.shape[0]

        tail = input_ids[..., -utils.koboldai_vars.generated_tkns :]
        for i, t in enumerate(tail):
            decoded = utils.decodenewlines(model.tokenizer.decode(t))
            _, _, _, found = utils.koboldai_vars.calc_ai_text(
                submitted_text=decoded, send_context=False
            )
            found = list(
                set(found) - set(model.gen_state["wi_scanner_excluded_keys"][i])
            )
            if found:
                print("FOUNDWI", found)
                return True
        return False

    @staticmethod
    def chat_mode_stopper(
        model: InferenceModel,
        input_ids: torch.LongTensor,
    ) -> bool:
        """In chat mode, stop once every sequence in the batch has generated
        the user's chat name followed by a colon."""
        if not utils.koboldai_vars.chatmode:
            return False

        data = [model.tokenizer.decode(x) for x in input_ids]
        # null_character = model.tokenizer.encode(chr(0))[0]
        if "completed" not in model.gen_state:
            model.gen_state["completed"] = [False] * len(input_ids)

        for i in range(len(input_ids)):
            if (
                data[i][-1 * (len(utils.koboldai_vars.chatname) + 1) :]
                == utils.koboldai_vars.chatname + ":"
            ):
                model.gen_state["completed"][i] = True
        if all(model.gen_state["completed"]):
            utils.koboldai_vars.generated_tkns = utils.koboldai_vars.genamt
            del model.gen_state["completed"]
            return True
        return False
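
The end-of-turn test in chat_mode_stopper is a plain string-suffix check on each decoded batch entry. A toy illustration of the slice arithmetic (the chat name here is made up):

    chatname = "Alice"  # stand-in for utils.koboldai_vars.chatname
    decoded = "Bob: how are you?\nAlice:"
    # The last len(chatname) + 1 characters must spell "<chatname>:".
    print(decoded[-(len(chatname) + 1) :] == chatname + ":")  # True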
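All three stoppers share one callable shape: take the model plus the current input_ids tensor and return True to halt generation. Below is a minimal, self-contained sketch of that protocol; DummyModel, stopper_hooks, and stop_after_n_tokens are illustrative names invented for this note, and how InferenceModel actually polls its stoppers is not shown in this diff.

    from typing import Callable, List

    import torch


    class DummyModel:
        """Illustrative stand-in for InferenceModel: holds mutable state and
        a list of stopper callables polled after each generated token."""

        def __init__(self) -> None:
            self.gen_state: dict = {}
            self.stopper_hooks: List[
                Callable[["DummyModel", torch.LongTensor], bool]
            ] = []


    def stop_after_n_tokens(n: int) -> Callable[[DummyModel, torch.LongTensor], bool]:
        # Same (model, input_ids) -> bool signature as the Stoppers methods.
        def stopper(model: DummyModel, input_ids: torch.LongTensor) -> bool:
            return input_ids.shape[-1] >= n

        return stopper


    model = DummyModel()
    model.stopper_hooks.append(stop_after_n_tokens(8))

    # Simulated decode loop: append one fake token per step, then poll hooks.
    ids = torch.zeros((1, 1), dtype=torch.long)
    while not any(hook(model, ids) for hook in model.stopper_hooks):
        ids = torch.cat([ids, torch.zeros((1, 1), dtype=torch.long)], dim=-1)

    print(ids.shape)  # torch.Size([1, 8])

The gen_state dict plays the same role as in the real methods: stoppers stash per-run bookkeeping (e.g. "completed", "regeneration_required") on the model between calls instead of keeping globals.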