Additional validation for soft_prompt in API

This commit is contained in:
vfbd 2022-08-08 13:52:07 -04:00
parent de1e8f266a
commit ce064168e3
1 changed files with 17 additions and 4 deletions

View File

@ -6758,13 +6758,26 @@ class SamplerSettingsSchema(KoboldSchema):
typical: Optional[float] = fields.Float(validate=validate.Range(min=0, max=1), metadata={"description": "Typical sampling value."})
temperature: Optional[float] = fields.Float(validate=validate.Range(min=0, min_inclusive=False), metadata={"description": "Temperature value."})
def soft_prompt_validator(soft_prompt: str):
if len(soft_prompt.strip()) == 0:
return
if not vars.allowsp:
raise ValidationError("Cannot use soft prompts with current backend.")
if any(q in soft_prompt for q in ("/", "\\")):
return
z, _, _, _, _ = fileops.checksp(soft_prompt.strip(), vars.modeldim)
if isinstance(z, int):
raise ValidationError("Must be a valid soft prompt name.")
z.close()
return True
class GenerationInputSchema(SamplerSettingsSchema):
prompt: str = fields.String(required=True, metadata={"description": "This is the submission."})
use_memory: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the memory from the KoboldAI GUI when generating text."})
use_story: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the story from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."})
use_world_info: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the world info from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."})
use_userscripts: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the userscripts from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."})
soft_prompt: Optional[str] = fields.String(metadata={"description": "Soft prompt to use when generating. If set to the empty string or any other string containing no non-whitespace characters, uses no soft prompt."})
soft_prompt: Optional[str] = fields.String(metadata={"description": "Soft prompt to use when generating. If set to the empty string or any other string containing no non-whitespace characters, uses no soft prompt."}, validate=[soft_prompt_validator, validate.Regexp(r"^[^/\\]*$")])
max_length: int = fields.Integer(validate=validate.Range(min=1, max=2048), metadata={"description": "Number of tokens to generate."})
n: int = fields.Integer(validate=validate.Range(min=1, max=5), metadata={"description": "Number of outputs to generate."})
disable_output_formatting: bool = fields.Boolean(load_default=True, metadata={"description": "When enabled, disables all output formatting options, overriding their individual enabled/disabled states."})
@ -6828,11 +6841,11 @@ def _generate_text(body: GenerationInputSchema):
saved_settings[key] = getattr(entry[0], entry[1])
setattr(entry[0], entry[1], getattr(body, key))
try:
if getattr(body, "soft_prompt", None) is not None:
if vars.allowsp and getattr(body, "soft_prompt", None) is not None:
if any(q in body.soft_prompt for q in ("/", "\\")):
raise RuntimeError
old_spfilename = vars.spfilename
spRequest(body.soft_prompt)
spRequest(body.soft_prompt.strip())
genout = apiactionsubmit(body.prompt, use_memory=body.use_memory)
output = {"results": [{"text": txt} for txt in genout]}
finally:
@ -6847,7 +6860,7 @@ def _generate_text(body: GenerationInputSchema):
setattr(entry[0], entry[1], saved_settings[key])
vars.disable_set_aibusy = disable_set_aibusy
vars.standalone = _standalone
if getattr(body, "soft_prompt", None) is not None:
if vars.allowsp and getattr(body, "soft_prompt", None) is not None:
spRequest(old_spfilename)
set_aibusy(0)
return output