Additional validation for soft_prompt in API
This commit is contained in:
parent
de1e8f266a
commit
ce064168e3
21
aiserver.py
21
aiserver.py
|
@ -6758,13 +6758,26 @@ class SamplerSettingsSchema(KoboldSchema):
|
||||||
typical: Optional[float] = fields.Float(validate=validate.Range(min=0, max=1), metadata={"description": "Typical sampling value."})
|
typical: Optional[float] = fields.Float(validate=validate.Range(min=0, max=1), metadata={"description": "Typical sampling value."})
|
||||||
temperature: Optional[float] = fields.Float(validate=validate.Range(min=0, min_inclusive=False), metadata={"description": "Temperature value."})
|
temperature: Optional[float] = fields.Float(validate=validate.Range(min=0, min_inclusive=False), metadata={"description": "Temperature value."})
|
||||||
|
|
||||||
|
def soft_prompt_validator(soft_prompt: str):
|
||||||
|
if len(soft_prompt.strip()) == 0:
|
||||||
|
return
|
||||||
|
if not vars.allowsp:
|
||||||
|
raise ValidationError("Cannot use soft prompts with current backend.")
|
||||||
|
if any(q in soft_prompt for q in ("/", "\\")):
|
||||||
|
return
|
||||||
|
z, _, _, _, _ = fileops.checksp(soft_prompt.strip(), vars.modeldim)
|
||||||
|
if isinstance(z, int):
|
||||||
|
raise ValidationError("Must be a valid soft prompt name.")
|
||||||
|
z.close()
|
||||||
|
return True
|
||||||
|
|
||||||
class GenerationInputSchema(SamplerSettingsSchema):
|
class GenerationInputSchema(SamplerSettingsSchema):
|
||||||
prompt: str = fields.String(required=True, metadata={"description": "This is the submission."})
|
prompt: str = fields.String(required=True, metadata={"description": "This is the submission."})
|
||||||
use_memory: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the memory from the KoboldAI GUI when generating text."})
|
use_memory: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the memory from the KoboldAI GUI when generating text."})
|
||||||
use_story: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the story from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."})
|
use_story: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the story from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."})
|
||||||
use_world_info: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the world info from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."})
|
use_world_info: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the world info from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."})
|
||||||
use_userscripts: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the userscripts from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."})
|
use_userscripts: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the userscripts from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."})
|
||||||
soft_prompt: Optional[str] = fields.String(metadata={"description": "Soft prompt to use when generating. If set to the empty string or any other string containing no non-whitespace characters, uses no soft prompt."})
|
soft_prompt: Optional[str] = fields.String(metadata={"description": "Soft prompt to use when generating. If set to the empty string or any other string containing no non-whitespace characters, uses no soft prompt."}, validate=[soft_prompt_validator, validate.Regexp(r"^[^/\\]*$")])
|
||||||
max_length: int = fields.Integer(validate=validate.Range(min=1, max=2048), metadata={"description": "Number of tokens to generate."})
|
max_length: int = fields.Integer(validate=validate.Range(min=1, max=2048), metadata={"description": "Number of tokens to generate."})
|
||||||
n: int = fields.Integer(validate=validate.Range(min=1, max=5), metadata={"description": "Number of outputs to generate."})
|
n: int = fields.Integer(validate=validate.Range(min=1, max=5), metadata={"description": "Number of outputs to generate."})
|
||||||
disable_output_formatting: bool = fields.Boolean(load_default=True, metadata={"description": "When enabled, disables all output formatting options, overriding their individual enabled/disabled states."})
|
disable_output_formatting: bool = fields.Boolean(load_default=True, metadata={"description": "When enabled, disables all output formatting options, overriding their individual enabled/disabled states."})
|
||||||
|
@ -6828,11 +6841,11 @@ def _generate_text(body: GenerationInputSchema):
|
||||||
saved_settings[key] = getattr(entry[0], entry[1])
|
saved_settings[key] = getattr(entry[0], entry[1])
|
||||||
setattr(entry[0], entry[1], getattr(body, key))
|
setattr(entry[0], entry[1], getattr(body, key))
|
||||||
try:
|
try:
|
||||||
if getattr(body, "soft_prompt", None) is not None:
|
if vars.allowsp and getattr(body, "soft_prompt", None) is not None:
|
||||||
if any(q in body.soft_prompt for q in ("/", "\\")):
|
if any(q in body.soft_prompt for q in ("/", "\\")):
|
||||||
raise RuntimeError
|
raise RuntimeError
|
||||||
old_spfilename = vars.spfilename
|
old_spfilename = vars.spfilename
|
||||||
spRequest(body.soft_prompt)
|
spRequest(body.soft_prompt.strip())
|
||||||
genout = apiactionsubmit(body.prompt, use_memory=body.use_memory)
|
genout = apiactionsubmit(body.prompt, use_memory=body.use_memory)
|
||||||
output = {"results": [{"text": txt} for txt in genout]}
|
output = {"results": [{"text": txt} for txt in genout]}
|
||||||
finally:
|
finally:
|
||||||
|
@ -6847,7 +6860,7 @@ def _generate_text(body: GenerationInputSchema):
|
||||||
setattr(entry[0], entry[1], saved_settings[key])
|
setattr(entry[0], entry[1], saved_settings[key])
|
||||||
vars.disable_set_aibusy = disable_set_aibusy
|
vars.disable_set_aibusy = disable_set_aibusy
|
||||||
vars.standalone = _standalone
|
vars.standalone = _standalone
|
||||||
if getattr(body, "soft_prompt", None) is not None:
|
if vars.allowsp and getattr(body, "soft_prompt", None) is not None:
|
||||||
spRequest(old_spfilename)
|
spRequest(old_spfilename)
|
||||||
set_aibusy(0)
|
set_aibusy(0)
|
||||||
return output
|
return output
|
||||||
|
|
Loading…
Reference in New Issue