Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Model: Add singleline_stopper and fix stopper code
singleline_stopper is adapted from MasterAibo's implementation in 0ba7ac9
@@ -63,9 +63,13 @@ class HFTorchInferenceModel(HFInferenceModel):
         self.low_mem = low_mem
 
         self.post_token_hooks = [
-            Stoppers.core_stopper,
             PostTokenHooks.stream_tokens,
+        ]
+
+        self.stopper_hooks = [
+            Stoppers.core_stopper,
             Stoppers.dynamic_wi_scanner,
+            Stoppers.singleline_stopper,
             Stoppers.chat_mode_stopper,
         ]
 
@@ -104,7 +108,7 @@ class HFTorchInferenceModel(HFInferenceModel):
         self._post_token_gen(input_ids)
 
         for stopper in self.stopper_hooks:
-            do_stop = stopper(input_ids)
+            do_stop = stopper(self, input_ids)
             if do_stop:
                 return True
         return False
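The second hunk is the "fix stopper code" part: each entry in stopper_hooks is a plain @staticmethod on Stoppers that expects the model instance as its first argument, so the loop now calls stopper(self, input_ids) instead of stopper(input_ids). Below is a minimal sketch, not the repository's actual wiring, of how such a hook loop can be handed to Hugging Face generate() as a stopping criterion; the class name StopperBridge and the wrapper_model variable are invented for illustration, and only transformers' StoppingCriteria / StoppingCriteriaList API is taken as given.

import torch
from transformers import StoppingCriteria, StoppingCriteriaList


class StopperBridge(StoppingCriteria):
    def __init__(self, model) -> None:
        # `model` is whatever object owns the stopper_hooks list; each hook is a
        # callable taking (model, input_ids) and returning True to halt generation.
        self.model = model

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        **kwargs,
    ) -> bool:
        # Generation stops as soon as any hook reports that it should.
        for stopper in self.model.stopper_hooks:
            if stopper(self.model, input_ids):
                return True
        return False


# Usage sketch (names hypothetical):
# output = hf_model.generate(
#     input_ids,
#     stopping_criteria=StoppingCriteriaList([StopperBridge(wrapper_model)]),
# )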
@@ -115,3 +115,32 @@ class Stoppers:
             del model.gen_state["completed"]
             return True
         return False
+
+    @staticmethod
+    def singleline_stopper(
+        model: InferenceModel,
+        input_ids: torch.LongTensor,
+    ) -> bool:
+        """If singleline mode is enabled, it's pointless to generate output beyond the first newline."""
+
+        if not utils.koboldai_vars.singleline:
+            return False
+
+        # Keep track of presence of newlines in each sequence; we cannot stop a
+        # batch member individually, so we must wait for all of them to contain
+        # a newline.
+        if "newline_in_sequence" not in model.gen_state:
+            model.gen_state["newline_in_sequence"] = [False] * len(input_ids)
+
+        print(model.gen_state["newline_in_sequence"])
+
+        for sequence_idx, batch_sequence in enumerate(input_ids):
+            if model.tokenizer.decode(batch_sequence[-1]) == "\n":
+                model.gen_state["newline_in_sequence"][sequence_idx] = True
+
+        if all(model.gen_state["newline_in_sequence"]):
+            del model.gen_state["newline_in_sequence"]
+            print("OUT")
+            return True
+        print("nah its ok")
+        return False
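For clarity, here is a stand-alone re-implementation of the newline bookkeeping above, run against stub objects so the batch behaviour is easy to see: the stopper only fires once every sequence in the batch has produced a newline, since generation cannot be halted for a single batch member. The Stub* names and singleline_should_stop are invented for this sketch and are not part of KoboldAI; the utils.koboldai_vars.singleline check is omitted.

from dataclasses import dataclass, field
from typing import List


class StubTokenizer:
    # Pretend token id 0 decodes to a newline and everything else to "x".
    def decode(self, token_id) -> str:
        return "\n" if int(token_id) == 0 else "x"


@dataclass
class StubModel:
    tokenizer: StubTokenizer = field(default_factory=StubTokenizer)
    gen_state: dict = field(default_factory=dict)


def singleline_should_stop(model, input_ids: List[List[int]]) -> bool:
    # Same bookkeeping as Stoppers.singleline_stopper: remember per-sequence
    # newline sightings across steps, stop only when every sequence has one.
    if "newline_in_sequence" not in model.gen_state:
        model.gen_state["newline_in_sequence"] = [False] * len(input_ids)
    for sequence_idx, batch_sequence in enumerate(input_ids):
        if model.tokenizer.decode(batch_sequence[-1]) == "\n":
            model.gen_state["newline_in_sequence"][sequence_idx] = True
    if all(model.gen_state["newline_in_sequence"]):
        del model.gen_state["newline_in_sequence"]
        return True
    return False


model = StubModel()
print(singleline_should_stop(model, [[5, 0], [5, 5]]))        # False: only sequence 0 has hit "\n"
print(singleline_should_stop(model, [[5, 0, 5], [5, 5, 0]]))  # True: every sequence has now hit "\n"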