Unknown values in API input are now ignored instead of causing error

This commit is contained in:
vfbd 2022-08-08 13:17:53 -04:00
parent 3b56859c12
commit 596f619999
1 changed files with 13 additions and 9 deletions

View File

@ -402,9 +402,13 @@ from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from apispec.ext.marshmallow.field_converter import make_min_max_attributes
from apispec_webframeworks.flask import FlaskPlugin
from marshmallow import Schema, fields, validate
from marshmallow import Schema, fields, validate, EXCLUDE
from marshmallow.exceptions import ValidationError
class KoboldSchema(Schema):
class Meta:
unknown = EXCLUDE # If there are unknown values in the input to an API endpoint, ignore them instead of raising error 422.
def new_make_min_max_attributes(validators, min_attr, max_attr) -> dict:
# Patched apispec function that creates "exclusiveMinimum"/"exclusiveMaximum" OpenAPI attributes insteaed of "minimum"/"maximum" when using validators.Range or validators.Length with min_inclusive=False or max_inclusive=False
attributes = {}
@ -6669,11 +6673,11 @@ def get_files_folders(starting_folder):
socketio.emit("popup_breadcrumbs", breadcrumbs, broadcast=True)
class BasicErrorSchema(Schema):
class BasicErrorSchema(KoboldSchema):
msg: str = fields.String(required=True)
type: str = fields.String(required=True)
class OutOfMemoryErrorSchema(Schema):
class OutOfMemoryErrorSchema(KoboldSchema):
detail: BasicErrorSchema = fields.Nested(BasicErrorSchema, required=True)
api_out_of_memory_response = """507:
@ -6708,7 +6712,7 @@ api_out_of_memory_response = """507:
msg: "KoboldAI ran out of memory."
type: out_of_memory.unknown.unknown"""
class ValidationErrorSchema(Schema):
class ValidationErrorSchema(KoboldSchema):
detail: Dict[str, List[str]] = fields.Dict(keys=fields.String(), values=fields.List(fields.String()), required=True)
api_validation_error_response = """422:
@ -6717,7 +6721,7 @@ api_validation_error_response = """422:
application/json:
schema: ValidationErrorSchema"""
class ServerBusyErrorSchema(Schema):
class ServerBusyErrorSchema(KoboldSchema):
detail: BasicErrorSchema = fields.Nested(BasicErrorSchema, required=True)
api_server_busy_response = """503:
@ -6730,7 +6734,7 @@ api_server_busy_response = """503:
msg: Server is busy; please try again later.
type: service_unavailable"""
class NotImplementedErrorSchema(Schema):
class NotImplementedErrorSchema(KoboldSchema):
detail: BasicErrorSchema = fields.Nested(BasicErrorSchema, required=True)
api_not_implemented_response = """501:
@ -6743,7 +6747,7 @@ api_not_implemented_response = """501:
msg: API generation is not supported in read-only mode; please load a model and then try again.
type: not_implemented"""
class SamplerSettingsSchema(Schema):
class SamplerSettingsSchema(KoboldSchema):
rep_pen: Optional[float] = fields.Float(validate=validate.Range(min=1), metadata={"description": "Base repetition penalty value."})
rep_pen_range: Optional[int] = fields.Integer(validate=validate.Range(min=0), metadata={"description": "Repetition penalty range."})
rep_pen_slope: Optional[float] = fields.Float(validate=validate.Range(min=0), metadata={"description": "Repetition penalty slope."})
@ -6771,10 +6775,10 @@ class GenerationInputSchema(SamplerSettingsSchema):
disable_input_formatting: bool = fields.Boolean(load_default=True, metadata={"description": "When enabled, disables all input formatting options, overriding their individual enabled/disabled states."})
frmtadsnsp: Optional[bool] = fields.Boolean(metadata={"description": "Input formatting option. When enabled, adds a leading space to your input if there is no trailing whitespace at the end of the previous action."})
class GenerationResultSchema(Schema):
class GenerationResultSchema(KoboldSchema):
text: str = fields.String(required=True, metadata={"description": "Generated output as plain text."})
class GenerationOutputSchema(Schema):
class GenerationOutputSchema(KoboldSchema):
results: List[GenerationResultSchema] = fields.List(fields.Nested(GenerationResultSchema), required=True, metadata={"description": "Array of generated outputs."})
def _generate_text(body: GenerationInputSchema):