Additional validation for soft_prompt in API
This commit is contained in:
parent
de1e8f266a
commit
ce064168e3
21
aiserver.py
21
aiserver.py
|
@ -6758,13 +6758,26 @@ class SamplerSettingsSchema(KoboldSchema):
|
|||
typical: Optional[float] = fields.Float(validate=validate.Range(min=0, max=1), metadata={"description": "Typical sampling value."})
|
||||
temperature: Optional[float] = fields.Float(validate=validate.Range(min=0, min_inclusive=False), metadata={"description": "Temperature value."})
|
||||
|
||||
def soft_prompt_validator(soft_prompt: str):
|
||||
if len(soft_prompt.strip()) == 0:
|
||||
return
|
||||
if not vars.allowsp:
|
||||
raise ValidationError("Cannot use soft prompts with current backend.")
|
||||
if any(q in soft_prompt for q in ("/", "\\")):
|
||||
return
|
||||
z, _, _, _, _ = fileops.checksp(soft_prompt.strip(), vars.modeldim)
|
||||
if isinstance(z, int):
|
||||
raise ValidationError("Must be a valid soft prompt name.")
|
||||
z.close()
|
||||
return True
|
||||
|
||||
class GenerationInputSchema(SamplerSettingsSchema):
|
||||
prompt: str = fields.String(required=True, metadata={"description": "This is the submission."})
|
||||
use_memory: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the memory from the KoboldAI GUI when generating text."})
|
||||
use_story: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the story from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."})
|
||||
use_world_info: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the world info from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."})
|
||||
use_userscripts: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the userscripts from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."})
|
||||
soft_prompt: Optional[str] = fields.String(metadata={"description": "Soft prompt to use when generating. If set to the empty string or any other string containing no non-whitespace characters, uses no soft prompt."})
|
||||
soft_prompt: Optional[str] = fields.String(metadata={"description": "Soft prompt to use when generating. If set to the empty string or any other string containing no non-whitespace characters, uses no soft prompt."}, validate=[soft_prompt_validator, validate.Regexp(r"^[^/\\]*$")])
|
||||
max_length: int = fields.Integer(validate=validate.Range(min=1, max=2048), metadata={"description": "Number of tokens to generate."})
|
||||
n: int = fields.Integer(validate=validate.Range(min=1, max=5), metadata={"description": "Number of outputs to generate."})
|
||||
disable_output_formatting: bool = fields.Boolean(load_default=True, metadata={"description": "When enabled, disables all output formatting options, overriding their individual enabled/disabled states."})
|
||||
|
@ -6828,11 +6841,11 @@ def _generate_text(body: GenerationInputSchema):
|
|||
saved_settings[key] = getattr(entry[0], entry[1])
|
||||
setattr(entry[0], entry[1], getattr(body, key))
|
||||
try:
|
||||
if getattr(body, "soft_prompt", None) is not None:
|
||||
if vars.allowsp and getattr(body, "soft_prompt", None) is not None:
|
||||
if any(q in body.soft_prompt for q in ("/", "\\")):
|
||||
raise RuntimeError
|
||||
old_spfilename = vars.spfilename
|
||||
spRequest(body.soft_prompt)
|
||||
spRequest(body.soft_prompt.strip())
|
||||
genout = apiactionsubmit(body.prompt, use_memory=body.use_memory)
|
||||
output = {"results": [{"text": txt} for txt in genout]}
|
||||
finally:
|
||||
|
@ -6847,7 +6860,7 @@ def _generate_text(body: GenerationInputSchema):
|
|||
setattr(entry[0], entry[1], saved_settings[key])
|
||||
vars.disable_set_aibusy = disable_set_aibusy
|
||||
vars.standalone = _standalone
|
||||
if getattr(body, "soft_prompt", None) is not None:
|
||||
if vars.allowsp and getattr(body, "soft_prompt", None) is not None:
|
||||
spRequest(old_spfilename)
|
||||
set_aibusy(0)
|
||||
return output
|
||||
|
|
Loading…
Reference in New Issue