From 596f61999970df10efea07476640d4ff7e994070 Mon Sep 17 00:00:00 2001 From: vfbd Date: Mon, 8 Aug 2022 13:17:53 -0400 Subject: [PATCH] Unknown values in API input are now ignored instead of causing error --- aiserver.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/aiserver.py b/aiserver.py index 0c7aa752..93f162dd 100644 --- a/aiserver.py +++ b/aiserver.py @@ -402,9 +402,13 @@ from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin from apispec.ext.marshmallow.field_converter import make_min_max_attributes from apispec_webframeworks.flask import FlaskPlugin -from marshmallow import Schema, fields, validate +from marshmallow import Schema, fields, validate, EXCLUDE from marshmallow.exceptions import ValidationError +class KoboldSchema(Schema): + class Meta: + unknown = EXCLUDE # If there are unknown values in the input to an API endpoint, ignore them instead of raising error 422. + def new_make_min_max_attributes(validators, min_attr, max_attr) -> dict: # Patched apispec function that creates "exclusiveMinimum"/"exclusiveMaximum" OpenAPI attributes insteaed of "minimum"/"maximum" when using validators.Range or validators.Length with min_inclusive=False or max_inclusive=False attributes = {} @@ -6669,11 +6673,11 @@ def get_files_folders(starting_folder): socketio.emit("popup_breadcrumbs", breadcrumbs, broadcast=True) -class BasicErrorSchema(Schema): +class BasicErrorSchema(KoboldSchema): msg: str = fields.String(required=True) type: str = fields.String(required=True) -class OutOfMemoryErrorSchema(Schema): +class OutOfMemoryErrorSchema(KoboldSchema): detail: BasicErrorSchema = fields.Nested(BasicErrorSchema, required=True) api_out_of_memory_response = """507: @@ -6708,7 +6712,7 @@ api_out_of_memory_response = """507: msg: "KoboldAI ran out of memory." type: out_of_memory.unknown.unknown""" -class ValidationErrorSchema(Schema): +class ValidationErrorSchema(KoboldSchema): detail: Dict[str, List[str]] = fields.Dict(keys=fields.String(), values=fields.List(fields.String()), required=True) api_validation_error_response = """422: @@ -6717,7 +6721,7 @@ api_validation_error_response = """422: application/json: schema: ValidationErrorSchema""" -class ServerBusyErrorSchema(Schema): +class ServerBusyErrorSchema(KoboldSchema): detail: BasicErrorSchema = fields.Nested(BasicErrorSchema, required=True) api_server_busy_response = """503: @@ -6730,7 +6734,7 @@ api_server_busy_response = """503: msg: Server is busy; please try again later. type: service_unavailable""" -class NotImplementedErrorSchema(Schema): +class NotImplementedErrorSchema(KoboldSchema): detail: BasicErrorSchema = fields.Nested(BasicErrorSchema, required=True) api_not_implemented_response = """501: @@ -6743,7 +6747,7 @@ api_not_implemented_response = """501: msg: API generation is not supported in read-only mode; please load a model and then try again. type: not_implemented""" -class SamplerSettingsSchema(Schema): +class SamplerSettingsSchema(KoboldSchema): rep_pen: Optional[float] = fields.Float(validate=validate.Range(min=1), metadata={"description": "Base repetition penalty value."}) rep_pen_range: Optional[int] = fields.Integer(validate=validate.Range(min=0), metadata={"description": "Repetition penalty range."}) rep_pen_slope: Optional[float] = fields.Float(validate=validate.Range(min=0), metadata={"description": "Repetition penalty slope."}) @@ -6771,10 +6775,10 @@ class GenerationInputSchema(SamplerSettingsSchema): disable_input_formatting: bool = fields.Boolean(load_default=True, metadata={"description": "When enabled, disables all input formatting options, overriding their individual enabled/disabled states."}) frmtadsnsp: Optional[bool] = fields.Boolean(metadata={"description": "Input formatting option. When enabled, adds a leading space to your input if there is no trailing whitespace at the end of the previous action."}) -class GenerationResultSchema(Schema): +class GenerationResultSchema(KoboldSchema): text: str = fields.String(required=True, metadata={"description": "Generated output as plain text."}) -class GenerationOutputSchema(Schema): +class GenerationOutputSchema(KoboldSchema): results: List[GenerationResultSchema] = fields.List(fields.Nested(GenerationResultSchema), required=True, metadata={"description": "Array of generated outputs."}) def _generate_text(body: GenerationInputSchema):