From ce064168e3d9feab3cb427cde26fa47e79f8fba0 Mon Sep 17 00:00:00 2001 From: vfbd Date: Mon, 8 Aug 2022 13:52:07 -0400 Subject: [PATCH] Additional validation for soft_prompt in API --- aiserver.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/aiserver.py b/aiserver.py index b11f0e93..1a196bb9 100644 --- a/aiserver.py +++ b/aiserver.py @@ -6758,13 +6758,26 @@ class SamplerSettingsSchema(KoboldSchema): typical: Optional[float] = fields.Float(validate=validate.Range(min=0, max=1), metadata={"description": "Typical sampling value."}) temperature: Optional[float] = fields.Float(validate=validate.Range(min=0, min_inclusive=False), metadata={"description": "Temperature value."}) +def soft_prompt_validator(soft_prompt: str): + if len(soft_prompt.strip()) == 0: + return + if not vars.allowsp: + raise ValidationError("Cannot use soft prompts with current backend.") + if any(q in soft_prompt for q in ("/", "\\")): + return + z, _, _, _, _ = fileops.checksp(soft_prompt.strip(), vars.modeldim) + if isinstance(z, int): + raise ValidationError("Must be a valid soft prompt name.") + z.close() + return True + class GenerationInputSchema(SamplerSettingsSchema): prompt: str = fields.String(required=True, metadata={"description": "This is the submission."}) use_memory: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the memory from the KoboldAI GUI when generating text."}) use_story: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the story from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."}) use_world_info: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the world info from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."}) use_userscripts: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the userscripts from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."}) - soft_prompt: Optional[str] = fields.String(metadata={"description": "Soft prompt to use when generating. If set to the empty string or any other string containing no non-whitespace characters, uses no soft prompt."}) + soft_prompt: Optional[str] = fields.String(metadata={"description": "Soft prompt to use when generating. If set to the empty string or any other string containing no non-whitespace characters, uses no soft prompt."}, validate=[soft_prompt_validator, validate.Regexp(r"^[^/\\]*$")]) max_length: int = fields.Integer(validate=validate.Range(min=1, max=2048), metadata={"description": "Number of tokens to generate."}) n: int = fields.Integer(validate=validate.Range(min=1, max=5), metadata={"description": "Number of outputs to generate."}) disable_output_formatting: bool = fields.Boolean(load_default=True, metadata={"description": "When enabled, disables all output formatting options, overriding their individual enabled/disabled states."}) @@ -6828,11 +6841,11 @@ def _generate_text(body: GenerationInputSchema): saved_settings[key] = getattr(entry[0], entry[1]) setattr(entry[0], entry[1], getattr(body, key)) try: - if getattr(body, "soft_prompt", None) is not None: + if vars.allowsp and getattr(body, "soft_prompt", None) is not None: if any(q in body.soft_prompt for q in ("/", "\\")): raise RuntimeError old_spfilename = vars.spfilename - spRequest(body.soft_prompt) + spRequest(body.soft_prompt.strip()) genout = apiactionsubmit(body.prompt, use_memory=body.use_memory) output = {"results": [{"text": txt} for txt in genout]} finally: @@ -6847,7 +6860,7 @@ def _generate_text(body: GenerationInputSchema): setattr(entry[0], entry[1], saved_settings[key]) vars.disable_set_aibusy = disable_set_aibusy vars.standalone = _standalone - if getattr(body, "soft_prompt", None) is not None: + if vars.allowsp and getattr(body, "soft_prompt", None) is not None: spRequest(old_spfilename) set_aibusy(0) return output