mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Generation Mode: UNTIL_EOS
This mode enables the EOS token and keeps generating indefinitely until the model emits it.
This commit is contained in:
@@ -585,19 +585,21 @@ class InferenceModel:
|
|||||||
"wi_scanner_excluded_keys", set()
|
"wi_scanner_excluded_keys", set()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.gen_state["allow_eos"] = False
|
||||||
|
|
||||||
temp_stoppers = []
|
temp_stoppers = []
|
||||||
|
|
||||||
|
if gen_mode not in self.get_supported_gen_modes():
|
||||||
|
gen_mode = GenerationMode.STANDARD
|
||||||
|
logger.warning(f"User requested unsupported GenerationMode '{gen_mode}'!")
|
||||||
|
|
||||||
if gen_mode == GenerationMode.FOREVER:
|
if gen_mode == GenerationMode.FOREVER:
|
||||||
if self.capabilties.stopper_hooks:
|
|
||||||
self.gen_state["stop_at_genamt"] = False
|
self.gen_state["stop_at_genamt"] = False
|
||||||
max_new = 1e7
|
max_new = 1e7
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"User requested infinite generation on model that doesn't support stop hooks. Recipe for disaster!"
|
|
||||||
)
|
|
||||||
elif gen_mode == GenerationMode.UNTIL_EOS:
|
elif gen_mode == GenerationMode.UNTIL_EOS:
|
||||||
# Still need to unban
|
self.gen_state["allow_eos"] = True
|
||||||
raise NotImplementedError()
|
self.gen_state["stop_at_genamt"] = False
|
||||||
|
max_new = 1e7
|
||||||
elif gen_mode == GenerationMode.UNTIL_NEWLINE:
|
elif gen_mode == GenerationMode.UNTIL_NEWLINE:
|
||||||
# TODO: Look into replacing `single_line` with `generation_mode`
|
# TODO: Look into replacing `single_line` with `generation_mode`
|
||||||
temp_stoppers.append(Stoppers.newline_stopper)
|
temp_stoppers.append(Stoppers.newline_stopper)
|
||||||
@@ -668,11 +670,11 @@ class InferenceModel:
|
|||||||
Returns:
|
Returns:
|
||||||
List[GenerationMode]: A list of compatible `GenerationMode`s.
|
List[GenerationMode]: A list of compatible `GenerationMode`s.
|
||||||
"""
|
"""
|
||||||
ret = []
|
ret = [GenerationMode.STANDARD]
|
||||||
|
|
||||||
if self.capabilties.stopper_hooks:
|
if self.capabilties.stopper_hooks:
|
||||||
ret += [
|
ret += [
|
||||||
GenerationMode.FOREVER,
|
GenerationMode.FOREVER,
|
||||||
GenerationMode.UNTIL_EOS,
|
|
||||||
GenerationMode.UNTIL_NEWLINE,
|
GenerationMode.UNTIL_NEWLINE,
|
||||||
GenerationMode.UNTIL_SENTENCE_END,
|
GenerationMode.UNTIL_SENTENCE_END,
|
||||||
]
|
]
|
||||||
|
@@ -31,6 +31,7 @@ from modeling.stoppers import Stoppers
|
|||||||
from modeling.post_token_hooks import PostTokenHooks
|
from modeling.post_token_hooks import PostTokenHooks
|
||||||
from modeling.inference_models.hf import HFInferenceModel
|
from modeling.inference_models.hf import HFInferenceModel
|
||||||
from modeling.inference_model import (
|
from modeling.inference_model import (
|
||||||
|
GenerationMode,
|
||||||
GenerationResult,
|
GenerationResult,
|
||||||
GenerationSettings,
|
GenerationSettings,
|
||||||
ModelCapabilities,
|
ModelCapabilities,
|
||||||
@@ -254,7 +255,11 @@ class HFTorchInferenceModel(HFInferenceModel):
|
|||||||
kwargs["logits_warper"] = new_get_logits_warper(
|
kwargs["logits_warper"] = new_get_logits_warper(
|
||||||
beams=1,
|
beams=1,
|
||||||
)
|
)
|
||||||
if utils.koboldai_vars.newlinemode in ["s", "ns"]:
|
|
||||||
|
if (
|
||||||
|
utils.koboldai_vars.newlinemode in ["s", "ns"]
|
||||||
|
and not m_self.gen_state["allow_eos"]
|
||||||
|
):
|
||||||
kwargs["eos_token_id"] = -1
|
kwargs["eos_token_id"] = -1
|
||||||
kwargs.setdefault("pad_token_id", 2)
|
kwargs.setdefault("pad_token_id", 2)
|
||||||
return new_sample.old_sample(self, *args, **kwargs)
|
return new_sample.old_sample(self, *args, **kwargs)
|
||||||
@@ -605,3 +610,9 @@ class HFTorchInferenceModel(HFInferenceModel):
|
|||||||
self.breakmodel = False
|
self.breakmodel = False
|
||||||
self.usegpu = False
|
self.usegpu = False
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def get_supported_gen_modes(self) -> List[GenerationMode]:
    """Return the generation modes this model can honor.

    Extends whatever the parent class reports with
    `GenerationMode.UNTIL_EOS`.

    Returns:
        List[GenerationMode]: The parent's supported modes followed by
        `GenerationMode.UNTIL_EOS`.
    """
    # UNTIL_EOS is advertised here because this backend's torch sampling
    # patch is what decides whether EOS stays banned.
    # NOTE(review): presumably tied to the `allow_eos` gen_state flag --
    # confirm against the sampling patch.
    inherited_modes = super().get_supported_gen_modes()
    return inherited_modes + [GenerationMode.UNTIL_EOS]
|
Reference in New Issue
Block a user