Add basic support for some of the quick stoppers

somebody
2023-07-21 13:27:30 -05:00
parent 6cf63f781a
commit 3a43b254b8
4 changed files with 137 additions and 34 deletions

@@ -3,15 +3,12 @@ from __future__ import annotations
 import torch
 import utils
-from modeling.inference_model import (
-    InferenceModel,
-)
+from modeling import inference_model
 class Stoppers:
     @staticmethod
     def core_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
         if not utils.koboldai_vars.inference_config.do_core:
@@ -62,7 +59,7 @@ class Stoppers:
     @staticmethod
     def dynamic_wi_scanner(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
         if not utils.koboldai_vars.inference_config.do_dynamic_wi:
@@ -93,7 +90,7 @@ class Stoppers:
     @staticmethod
     def chat_mode_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
         if not utils.koboldai_vars.chatmode:
@@ -118,7 +115,7 @@ class Stoppers:
     @staticmethod
     def stop_sequence_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
@@ -145,14 +142,22 @@ class Stoppers:
     @staticmethod
     def singleline_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
-        """If singleline mode is enabled, it's pointless to generate output beyond the first newline."""
+        """Stop on occurrences of newlines **if singleline is enabled**."""
         # It might be better just to do this further up the line
         if not utils.koboldai_vars.singleline:
             return False
+        return Stoppers.newline_stopper(model, input_ids)
+
+    @staticmethod
+    def newline_stopper(
+        model: inference_model.InferenceModel,
+        input_ids: torch.LongTensor,
+    ) -> bool:
+        """Stop on occurrences of newlines."""
         # Keep track of presence of newlines in each sequence; we cannot stop a
         # batch member individually, so we must wait for all of them to contain
         # a newline.
@@ -167,3 +172,30 @@ class Stoppers:
             del model.gen_state["newline_in_sequence"]
             return True
         return False
+
+    @staticmethod
+    def sentence_end_stopper(
+        model: inference_model.InferenceModel,
+        input_ids: torch.LongTensor,
+    ) -> bool:
+        """Stops at the end of sentences."""
+        # TODO: Make this more robust
+        SENTENCE_ENDS = [".", "?", "!"]
+        # We need to keep track of stopping for each batch, since we can't stop
+        # one individually.
+        if "sentence_end_sequence" not in model.gen_state:
+            model.gen_state["sentence_end_sequence"] = [False] * len(input_ids)
+        for sequence_idx, batch_sequence in enumerate(input_ids):
+            decoded = model.tokenizer.decode(batch_sequence[-1])
+            for end in SENTENCE_ENDS:
+                if end in decoded:
+                    model.gen_state["sentence_end_sequence"][sequence_idx] = True
+                    break
+        if all(model.gen_state["sentence_end_sequence"]):
+            del model.gen_state["sentence_end_sequence"]
+            return True
+        return False
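
All of these stoppers share the same shape: a callable that takes the model and the current (batch, seq_len) input_ids tensor and returns True once the whole batch should stop. The sketch below shows how such callbacks might be driven from a generation loop; DummyTokenizer, DummyModel, and run_with_stoppers are stand-ins invented for illustration and are not the actual KoboldAI generation API.

# Illustrative sketch only: DummyTokenizer, DummyModel, and run_with_stoppers
# are assumed stand-ins, not the real InferenceModel API.
from typing import Callable, List

import torch


class DummyTokenizer:
    def decode(self, token_ids) -> str:
        # Pretend every token decodes to a period so sentence-end-style
        # stoppers fire on the first step.
        return "."


class DummyModel:
    def __init__(self):
        self.tokenizer = DummyTokenizer()
        self.gen_state = {}  # scratch dict shared by stoppers, as in the diff


def run_with_stoppers(
    model,
    input_ids: torch.LongTensor,
    next_token: Callable[[torch.LongTensor], torch.LongTensor],
    stoppers: List[Callable],
) -> torch.LongTensor:
    """Append one token per step until any stopper returns True."""
    while True:
        step = next_token(input_ids)  # shape: (batch, 1)
        input_ids = torch.cat([input_ids, step], dim=-1)
        if any(stopper(model, input_ids) for stopper in stoppers):
            return input_ids

With the real class this would be invoked with something like stoppers=[Stoppers.sentence_end_stopper, Stoppers.chat_mode_stopper]; the actual hook point inside the model's generation loop is not shown in this diff.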