Generation Mode: UNTIL_EOS

This mode unbans the EOS token and lifts the generation-length cap, so
generation continues until the model emits EOS.
commit e5d0a597a1
parent c78401bd12
Author: somebody
Date:   2023-07-21 15:36:32 -05:00

2 files changed, 25 insertions(+), 12 deletions(-)

modeling/inference_model.py

@@ -585,19 +585,21 @@ class InferenceModel:
             "wi_scanner_excluded_keys", set()
         )
+        self.gen_state["allow_eos"] = False
 
         temp_stoppers = []
 
+        if gen_mode not in self.get_supported_gen_modes():
+            logger.warning(f"User requested unsupported GenerationMode '{gen_mode}'!")
+            gen_mode = GenerationMode.STANDARD
+
         if gen_mode == GenerationMode.FOREVER:
-            if self.capabilties.stopper_hooks:
-                self.gen_state["stop_at_genamt"] = False
-                max_new = 1e7
-            else:
-                logger.warning(
-                    "User requested infinite generation on model that doesn't support stop hooks. Recipe for disaster!"
-                )
+            self.gen_state["stop_at_genamt"] = False
+            max_new = 1e7
         elif gen_mode == GenerationMode.UNTIL_EOS:
-            # Still need to unban
-            raise NotImplementedError()
+            self.gen_state["allow_eos"] = True
+            self.gen_state["stop_at_genamt"] = False
+            max_new = 1e7
         elif gen_mode == GenerationMode.UNTIL_NEWLINE:
             # TODO: Look into replacing `single_line` with `generation_mode`
             temp_stoppers.append(Stoppers.newline_stopper)
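
For orientation, here is a minimal self-contained sketch of the dispatch above. `apply_gen_mode` is a hypothetical helper invented for illustration; the enum members and the 1e7 "effectively unbounded" sentinel mirror the diff.

    from enum import Enum, auto

    class GenerationMode(Enum):
        STANDARD = auto()
        FOREVER = auto()
        UNTIL_EOS = auto()
        UNTIL_NEWLINE = auto()
        UNTIL_SENTENCE_END = auto()

    def apply_gen_mode(gen_state: dict, gen_mode: GenerationMode, max_new: int) -> int:
        # Defaults: EOS stays banned and the configured token budget is respected.
        gen_state["allow_eos"] = False
        gen_state["stop_at_genamt"] = True
        if gen_mode in (GenerationMode.FOREVER, GenerationMode.UNTIL_EOS):
            gen_state["stop_at_genamt"] = False
            max_new = int(1e7)  # no practical cap; a stopper or EOS ends generation
        if gen_mode == GenerationMode.UNTIL_EOS:
            gen_state["allow_eos"] = True  # let sampling emit the real EOS token
        return max_new
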
@@ -668,11 +670,11 @@ class InferenceModel:
         Returns:
             List[GenerationMode]: A list of compatible `GenerationMode`s.
         """
-        ret = []
+        ret = [GenerationMode.STANDARD]
+
         if self.capabilties.stopper_hooks:
             ret += [
                 GenerationMode.FOREVER,
-                GenerationMode.UNTIL_EOS,
                 GenerationMode.UNTIL_NEWLINE,
                 GenerationMode.UNTIL_SENTENCE_END,
             ]
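
Reusing the `GenerationMode` enum from the sketch above, the capability gating can be exercised with a stand-in class. This is an illustration only, reduced to the one `stopper_hooks` flag the diff consults (the attribute really is spelled `capabilties` in this codebase):

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class Capabilities:
        stopper_hooks: bool = False

    class StubModel:
        def __init__(self, capabilties: Capabilities):
            self.capabilties = capabilties  # spelling matches the codebase

        def get_supported_gen_modes(self) -> List[GenerationMode]:
            # STANDARD is always safe; stopper-driven modes need hook support.
            ret = [GenerationMode.STANDARD]
            if self.capabilties.stopper_hooks:
                ret += [
                    GenerationMode.FOREVER,
                    GenerationMode.UNTIL_NEWLINE,
                    GenerationMode.UNTIL_SENTENCE_END,
                ]
            return ret
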

modeling/inference_models/hf_torch.py

@@ -31,6 +31,7 @@ from modeling.stoppers import Stoppers
 from modeling.post_token_hooks import PostTokenHooks
 from modeling.inference_models.hf import HFInferenceModel
 from modeling.inference_model import (
+    GenerationMode,
     GenerationResult,
     GenerationSettings,
     ModelCapabilities,
@@ -254,7 +255,11 @@ class HFTorchInferenceModel(HFInferenceModel):
             kwargs["logits_warper"] = new_get_logits_warper(
                 beams=1,
             )
-            if utils.koboldai_vars.newlinemode in ["s", "ns"]:
+            if (
+                utils.koboldai_vars.newlinemode in ["s", "ns"]
+                and not m_self.gen_state["allow_eos"]
+            ):
                 kwargs["eos_token_id"] = -1
             kwargs.setdefault("pad_token_id", 2)
             return new_sample.old_sample(self, *args, **kwargs)
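
The -1 sentinel works because no sampled token id can ever equal -1, so Hugging Face's sampling loop never treats EOS as a stop condition; with `allow_eos` set, the real EOS id is left in place and decoding can terminate naturally. A rough sketch of the same idea against plain `transformers` (the model choice is a placeholder, and whether `generate` accepts -1 varies by library version):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    ids = tok("Once upon a time", return_tensors="pt").input_ids

    allow_eos = True
    out = model.generate(
        ids,
        do_sample=True,
        max_new_tokens=64,
        # With allow_eos, keep the real EOS id so generation can stop on it;
        # otherwise pass an id that never matches to suppress early stopping.
        eos_token_id=tok.eos_token_id if allow_eos else -1,
        pad_token_id=tok.eos_token_id,
    )
    print(tok.decode(out[0], skip_special_tokens=False))
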
@@ -605,3 +610,9 @@ class HFTorchInferenceModel(HFInferenceModel):
         self.breakmodel = False
         self.usegpu = False
         return
+
+    def get_supported_gen_modes(self) -> List[GenerationMode]:
+        # The sampler patch above unbans EOS when gen_state["allow_eos"] is set, enabling UNTIL_EOS.
+        return super().get_supported_gen_modes() + [
+            GenerationMode.UNTIL_EOS,
+        ]
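
Sketched against the same stub classes as before, the override pattern looks like this: only backends whose sampler patch honors `allow_eos` advertise UNTIL_EOS on top of the base list.

    class StubTorchModel(StubModel):
        def get_supported_gen_modes(self) -> List[GenerationMode]:
            # The patched sampler unbans EOS when allow_eos is set, so this
            # backend can additionally offer UNTIL_EOS.
            return super().get_supported_gen_modes() + [GenerationMode.UNTIL_EOS]

    caps = Capabilities(stopper_hooks=True)
    assert GenerationMode.UNTIL_EOS not in StubModel(caps).get_supported_gen_modes()
    assert GenerationMode.UNTIL_EOS in StubTorchModel(caps).get_supported_gen_modes()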