mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Add basic support for some of the quick stoppers
@@ -3,15 +3,12 @@ from __future__ import annotations
 import torch
 
 import utils
-from modeling.inference_model import (
-    InferenceModel,
-)
-
+from modeling import inference_model
 
 class Stoppers:
     @staticmethod
     def core_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
         if not utils.koboldai_vars.inference_config.do_core:
@@ -62,7 +59,7 @@ class Stoppers:
 
     @staticmethod
     def dynamic_wi_scanner(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
         if not utils.koboldai_vars.inference_config.do_dynamic_wi:
@@ -93,7 +90,7 @@ class Stoppers:
 
     @staticmethod
     def chat_mode_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
         if not utils.koboldai_vars.chatmode:
@@ -118,7 +115,7 @@ class Stoppers:
 
     @staticmethod
     def stop_sequence_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
 
@@ -145,14 +142,22 @@ class Stoppers:
 
     @staticmethod
     def singleline_stopper(
-        model: InferenceModel,
+        model: inference_model.InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
-        """If singleline mode is enabled, it's pointless to generate output beyond the first newline."""
+        """Stop on occurrences of newlines **if singleline is enabled**."""
 
         # It might be better just to do this further up the line
         if not utils.koboldai_vars.singleline:
             return False
+        return Stoppers.newline_stopper(model, input_ids)
+
+    @staticmethod
+    def newline_stopper(
+        model: inference_model.InferenceModel,
+        input_ids: torch.LongTensor,
+    ) -> bool:
+        """Stop on occurrences of newlines."""
         # Keep track of presence of newlines in each sequence; we cannot stop a
         # batch member individually, so we must wait for all of them to contain
         # a newline.
@@ -167,3 +172,30 @@ class Stoppers:
             del model.gen_state["newline_in_sequence"]
             return True
         return False
+
+    @staticmethod
+    def sentence_end_stopper(
+        model: inference_model.InferenceModel,
+        input_ids: torch.LongTensor,
+    ) -> bool:
+        """Stops at the end of sentences."""
+
+        # TODO: Make this more robust
+        SENTENCE_ENDS = [".", "?", "!"]
+
+        # We need to keep track of stopping for each batch, since we can't stop
+        # one individually.
+        if "sentence_end_sequence" not in model.gen_state:
+            model.gen_state["sentence_end_sequence"] = [False] * len(input_ids)
+
+        for sequence_idx, batch_sequence in enumerate(input_ids):
+            decoded = model.tokenizer.decode(batch_sequence[-1])
+            for end in SENTENCE_ENDS:
+                if end in decoded:
+                    model.gen_state["sentence_end_sequence"][sequence_idx] = True
+                    break
+
+        if all(model.gen_state["sentence_end_sequence"]):
+            del model.gen_state["sentence_end_sequence"]
+            return True
+        return False
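A pattern shared by the new newline_stopper and sentence_end_stopper is worth spelling out: generation can only be halted for the whole batch, so each stopper records per-sequence progress in model.gen_state and returns True only once every sequence has met the condition. Below is a minimal standalone sketch of that bookkeeping; the stub classes and token values are illustrative assumptions, not code from this commit (in the repo, input_ids is a torch.LongTensor batch and tokenizer is the model's real tokenizer).

class StubTokenizer:
    def decode(self, token_id):
        # Pretend token id 0 decodes to a period and everything else to a word.
        return "." if token_id == 0 else " word"


class StubModel:
    def __init__(self):
        self.gen_state = {}  # persists across generation steps
        self.tokenizer = StubTokenizer()


SENTENCE_ENDS = [".", "?", "!"]


def sentence_end_stopper(model, input_ids) -> bool:
    # Same bookkeeping as the stopper added above: remember which batch members
    # have already produced a sentence end, and stop only when all of them have.
    if "sentence_end_sequence" not in model.gen_state:
        model.gen_state["sentence_end_sequence"] = [False] * len(input_ids)

    for sequence_idx, batch_sequence in enumerate(input_ids):
        decoded = model.tokenizer.decode(batch_sequence[-1])
        if any(end in decoded for end in SENTENCE_ENDS):
            model.gen_state["sentence_end_sequence"][sequence_idx] = True

    if all(model.gen_state["sentence_end_sequence"]):
        del model.gen_state["sentence_end_sequence"]
        return True
    return False


model = StubModel()
print(sentence_end_stopper(model, [[5, 0], [5, 5]]))        # False: only sequence 0 has ended a sentence
print(sentence_end_stopper(model, [[5, 0, 5], [5, 5, 0]]))  # True: both sequences have now ended one

The first call returns False because only one batch member has hit a sentence end; that fact is remembered in gen_state, so a later call can return True and clear the state for the next generation.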
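How these quick stoppers get attached to a model is outside this diff. If the surrounding code follows the existing stopper pattern, registration looks roughly like the sketch below; the modeling.stoppers import path and the stopper_hooks attribute are assumptions about the rest of the codebase, not something this commit shows.

# Hypothetical registration sketch -- module path and hook attribute are assumptions.
from modeling import inference_model
from modeling.stoppers import Stoppers  # assumed location of the class in this diff


def enable_quick_stoppers(model: inference_model.InferenceModel) -> None:
    # Each hook is called with (model, input_ids) during generation;
    # returning True asks the backend to stop producing further tokens.
    for stopper in (Stoppers.newline_stopper, Stoppers.sentence_end_stopper):
        if stopper not in model.stopper_hooks:
            model.stopper_hooks.append(stopper)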