KoboldAI-Client/modeling/stoppers.py
from __future__ import annotations

import torch

import utils
from modeling.inference_model import (
    InferenceModel,
)


class Stoppers:
    """Static stopping-criteria callbacks run during token generation.

    Each stopper receives the active InferenceModel and the current batch of
    token ids, and returns True when generation should halt.
    """

    @staticmethod
    def core_stopper(
        model: InferenceModel,
        input_ids: torch.LongTensor,
    ) -> bool:
        """Count each generated token, keep the Lua bridge in sync, and stop on abort or once genamt is reached."""
        if not utils.koboldai_vars.inference_config.do_core:
            return False

        utils.koboldai_vars.generated_tkns += 1

        if (
            not utils.koboldai_vars.standalone
            and utils.koboldai_vars.lua_koboldbridge.generated_cols
            and utils.koboldai_vars.generated_tkns
            != utils.koboldai_vars.lua_koboldbridge.generated_cols
        ):
            raise RuntimeError(
                f"Inconsistency detected between KoboldAI Python and Lua backends ({utils.koboldai_vars.generated_tkns} != {utils.koboldai_vars.lua_koboldbridge.generated_cols})"
            )

        if utils.koboldai_vars.abort or (
            utils.koboldai_vars.inference_config.stop_at_genamt
            and utils.koboldai_vars.generated_tkns >= utils.koboldai_vars.genamt
        ):
            utils.koboldai_vars.abort = False
            model.gen_state["regeneration_required"] = False
            model.gen_state["halt"] = False
            return True

        if utils.koboldai_vars.standalone:
            return False

        assert input_ids.ndim == 2

        model.gen_state[
            "regeneration_required"
        ] = utils.koboldai_vars.lua_koboldbridge.regeneration_required
        model.gen_state["halt"] = not utils.koboldai_vars.lua_koboldbridge.generating
        utils.koboldai_vars.lua_koboldbridge.regeneration_required = False

        for i in (
            range(utils.koboldai_vars.numseqs)
            if not utils.koboldai_vars.alt_multi_gen
            else range(1)
        ):
            utils.koboldai_vars.lua_koboldbridge.generated[i + 1][
                utils.koboldai_vars.generated_tkns
            ] = int(input_ids[i, -1].item())

        return model.gen_state["regeneration_required"] or model.gen_state["halt"]

    @staticmethod
    def dynamic_wi_scanner(
        model: InferenceModel,
        input_ids: torch.LongTensor,
    ) -> bool:
        """Stop generation when newly generated text triggers a world-info entry that is not excluded."""
        if not utils.koboldai_vars.inference_config.do_dynamic_wi:
            return False

        if not utils.koboldai_vars.dynamicscan:
            return False

        if len(model.gen_state["wi_scanner_excluded_keys"]) != input_ids.shape[0]:
            # Debug output for the mismatch caught by the assert below.
            print(model.gen_state["wi_scanner_excluded_keys"])
            print(input_ids.shape[0])

        assert len(model.gen_state["wi_scanner_excluded_keys"]) == input_ids.shape[0]

        tail = input_ids[..., -utils.koboldai_vars.generated_tkns :]
        for i, t in enumerate(tail):
            decoded = utils.decodenewlines(model.tokenizer.decode(t))
            _, _, _, found = utils.koboldai_vars.calc_ai_text(
                submitted_text=decoded, send_context=False
            )
            found = list(
                set(found) - set(model.gen_state["wi_scanner_excluded_keys"][i])
            )
            if found:
                print("FOUNDWI", found)
                return True
        return False

    @staticmethod
    def chat_mode_stopper(
        model: InferenceModel,
        input_ids: torch.LongTensor,
    ) -> bool:
        """In chat mode, stop once every sequence in the batch ends with the chat name followed by a colon."""
        if not utils.koboldai_vars.chatmode:
            return False

        data = [model.tokenizer.decode(x) for x in input_ids]
        # null_character = model.tokenizer.encode(chr(0))[0]
        if "completed" not in model.gen_state:
            model.gen_state["completed"] = [False] * len(input_ids)

        for i in range(len(input_ids)):
            if (
                data[i][-1 * (len(utils.koboldai_vars.chatname) + 1) :]
                == utils.koboldai_vars.chatname + ":"
            ):
                model.gen_state["completed"][i] = True

        if all(model.gen_state["completed"]):
            utils.koboldai_vars.generated_tkns = utils.koboldai_vars.genamt
            del model.gen_state["completed"]
            return True
        return False

    @staticmethod
    def stop_sequence_stopper(
        model: InferenceModel,
        input_ids: torch.LongTensor,
    ) -> bool:
        """Stop once every sequence in the batch ends with one of the configured stop sequences."""
        data = [model.tokenizer.decode(x) for x in input_ids]
        # null_character = model.tokenizer.encode(chr(0))[0]
        if "completed" not in model.gen_state:
            model.gen_state["completed"] = [False] * len(input_ids)

        # One known issue: the stop sequence may not align with the end of a
        # token if it is a subsection of a longer token (see the note after
        # this method).
        for stopper in utils.koboldai_vars.stop_sequence:
            for i in range(len(input_ids)):
                if data[i][-1 * (len(stopper)) :] == stopper:
                    model.gen_state["completed"][i] = True

        if all(model.gen_state["completed"]):
            utils.koboldai_vars.generated_tkns = utils.koboldai_vars.genamt
            del model.gen_state["completed"]
            return True
        return False
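
    # Note on the caveat flagged above: the check compares each stop sequence
    # against the exact suffix of the decoded text, so a stop string buried
    # inside a longer token can be missed. For example, with a stop sequence of
    # "###", a final token that decodes to "###\n" leaves the text ending in
    # "\n", and the suffix comparison never fires. A sketch of a more forgiving
    # check (hypothetical helper, not used by this class; the trailing-window
    # size is an arbitrary assumption) could search the recent tail instead:
    #
    #     def _tail_contains_stop(decoded: str, stop: str, window: int = 32) -> bool:
    #         return stop in decoded[-(len(stop) + window):]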

    @staticmethod
    def singleline_stopper(
        model: InferenceModel,
        input_ids: torch.LongTensor,
    ) -> bool:
        """If singleline mode is enabled, it's pointless to generate output beyond the first newline."""
        if not utils.koboldai_vars.singleline:
            return False

        # Keep track of presence of newlines in each sequence; we cannot stop a
        # batch member individually, so we must wait for all of them to contain
        # a newline.
        if "newline_in_sequence" not in model.gen_state:
            model.gen_state["newline_in_sequence"] = [False] * len(input_ids)

        for sequence_idx, batch_sequence in enumerate(input_ids):
            if model.tokenizer.decode(batch_sequence[-1]) == "\n":
                model.gen_state["newline_in_sequence"][sequence_idx] = True

        if all(model.gen_state["newline_in_sequence"]):
            del model.gen_state["newline_in_sequence"]
            return True
        return False
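

# Usage sketch (not part of the original module): backends are expected to
# register these callbacks on a model so that they run once per generated
# token, each returning True to halt generation. The attribute name
# "stopper_hooks" below is an assumption based on how InferenceModel backends
# wire stoppers up elsewhere in the repository; adjust it to whatever hook
# list your backend actually exposes.
#
#     model.stopper_hooks = [
#         Stoppers.core_stopper,
#         Stoppers.dynamic_wi_scanner,
#         Stoppers.singleline_stopper,
#         Stoppers.chat_mode_stopper,
#         Stoppers.stop_sequence_stopper,
#     ]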