mirror of
				https://github.com/KoboldAI/KoboldAI-Client.git
				synced 2025-06-05 21:59:24 +02:00 
			
		
		
		
	Upload basic API with /generate POST endpoint
This commit is contained in:
		
							
								
								
									
										468
									
								
								aiserver.py
									
									
									
									
									
								
							
							
						
						
									
										468
									
								
								aiserver.py
									
									
									
									
									
								
							| @@ -36,8 +36,9 @@ import itertools | |||||||
| import bisect | import bisect | ||||||
| import functools | import functools | ||||||
| import traceback | import traceback | ||||||
|  | import inspect | ||||||
| from collections.abc import Iterable | from collections.abc import Iterable | ||||||
| from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List | from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List, Optional, Type | ||||||
|  |  | ||||||
| import requests | import requests | ||||||
| import html | import html | ||||||
| @@ -352,6 +353,10 @@ class vars: | |||||||
|     use_colab_tpu = os.environ.get("COLAB_TPU_ADDR", "") != "" or os.environ.get("TPU_NAME", "") != ""  # Whether or not we're in a Colab TPU instance or Kaggle TPU instance and are going to use the TPU rather than the CPU |     use_colab_tpu = os.environ.get("COLAB_TPU_ADDR", "") != "" or os.environ.get("TPU_NAME", "") != ""  # Whether or not we're in a Colab TPU instance or Kaggle TPU instance and are going to use the TPU rather than the CPU | ||||||
|     revision    = None |     revision    = None | ||||||
|     output_streaming = False |     output_streaming = False | ||||||
|  |     standalone = False | ||||||
|  |     disable_set_aibusy = False | ||||||
|  |     disable_input_formatting = False | ||||||
|  |     disable_output_formatting = False | ||||||
|     token_stream_queue = [] # Queue for the token streaming |     token_stream_queue = [] # Queue for the token streaming | ||||||
|  |  | ||||||
| utils.vars = vars | utils.vars = vars | ||||||
| @@ -372,9 +377,11 @@ log.setLevel(logging.ERROR) | |||||||
|  |  | ||||||
| # Start flask & SocketIO | # Start flask & SocketIO | ||||||
| print("{0}Initializing Flask... {1}".format(colors.PURPLE, colors.END), end="") | print("{0}Initializing Flask... {1}".format(colors.PURPLE, colors.END), end="") | ||||||
| from flask import Flask, render_template, Response, request, copy_current_request_context, send_from_directory, session | from flask import Flask, render_template, Response, request, copy_current_request_context, send_from_directory, session, jsonify, abort | ||||||
| from flask_socketio import SocketIO, emit | from flask_socketio import SocketIO | ||||||
|  | from flask_socketio import emit as _emit | ||||||
| from flask_session import Session | from flask_session import Session | ||||||
|  | from werkzeug.exceptions import HTTPException, ServiceUnavailable | ||||||
| import secrets | import secrets | ||||||
| app = Flask(__name__, root_path=os.getcwd()) | app = Flask(__name__, root_path=os.getcwd()) | ||||||
| app.secret_key = secrets.token_hex() | app.secret_key = secrets.token_hex() | ||||||
| @@ -384,6 +391,144 @@ Session(app) | |||||||
| socketio = SocketIO(app, async_method="eventlet") | socketio = SocketIO(app, async_method="eventlet") | ||||||
| print("{0}OK!{1}".format(colors.GREEN, colors.END)) | print("{0}OK!{1}".format(colors.GREEN, colors.END)) | ||||||
|  |  | ||||||
|  | def emit(*args, **kwargs): | ||||||
|  |     try: | ||||||
|  |         return _emit(*args, **kwargs) | ||||||
|  |     except AttributeError: | ||||||
|  |         return socketio.emit(*args, **kwargs) | ||||||
|  |  | ||||||
|  | # marshmallow/apispec setup | ||||||
|  | from apispec import APISpec | ||||||
|  | from apispec.ext.marshmallow import MarshmallowPlugin | ||||||
|  | from apispec.ext.marshmallow.field_converter import make_min_max_attributes | ||||||
|  | from apispec_webframeworks.flask import FlaskPlugin | ||||||
|  | from marshmallow import Schema, fields, validate | ||||||
|  | from marshmallow.exceptions import ValidationError | ||||||
|  |  | ||||||
|  | def new_make_min_max_attributes(validators, min_attr, max_attr) -> dict: | ||||||
|  |     # Patched apispec function that creates "exclusiveMinimum"/"exclusiveMaximum" OpenAPI attributes insteaed of "minimum"/"maximum" when using validators.Range or validators.Length with min_inclusive=False or max_inclusive=False | ||||||
|  |     attributes = {} | ||||||
|  |     min_list = [validator.min for validator in validators if validator.min is not None] | ||||||
|  |     max_list = [validator.max for validator in validators if validator.max is not None] | ||||||
|  |     min_inclusive_list = [getattr(validator, "min_inclusive", True) for validator in validators if validator.min is not None] | ||||||
|  |     max_inclusive_list = [getattr(validator, "max_inclusive", True) for validator in validators if validator.max is not None] | ||||||
|  |     if min_list: | ||||||
|  |         if min_attr == "minimum" and not min_inclusive_list[max(range(len(min_list)), key=min_list.__getitem__)]: | ||||||
|  |             min_attr = "exclusiveMinimum" | ||||||
|  |         attributes[min_attr] = max(min_list) | ||||||
|  |     if max_list: | ||||||
|  |         if min_attr == "maximum" and not max_inclusive_list[min(range(len(max_list)), key=max_list.__getitem__)]: | ||||||
|  |             min_attr = "exclusiveMaximum" | ||||||
|  |         attributes[max_attr] = min(max_list) | ||||||
|  |     return attributes | ||||||
|  | make_min_max_attributes.__code__ = new_make_min_max_attributes.__code__ | ||||||
|  |  | ||||||
|  | def api_format_docstring(f): | ||||||
|  |     f.__doc__ = eval('f"""{}"""'.format(f.__doc__)) | ||||||
|  |     return f | ||||||
|  |  | ||||||
|  | def api_catch_out_of_memory_errors(f): | ||||||
|  |     @functools.wraps(f) | ||||||
|  |     def decorated(*args, **kwargs): | ||||||
|  |         try: | ||||||
|  |             return f(*args, **kwargs) | ||||||
|  |         except Exception as e: | ||||||
|  |             if any (s in traceback.format_exc().lower() for s in ("out of memory", "not enough memory")): | ||||||
|  |                 for line in reversed(traceback.format_exc().split("\n")): | ||||||
|  |                     if any(s in line.lower() for s in ("out of memory", "not enough memory")) and line.count(":"): | ||||||
|  |                         line = line.split(":", 1)[1] | ||||||
|  |                         line = re.sub(r"\[.+?\] +data\.", "", line).strip() | ||||||
|  |                         raise KoboldOutOfMemoryError("KoboldAI ran out of memory: " + line, type="out_of_memory.gpu.cuda" if "cuda out of memory" in line.lower() else "out_of_memory.gpu.hip" if "hip out of memory" in line.lower() else "out_of_memory.tpu.hbm" if "memory space hbm" in line.lower() else "out_of_memory.cpu.default_memory_allocator" if "defaultmemoryallocator" in line.lower() else "out_of_memory.unknown.unknown") | ||||||
|  |                 raise KoboldOutOfMemoryError(type="out_of_memory.unknown.unknown") | ||||||
|  |             raise e | ||||||
|  |     return decorated | ||||||
|  |  | ||||||
|  | def api_schema_wrap(f): | ||||||
|  |     input_schema: Type[Schema] = next(iter(inspect.signature(f).parameters.values())).annotation | ||||||
|  |     assert inspect.isclass(input_schema) and issubclass(input_schema, Schema) | ||||||
|  |     f = api_format_docstring(f) | ||||||
|  |     f = api_catch_out_of_memory_errors(f) | ||||||
|  |     @functools.wraps(f) | ||||||
|  |     def decorated(*args, **Kwargs): | ||||||
|  |         body = request.get_json() | ||||||
|  |         schema = input_schema.from_dict(input_schema().load(body)) | ||||||
|  |         response = f(schema) | ||||||
|  |         if not isinstance(response, Response): | ||||||
|  |             response = jsonify(response) | ||||||
|  |         return response | ||||||
|  |     return decorated | ||||||
|  |  | ||||||
|  | @app.errorhandler(HTTPException) | ||||||
|  | def handler(e): | ||||||
|  |     return jsonify(detail={"type": "generic.error_" + str(e.code), "msg": str(e)}), e.code | ||||||
|  |  | ||||||
|  | class KoboldOutOfMemoryError(HTTPException): | ||||||
|  |     code = 507 | ||||||
|  |     description = "KoboldAI ran out of memory." | ||||||
|  |     type = "out_of_memory.unknown" | ||||||
|  |     def __init__(self, *args, type=None, **kwargs): | ||||||
|  |         super().__init__(*args, **kwargs) | ||||||
|  |         if type is not None: | ||||||
|  |             self.type = type | ||||||
|  | @app.errorhandler(KoboldOutOfMemoryError) | ||||||
|  | def handler(e): | ||||||
|  |     return jsonify(detail={"type": e.type, "msg": e.description}), e.code | ||||||
|  |  | ||||||
|  | @app.errorhandler(ValidationError) | ||||||
|  | def handler(e): | ||||||
|  |     return jsonify(detail=e.messages), 422 | ||||||
|  |  | ||||||
|  | @app.errorhandler(NotImplementedError) | ||||||
|  | def handler(e): | ||||||
|  |     return jsonify(detail={"type": "not_implemented", "msg": str(e).strip()}), 501 | ||||||
|  |  | ||||||
|  | class KoboldAPISpec(APISpec): | ||||||
|  |     class KoboldFlaskPlugin(FlaskPlugin): | ||||||
|  |         def __init__(self, api: "KoboldAPISpec", *args, **kwargs): | ||||||
|  |             self._kobold_api_spec = api | ||||||
|  |             super().__init__(*args, **kwargs) | ||||||
|  |  | ||||||
|  |         def path_helper(self, *args, **kwargs): | ||||||
|  |             return super().path_helper(*args, **kwargs)[len(self._kobold_api_spec._prefixes[0]):] | ||||||
|  |  | ||||||
|  |     def __init__(self, *args, title: str = "KoboldAI API", openapi_version: str = "3.0.3", prefixes: List[str] = None, **kwargs): | ||||||
|  |         plugins = [KoboldAPISpec.KoboldFlaskPlugin(self), MarshmallowPlugin()] | ||||||
|  |         self._prefixes = prefixes if prefixes is not None else [""] | ||||||
|  |         super().__init__(*args, title=title, openapi_version=openapi_version, plugins=plugins, servers=[{"url": self._prefixes[0]}], **kwargs) | ||||||
|  |         for prefix in self._prefixes: | ||||||
|  |             app.route(prefix + "/docs", endpoint="~KoboldAPISpec~" + prefix + "/docs")(lambda: render_template("swagger-ui.html", url=self._prefixes[0] + "/openapi.json")) | ||||||
|  |             app.route(prefix + "/openapi.json", endpoint="~KoboldAPISpec~" + prefix + "/openapi.json")(lambda: jsonify(self.to_dict())) | ||||||
|  |  | ||||||
|  |     def route(self, rule: str, methods=["GET"], **kwargs): | ||||||
|  |         __F = TypeVar("__F", bound=Callable[..., Any]) | ||||||
|  |         def new_decorator(f: __F) -> __F: | ||||||
|  |             for prefix in self._prefixes: | ||||||
|  |                 f = app.route(prefix + rule, methods=methods, **kwargs)(f) | ||||||
|  |             with app.test_request_context(): | ||||||
|  |                 self.path(view=f, **kwargs) | ||||||
|  |             return f | ||||||
|  |         return new_decorator | ||||||
|  |  | ||||||
|  |     def get(self, rule: str, **kwargs): | ||||||
|  |         return self.route(rule, methods=["GET"], **kwargs) | ||||||
|  |      | ||||||
|  |     def post(self, rule: str, **kwargs): | ||||||
|  |         return self.route(rule, methods=["POST"], **kwargs) | ||||||
|  |      | ||||||
|  |     def put(self, rule: str, **kwargs): | ||||||
|  |         return self.route(rule, methods=["PUT"], **kwargs) | ||||||
|  |      | ||||||
|  |     def patch(self, rule: str, **kwargs): | ||||||
|  |         return self.route(rule, methods=["PATCH"], **kwargs) | ||||||
|  |      | ||||||
|  |     def delete(self, rule: str, **kwargs): | ||||||
|  |         return self.route(rule, methods=["DELETE"], **kwargs) | ||||||
|  |  | ||||||
|  | api_v1 = KoboldAPISpec( | ||||||
|  |     version="1.0.0", | ||||||
|  |     prefixes=["/api/v1", "/api/latest"], | ||||||
|  | ) | ||||||
|  |  | ||||||
| #==================================================================# | #==================================================================# | ||||||
| # Function to get model selection at startup | # Function to get model selection at startup | ||||||
| #==================================================================# | #==================================================================# | ||||||
| @@ -1492,6 +1637,9 @@ def patch_transformers(): | |||||||
|             self.regeneration_required = False |             self.regeneration_required = False | ||||||
|             self.halt = False |             self.halt = False | ||||||
|  |  | ||||||
|  |             if(vars.standalone): | ||||||
|  |                 return scores | ||||||
|  |  | ||||||
|             scores_shape = scores.shape |             scores_shape = scores.shape | ||||||
|             scores_list = scores.tolist() |             scores_list = scores.tolist() | ||||||
|             vars.lua_koboldbridge.logits = vars.lua_state.table() |             vars.lua_koboldbridge.logits = vars.lua_state.table() | ||||||
| @@ -1595,12 +1743,14 @@ def patch_transformers(): | |||||||
|             **kwargs, |             **kwargs, | ||||||
|         ) -> bool: |         ) -> bool: | ||||||
|             vars.generated_tkns += 1 |             vars.generated_tkns += 1 | ||||||
|             if(vars.lua_koboldbridge.generated_cols and vars.generated_tkns != vars.lua_koboldbridge.generated_cols): |             if(not vars.standalone and vars.lua_koboldbridge.generated_cols and vars.generated_tkns != vars.lua_koboldbridge.generated_cols): | ||||||
|                 raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({vars.generated_tkns} != {vars.lua_koboldbridge.generated_cols})") |                 raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({vars.generated_tkns} != {vars.lua_koboldbridge.generated_cols})") | ||||||
|             if(vars.abort or vars.generated_tkns >= vars.genamt): |             if(vars.abort or vars.generated_tkns >= vars.genamt): | ||||||
|                 self.regeneration_required = False |                 self.regeneration_required = False | ||||||
|                 self.halt = False |                 self.halt = False | ||||||
|                 return True |                 return True | ||||||
|  |             if(vars.standalone): | ||||||
|  |                 return False | ||||||
|  |  | ||||||
|             assert input_ids.ndim == 2 |             assert input_ids.ndim == 2 | ||||||
|             assert len(self.excluded_world_info) == input_ids.shape[0] |             assert len(self.excluded_world_info) == input_ids.shape[0] | ||||||
| @@ -3767,6 +3917,97 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False, | |||||||
|                 emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True) |                 emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True) | ||||||
|                 break |                 break | ||||||
|  |  | ||||||
|  | def apiactionsubmit_generate(txt, minimum, maximum): | ||||||
|  |     vars.generated_tkns = 0 | ||||||
|  |  | ||||||
|  |     if not vars.quiet: | ||||||
|  |         print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, utils.decodenewlines(tokenizer.decode(txt)), colors.END)) | ||||||
|  |  | ||||||
|  |     # Clear CUDA cache if using GPU | ||||||
|  |     if(vars.hascuda and (vars.usegpu or vars.breakmodel)): | ||||||
|  |         gc.collect() | ||||||
|  |         torch.cuda.empty_cache() | ||||||
|  |  | ||||||
|  |     # Submit input text to generator | ||||||
|  |     _genout, already_generated = tpool.execute(_generate, txt, minimum, maximum, set()) | ||||||
|  |  | ||||||
|  |     genout = [applyoutputformatting(utils.decodenewlines(tokenizer.decode(tokens[-already_generated:]))) for tokens in _genout] | ||||||
|  |  | ||||||
|  |     # Clear CUDA cache again if using GPU | ||||||
|  |     if(vars.hascuda and (vars.usegpu or vars.breakmodel)): | ||||||
|  |         del _genout | ||||||
|  |         gc.collect() | ||||||
|  |         torch.cuda.empty_cache() | ||||||
|  |  | ||||||
|  |     return genout | ||||||
|  |  | ||||||
|  | def apiactionsubmit_tpumtjgenerate(txt, minimum, maximum): | ||||||
|  |     vars.generated_tkns = 0 | ||||||
|  |  | ||||||
|  |     if(vars.full_determinism): | ||||||
|  |         tpu_mtj_backend.set_rng_seed(vars.seed) | ||||||
|  |  | ||||||
|  |     if not vars.quiet: | ||||||
|  |         print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, utils.decodenewlines(tokenizer.decode(txt)), colors.END)) | ||||||
|  |  | ||||||
|  |     vars._actions = vars.actions | ||||||
|  |     vars._prompt = vars.prompt | ||||||
|  |     if(vars.dynamicscan): | ||||||
|  |         vars._actions = vars._actions.copy() | ||||||
|  |  | ||||||
|  |     # Submit input text to generator | ||||||
|  |     soft_tokens = tpumtjgetsofttokens() | ||||||
|  |     genout = tpool.execute( | ||||||
|  |         tpu_mtj_backend.infer_static, | ||||||
|  |         np.uint32(txt), | ||||||
|  |         gen_len = maximum-minimum+1, | ||||||
|  |         temp=vars.temp, | ||||||
|  |         top_p=vars.top_p, | ||||||
|  |         top_k=vars.top_k, | ||||||
|  |         tfs=vars.tfs, | ||||||
|  |         typical=vars.typical, | ||||||
|  |         top_a=vars.top_a, | ||||||
|  |         numseqs=vars.numseqs, | ||||||
|  |         repetition_penalty=vars.rep_pen, | ||||||
|  |         rpslope=vars.rep_pen_slope, | ||||||
|  |         rprange=vars.rep_pen_range, | ||||||
|  |         soft_embeddings=vars.sp, | ||||||
|  |         soft_tokens=soft_tokens, | ||||||
|  |         sampler_order=vars.sampler_order, | ||||||
|  |     ) | ||||||
|  |     genout = [applyoutputformatting(utils.decodenewlines(tokenizer.decode(txt))) for txt in genout] | ||||||
|  |  | ||||||
|  |     return genout | ||||||
|  |  | ||||||
|  | def apiactionsubmit(data, use_memory=False): | ||||||
|  |     if(vars.model == "Colab"): | ||||||
|  |         raise NotImplementedError("API generation is not supported in old Colab API mode.") | ||||||
|  |     elif(vars.model == "OAI"): | ||||||
|  |         raise NotImplementedError("API generation is not supported in OpenAI/GooseAI mode.") | ||||||
|  |     elif(vars.model == "ReadOnly"): | ||||||
|  |         raise NotImplementedError("API generation is not supported in read-only mode; please load a model and then try again.") | ||||||
|  |  | ||||||
|  |     if(vars.memory != "" and vars.memory[-1] != "\n"): | ||||||
|  |         mem = vars.memory + "\n" | ||||||
|  |     else: | ||||||
|  |         mem = vars.memory | ||||||
|  |     tokens = [] | ||||||
|  |     if(use_memory): | ||||||
|  |         tokens += tokenizer.encode(utils.encodenewlines(mem))[-(vars.max_length - vars.sp_length - vars.genamt - len(tokenizer._koboldai_header) - len(tokens)):] | ||||||
|  |     tokens += tokenizer.encode(utils.encodenewlines(data))[-(vars.max_length - vars.sp_length - vars.genamt - len(tokenizer._koboldai_header) - len(tokens)):] | ||||||
|  |     tokens = tokenizer._koboldai_header + tokens | ||||||
|  |     minimum = len(tokens) + 1 | ||||||
|  |     maximum = len(tokens) + vars.genamt | ||||||
|  |  | ||||||
|  |     if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]): | ||||||
|  |         genout = apiactionsubmit_generate(tokens, minimum, maximum) | ||||||
|  |     elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")): | ||||||
|  |         genout = apiactionsubmit_tpumtjgenerate(tokens, minimum, maximum) | ||||||
|  |  | ||||||
|  |     genout = [applyoutputformatting(txt) for txt in genout] | ||||||
|  |  | ||||||
|  |     return genout | ||||||
|  |  | ||||||
| #==================================================================# | #==================================================================# | ||||||
| #   | #   | ||||||
| #==================================================================# | #==================================================================# | ||||||
| @@ -4727,6 +4968,8 @@ def refresh_settings(): | |||||||
| #  Sets the logical and display states for the AI Busy condition | #  Sets the logical and display states for the AI Busy condition | ||||||
| #==================================================================# | #==================================================================# | ||||||
| def set_aibusy(state): | def set_aibusy(state): | ||||||
|  |     if(vars.disable_set_aibusy): | ||||||
|  |         return | ||||||
|     if(state): |     if(state): | ||||||
|         vars.aibusy = True |         vars.aibusy = True | ||||||
|         emit('from_server', {'cmd': 'setgamestate', 'data': 'wait'}, broadcast=True) |         emit('from_server', {'cmd': 'setgamestate', 'data': 'wait'}, broadcast=True) | ||||||
| @@ -6420,6 +6663,223 @@ def get_files_folders(starting_folder): | |||||||
|         socketio.emit("popup_breadcrumbs", breadcrumbs, broadcast=True) |         socketio.emit("popup_breadcrumbs", breadcrumbs, broadcast=True) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class BasicErrorSchema(Schema): | ||||||
|  |     msg: str = fields.String(required=True) | ||||||
|  |     type: str = fields.String(required=True) | ||||||
|  |  | ||||||
|  | class OutOfMemoryErrorSchema(Schema): | ||||||
|  |     detail: BasicErrorSchema = fields.Nested(BasicErrorSchema, required=True) | ||||||
|  |  | ||||||
|  | api_out_of_memory_response = """507: | ||||||
|  |           description: Out of memory | ||||||
|  |           content: | ||||||
|  |             application/json: | ||||||
|  |               schema: OutOfMemoryErrorSchema | ||||||
|  |               examples: | ||||||
|  |                 gpu.cuda: | ||||||
|  |                   value: | ||||||
|  |                     detail: | ||||||
|  |                       msg: "KoboldAI ran out of memory: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 4.00 GiB total capacity; 2.97 GiB already allocated; 0 bytes free; 2.99 GiB reserved in total by PyTorch)" | ||||||
|  |                       type: out_of_memory.gpu.cuda | ||||||
|  |                 gpu.hip: | ||||||
|  |                   value: | ||||||
|  |                     detail: | ||||||
|  |                       msg: "KoboldAI ran out of memory: HIP out of memory. Tried to allocate 20.00 MiB (GPU 0; 4.00 GiB total capacity; 2.97 GiB already allocated; 0 bytes free; 2.99 GiB reserved in total by PyTorch)" | ||||||
|  |                       type: out_of_memory.gpu.hip | ||||||
|  |                 tpu.hbm: | ||||||
|  |                   value: | ||||||
|  |                     detail: | ||||||
|  |                       msg: "KoboldAI ran out of memory: Compilation failed: Compilation failure: Ran out of memory in memory space hbm. Used 8.83G of 8.00G hbm. Exceeded hbm capacity by 848.88M." | ||||||
|  |                       type: out_of_memory.tpu.hbm | ||||||
|  |                 cpu.default_cpu_allocator: | ||||||
|  |                   value: | ||||||
|  |                     detail: | ||||||
|  |                       msg: "KoboldAI ran out of memory: DefaultCPUAllocator: not enough memory: you tried to allocate 209715200 bytes." | ||||||
|  |                       type: out_of_memory.cpu.default_cpu_allocator | ||||||
|  |                 unknown.unknown: | ||||||
|  |                   value: | ||||||
|  |                     detail: | ||||||
|  |                       msg: "KoboldAI ran out of memory." | ||||||
|  |                       type: out_of_memory.unknown.unknown""" | ||||||
|  |  | ||||||
|  | class ValidationErrorSchema(Schema): | ||||||
|  |     detail: Dict[str, List[str]] = fields.Dict(keys=fields.String(), values=fields.List(fields.String()), required=True) | ||||||
|  |  | ||||||
|  | api_validation_error_response = """422: | ||||||
|  |           description: Validation error | ||||||
|  |           content: | ||||||
|  |             application/json: | ||||||
|  |               schema: ValidationErrorSchema""" | ||||||
|  |  | ||||||
|  | class ServerBusyErrorSchema(Schema): | ||||||
|  |     detail: BasicErrorSchema = fields.Nested(BasicErrorSchema, required=True) | ||||||
|  |  | ||||||
|  | api_server_busy_response = """503: | ||||||
|  |           description: Server is busy | ||||||
|  |           content: | ||||||
|  |             application/json: | ||||||
|  |               schema: ServerBusyErrorSchema | ||||||
|  |               example: | ||||||
|  |                 detail: | ||||||
|  |                   msg: Server is busy; please try again later. | ||||||
|  |                   type: service_unavailable""" | ||||||
|  |  | ||||||
|  | class NotImplementedErrorSchema(Schema): | ||||||
|  |     detail: BasicErrorSchema = fields.Nested(BasicErrorSchema, required=True) | ||||||
|  |  | ||||||
|  | api_not_implemented_response = """501: | ||||||
|  |           description: Not implemented | ||||||
|  |           content: | ||||||
|  |             application/json: | ||||||
|  |               schema: NotImplementedErrorSchema | ||||||
|  |               example: | ||||||
|  |                 detail: | ||||||
|  |                   msg: API generation is not supported in read-only mode; please load a model and then try again. | ||||||
|  |                   type: not_implemented""" | ||||||
|  |  | ||||||
|  | class SamplerSettingsSchema(Schema): | ||||||
|  |     rep_pen: Optional[float] = fields.Float(validate=validate.Range(min=1), metadata={"description": "Base repetition penalty value."}) | ||||||
|  |     rep_pen_range: Optional[int] = fields.Integer(validate=validate.Range(min=0), metadata={"description": "Repetition penalty range."}) | ||||||
|  |     rep_pen_slope: Optional[float] = fields.Float(validate=validate.Range(min=0), metadata={"description": "Repetition penalty slope."}) | ||||||
|  |     top_k: Optional[int] = fields.Int(validate=validate.Range(min=0), metadata={"description": "Top-k sampling value."}) | ||||||
|  |     top_a: Optional[float] = fields.Float(validate=validate.Range(min=0), metadata={"description": "Top-a sampling value."}) | ||||||
|  |     top_p: Optional[float] = fields.Float(validate=validate.Range(min=0, max=1), metadata={"description": "Top-p sampling value."}) | ||||||
|  |     tfs: Optional[float] = fields.Float(validate=validate.Range(min=0, max=1), metadata={"description": "Tail free sampling value."}) | ||||||
|  |     typical: Optional[float] = fields.Float(validate=validate.Range(min=0, max=1), metadata={"description": "Typical sampling value."}) | ||||||
|  |     temperature: Optional[float] = fields.Float(validate=validate.Range(min=0, min_inclusive=False), metadata={"description": "Temperature value."}) | ||||||
|  |  | ||||||
|  | class GenerationInputSchema(SamplerSettingsSchema): | ||||||
|  |     prompt: str = fields.String(required=True, metadata={"description": "This is the submission."}) | ||||||
|  |     use_memory: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the memory from the KoboldAI GUI when generating text."}) | ||||||
|  |     use_story: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the story from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."}) | ||||||
|  |     use_world_info: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the world info from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."}) | ||||||
|  |     use_userscripts: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the userscripts from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."}) | ||||||
|  |     soft_prompt: Optional[str] = fields.String(metadata={"description": "Soft prompt to use when generating. If set to the empty string or any other string containing no non-whitespace characters, uses no soft prompt."}) | ||||||
|  |     max_length: int = fields.Integer(validate=validate.Range(min=1, max=2048), metadata={"description": "Number of tokens to generate."}) | ||||||
|  |     n: int = fields.Integer(validate=validate.Range(min=1, max=5), metadata={"description": "Number of outputs to generate."}) | ||||||
|  |     disable_output_formatting: bool = fields.Boolean(load_default=True, metadata={"description": "When enabled, disables all output formatting options, overriding their individual enabled/disabled states."}) | ||||||
|  |     frmttriminc: Optional[bool] = fields.Boolean(metadata={"description": "Output formatting option. When enabled, removes some characters from the end of the output such that the output doesn't end in the middle of a sentence. If the output is less than one sentence long, does nothing."}) | ||||||
|  |     frmtrmblln: Optional[bool] = fields.Boolean(metadata={"description": "Output formatting option. When enabled, replaces all occurrences of two or more consecutive newlines in the output with one newline."}) | ||||||
|  |     frmtrmspch: Optional[bool] = fields.Boolean(metadata={"description": "Output formatting option. When enabled, removes `#/@%{}+=~|\^<>` from the output."}) | ||||||
|  |     singleline: Optional[bool] = fields.Boolean(metadata={"description": "Output formatting option. When enabled, removes everything after the first line of the output, including the newline."}) | ||||||
|  |     disable_input_formatting: bool = fields.Boolean(load_default=True, metadata={"description": "When enabled, disables all input formatting options, overriding their individual enabled/disabled states."}) | ||||||
|  |     frmtadsnsp: Optional[bool] = fields.Boolean(metadata={"description": "Input formatting option. When enabled, adds a leading space to your input if there is no trailing whitespace at the end of the previous action."}) | ||||||
|  |  | ||||||
|  | class GenerationResultSchema(Schema): | ||||||
|  |     text: str = fields.String(required=True, metadata={"description": "Generated output as plain text."}) | ||||||
|  |  | ||||||
|  | class GenerationOutputSchema(Schema): | ||||||
|  |     results: List[GenerationResultSchema] = fields.List(fields.Nested(GenerationResultSchema), required=True, metadata={"description": "Array of generated outputs."}) | ||||||
|  |  | ||||||
|  | def _generate_text(body: GenerationInputSchema): | ||||||
|  |     if vars.aibusy or vars.genseqs: | ||||||
|  |         abort(Response(json.dumps({"detail": { | ||||||
|  |             "type": "service_unavailable", | ||||||
|  |             "msg": "Server is busy; please try again later.", | ||||||
|  |         }}), mimetype="application/json", status=503)) | ||||||
|  |     if body.use_story: | ||||||
|  |         raise NotImplementedError("use_story is not currently supported.") | ||||||
|  |     if body.use_world_info: | ||||||
|  |         raise NotImplementedError("use_world_info is not currently supported.") | ||||||
|  |     if body.use_userscripts: | ||||||
|  |         raise NotImplementedError("use_userscripts is not currently supported.") | ||||||
|  |     mapping = { | ||||||
|  |         "rep_pen": (vars, "rep_pen"), | ||||||
|  |         "rep_pen_range": (vars, "rep_pen_range"), | ||||||
|  |         "rep_pen_slope": (vars, "rep_pen_slope"), | ||||||
|  |         "top_k": (vars, "top_k"), | ||||||
|  |         "top_a": (vars, "top_a"), | ||||||
|  |         "top_p": (vars, "top_p"), | ||||||
|  |         "tfs": (vars, "tfs"), | ||||||
|  |         "typical": (vars, "typical"), | ||||||
|  |         "temperature": (vars, "temp"), | ||||||
|  |         "frmtadnsp": (vars.formatoptns, "@frmtadnsp"), | ||||||
|  |         "frmttriminc": (vars.formatoptns, "@frmttriminc"), | ||||||
|  |         "frmtrmblln": (vars.formatoptns, "@frmtrmblln"), | ||||||
|  |         "frmtrmspch": (vars.formatoptns, "@frmtrmspch"), | ||||||
|  |         "singleline": (vars.formatoptns, "@singleline"), | ||||||
|  |         "disable_input_formatting": (vars, "disable_input_formatting"), | ||||||
|  |         "disable_output_formatting": (vars, "disable_output_formatting"), | ||||||
|  |         "max_length": (vars, "genamt"), | ||||||
|  |         "n": (vars, "numseqs"), | ||||||
|  |     } | ||||||
|  |     saved_settings = {} | ||||||
|  |     set_aibusy(1) | ||||||
|  |     disable_set_aibusy = vars.disable_set_aibusy | ||||||
|  |     vars.disable_set_aibusy = True | ||||||
|  |     _standalone = vars.standalone | ||||||
|  |     vars.standalone = True | ||||||
|  |     for key, entry in mapping.items(): | ||||||
|  |         if getattr(body, key, None) is not None: | ||||||
|  |             if entry[1].startswith("@"): | ||||||
|  |                 saved_settings[key] = entry[0][entry[1][1:]] | ||||||
|  |                 entry[0][entry[1][1:]] = getattr(body, key) | ||||||
|  |             else: | ||||||
|  |                 saved_settings[key] = getattr(entry[0], entry[1]) | ||||||
|  |                 setattr(entry[0], entry[1], getattr(body, key)) | ||||||
|  |     try: | ||||||
|  |         if getattr(body, "soft_prompt", None) is not None: | ||||||
|  |             if any(q in body.soft_prompt for q in ("/", "\\")): | ||||||
|  |                 raise RuntimeError | ||||||
|  |             old_spfilename = vars.spfilename | ||||||
|  |             spRequest(body.soft_prompt) | ||||||
|  |         genout = apiactionsubmit(body.prompt, use_memory=body.use_memory) | ||||||
|  |         output = {"results": [{"text": txt} for txt in genout]} | ||||||
|  |     finally: | ||||||
|  |         for key in saved_settings: | ||||||
|  |             entry = mapping[key] | ||||||
|  |             if getattr(body, key, None) is not None: | ||||||
|  |                 if entry[1].startswith("@"): | ||||||
|  |                     if entry[0][entry[1][1:]] == getattr(body, key): | ||||||
|  |                         entry[0][entry[1][1:]] = saved_settings[key] | ||||||
|  |                 else: | ||||||
|  |                     if getattr(entry[0], entry[1]) == getattr(body, key): | ||||||
|  |                         setattr(entry[0], entry[1], saved_settings[key]) | ||||||
|  |         vars.disable_set_aibusy = disable_set_aibusy | ||||||
|  |         vars.standalone = _standalone | ||||||
|  |         if getattr(body, "soft_prompt", None) is not None: | ||||||
|  |             spRequest(old_spfilename) | ||||||
|  |         set_aibusy(0) | ||||||
|  |     return output | ||||||
|  |  | ||||||
|  | @api_v1.post("/generate") | ||||||
|  | @api_schema_wrap | ||||||
|  | def post_completion_standalone(body: GenerationInputSchema): | ||||||
|  |     r"""Generate text | ||||||
|  |     --- | ||||||
|  |     post: | ||||||
|  |       description: |-2 | ||||||
|  |         Generates text given a submission, sampler settings, soft prompt and number of return sequences. | ||||||
|  |  | ||||||
|  |         Unless otherwise specified, optional values default to the values in the KoboldAI GUI. | ||||||
|  |       requestBody: | ||||||
|  |         required: true | ||||||
|  |         content: | ||||||
|  |           application/json: | ||||||
|  |             schema: GenerationInputSchema | ||||||
|  |             example: | ||||||
|  |               prompt: |-2 | ||||||
|  |                 Explosions of suspicious origin occur at AMNAT satellite-receiver stations from Turkey to Labrador as three high-level Canadian defense ministers vanish and then a couple of days later are photographed at a Volgograd bistro hoisting shots of Stolichnaya with Slavic bimbos on their knee. | ||||||
|  |               top_p: 0.9 | ||||||
|  |               temperature: 0.5 | ||||||
|  |       responses: | ||||||
|  |         200: | ||||||
|  |           description: Successful request | ||||||
|  |           content: | ||||||
|  |             application/json: | ||||||
|  |               schema: GenerationOutputSchema | ||||||
|  |               example: | ||||||
|  |                 results: | ||||||
|  |                   - text: |-2 | ||||||
|  |                        It is later established that all of the cabinet members have died of old age. | ||||||
|  |                       MEGAMATRIX becomes involved in the growing number of mass abductions and kidnappings. Many disappearances occur along highways in western Canada, usually when traffic has come to a standstill because of a stalled truck or snowstorm. One or two abducted individuals will be released within a day or so but never | ||||||
|  |         {api_validation_error_response} | ||||||
|  |         {api_not_implemented_response} | ||||||
|  |         {api_server_busy_response} | ||||||
|  |         {api_out_of_memory_response} | ||||||
|  |     """ | ||||||
|  |     return _generate_text(body) | ||||||
|  |  | ||||||
|  |  | ||||||
| #==================================================================# | #==================================================================# | ||||||
| #  Final startup commands to launch Flask app | #  Final startup commands to launch Flask app | ||||||
|   | |||||||
| @@ -16,6 +16,8 @@ dependencies: | |||||||
|   - bleach=4.1.0 |   - bleach=4.1.0 | ||||||
|   - pip |   - pip | ||||||
|   - git=2.35.1 |   - git=2.35.1 | ||||||
|  |   - marshmallow>=3.13 | ||||||
|  |   - apispec-webframeworks | ||||||
|   - pip: |   - pip: | ||||||
|     - git+https://github.com/finetuneanon/transformers@gpt-neo-localattention3-rp-b |     - git+https://github.com/finetuneanon/transformers@gpt-neo-localattention3-rp-b | ||||||
|     - flask-cloudflared |     - flask-cloudflared | ||||||
|   | |||||||
| @@ -17,9 +17,11 @@ dependencies: | |||||||
|   - git=2.35.1 |   - git=2.35.1 | ||||||
|   - sentencepiece |   - sentencepiece | ||||||
|   - protobuf |   - protobuf | ||||||
|  |   - marshmallow>=3.13 | ||||||
|  |   - apispec-webframeworks | ||||||
|   - pip: |   - pip: | ||||||
|     - flask-cloudflared |     - flask-cloudflared | ||||||
|     - flask-ngrok |     - flask-ngrok | ||||||
|     - lupa==1.10 |     - lupa==1.10 | ||||||
|     - transformers>=4.20.1 |     - transformers>=4.20.1 | ||||||
|     - accelerate |     - accelerate | ||||||
|   | |||||||
| @@ -12,6 +12,8 @@ dependencies: | |||||||
|   - bleach=4.1.0 |   - bleach=4.1.0 | ||||||
|   - pip |   - pip | ||||||
|   - git=2.35.1 |   - git=2.35.1 | ||||||
|  |   - marshmallow>=3.13 | ||||||
|  |   - apispec-webframeworks | ||||||
|   - pip: |   - pip: | ||||||
|     - --find-links https://download.pytorch.org/whl/rocm4.2/torch_stable.html |     - --find-links https://download.pytorch.org/whl/rocm4.2/torch_stable.html | ||||||
|     - torch |     - torch | ||||||
|   | |||||||
| @@ -14,6 +14,8 @@ dependencies: | |||||||
|   - git=2.35.1 |   - git=2.35.1 | ||||||
|   - sentencepiece |   - sentencepiece | ||||||
|   - protobuf |   - protobuf | ||||||
|  |   - marshmallow>=3.13 | ||||||
|  |   - apispec-webframeworks | ||||||
|   - pip: |   - pip: | ||||||
|     - --find-links https://download.pytorch.org/whl/rocm4.2/torch_stable.html |     - --find-links https://download.pytorch.org/whl/rocm4.2/torch_stable.html | ||||||
|     - torch==1.10.* |     - torch==1.10.* | ||||||
|   | |||||||
| @@ -12,4 +12,6 @@ bleach==4.1.0 | |||||||
| sentencepiece | sentencepiece | ||||||
| protobuf | protobuf | ||||||
| accelerate | accelerate | ||||||
| flask-session | flask-session | ||||||
|  | marshmallow>=3.13 | ||||||
|  | apispec-webframeworks | ||||||
|   | |||||||
| @@ -17,4 +17,6 @@ eventlet | |||||||
| lupa==1.10 | lupa==1.10 | ||||||
| markdown | markdown | ||||||
| bleach==4.1.0 | bleach==4.1.0 | ||||||
| flask-session | flask-session | ||||||
|  | marshmallow>=3.13 | ||||||
|  | apispec-webframeworks | ||||||
|   | |||||||
							
								
								
									
										202
									
								
								static/swagger-ui/LICENSE
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										202
									
								
								static/swagger-ui/LICENSE
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,202 @@ | |||||||
|  |  | ||||||
|  |                                  Apache License | ||||||
|  |                            Version 2.0, January 2004 | ||||||
|  |                         http://www.apache.org/licenses/ | ||||||
|  |  | ||||||
|  |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION | ||||||
|  |  | ||||||
|  |    1. Definitions. | ||||||
|  |  | ||||||
|  |       "License" shall mean the terms and conditions for use, reproduction, | ||||||
|  |       and distribution as defined by Sections 1 through 9 of this document. | ||||||
|  |  | ||||||
|  |       "Licensor" shall mean the copyright owner or entity authorized by | ||||||
|  |       the copyright owner that is granting the License. | ||||||
|  |  | ||||||
|  |       "Legal Entity" shall mean the union of the acting entity and all | ||||||
|  |       other entities that control, are controlled by, or are under common | ||||||
|  |       control with that entity. For the purposes of this definition, | ||||||
|  |       "control" means (i) the power, direct or indirect, to cause the | ||||||
|  |       direction or management of such entity, whether by contract or | ||||||
|  |       otherwise, or (ii) ownership of fifty percent (50%) or more of the | ||||||
|  |       outstanding shares, or (iii) beneficial ownership of such entity. | ||||||
|  |  | ||||||
|  |       "You" (or "Your") shall mean an individual or Legal Entity | ||||||
|  |       exercising permissions granted by this License. | ||||||
|  |  | ||||||
|  |       "Source" form shall mean the preferred form for making modifications, | ||||||
|  |       including but not limited to software source code, documentation | ||||||
|  |       source, and configuration files. | ||||||
|  |  | ||||||
|  |       "Object" form shall mean any form resulting from mechanical | ||||||
|  |       transformation or translation of a Source form, including but | ||||||
|  |       not limited to compiled object code, generated documentation, | ||||||
|  |       and conversions to other media types. | ||||||
|  |  | ||||||
|  |       "Work" shall mean the work of authorship, whether in Source or | ||||||
|  |       Object form, made available under the License, as indicated by a | ||||||
|  |       copyright notice that is included in or attached to the work | ||||||
|  |       (an example is provided in the Appendix below). | ||||||
|  |  | ||||||
|  |       "Derivative Works" shall mean any work, whether in Source or Object | ||||||
|  |       form, that is based on (or derived from) the Work and for which the | ||||||
|  |       editorial revisions, annotations, elaborations, or other modifications | ||||||
|  |       represent, as a whole, an original work of authorship. For the purposes | ||||||
|  |       of this License, Derivative Works shall not include works that remain | ||||||
|  |       separable from, or merely link (or bind by name) to the interfaces of, | ||||||
|  |       the Work and Derivative Works thereof. | ||||||
|  |  | ||||||
|  |       "Contribution" shall mean any work of authorship, including | ||||||
|  |       the original version of the Work and any modifications or additions | ||||||
|  |       to that Work or Derivative Works thereof, that is intentionally | ||||||
|  |       submitted to Licensor for inclusion in the Work by the copyright owner | ||||||
|  |       or by an individual or Legal Entity authorized to submit on behalf of | ||||||
|  |       the copyright owner. For the purposes of this definition, "submitted" | ||||||
|  |       means any form of electronic, verbal, or written communication sent | ||||||
|  |       to the Licensor or its representatives, including but not limited to | ||||||
|  |       communication on electronic mailing lists, source code control systems, | ||||||
|  |       and issue tracking systems that are managed by, or on behalf of, the | ||||||
|  |       Licensor for the purpose of discussing and improving the Work, but | ||||||
|  |       excluding communication that is conspicuously marked or otherwise | ||||||
|  |       designated in writing by the copyright owner as "Not a Contribution." | ||||||
|  |  | ||||||
|  |       "Contributor" shall mean Licensor and any individual or Legal Entity | ||||||
|  |       on behalf of whom a Contribution has been received by Licensor and | ||||||
|  |       subsequently incorporated within the Work. | ||||||
|  |  | ||||||
|  |    2. Grant of Copyright License. Subject to the terms and conditions of | ||||||
|  |       this License, each Contributor hereby grants to You a perpetual, | ||||||
|  |       worldwide, non-exclusive, no-charge, royalty-free, irrevocable | ||||||
|  |       copyright license to reproduce, prepare Derivative Works of, | ||||||
|  |       publicly display, publicly perform, sublicense, and distribute the | ||||||
|  |       Work and such Derivative Works in Source or Object form. | ||||||
|  |  | ||||||
|  |    3. Grant of Patent License. Subject to the terms and conditions of | ||||||
|  |       this License, each Contributor hereby grants to You a perpetual, | ||||||
|  |       worldwide, non-exclusive, no-charge, royalty-free, irrevocable | ||||||
|  |       (except as stated in this section) patent license to make, have made, | ||||||
|  |       use, offer to sell, sell, import, and otherwise transfer the Work, | ||||||
|  |       where such license applies only to those patent claims licensable | ||||||
|  |       by such Contributor that are necessarily infringed by their | ||||||
|  |       Contribution(s) alone or by combination of their Contribution(s) | ||||||
|  |       with the Work to which such Contribution(s) was submitted. If You | ||||||
|  |       institute patent litigation against any entity (including a | ||||||
|  |       cross-claim or counterclaim in a lawsuit) alleging that the Work | ||||||
|  |       or a Contribution incorporated within the Work constitutes direct | ||||||
|  |       or contributory patent infringement, then any patent licenses | ||||||
|  |       granted to You under this License for that Work shall terminate | ||||||
|  |       as of the date such litigation is filed. | ||||||
|  |  | ||||||
|  |    4. Redistribution. You may reproduce and distribute copies of the | ||||||
|  |       Work or Derivative Works thereof in any medium, with or without | ||||||
|  |       modifications, and in Source or Object form, provided that You | ||||||
|  |       meet the following conditions: | ||||||
|  |  | ||||||
|  |       (a) You must give any other recipients of the Work or | ||||||
|  |           Derivative Works a copy of this License; and | ||||||
|  |  | ||||||
|  |       (b) You must cause any modified files to carry prominent notices | ||||||
|  |           stating that You changed the files; and | ||||||
|  |  | ||||||
|  |       (c) You must retain, in the Source form of any Derivative Works | ||||||
|  |           that You distribute, all copyright, patent, trademark, and | ||||||
|  |           attribution notices from the Source form of the Work, | ||||||
|  |           excluding those notices that do not pertain to any part of | ||||||
|  |           the Derivative Works; and | ||||||
|  |  | ||||||
|  |       (d) If the Work includes a "NOTICE" text file as part of its | ||||||
|  |           distribution, then any Derivative Works that You distribute must | ||||||
|  |           include a readable copy of the attribution notices contained | ||||||
|  |           within such NOTICE file, excluding those notices that do not | ||||||
|  |           pertain to any part of the Derivative Works, in at least one | ||||||
|  |           of the following places: within a NOTICE text file distributed | ||||||
|  |           as part of the Derivative Works; within the Source form or | ||||||
|  |           documentation, if provided along with the Derivative Works; or, | ||||||
|  |           within a display generated by the Derivative Works, if and | ||||||
|  |           wherever such third-party notices normally appear. The contents | ||||||
|  |           of the NOTICE file are for informational purposes only and | ||||||
|  |           do not modify the License. You may add Your own attribution | ||||||
|  |           notices within Derivative Works that You distribute, alongside | ||||||
|  |           or as an addendum to the NOTICE text from the Work, provided | ||||||
|  |           that such additional attribution notices cannot be construed | ||||||
|  |           as modifying the License. | ||||||
|  |  | ||||||
|  |       You may add Your own copyright statement to Your modifications and | ||||||
|  |       may provide additional or different license terms and conditions | ||||||
|  |       for use, reproduction, or distribution of Your modifications, or | ||||||
|  |       for any such Derivative Works as a whole, provided Your use, | ||||||
|  |       reproduction, and distribution of the Work otherwise complies with | ||||||
|  |       the conditions stated in this License. | ||||||
|  |  | ||||||
|  |    5. Submission of Contributions. Unless You explicitly state otherwise, | ||||||
|  |       any Contribution intentionally submitted for inclusion in the Work | ||||||
|  |       by You to the Licensor shall be under the terms and conditions of | ||||||
|  |       this License, without any additional terms or conditions. | ||||||
|  |       Notwithstanding the above, nothing herein shall supersede or modify | ||||||
|  |       the terms of any separate license agreement you may have executed | ||||||
|  |       with Licensor regarding such Contributions. | ||||||
|  |  | ||||||
|  |    6. Trademarks. This License does not grant permission to use the trade | ||||||
|  |       names, trademarks, service marks, or product names of the Licensor, | ||||||
|  |       except as required for reasonable and customary use in describing the | ||||||
|  |       origin of the Work and reproducing the content of the NOTICE file. | ||||||
|  |  | ||||||
|  |    7. Disclaimer of Warranty. Unless required by applicable law or | ||||||
|  |       agreed to in writing, Licensor provides the Work (and each | ||||||
|  |       Contributor provides its Contributions) on an "AS IS" BASIS, | ||||||
|  |       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||||
|  |       implied, including, without limitation, any warranties or conditions | ||||||
|  |       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A | ||||||
|  |       PARTICULAR PURPOSE. You are solely responsible for determining the | ||||||
|  |       appropriateness of using or redistributing the Work and assume any | ||||||
|  |       risks associated with Your exercise of permissions under this License. | ||||||
|  |  | ||||||
|  |    8. Limitation of Liability. In no event and under no legal theory, | ||||||
|  |       whether in tort (including negligence), contract, or otherwise, | ||||||
|  |       unless required by applicable law (such as deliberate and grossly | ||||||
|  |       negligent acts) or agreed to in writing, shall any Contributor be | ||||||
|  |       liable to You for damages, including any direct, indirect, special, | ||||||
|  |       incidental, or consequential damages of any character arising as a | ||||||
|  |       result of this License or out of the use or inability to use the | ||||||
|  |       Work (including but not limited to damages for loss of goodwill, | ||||||
|  |       work stoppage, computer failure or malfunction, or any and all | ||||||
|  |       other commercial damages or losses), even if such Contributor | ||||||
|  |       has been advised of the possibility of such damages. | ||||||
|  |  | ||||||
|  |    9. Accepting Warranty or Additional Liability. While redistributing | ||||||
|  |       the Work or Derivative Works thereof, You may choose to offer, | ||||||
|  |       and charge a fee for, acceptance of support, warranty, indemnity, | ||||||
|  |       or other liability obligations and/or rights consistent with this | ||||||
|  |       License. However, in accepting such obligations, You may act only | ||||||
|  |       on Your own behalf and on Your sole responsibility, not on behalf | ||||||
|  |       of any other Contributor, and only if You agree to indemnify, | ||||||
|  |       defend, and hold each Contributor harmless for any liability | ||||||
|  |       incurred by, or claims asserted against, such Contributor by reason | ||||||
|  |       of your accepting any such warranty or additional liability. | ||||||
|  |  | ||||||
|  |    END OF TERMS AND CONDITIONS | ||||||
|  |  | ||||||
|  |    APPENDIX: How to apply the Apache License to your work. | ||||||
|  |  | ||||||
|  |       To apply the Apache License to your work, attach the following | ||||||
|  |       boilerplate notice, with the fields enclosed by brackets "[]" | ||||||
|  |       replaced with your own identifying information. (Don't include | ||||||
|  |       the brackets!)  The text should be enclosed in the appropriate | ||||||
|  |       comment syntax for the file format. We also recommend that a | ||||||
|  |       file or class name and description of purpose be included on the | ||||||
|  |       same "printed page" as the copyright notice for easier | ||||||
|  |       identification within third-party archives. | ||||||
|  |  | ||||||
|  |    Copyright [yyyy] [name of copyright owner] | ||||||
|  |  | ||||||
|  |    Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  |    you may not use this file except in compliance with the License. | ||||||
|  |    You may obtain a copy of the License at | ||||||
|  |  | ||||||
|  |        http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  |  | ||||||
|  |    Unless required by applicable law or agreed to in writing, software | ||||||
|  |    distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  |    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  |    See the License for the specific language governing permissions and | ||||||
|  |    limitations under the License. | ||||||
							
								
								
									
										853
									
								
								static/swagger-ui/SwaggerDark.css
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										853
									
								
								static/swagger-ui/SwaggerDark.css
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,853 @@ | |||||||
|  | /*! | ||||||
|  |  * MIT License | ||||||
|  |  *  | ||||||
|  |  * Copyright (c) 2020 Romans Pokrovskis | ||||||
|  |  *  | ||||||
|  |  * Permission is hereby granted, free of charge, to any person obtaining a copy | ||||||
|  |  * of this software and associated documentation files (the "Software"), to deal | ||||||
|  |  * in the Software without restriction, including without limitation the rights | ||||||
|  |  * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||||
|  |  * copies of the Software, and to permit persons to whom the Software is | ||||||
|  |  * furnished to do so, subject to the following conditions: | ||||||
|  |  *  | ||||||
|  |  * The above copyright notice and this permission notice shall be included in all | ||||||
|  |  * copies or substantial portions of the Software. | ||||||
|  |  *  | ||||||
|  |  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||||||
|  |  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||||||
|  |  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||||||
|  |  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||||||
|  |  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||||||
|  |  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||||||
|  |  * SOFTWARE. | ||||||
|  |  */ | ||||||
|  |  | ||||||
|  | a { color: #8c8cfa; } | ||||||
|  |  | ||||||
|  | ::-webkit-scrollbar-track-piece { background-color: rgba(255, 255, 255, .2) !important; } | ||||||
|  |  | ||||||
|  | ::-webkit-scrollbar-track { background-color: rgba(255, 255, 255, .3) !important; } | ||||||
|  |  | ||||||
|  | ::-webkit-scrollbar-thumb { background-color: rgba(255, 255, 255, .5) !important; } | ||||||
|  |  | ||||||
|  | embed[type="application/pdf"] { filter: invert(90%); } | ||||||
|  |  | ||||||
|  | html { | ||||||
|  |     background: #1f1f1f !important; | ||||||
|  |     box-sizing: border-box; | ||||||
|  |     filter: contrast(100%) brightness(100%) saturate(100%); | ||||||
|  |     overflow-y: scroll; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | body { | ||||||
|  |     background: #1f1f1f; | ||||||
|  |     background-color: #1f1f1f; | ||||||
|  |     background-image: none !important; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | button, input, select, textarea { | ||||||
|  |     background-color: #1f1f1f; | ||||||
|  |     color: #bfbfbf; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | font, html { color: #bfbfbf; } | ||||||
|  |  | ||||||
|  | .swagger-ui, .swagger-ui section h3 { color: #b5bac9; } | ||||||
|  |  | ||||||
|  | .swagger-ui a { background-color: transparent; } | ||||||
|  |  | ||||||
|  | .swagger-ui mark { | ||||||
|  |     background-color: #664b00; | ||||||
|  |     color: #bfbfbf; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui legend { color: inherit; } | ||||||
|  |  | ||||||
|  | .swagger-ui .debug * { outline: #e6da99 solid 1px; } | ||||||
|  |  | ||||||
|  | .swagger-ui .debug-white * { outline: #fff solid 1px; } | ||||||
|  |  | ||||||
|  | .swagger-ui .debug-black * { outline: #bfbfbf solid 1px; } | ||||||
|  |  | ||||||
|  | .swagger-ui .debug-grid { background: url(data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAICAYAAADED76LAAAAGXRFWHRTb2Z0d2FyZQBBZG9iZSBJbWFnZVJlYWR5ccllPAAAAyhpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADw/eHBhY2tldCBiZWdpbj0i77u/IiBpZD0iVzVNME1wQ2VoaUh6cmVTek5UY3prYzlkIj8+IDx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6bWV0YS8iIHg6eG1wdGs9IkFkb2JlIFhNUCBDb3JlIDUuNi1jMTExIDc5LjE1ODMyNSwgMjAxNS8wOS8xMC0wMToxMDoyMCAgICAgICAgIj4gPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4gPHJkZjpEZXNjcmlwdGlvbiByZGY6YWJvdXQ9IiIgeG1sbnM6eG1wTU09Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC9tbS8iIHhtbG5zOnN0UmVmPSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvc1R5cGUvUmVzb3VyY2VSZWYjIiB4bWxuczp4bXA9Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC8iIHhtcE1NOkRvY3VtZW50SUQ9InhtcC5kaWQ6MTRDOTY4N0U2N0VFMTFFNjg2MzZDQjkwNkQ4MjgwMEIiIHhtcE1NOkluc3RhbmNlSUQ9InhtcC5paWQ6MTRDOTY4N0Q2N0VFMTFFNjg2MzZDQjkwNkQ4MjgwMEIiIHhtcDpDcmVhdG9yVG9vbD0iQWRvYmUgUGhvdG9zaG9wIENDIDIwMTUgKE1hY2ludG9zaCkiPiA8eG1wTU06RGVyaXZlZEZyb20gc3RSZWY6aW5zdGFuY2VJRD0ieG1wLmlpZDo3NjcyQkQ3NjY3QzUxMUU2QjJCQ0UyNDA4MTAwMjE3MSIgc3RSZWY6ZG9jdW1lbnRJRD0ieG1wLmRpZDo3NjcyQkQ3NzY3QzUxMUU2QjJCQ0UyNDA4MTAwMjE3MSIvPiA8L3JkZjpEZXNjcmlwdGlvbj4gPC9yZGY6UkRGPiA8L3g6eG1wbWV0YT4gPD94cGFja2V0IGVuZD0iciI/PsBS+GMAAAAjSURBVHjaYvz//z8DLsD4gcGXiYEAGBIKGBne//fFpwAgwAB98AaF2pjlUQAAAABJRU5ErkJggg==) 0 0; } | ||||||
|  |  | ||||||
|  | .swagger-ui .debug-grid-16 { background: url(data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAGXRFWHRTb2Z0d2FyZQBBZG9iZSBJbWFnZVJlYWR5ccllPAAAAyhpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADw/eHBhY2tldCBiZWdpbj0i77u/IiBpZD0iVzVNME1wQ2VoaUh6cmVTek5UY3prYzlkIj8+IDx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6bWV0YS8iIHg6eG1wdGs9IkFkb2JlIFhNUCBDb3JlIDUuNi1jMTExIDc5LjE1ODMyNSwgMjAxNS8wOS8xMC0wMToxMDoyMCAgICAgICAgIj4gPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4gPHJkZjpEZXNjcmlwdGlvbiByZGY6YWJvdXQ9IiIgeG1sbnM6eG1wTU09Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC9tbS8iIHhtbG5zOnN0UmVmPSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvc1R5cGUvUmVzb3VyY2VSZWYjIiB4bWxuczp4bXA9Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC8iIHhtcE1NOkRvY3VtZW50SUQ9InhtcC5kaWQ6ODYyRjhERDU2N0YyMTFFNjg2MzZDQjkwNkQ4MjgwMEIiIHhtcE1NOkluc3RhbmNlSUQ9InhtcC5paWQ6ODYyRjhERDQ2N0YyMTFFNjg2MzZDQjkwNkQ4MjgwMEIiIHhtcDpDcmVhdG9yVG9vbD0iQWRvYmUgUGhvdG9zaG9wIENDIDIwMTUgKE1hY2ludG9zaCkiPiA8eG1wTU06RGVyaXZlZEZyb20gc3RSZWY6aW5zdGFuY2VJRD0ieG1wLmlpZDo3NjcyQkQ3QTY3QzUxMUU2QjJCQ0UyNDA4MTAwMjE3MSIgc3RSZWY6ZG9jdW1lbnRJRD0ieG1wLmRpZDo3NjcyQkQ3QjY3QzUxMUU2QjJCQ0UyNDA4MTAwMjE3MSIvPiA8L3JkZjpEZXNjcmlwdGlvbj4gPC9yZGY6UkRGPiA8L3g6eG1wbWV0YT4gPD94cGFja2V0IGVuZD0iciI/PvCS01IAAABMSURBVHjaYmR4/5+BFPBfAMFm/MBgx8RAGWCn1AAmSg34Q6kBDKMGMDCwICeMIemF/5QawEipAWwUhwEjMDvbAWlWkvVBwu8vQIABAEwBCph8U6c0AAAAAElFTkSuQmCC) 0 0; } | ||||||
|  |  | ||||||
|  | .swagger-ui .debug-grid-8-solid { background: url(data:image/jpeg;base64,/9j/4QAYRXhpZgAASUkqAAgAAAAAAAAAAAAAAP/sABFEdWNreQABAAQAAAAAAAD/4QMxaHR0cDovL25zLmFkb2JlLmNvbS94YXAvMS4wLwA8P3hwYWNrZXQgYmVnaW49Iu+7vyIgaWQ9Ilc1TTBNcENlaGlIenJlU3pOVGN6a2M5ZCI/PiA8eDp4bXBtZXRhIHhtbG5zOng9ImFkb2JlOm5zOm1ldGEvIiB4OnhtcHRrPSJBZG9iZSBYTVAgQ29yZSA1LjYtYzExMSA3OS4xNTgzMjUsIDIwMTUvMDkvMTAtMDE6MTA6MjAgICAgICAgICI+IDxyZGY6UkRGIHhtbG5zOnJkZj0iaHR0cDovL3d3dy53My5vcmcvMTk5OS8wMi8yMi1yZGYtc3ludGF4LW5zIyI+IDxyZGY6RGVzY3JpcHRpb24gcmRmOmFib3V0PSIiIHhtbG5zOnhtcD0iaHR0cDovL25zLmFkb2JlLmNvbS94YXAvMS4wLyIgeG1sbnM6eG1wTU09Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC9tbS8iIHhtbG5zOnN0UmVmPSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvc1R5cGUvUmVzb3VyY2VSZWYjIiB4bXA6Q3JlYXRvclRvb2w9IkFkb2JlIFBob3Rvc2hvcCBDQyAyMDE1IChNYWNpbnRvc2gpIiB4bXBNTTpJbnN0YW5jZUlEPSJ4bXAuaWlkOkIxMjI0OTczNjdCMzExRTZCMkJDRTI0MDgxMDAyMTcxIiB4bXBNTTpEb2N1bWVudElEPSJ4bXAuZGlkOkIxMjI0OTc0NjdCMzExRTZCMkJDRTI0MDgxMDAyMTcxIj4gPHhtcE1NOkRlcml2ZWRGcm9tIHN0UmVmOmluc3RhbmNlSUQ9InhtcC5paWQ6QjEyMjQ5NzE2N0IzMTFFNkIyQkNFMjQwODEwMDIxNzEiIHN0UmVmOmRvY3VtZW50SUQ9InhtcC5kaWQ6QjEyMjQ5NzI2N0IzMTFFNkIyQkNFMjQwODEwMDIxNzEiLz4gPC9yZGY6RGVzY3JpcHRpb24+IDwvcmRmOlJERj4gPC94OnhtcG1ldGE+IDw/eHBhY2tldCBlbmQ9InIiPz7/7gAOQWRvYmUAZMAAAAAB/9sAhAAbGhopHSlBJiZBQi8vL0JHPz4+P0dHR0dHR0dHR0dHR0dHR0dHR0dHR0dHR0dHR0dHR0dHR0dHR0dHR0dHR0dHAR0pKTQmND8oKD9HPzU/R0dHR0dHR0dHR0dHR0dHR0dHR0dHR0dHR0dHR0dHR0dHR0dHR0dHR0dHR0dHR0dHR0f/wAARCAAIAAgDASIAAhEBAxEB/8QAWQABAQAAAAAAAAAAAAAAAAAAAAYBAQEAAAAAAAAAAAAAAAAAAAIEEAEBAAMBAAAAAAAAAAAAAAABADECA0ERAAEDBQAAAAAAAAAAAAAAAAARITFBUWESIv/aAAwDAQACEQMRAD8AoOnTV1QTD7JJshP3vSM3P//Z) 0 0 #1c1c21; } | ||||||
|  |  | ||||||
|  | .swagger-ui .debug-grid-16-solid { background: url(data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAIAAACQkWg2AAAAGXRFWHRTb2Z0d2FyZQBBZG9iZSBJbWFnZVJlYWR5ccllPAAAAyhpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADw/eHBhY2tldCBiZWdpbj0i77u/IiBpZD0iVzVNME1wQ2VoaUh6cmVTek5UY3prYzlkIj8+IDx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6bWV0YS8iIHg6eG1wdGs9IkFkb2JlIFhNUCBDb3JlIDUuNi1jMTExIDc5LjE1ODMyNSwgMjAxNS8wOS8xMC0wMToxMDoyMCAgICAgICAgIj4gPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4gPHJkZjpEZXNjcmlwdGlvbiByZGY6YWJvdXQ9IiIgeG1sbnM6eG1wPSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvIiB4bWxuczp4bXBNTT0iaHR0cDovL25zLmFkb2JlLmNvbS94YXAvMS4wL21tLyIgeG1sbnM6c3RSZWY9Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC9zVHlwZS9SZXNvdXJjZVJlZiMiIHhtcDpDcmVhdG9yVG9vbD0iQWRvYmUgUGhvdG9zaG9wIENDIDIwMTUgKE1hY2ludG9zaCkiIHhtcE1NOkluc3RhbmNlSUQ9InhtcC5paWQ6NzY3MkJEN0U2N0M1MTFFNkIyQkNFMjQwODEwMDIxNzEiIHhtcE1NOkRvY3VtZW50SUQ9InhtcC5kaWQ6NzY3MkJEN0Y2N0M1MTFFNkIyQkNFMjQwODEwMDIxNzEiPiA8eG1wTU06RGVyaXZlZEZyb20gc3RSZWY6aW5zdGFuY2VJRD0ieG1wLmlpZDo3NjcyQkQ3QzY3QzUxMUU2QjJCQ0UyNDA4MTAwMjE3MSIgc3RSZWY6ZG9jdW1lbnRJRD0ieG1wLmRpZDo3NjcyQkQ3RDY3QzUxMUU2QjJCQ0UyNDA4MTAwMjE3MSIvPiA8L3JkZjpEZXNjcmlwdGlvbj4gPC9yZGY6UkRGPiA8L3g6eG1wbWV0YT4gPD94cGFja2V0IGVuZD0iciI/Pve6J3kAAAAzSURBVHjaYvz//z8D0UDsMwMjSRoYP5Gq4SPNbRjVMEQ1fCRDg+in/6+J1AJUxsgAEGAA31BAJMS0GYEAAAAASUVORK5CYII=) 0 0 #1c1c21; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--black { border-color: #000; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--near-black { border-color: #121212; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--dark-gray { border-color: #333; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--mid-gray { border-color: #545454; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--gray { border-color: #787878; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--silver { border-color: #999; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--light-silver { border-color: #6e6e6e; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--moon-gray { border-color: #4d4d4d; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--light-gray { border-color: #2b2b2b; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--near-white { border-color: #242424; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--white { border-color: #1c1c21; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--white-90 { border-color: rgba(28, 28, 33, .9); } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--white-80 { border-color: rgba(28, 28, 33, .8); } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--white-70 { border-color: rgba(28, 28, 33, .7); } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--white-60 { border-color: rgba(28, 28, 33, .6); } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--white-50 { border-color: rgba(28, 28, 33, .5); } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--white-40 { border-color: rgba(28, 28, 33, .4); } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--white-30 { border-color: rgba(28, 28, 33, .3); } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--white-20 { border-color: rgba(28, 28, 33, .2); } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--white-10 { border-color: rgba(28, 28, 33, .1); } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--white-05 { border-color: rgba(28, 28, 33, .05); } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--white-025 { border-color: rgba(28, 28, 33, .024); } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--white-0125 { border-color: rgba(28, 28, 33, .01); } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--black-90 { border-color: rgba(0, 0, 0, .9); } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--black-80 { border-color: rgba(0, 0, 0, .8); } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--black-70 { border-color: rgba(0, 0, 0, .7); } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--black-60 { border-color: rgba(0, 0, 0, .6); } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--black-50 { border-color: rgba(0, 0, 0, .5); } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--black-40 { border-color: rgba(0, 0, 0, .4); } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--black-30 { border-color: rgba(0, 0, 0, .3); } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--black-20 { border-color: rgba(0, 0, 0, .2); } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--black-10 { border-color: rgba(0, 0, 0, .1); } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--black-05 { border-color: rgba(0, 0, 0, .05); } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--black-025 { border-color: rgba(0, 0, 0, .024); } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--black-0125 { border-color: rgba(0, 0, 0, .01); } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--dark-red { border-color: #bc2f36; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--red { border-color: #c83932; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--light-red { border-color: #ab3c2b; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--orange { border-color: #cc6e33; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--purple { border-color: #5e2ca5; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--light-purple { border-color: #672caf; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--dark-pink { border-color: #ab2b81; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--hot-pink { border-color: #c03086; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--pink { border-color: #8f2464; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--light-pink { border-color: #721d4d; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--dark-green { border-color: #1c6e50; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--green { border-color: #279b70; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--light-green { border-color: #228762; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--navy { border-color: #0d1d35; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--dark-blue { border-color: #20497e; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--blue { border-color: #4380d0; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--light-blue { border-color: #20517e; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--lightest-blue { border-color: #143a52; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--washed-blue { border-color: #0c312d; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--washed-green { border-color: #0f3d2c; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--washed-red { border-color: #411010; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--transparent { border-color: transparent; } | ||||||
|  |  | ||||||
|  | .swagger-ui .b--gold, .swagger-ui .b--light-yellow, .swagger-ui .b--washed-yellow, .swagger-ui .b--yellow { border-color: #664b00; } | ||||||
|  |  | ||||||
|  | .swagger-ui .shadow-1 { box-shadow: rgba(0, 0, 0, .2) 0 0 4px 2px; } | ||||||
|  |  | ||||||
|  | .swagger-ui .shadow-2 { box-shadow: rgba(0, 0, 0, .2) 0 0 8px 2px; } | ||||||
|  |  | ||||||
|  | .swagger-ui .shadow-3 { box-shadow: rgba(0, 0, 0, .2) 2px 2px 4px 2px; } | ||||||
|  |  | ||||||
|  | .swagger-ui .shadow-4 { box-shadow: rgba(0, 0, 0, .2) 2px 2px 8px 0; } | ||||||
|  |  | ||||||
|  | .swagger-ui .shadow-5 { box-shadow: rgba(0, 0, 0, .2) 4px 4px 8px 0; } | ||||||
|  |  | ||||||
|  | @media screen and (min-width: 30em) { | ||||||
|  |     .swagger-ui .shadow-1-ns { box-shadow: rgba(0, 0, 0, .2) 0 0 4px 2px; } | ||||||
|  |  | ||||||
|  |     .swagger-ui .shadow-2-ns { box-shadow: rgba(0, 0, 0, .2) 0 0 8px 2px; } | ||||||
|  |  | ||||||
|  |     .swagger-ui .shadow-3-ns { box-shadow: rgba(0, 0, 0, .2) 2px 2px 4px 2px; } | ||||||
|  |  | ||||||
|  |     .swagger-ui .shadow-4-ns { box-shadow: rgba(0, 0, 0, .2) 2px 2px 8px 0; } | ||||||
|  |  | ||||||
|  |     .swagger-ui .shadow-5-ns { box-shadow: rgba(0, 0, 0, .2) 4px 4px 8px 0; } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @media screen and (max-width: 60em) and (min-width: 30em) { | ||||||
|  |     .swagger-ui .shadow-1-m { box-shadow: rgba(0, 0, 0, .2) 0 0 4px 2px; } | ||||||
|  |  | ||||||
|  |     .swagger-ui .shadow-2-m { box-shadow: rgba(0, 0, 0, .2) 0 0 8px 2px; } | ||||||
|  |  | ||||||
|  |     .swagger-ui .shadow-3-m { box-shadow: rgba(0, 0, 0, .2) 2px 2px 4px 2px; } | ||||||
|  |  | ||||||
|  |     .swagger-ui .shadow-4-m { box-shadow: rgba(0, 0, 0, .2) 2px 2px 8px 0; } | ||||||
|  |  | ||||||
|  |     .swagger-ui .shadow-5-m { box-shadow: rgba(0, 0, 0, .2) 4px 4px 8px 0; } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @media screen and (min-width: 60em) { | ||||||
|  |     .swagger-ui .shadow-1-l { box-shadow: rgba(0, 0, 0, .2) 0 0 4px 2px; } | ||||||
|  |  | ||||||
|  |     .swagger-ui .shadow-2-l { box-shadow: rgba(0, 0, 0, .2) 0 0 8px 2px; } | ||||||
|  |  | ||||||
|  |     .swagger-ui .shadow-3-l { box-shadow: rgba(0, 0, 0, .2) 2px 2px 4px 2px; } | ||||||
|  |  | ||||||
|  |     .swagger-ui .shadow-4-l { box-shadow: rgba(0, 0, 0, .2) 2px 2px 8px 0; } | ||||||
|  |  | ||||||
|  |     .swagger-ui .shadow-5-l { box-shadow: rgba(0, 0, 0, .2) 4px 4px 8px 0; } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .black-05 { color: rgba(191, 191, 191, .05); } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-black-05 { background-color: rgba(0, 0, 0, .05); } | ||||||
|  |  | ||||||
|  | .swagger-ui .black-90, .swagger-ui .hover-black-90:focus, .swagger-ui .hover-black-90:hover { color: rgba(191, 191, 191, .9); } | ||||||
|  |  | ||||||
|  | .swagger-ui .black-80, .swagger-ui .hover-black-80:focus, .swagger-ui .hover-black-80:hover { color: rgba(191, 191, 191, .8); } | ||||||
|  |  | ||||||
|  | .swagger-ui .black-70, .swagger-ui .hover-black-70:focus, .swagger-ui .hover-black-70:hover { color: rgba(191, 191, 191, .7); } | ||||||
|  |  | ||||||
|  | .swagger-ui .black-60, .swagger-ui .hover-black-60:focus, .swagger-ui .hover-black-60:hover { color: rgba(191, 191, 191, .6); } | ||||||
|  |  | ||||||
|  | .swagger-ui .black-50, .swagger-ui .hover-black-50:focus, .swagger-ui .hover-black-50:hover { color: rgba(191, 191, 191, .5); } | ||||||
|  |  | ||||||
|  | .swagger-ui .black-40, .swagger-ui .hover-black-40:focus, .swagger-ui .hover-black-40:hover { color: rgba(191, 191, 191, .4); } | ||||||
|  |  | ||||||
|  | .swagger-ui .black-30, .swagger-ui .hover-black-30:focus, .swagger-ui .hover-black-30:hover { color: rgba(191, 191, 191, .3); } | ||||||
|  |  | ||||||
|  | .swagger-ui .black-20, .swagger-ui .hover-black-20:focus, .swagger-ui .hover-black-20:hover { color: rgba(191, 191, 191, .2); } | ||||||
|  |  | ||||||
|  | .swagger-ui .black-10, .swagger-ui .hover-black-10:focus, .swagger-ui .hover-black-10:hover { color: rgba(191, 191, 191, .1); } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-white-90:focus, .swagger-ui .hover-white-90:hover, .swagger-ui .white-90 { color: rgba(255, 255, 255, .9); } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-white-80:focus, .swagger-ui .hover-white-80:hover, .swagger-ui .white-80 { color: rgba(255, 255, 255, .8); } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-white-70:focus, .swagger-ui .hover-white-70:hover, .swagger-ui .white-70 { color: rgba(255, 255, 255, .7); } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-white-60:focus, .swagger-ui .hover-white-60:hover, .swagger-ui .white-60 { color: rgba(255, 255, 255, .6); } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-white-50:focus, .swagger-ui .hover-white-50:hover, .swagger-ui .white-50 { color: rgba(255, 255, 255, .5); } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-white-40:focus, .swagger-ui .hover-white-40:hover, .swagger-ui .white-40 { color: rgba(255, 255, 255, .4); } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-white-30:focus, .swagger-ui .hover-white-30:hover, .swagger-ui .white-30 { color: rgba(255, 255, 255, .3); } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-white-20:focus, .swagger-ui .hover-white-20:hover, .swagger-ui .white-20 { color: rgba(255, 255, 255, .2); } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-white-10:focus, .swagger-ui .hover-white-10:hover, .swagger-ui .white-10 { color: rgba(255, 255, 255, .1); } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-moon-gray:focus, .swagger-ui .hover-moon-gray:hover, .swagger-ui .moon-gray { color: #ccc; } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-light-gray:focus, .swagger-ui .hover-light-gray:hover, .swagger-ui .light-gray { color: #ededed; } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-near-white:focus, .swagger-ui .hover-near-white:hover, .swagger-ui .near-white { color: #f5f5f5; } | ||||||
|  |  | ||||||
|  | .swagger-ui .dark-red, .swagger-ui .hover-dark-red:focus, .swagger-ui .hover-dark-red:hover { color: #e6999d; } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-red:focus, .swagger-ui .hover-red:hover, .swagger-ui .red { color: #e69d99; } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-light-red:focus, .swagger-ui .hover-light-red:hover, .swagger-ui .light-red { color: #e6a399; } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-orange:focus, .swagger-ui .hover-orange:hover, .swagger-ui .orange { color: #e6b699; } | ||||||
|  |  | ||||||
|  | .swagger-ui .gold, .swagger-ui .hover-gold:focus, .swagger-ui .hover-gold:hover { color: #e6d099; } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-yellow:focus, .swagger-ui .hover-yellow:hover, .swagger-ui .yellow { color: #e6da99; } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-light-yellow:focus, .swagger-ui .hover-light-yellow:hover, .swagger-ui .light-yellow { color: #ede6b6; } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-purple:focus, .swagger-ui .hover-purple:hover, .swagger-ui .purple { color: #b99ae4; } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-light-purple:focus, .swagger-ui .hover-light-purple:hover, .swagger-ui .light-purple { color: #bb99e6; } | ||||||
|  |  | ||||||
|  | .swagger-ui .dark-pink, .swagger-ui .hover-dark-pink:focus, .swagger-ui .hover-dark-pink:hover { color: #e699cc; } | ||||||
|  |  | ||||||
|  | .swagger-ui .hot-pink, .swagger-ui .hover-hot-pink:focus, .swagger-ui .hover-hot-pink:hover, .swagger-ui .hover-pink:focus, .swagger-ui .hover-pink:hover, .swagger-ui .pink { color: #e699c7; } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-light-pink:focus, .swagger-ui .hover-light-pink:hover, .swagger-ui .light-pink { color: #edb6d5; } | ||||||
|  |  | ||||||
|  | .swagger-ui .dark-green, .swagger-ui .green, .swagger-ui .hover-dark-green:focus, .swagger-ui .hover-dark-green:hover, .swagger-ui .hover-green:focus, .swagger-ui .hover-green:hover { color: #99e6c9; } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-light-green:focus, .swagger-ui .hover-light-green:hover, .swagger-ui .light-green { color: #a1e8ce; } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-navy:focus, .swagger-ui .hover-navy:hover, .swagger-ui .navy { color: #99b8e6; } | ||||||
|  |  | ||||||
|  | .swagger-ui .blue, .swagger-ui .dark-blue, .swagger-ui .hover-blue:focus, .swagger-ui .hover-blue:hover, .swagger-ui .hover-dark-blue:focus, .swagger-ui .hover-dark-blue:hover { color: #99bae6; } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-light-blue:focus, .swagger-ui .hover-light-blue:hover, .swagger-ui .light-blue { color: #a9cbea; } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-lightest-blue:focus, .swagger-ui .hover-lightest-blue:hover, .swagger-ui .lightest-blue { color: #d6e9f5; } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-washed-blue:focus, .swagger-ui .hover-washed-blue:hover, .swagger-ui .washed-blue { color: #f7fdfc; } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-washed-green:focus, .swagger-ui .hover-washed-green:hover, .swagger-ui .washed-green { color: #ebfaf4; } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-washed-yellow:focus, .swagger-ui .hover-washed-yellow:hover, .swagger-ui .washed-yellow { color: #fbf9ef; } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-washed-red:focus, .swagger-ui .hover-washed-red:hover, .swagger-ui .washed-red { color: #f9e7e7; } | ||||||
|  |  | ||||||
|  | .swagger-ui .color-inherit, .swagger-ui .hover-inherit:focus, .swagger-ui .hover-inherit:hover { color: inherit; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-black-90, .swagger-ui .hover-bg-black-90:focus, .swagger-ui .hover-bg-black-90:hover { background-color: rgba(0, 0, 0, .9); } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-black-80, .swagger-ui .hover-bg-black-80:focus, .swagger-ui .hover-bg-black-80:hover { background-color: rgba(0, 0, 0, .8); } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-black-70, .swagger-ui .hover-bg-black-70:focus, .swagger-ui .hover-bg-black-70:hover { background-color: rgba(0, 0, 0, .7); } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-black-60, .swagger-ui .hover-bg-black-60:focus, .swagger-ui .hover-bg-black-60:hover { background-color: rgba(0, 0, 0, .6); } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-black-50, .swagger-ui .hover-bg-black-50:focus, .swagger-ui .hover-bg-black-50:hover { background-color: rgba(0, 0, 0, .5); } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-black-40, .swagger-ui .hover-bg-black-40:focus, .swagger-ui .hover-bg-black-40:hover { background-color: rgba(0, 0, 0, .4); } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-black-30, .swagger-ui .hover-bg-black-30:focus, .swagger-ui .hover-bg-black-30:hover { background-color: rgba(0, 0, 0, .3); } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-black-20, .swagger-ui .hover-bg-black-20:focus, .swagger-ui .hover-bg-black-20:hover { background-color: rgba(0, 0, 0, .2); } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-white-90, .swagger-ui .hover-bg-white-90:focus, .swagger-ui .hover-bg-white-90:hover { background-color: rgba(28, 28, 33, .9); } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-white-80, .swagger-ui .hover-bg-white-80:focus, .swagger-ui .hover-bg-white-80:hover { background-color: rgba(28, 28, 33, .8); } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-white-70, .swagger-ui .hover-bg-white-70:focus, .swagger-ui .hover-bg-white-70:hover { background-color: rgba(28, 28, 33, .7); } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-white-60, .swagger-ui .hover-bg-white-60:focus, .swagger-ui .hover-bg-white-60:hover { background-color: rgba(28, 28, 33, .6); } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-white-50, .swagger-ui .hover-bg-white-50:focus, .swagger-ui .hover-bg-white-50:hover { background-color: rgba(28, 28, 33, .5); } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-white-40, .swagger-ui .hover-bg-white-40:focus, .swagger-ui .hover-bg-white-40:hover { background-color: rgba(28, 28, 33, .4); } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-white-30, .swagger-ui .hover-bg-white-30:focus, .swagger-ui .hover-bg-white-30:hover { background-color: rgba(28, 28, 33, .3); } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-white-20, .swagger-ui .hover-bg-white-20:focus, .swagger-ui .hover-bg-white-20:hover { background-color: rgba(28, 28, 33, .2); } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-black, .swagger-ui .hover-bg-black:focus, .swagger-ui .hover-bg-black:hover { background-color: #000; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-near-black, .swagger-ui .hover-bg-near-black:focus, .swagger-ui .hover-bg-near-black:hover { background-color: #121212; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-dark-gray, .swagger-ui .hover-bg-dark-gray:focus, .swagger-ui .hover-bg-dark-gray:hover { background-color: #333; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-mid-gray, .swagger-ui .hover-bg-mid-gray:focus, .swagger-ui .hover-bg-mid-gray:hover { background-color: #545454; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-gray, .swagger-ui .hover-bg-gray:focus, .swagger-ui .hover-bg-gray:hover { background-color: #787878; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-silver, .swagger-ui .hover-bg-silver:focus, .swagger-ui .hover-bg-silver:hover { background-color: #999; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-white, .swagger-ui .hover-bg-white:focus, .swagger-ui .hover-bg-white:hover { background-color: #1c1c21; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-transparent, .swagger-ui .hover-bg-transparent:focus, .swagger-ui .hover-bg-transparent:hover { background-color: transparent; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-dark-red, .swagger-ui .hover-bg-dark-red:focus, .swagger-ui .hover-bg-dark-red:hover { background-color: #bc2f36; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-red, .swagger-ui .hover-bg-red:focus, .swagger-ui .hover-bg-red:hover { background-color: #c83932; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-light-red, .swagger-ui .hover-bg-light-red:focus, .swagger-ui .hover-bg-light-red:hover { background-color: #ab3c2b; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-orange, .swagger-ui .hover-bg-orange:focus, .swagger-ui .hover-bg-orange:hover { background-color: #cc6e33; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-gold, .swagger-ui .bg-light-yellow, .swagger-ui .bg-washed-yellow, .swagger-ui .bg-yellow, .swagger-ui .hover-bg-gold:focus, .swagger-ui .hover-bg-gold:hover, .swagger-ui .hover-bg-light-yellow:focus, .swagger-ui .hover-bg-light-yellow:hover, .swagger-ui .hover-bg-washed-yellow:focus, .swagger-ui .hover-bg-washed-yellow:hover, .swagger-ui .hover-bg-yellow:focus, .swagger-ui .hover-bg-yellow:hover { background-color: #664b00; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-purple, .swagger-ui .hover-bg-purple:focus, .swagger-ui .hover-bg-purple:hover { background-color: #5e2ca5; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-light-purple, .swagger-ui .hover-bg-light-purple:focus, .swagger-ui .hover-bg-light-purple:hover { background-color: #672caf; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-dark-pink, .swagger-ui .hover-bg-dark-pink:focus, .swagger-ui .hover-bg-dark-pink:hover { background-color: #ab2b81; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-hot-pink, .swagger-ui .hover-bg-hot-pink:focus, .swagger-ui .hover-bg-hot-pink:hover { background-color: #c03086; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-pink, .swagger-ui .hover-bg-pink:focus, .swagger-ui .hover-bg-pink:hover { background-color: #8f2464; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-light-pink, .swagger-ui .hover-bg-light-pink:focus, .swagger-ui .hover-bg-light-pink:hover { background-color: #721d4d; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-dark-green, .swagger-ui .hover-bg-dark-green:focus, .swagger-ui .hover-bg-dark-green:hover { background-color: #1c6e50; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-green, .swagger-ui .hover-bg-green:focus, .swagger-ui .hover-bg-green:hover { background-color: #279b70; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-light-green, .swagger-ui .hover-bg-light-green:focus, .swagger-ui .hover-bg-light-green:hover { background-color: #228762; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-navy, .swagger-ui .hover-bg-navy:focus, .swagger-ui .hover-bg-navy:hover { background-color: #0d1d35; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-dark-blue, .swagger-ui .hover-bg-dark-blue:focus, .swagger-ui .hover-bg-dark-blue:hover { background-color: #20497e; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-blue, .swagger-ui .hover-bg-blue:focus, .swagger-ui .hover-bg-blue:hover { background-color: #4380d0; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-light-blue, .swagger-ui .hover-bg-light-blue:focus, .swagger-ui .hover-bg-light-blue:hover { background-color: #20517e; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-lightest-blue, .swagger-ui .hover-bg-lightest-blue:focus, .swagger-ui .hover-bg-lightest-blue:hover { background-color: #143a52; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-washed-blue, .swagger-ui .hover-bg-washed-blue:focus, .swagger-ui .hover-bg-washed-blue:hover { background-color: #0c312d; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-washed-green, .swagger-ui .hover-bg-washed-green:focus, .swagger-ui .hover-bg-washed-green:hover { background-color: #0f3d2c; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-washed-red, .swagger-ui .hover-bg-washed-red:focus, .swagger-ui .hover-bg-washed-red:hover { background-color: #411010; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-inherit, .swagger-ui .hover-bg-inherit:focus, .swagger-ui .hover-bg-inherit:hover { background-color: inherit; } | ||||||
|  |  | ||||||
|  | .swagger-ui .shadow-hover { transition: all .5s cubic-bezier(.165, .84, .44, 1) 0s; } | ||||||
|  |  | ||||||
|  | .swagger-ui .shadow-hover::after { | ||||||
|  |     border-radius: inherit; | ||||||
|  |     box-shadow: rgba(0, 0, 0, .2) 0 0 16px 2px; | ||||||
|  |     content: ""; | ||||||
|  |     height: 100%; | ||||||
|  |     left: 0; | ||||||
|  |     opacity: 0; | ||||||
|  |     position: absolute; | ||||||
|  |     top: 0; | ||||||
|  |     transition: opacity .5s cubic-bezier(.165, .84, .44, 1) 0s; | ||||||
|  |     width: 100%; | ||||||
|  |     z-index: -1; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-animate, .swagger-ui .bg-animate:focus, .swagger-ui .bg-animate:hover { transition: background-color .15s ease-in-out 0s; } | ||||||
|  |  | ||||||
|  | .swagger-ui .nested-links a { | ||||||
|  |     color: #99bae6; | ||||||
|  |     transition: color .15s ease-in 0s; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .nested-links a:focus, .swagger-ui .nested-links a:hover { | ||||||
|  |     color: #a9cbea; | ||||||
|  |     transition: color .15s ease-in 0s; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock-tag { | ||||||
|  |     border-bottom: 1px solid rgba(58, 64, 80, .3); | ||||||
|  |     color: #b5bac9; | ||||||
|  |     transition: all .2s ease 0s; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock-tag svg, .swagger-ui section.models h4 svg { transition: all .4s ease 0s; } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock { | ||||||
|  |     border: 1px solid #000; | ||||||
|  |     border-radius: 4px; | ||||||
|  |     box-shadow: rgba(0, 0, 0, .19) 0 0 3px; | ||||||
|  |     margin: 0 0 15px; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock .tab-header .tab-item.active h4 span::after { background: gray; } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock.is-open .opblock-summary { border-bottom: 1px solid #000; } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock .opblock-section-header { | ||||||
|  |     background: rgba(28, 28, 33, .8); | ||||||
|  |     box-shadow: rgba(0, 0, 0, .1) 0 1px 2px; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock .opblock-section-header > label > span { padding: 0 10px 0 0; } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock .opblock-summary-method { | ||||||
|  |     background: #000; | ||||||
|  |     color: #fff; | ||||||
|  |     text-shadow: rgba(0, 0, 0, .1) 0 1px 0; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock.opblock-post { | ||||||
|  |     background: rgba(72, 203, 144, .1); | ||||||
|  |     border-color: #48cb90; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock.opblock-post .opblock-summary-method, .swagger-ui .opblock.opblock-post .tab-header .tab-item.active h4 span::after { background: #48cb90; } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock.opblock-post .opblock-summary { border-color: #48cb90; } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock.opblock-put { | ||||||
|  |     background: rgba(213, 157, 88, .1); | ||||||
|  |     border-color: #d59d58; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock.opblock-put .opblock-summary-method, .swagger-ui .opblock.opblock-put .tab-header .tab-item.active h4 span::after { background: #d59d58; } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock.opblock-put .opblock-summary { border-color: #d59d58; } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock.opblock-delete { | ||||||
|  |     background: rgba(200, 50, 50, .1); | ||||||
|  |     border-color: #c83232; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock.opblock-delete .opblock-summary-method, .swagger-ui .opblock.opblock-delete .tab-header .tab-item.active h4 span::after { background: #c83232; } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock.opblock-delete .opblock-summary { border-color: #c83232; } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock.opblock-get { | ||||||
|  |     background: rgba(42, 105, 167, .1); | ||||||
|  |     border-color: #2a69a7; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock.opblock-get .opblock-summary-method, .swagger-ui .opblock.opblock-get .tab-header .tab-item.active h4 span::after { background: #2a69a7; } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock.opblock-get .opblock-summary { border-color: #2a69a7; } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock.opblock-patch { | ||||||
|  |     background: rgba(92, 214, 188, .1); | ||||||
|  |     border-color: #5cd6bc; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock.opblock-patch .opblock-summary-method, .swagger-ui .opblock.opblock-patch .tab-header .tab-item.active h4 span::after { background: #5cd6bc; } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock.opblock-patch .opblock-summary { border-color: #5cd6bc; } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock.opblock-head { | ||||||
|  |     background: rgba(140, 63, 207, .1); | ||||||
|  |     border-color: #8c3fcf; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock.opblock-head .opblock-summary-method, .swagger-ui .opblock.opblock-head .tab-header .tab-item.active h4 span::after { background: #8c3fcf; } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock.opblock-head .opblock-summary { border-color: #8c3fcf; } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock.opblock-options { | ||||||
|  |     background: rgba(36, 89, 143, .1); | ||||||
|  |     border-color: #24598f; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock.opblock-options .opblock-summary-method, .swagger-ui .opblock.opblock-options .tab-header .tab-item.active h4 span::after { background: #24598f; } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock.opblock-options .opblock-summary { border-color: #24598f; } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock.opblock-deprecated { | ||||||
|  |     background: rgba(46, 46, 46, .1); | ||||||
|  |     border-color: #2e2e2e; | ||||||
|  |     opacity: .6; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock.opblock-deprecated .opblock-summary-method, .swagger-ui .opblock.opblock-deprecated .tab-header .tab-item.active h4 span::after { background: #2e2e2e; } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock.opblock-deprecated .opblock-summary { border-color: #2e2e2e; } | ||||||
|  |  | ||||||
|  | .swagger-ui .filter .operation-filter-input { border: 2px solid #2b3446; } | ||||||
|  |  | ||||||
|  | .swagger-ui .tab li:first-of-type::after { background: rgba(0, 0, 0, .2); } | ||||||
|  |  | ||||||
|  | .swagger-ui .download-contents { | ||||||
|  |     background: #7c8192; | ||||||
|  |     color: #fff; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .scheme-container { | ||||||
|  |     background: #1c1c21; | ||||||
|  |     box-shadow: rgba(0, 0, 0, .15) 0 1px 2px 0; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .loading-container .loading::before { | ||||||
|  |     animation: 1s linear 0s infinite normal none running rotation, .5s ease 0s 1 normal none running opacity; | ||||||
|  |     border-color: rgba(0, 0, 0, .6) rgba(84, 84, 84, .1) rgba(84, 84, 84, .1); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .response-control-media-type--accept-controller select { border-color: #196619; } | ||||||
|  |  | ||||||
|  | .swagger-ui .response-control-media-type__accept-message { color: #99e699; } | ||||||
|  |  | ||||||
|  | .swagger-ui .version-pragma__message code { background-color: #3b3b3b; } | ||||||
|  |  | ||||||
|  | .swagger-ui .btn { | ||||||
|  |     background: 0 0; | ||||||
|  |     border: 2px solid gray; | ||||||
|  |     box-shadow: rgba(0, 0, 0, .1) 0 1px 2px; | ||||||
|  |     color: #b5bac9; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .btn:hover { box-shadow: rgba(0, 0, 0, .3) 0 0 5px; } | ||||||
|  |  | ||||||
|  | .swagger-ui .btn.authorize, .swagger-ui .btn.cancel { | ||||||
|  |     background-color: transparent; | ||||||
|  |     border-color: #a72a2a; | ||||||
|  |     color: #e69999; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .btn.authorize { | ||||||
|  |     border-color: #48cb90; | ||||||
|  |     color: #9ce3c3; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .btn.authorize svg { fill: #9ce3c3; } | ||||||
|  |  | ||||||
|  | .swagger-ui .btn.execute { | ||||||
|  |     background-color: #5892d5; | ||||||
|  |     border-color: #5892d5; | ||||||
|  |     color: #fff; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .copy-to-clipboard { background: #7c8192; } | ||||||
|  |  | ||||||
|  | .swagger-ui .copy-to-clipboard button { background: url("data:image/svg+xml;charset=utf-8,<svg xmlns=\"http://www.w3.org/2000/svg\" width=\"16\" height=\"16\" aria-hidden=\"true\"><path fill=\"%23fff\" fill-rule=\"evenodd\" d=\"M2 13h4v1H2v-1zm5-6H2v1h5V7zm2 3V8l-3 3 3 3v-2h5v-2H9zM4.5 9H2v1h2.5V9zM2 12h2.5v-1H2v1zm9 1h1v2c-.02.28-.11.52-.3.7-.19.18-.42.28-.7.3H1c-.55 0-1-.45-1-1V4c0-.55.45-1 1-1h3c0-1.11.89-2 2-2 1.11 0 2 .89 2 2h3c.55 0 1 .45 1 1v5h-1V6H1v9h10v-2zM2 5h8c0-.55-.45-1-1-1H8c-.55 0-1-.45-1-1s-.45-1-1-1-1 .45-1 1-.45 1-1 1H3c-.55 0-1 .45-1 1z\"/></svg>") 50% center no-repeat; } | ||||||
|  |  | ||||||
|  | .swagger-ui select { | ||||||
|  |     background: url("data:image/svg+xml;charset=utf-8,<svg xmlns=\"http://www.w3.org/2000/svg\" viewBox=\"0 0 20 20\"><path d=\"M13.418 7.859a.695.695 0 01.978 0 .68.68 0 010 .969l-3.908 3.83a.697.697 0 01-.979 0l-3.908-3.83a.68.68 0 010-.969.695.695 0 01.978 0L10 11l3.418-3.141z\"/></svg>") right 10px center/20px no-repeat #212121; | ||||||
|  |     background: url(data:image/svg+xml;base64,PD94bWwgdmVyc2lvbj0iMS4wIiBlbmNvZGluZz0iVVRGLTgiIHN0YW5kYWxvbmU9Im5vIj8+CjxzdmcKICAgeG1sbnM6ZGM9Imh0dHA6Ly9wdXJsLm9yZy9kYy9lbGVtZW50cy8xLjEvIgogICB4bWxuczpjYz0iaHR0cDovL2NyZWF0aXZlY29tbW9ucy5vcmcvbnMjIgogICB4bWxuczpyZGY9Imh0dHA6Ly93d3cudzMub3JnLzE5OTkvMDIvMjItcmRmLXN5bnRheC1ucyMiCiAgIHhtbG5zOnN2Zz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciCiAgIHhtbG5zPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwL3N2ZyIKICAgeG1sbnM6c29kaXBvZGk9Imh0dHA6Ly9zb2RpcG9kaS5zb3VyY2Vmb3JnZS5uZXQvRFREL3NvZGlwb2RpLTAuZHRkIgogICB4bWxuczppbmtzY2FwZT0iaHR0cDovL3d3dy5pbmtzY2FwZS5vcmcvbmFtZXNwYWNlcy9pbmtzY2FwZSIKICAgaW5rc2NhcGU6dmVyc2lvbj0iMS4wICg0MDM1YTRmYjQ5LCAyMDIwLTA1LTAxKSIKICAgc29kaXBvZGk6ZG9jbmFtZT0iZG93bmxvYWQuc3ZnIgogICBpZD0ic3ZnNCIKICAgdmVyc2lvbj0iMS4xIgogICB2aWV3Qm94PSIwIDAgMjAgMjAiPgogIDxtZXRhZGF0YQogICAgIGlkPSJtZXRhZGF0YTEwIj4KICAgIDxyZGY6UkRGPgogICAgICA8Y2M6V29yawogICAgICAgICByZGY6YWJvdXQ9IiI+CiAgICAgICAgPGRjOmZvcm1hdD5pbWFnZS9zdmcreG1sPC9kYzpmb3JtYXQ+CiAgICAgICAgPGRjOnR5cGUKICAgICAgICAgICByZGY6cmVzb3VyY2U9Imh0dHA6Ly9wdXJsLm9yZy9kYy9kY21pdHlwZS9TdGlsbEltYWdlIiAvPgogICAgICA8L2NjOldvcms+CiAgICA8L3JkZjpSREY+CiAgPC9tZXRhZGF0YT4KICA8ZGVmcwogICAgIGlkPSJkZWZzOCIgLz4KICA8c29kaXBvZGk6bmFtZWR2aWV3CiAgICAgaW5rc2NhcGU6Y3VycmVudC1sYXllcj0ic3ZnNCIKICAgICBpbmtzY2FwZTp3aW5kb3ctbWF4aW1pemVkPSIxIgogICAgIGlua3NjYXBlOndpbmRvdy15PSItOSIKICAgICBpbmtzY2FwZTp3aW5kb3cteD0iLTkiCiAgICAgaW5rc2NhcGU6Y3k9IjEwIgogICAgIGlua3NjYXBlOmN4PSIxMCIKICAgICBpbmtzY2FwZTp6b29tPSI0MS41IgogICAgIHNob3dncmlkPSJmYWxzZSIKICAgICBpZD0ibmFtZWR2aWV3NiIKICAgICBpbmtzY2FwZTp3aW5kb3ctaGVpZ2h0PSIxMDAxIgogICAgIGlua3NjYXBlOndpbmRvdy13aWR0aD0iMTkyMCIKICAgICBpbmtzY2FwZTpwYWdlc2hhZG93PSIyIgogICAgIGlua3NjYXBlOnBhZ2VvcGFjaXR5PSIwIgogICAgIGd1aWRldG9sZXJhbmNlPSIxMCIKICAgICBncmlkdG9sZXJhbmNlPSIxMCIKICAgICBvYmplY3R0b2xlcmFuY2U9IjEwIgogICAgIGJvcmRlcm9wYWNpdHk9IjEiCiAgICAgYm9yZGVyY29sb3I9IiM2NjY2NjYiCiAgICAgcGFnZWNvbG9yPSIjZmZmZmZmIiAvPgogIDxwYXRoCiAgICAgc3R5bGU9ImZpbGw6I2ZmZmZmZiIKICAgICBpZD0icGF0aDIiCiAgICAgZD0iTTEzLjQxOCA3Ljg1OWEuNjk1LjY5NSAwIDAxLjk3OCAwIC42OC42OCAwIDAxMCAuOTY5bC0zLjkwOCAzLjgzYS42OTcuNjk3IDAgMDEtLjk3OSAwbC0zLjkwOC0zLjgzYS42OC42OCAwIDAxMC0uOTY5LjY5NS42OTUgMCAwMS45NzggMEwxMCAxMWwzLjQxOC0zLjE0MXoiIC8+Cjwvc3ZnPgo=) right 10px center/20px no-repeat #1c1c21; | ||||||
|  |     border: 2px solid #41444e; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui select[multiple] { background: #212121; } | ||||||
|  |  | ||||||
|  | .swagger-ui button.invalid, .swagger-ui input[type=email].invalid, .swagger-ui input[type=file].invalid, .swagger-ui input[type=password].invalid, .swagger-ui input[type=search].invalid, .swagger-ui input[type=text].invalid, .swagger-ui select.invalid, .swagger-ui textarea.invalid { | ||||||
|  |     background: #390e0e; | ||||||
|  |     border-color: #c83232; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui input[type=email], .swagger-ui input[type=file], .swagger-ui input[type=password], .swagger-ui input[type=search], .swagger-ui input[type=text], .swagger-ui textarea { | ||||||
|  |     background: #1c1c21; | ||||||
|  |     border: 1px solid #404040; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui textarea { | ||||||
|  |     background: rgba(28, 28, 33, .8); | ||||||
|  |     color: #b5bac9; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui input[disabled], .swagger-ui select[disabled] { | ||||||
|  |     background-color: #1f1f1f; | ||||||
|  |     color: #bfbfbf; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui textarea[disabled] { | ||||||
|  |     background-color: #41444e; | ||||||
|  |     color: #fff; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui select[disabled] { border-color: #878787; } | ||||||
|  |  | ||||||
|  | .swagger-ui textarea:focus { border: 2px solid #2a69a7; } | ||||||
|  |  | ||||||
|  | .swagger-ui .checkbox input[type=checkbox] + label > .item { | ||||||
|  |     background: #303030; | ||||||
|  |     box-shadow: #303030 0 0 0 2px; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .checkbox input[type=checkbox]:checked + label > .item { background: url("data:image/svg+xml;charset=utf-8,<svg width=\"10\" height=\"8\" viewBox=\"3 7 10 8\" xmlns=\"http://www.w3.org/2000/svg\"><path fill=\"%2341474E\" fill-rule=\"evenodd\" d=\"M6.333 15L3 11.667l1.333-1.334 2 2L11.667 7 13 8.333z\"/></svg>") 50% center no-repeat #303030; } | ||||||
|  |  | ||||||
|  | .swagger-ui .dialog-ux .backdrop-ux { background: rgba(0, 0, 0, .8); } | ||||||
|  |  | ||||||
|  | .swagger-ui .dialog-ux .modal-ux { | ||||||
|  |     background: #1c1c21; | ||||||
|  |     border: 1px solid #2e2e2e; | ||||||
|  |     box-shadow: rgba(0, 0, 0, .2) 0 10px 30px 0; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .dialog-ux .modal-ux-header .close-modal { background: 0 0; } | ||||||
|  |  | ||||||
|  | .swagger-ui .model .deprecated span, .swagger-ui .model .deprecated td { color: #bfbfbf !important; } | ||||||
|  |  | ||||||
|  | .swagger-ui .model-toggle::after { background: url("data:image/svg+xml;charset=utf-8,<svg xmlns=\"http://www.w3.org/2000/svg\" width=\"24\" height=\"24\"><path d=\"M10 6L8.59 7.41 13.17 12l-4.58 4.59L10 18l6-6z\"/></svg>") 50% center/100% no-repeat; } | ||||||
|  |  | ||||||
|  | .swagger-ui .model-hint { | ||||||
|  |     background: rgba(0, 0, 0, .7); | ||||||
|  |     color: #ebebeb; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui section.models { border: 1px solid rgba(58, 64, 80, .3); } | ||||||
|  |  | ||||||
|  | .swagger-ui section.models.is-open h4 { border-bottom: 1px solid rgba(58, 64, 80, .3); } | ||||||
|  |  | ||||||
|  | .swagger-ui section.models .model-container { background: rgba(0, 0, 0, .05); } | ||||||
|  |  | ||||||
|  | .swagger-ui section.models .model-container:hover { background: rgba(0, 0, 0, .07); } | ||||||
|  |  | ||||||
|  | .swagger-ui .model-box { background: rgba(0, 0, 0, .1); } | ||||||
|  |  | ||||||
|  | .swagger-ui .prop-type { color: #aaaad4; } | ||||||
|  |  | ||||||
|  | .swagger-ui table thead tr td, .swagger-ui table thead tr th { | ||||||
|  |     border-bottom: 1px solid rgba(58, 64, 80, .2); | ||||||
|  |     color: #b5bac9; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .parameter__name.required::after { color: rgba(230, 153, 153, .6); } | ||||||
|  |  | ||||||
|  | .swagger-ui .topbar .download-url-wrapper .select-label { color: #f0f0f0; } | ||||||
|  |  | ||||||
|  | .swagger-ui .topbar .download-url-wrapper .download-url-button { | ||||||
|  |     background: #63a040; | ||||||
|  |     color: #fff; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .info .title small { background: #7c8492; } | ||||||
|  |  | ||||||
|  | .swagger-ui .info .title small.version-stamp { background-color: #7a9b27; } | ||||||
|  |  | ||||||
|  | .swagger-ui .auth-container .errors { | ||||||
|  |     background-color: #350d0d; | ||||||
|  |     color: #b5bac9; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .errors-wrapper { | ||||||
|  |     background: rgba(200, 50, 50, .1); | ||||||
|  |     border: 2px solid #c83232; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .markdown code, .swagger-ui .renderedmarkdown code { | ||||||
|  |     background: rgba(0, 0, 0, .05); | ||||||
|  |     color: #c299e6; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .model-toggle:after { background: url(data:image/svg+xml;base64,PD94bWwgdmVyc2lvbj0iMS4wIiBlbmNvZGluZz0iVVRGLTgiIHN0YW5kYWxvbmU9Im5vIj8+CjxzdmcKICAgeG1sbnM6ZGM9Imh0dHA6Ly9wdXJsLm9yZy9kYy9lbGVtZW50cy8xLjEvIgogICB4bWxuczpjYz0iaHR0cDovL2NyZWF0aXZlY29tbW9ucy5vcmcvbnMjIgogICB4bWxuczpyZGY9Imh0dHA6Ly93d3cudzMub3JnLzE5OTkvMDIvMjItcmRmLXN5bnRheC1ucyMiCiAgIHhtbG5zOnN2Zz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciCiAgIHhtbG5zPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwL3N2ZyIKICAgeG1sbnM6c29kaXBvZGk9Imh0dHA6Ly9zb2RpcG9kaS5zb3VyY2Vmb3JnZS5uZXQvRFREL3NvZGlwb2RpLTAuZHRkIgogICB4bWxuczppbmtzY2FwZT0iaHR0cDovL3d3dy5pbmtzY2FwZS5vcmcvbmFtZXNwYWNlcy9pbmtzY2FwZSIKICAgaW5rc2NhcGU6dmVyc2lvbj0iMS4wICg0MDM1YTRmYjQ5LCAyMDIwLTA1LTAxKSIKICAgc29kaXBvZGk6ZG9jbmFtZT0iZG93bmxvYWQyLnN2ZyIKICAgaWQ9InN2ZzQiCiAgIHZlcnNpb249IjEuMSIKICAgaGVpZ2h0PSIyNCIKICAgd2lkdGg9IjI0Ij4KICA8bWV0YWRhdGEKICAgICBpZD0ibWV0YWRhdGExMCI+CiAgICA8cmRmOlJERj4KICAgICAgPGNjOldvcmsKICAgICAgICAgcmRmOmFib3V0PSIiPgogICAgICAgIDxkYzpmb3JtYXQ+aW1hZ2Uvc3ZnK3htbDwvZGM6Zm9ybWF0PgogICAgICAgIDxkYzp0eXBlCiAgICAgICAgICAgcmRmOnJlc291cmNlPSJodHRwOi8vcHVybC5vcmcvZGMvZGNtaXR5cGUvU3RpbGxJbWFnZSIgLz4KICAgICAgPC9jYzpXb3JrPgogICAgPC9yZGY6UkRGPgogIDwvbWV0YWRhdGE+CiAgPGRlZnMKICAgICBpZD0iZGVmczgiIC8+CiAgPHNvZGlwb2RpOm5hbWVkdmlldwogICAgIGlua3NjYXBlOmN1cnJlbnQtbGF5ZXI9InN2ZzQiCiAgICAgaW5rc2NhcGU6d2luZG93LW1heGltaXplZD0iMSIKICAgICBpbmtzY2FwZTp3aW5kb3cteT0iLTkiCiAgICAgaW5rc2NhcGU6d2luZG93LXg9Ii05IgogICAgIGlua3NjYXBlOmN5PSIxMiIKICAgICBpbmtzY2FwZTpjeD0iMTIiCiAgICAgaW5rc2NhcGU6em9vbT0iMzQuNTgzMzMzIgogICAgIHNob3dncmlkPSJmYWxzZSIKICAgICBpZD0ibmFtZWR2aWV3NiIKICAgICBpbmtzY2FwZTp3aW5kb3ctaGVpZ2h0PSIxMDAxIgogICAgIGlua3NjYXBlOndpbmRvdy13aWR0aD0iMTkyMCIKICAgICBpbmtzY2FwZTpwYWdlc2hhZG93PSIyIgogICAgIGlua3NjYXBlOnBhZ2VvcGFjaXR5PSIwIgogICAgIGd1aWRldG9sZXJhbmNlPSIxMCIKICAgICBncmlkdG9sZXJhbmNlPSIxMCIKICAgICBvYmplY3R0b2xlcmFuY2U9IjEwIgogICAgIGJvcmRlcm9wYWNpdHk9IjEiCiAgICAgYm9yZGVyY29sb3I9IiM2NjY2NjYiCiAgICAgcGFnZWNvbG9yPSIjZmZmZmZmIiAvPgogIDxwYXRoCiAgICAgc3R5bGU9ImZpbGw6I2ZmZmZmZiIKICAgICBpZD0icGF0aDIiCiAgICAgZD0iTTEwIDZMOC41OSA3LjQxIDEzLjE3IDEybC00LjU4IDQuNTlMMTAgMThsNi02eiIgLz4KPC9zdmc+Cg==) 50% no-repeat; } | ||||||
|  |  | ||||||
|  | .swagger-ui .expand-operation svg, .swagger-ui section.models h4 svg { fill: #fff; } | ||||||
|  |  | ||||||
|  | ::-webkit-scrollbar-track { background-color: #646464 !important; } | ||||||
|  |  | ||||||
|  | ::-webkit-scrollbar-thumb { | ||||||
|  |     background-color: #242424 !important; | ||||||
|  |     border: 2px solid #3e4346 !important; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | ::-webkit-scrollbar-button:vertical:start:decrement { | ||||||
|  |     background: linear-gradient(130deg, #696969 40%, rgba(255, 0, 0, 0) 41%), linear-gradient(230deg, #696969 40%, transparent 41%), linear-gradient(0deg, #696969 40%, transparent 31%); | ||||||
|  |     background-color: #b6b6b6; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | ::-webkit-scrollbar-button:vertical:end:increment { | ||||||
|  |     background: linear-gradient(310deg, #696969 40%, transparent 41%), linear-gradient(50deg, #696969 40%, transparent 41%), linear-gradient(180deg, #696969 40%, transparent 31%); | ||||||
|  |     background-color: #b6b6b6; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | ::-webkit-scrollbar-button:horizontal:end:increment { | ||||||
|  |     background: linear-gradient(210deg, #696969 40%, transparent 41%), linear-gradient(330deg, #696969 40%, transparent 41%), linear-gradient(90deg, #696969 30%, transparent 31%); | ||||||
|  |     background-color: #b6b6b6; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | ::-webkit-scrollbar-button:horizontal:start:decrement { | ||||||
|  |     background: linear-gradient(30deg, #696969 40%, transparent 41%), linear-gradient(150deg, #696969 40%, transparent 41%), linear-gradient(270deg, #696969 30%, transparent 31%); | ||||||
|  |     background-color: #b6b6b6; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | ::-webkit-scrollbar-button, ::-webkit-scrollbar-track-piece { background-color: #3e4346 !important; } | ||||||
|  |  | ||||||
|  | .swagger-ui .black, .swagger-ui .checkbox, .swagger-ui .dark-gray, .swagger-ui .download-url-wrapper .loading, .swagger-ui .errors-wrapper .errors small, .swagger-ui .fallback, .swagger-ui .filter .loading, .swagger-ui .gray, .swagger-ui .hover-black:focus, .swagger-ui .hover-black:hover, .swagger-ui .hover-dark-gray:focus, .swagger-ui .hover-dark-gray:hover, .swagger-ui .hover-gray:focus, .swagger-ui .hover-gray:hover, .swagger-ui .hover-light-silver:focus, .swagger-ui .hover-light-silver:hover, .swagger-ui .hover-mid-gray:focus, .swagger-ui .hover-mid-gray:hover, .swagger-ui .hover-near-black:focus, .swagger-ui .hover-near-black:hover, .swagger-ui .hover-silver:focus, .swagger-ui .hover-silver:hover, .swagger-ui .light-silver, .swagger-ui .markdown pre, .swagger-ui .mid-gray, .swagger-ui .model .property, .swagger-ui .model .property.primitive, .swagger-ui .model-title, .swagger-ui .near-black, .swagger-ui .parameter__extension, .swagger-ui .parameter__in, .swagger-ui .prop-format, .swagger-ui .renderedmarkdown pre, .swagger-ui .response-col_links .response-undocumented, .swagger-ui .response-col_status .response-undocumented, .swagger-ui .silver, .swagger-ui section.models h4, .swagger-ui section.models h5, .swagger-ui span.token-not-formatted, .swagger-ui span.token-string, .swagger-ui table.headers .header-example, .swagger-ui table.model tr.description, .swagger-ui table.model tr.extension { color: #bfbfbf; } | ||||||
|  |  | ||||||
|  | .swagger-ui .hover-white:focus, .swagger-ui .hover-white:hover, .swagger-ui .info .title small pre, .swagger-ui .topbar a, .swagger-ui .white { color: #fff; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-black-10, .swagger-ui .hover-bg-black-10:focus, .swagger-ui .hover-bg-black-10:hover, .swagger-ui .stripe-dark:nth-child(2n + 1) { background-color: rgba(0, 0, 0, .1); } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-white-10, .swagger-ui .hover-bg-white-10:focus, .swagger-ui .hover-bg-white-10:hover, .swagger-ui .stripe-light:nth-child(2n + 1) { background-color: rgba(28, 28, 33, .1); } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-light-silver, .swagger-ui .hover-bg-light-silver:focus, .swagger-ui .hover-bg-light-silver:hover, .swagger-ui .striped--light-silver:nth-child(2n + 1) { background-color: #6e6e6e; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-moon-gray, .swagger-ui .hover-bg-moon-gray:focus, .swagger-ui .hover-bg-moon-gray:hover, .swagger-ui .striped--moon-gray:nth-child(2n + 1) { background-color: #4d4d4d; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-light-gray, .swagger-ui .hover-bg-light-gray:focus, .swagger-ui .hover-bg-light-gray:hover, .swagger-ui .striped--light-gray:nth-child(2n + 1) { background-color: #2b2b2b; } | ||||||
|  |  | ||||||
|  | .swagger-ui .bg-near-white, .swagger-ui .hover-bg-near-white:focus, .swagger-ui .hover-bg-near-white:hover, .swagger-ui .striped--near-white:nth-child(2n + 1) { background-color: #242424; } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock-tag:hover, .swagger-ui section.models h4:hover { background: rgba(0, 0, 0, .02); } | ||||||
|  |  | ||||||
|  | .swagger-ui .checkbox p, .swagger-ui .dialog-ux .modal-ux-content h4, .swagger-ui .dialog-ux .modal-ux-content p, .swagger-ui .dialog-ux .modal-ux-header h3, .swagger-ui .errors-wrapper .errors h4, .swagger-ui .errors-wrapper hgroup h4, .swagger-ui .info .base-url, .swagger-ui .info .title, .swagger-ui .info h1, .swagger-ui .info h2, .swagger-ui .info h3, .swagger-ui .info h4, .swagger-ui .info h5, .swagger-ui .info li, .swagger-ui .info p, .swagger-ui .info table, .swagger-ui .loading-container .loading::after, .swagger-ui .model, .swagger-ui .opblock .opblock-section-header h4, .swagger-ui .opblock .opblock-section-header > label, .swagger-ui .opblock .opblock-summary-description, .swagger-ui .opblock .opblock-summary-operation-id, .swagger-ui .opblock .opblock-summary-path, .swagger-ui .opblock .opblock-summary-path__deprecated, .swagger-ui .opblock-description-wrapper, .swagger-ui .opblock-description-wrapper h4, .swagger-ui .opblock-description-wrapper p, .swagger-ui .opblock-external-docs-wrapper, .swagger-ui .opblock-external-docs-wrapper h4, .swagger-ui .opblock-external-docs-wrapper p, .swagger-ui .opblock-tag small, .swagger-ui .opblock-title_normal, .swagger-ui .opblock-title_normal h4, .swagger-ui .opblock-title_normal p, .swagger-ui .parameter__name, .swagger-ui .parameter__type, .swagger-ui .response-col_links, .swagger-ui .response-col_status, .swagger-ui .responses-inner h4, .swagger-ui .responses-inner h5, .swagger-ui .scheme-container .schemes > label, .swagger-ui .scopes h2, .swagger-ui .servers > label, .swagger-ui .tab li, .swagger-ui label, .swagger-ui select, .swagger-ui table.headers td { color: #b5bac9; } | ||||||
|  |  | ||||||
|  | .swagger-ui .download-url-wrapper .failed, .swagger-ui .filter .failed, .swagger-ui .model-deprecated-warning, .swagger-ui .parameter__deprecated, .swagger-ui .parameter__name.required span, .swagger-ui table.model tr.property-row .star { color: #e69999; } | ||||||
|  |  | ||||||
|  | .swagger-ui .opblock-body pre.microlight, .swagger-ui textarea.curl { | ||||||
|  |     background: #41444e; | ||||||
|  |     border-radius: 4px; | ||||||
|  |     color: #fff; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | .swagger-ui .expand-methods svg, .swagger-ui .expand-methods:hover svg { fill: #bfbfbf; } | ||||||
|  |  | ||||||
|  | .swagger-ui .auth-container, .swagger-ui .dialog-ux .modal-ux-header { border-bottom: 1px solid #2e2e2e; } | ||||||
|  |  | ||||||
|  | .swagger-ui .topbar .download-url-wrapper .select-label select, .swagger-ui .topbar .download-url-wrapper input[type=text] { border: 2px solid #63a040; } | ||||||
|  |  | ||||||
|  | .swagger-ui .info a, .swagger-ui .info a:hover, .swagger-ui .scopes h2 a { color: #99bde6; } | ||||||
|  |  | ||||||
|  | /* Dark Scrollbar */ | ||||||
|  | ::-webkit-scrollbar { | ||||||
|  |     width: 14px; | ||||||
|  |     height: 14px; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | ::-webkit-scrollbar-button { | ||||||
|  |     background-color: #3e4346 !important; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | ::-webkit-scrollbar-track { | ||||||
|  |     background-color: #646464 !important; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | ::-webkit-scrollbar-track-piece { | ||||||
|  |     background-color: #3e4346 !important; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | ::-webkit-scrollbar-thumb { | ||||||
|  |     height: 50px; | ||||||
|  |     background-color: #242424 !important; | ||||||
|  |     border: 2px solid #3e4346 !important; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | ::-webkit-scrollbar-corner {} | ||||||
|  |  | ||||||
|  | ::-webkit-resizer {} | ||||||
|  |  | ||||||
|  | ::-webkit-scrollbar-button:vertical:start:decrement { | ||||||
|  |     background: | ||||||
|  |         linear-gradient(130deg, #696969 40%, rgba(255, 0, 0, 0) 41%), | ||||||
|  |         linear-gradient(230deg, #696969 40%, rgba(0, 0, 0, 0) 41%), | ||||||
|  |         linear-gradient(0deg, #696969 40%, rgba(0, 0, 0, 0) 31%); | ||||||
|  |     background-color: #b6b6b6; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | ::-webkit-scrollbar-button:vertical:end:increment { | ||||||
|  |     background: | ||||||
|  |         linear-gradient(310deg, #696969 40%, rgba(0, 0, 0, 0) 41%), | ||||||
|  |         linear-gradient(50deg, #696969 40%, rgba(0, 0, 0, 0) 41%), | ||||||
|  |         linear-gradient(180deg, #696969 40%, rgba(0, 0, 0, 0) 31%); | ||||||
|  |     background-color: #b6b6b6; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | ::-webkit-scrollbar-button:horizontal:end:increment { | ||||||
|  |     background: | ||||||
|  |         linear-gradient(210deg, #696969 40%, rgba(0, 0, 0, 0) 41%), | ||||||
|  |         linear-gradient(330deg, #696969 40%, rgba(0, 0, 0, 0) 41%), | ||||||
|  |         linear-gradient(90deg, #696969 30%, rgba(0, 0, 0, 0) 31%); | ||||||
|  |     background-color: #b6b6b6; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | ::-webkit-scrollbar-button:horizontal:start:decrement { | ||||||
|  |     background: | ||||||
|  |         linear-gradient(30deg, #696969 40%, rgba(0, 0, 0, 0) 41%), | ||||||
|  |         linear-gradient(150deg, #696969 40%, rgba(0, 0, 0, 0) 41%), | ||||||
|  |         linear-gradient(270deg, #696969 30%, rgba(0, 0, 0, 0) 31%); | ||||||
|  |     background-color: #b6b6b6; | ||||||
|  | } | ||||||
							
								
								
									
										17
									
								
								static/swagger-ui/index.css
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								static/swagger-ui/index.css
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,17 @@ | |||||||
|  | /*! Swagger UI 4.13.2 | https://swagger.io/tools/swagger-ui/ | Apache License 2.0 (license file can be found at ./LICENSE) */ | ||||||
|  | html { | ||||||
|  |     box-sizing: border-box; | ||||||
|  |     overflow: -moz-scrollbars-vertical; | ||||||
|  |     overflow-y: scroll; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | *, | ||||||
|  | *:before, | ||||||
|  | *:after { | ||||||
|  |     box-sizing: inherit; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | body { | ||||||
|  |     margin: 0; | ||||||
|  |     background: #fafafa; | ||||||
|  | } | ||||||
							
								
								
									
										79
									
								
								static/swagger-ui/oauth2-redirect.html
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										79
									
								
								static/swagger-ui/oauth2-redirect.html
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,79 @@ | |||||||
|  | <!doctype html> | ||||||
|  | <html lang="en-US"> | ||||||
|  | <head> | ||||||
|  |     <title>Swagger UI: OAuth2 Redirect</title> | ||||||
|  | </head> | ||||||
|  | <body> | ||||||
|  | <script> | ||||||
|  |     'use strict'; | ||||||
|  |     function run () { | ||||||
|  |         var oauth2 = window.opener.swaggerUIRedirectOauth2; | ||||||
|  |         var sentState = oauth2.state; | ||||||
|  |         var redirectUrl = oauth2.redirectUrl; | ||||||
|  |         var isValid, qp, arr; | ||||||
|  |  | ||||||
|  |         if (/code|token|error/.test(window.location.hash)) { | ||||||
|  |             qp = window.location.hash.substring(1); | ||||||
|  |         } else { | ||||||
|  |             qp = location.search.substring(1); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         arr = qp.split("&"); | ||||||
|  |         arr.forEach(function (v,i,_arr) { _arr[i] = '"' + v.replace('=', '":"') + '"';}); | ||||||
|  |         qp = qp ? JSON.parse('{' + arr.join() + '}', | ||||||
|  |                 function (key, value) { | ||||||
|  |                     return key === "" ? value : decodeURIComponent(value); | ||||||
|  |                 } | ||||||
|  |         ) : {}; | ||||||
|  |  | ||||||
|  |         isValid = qp.state === sentState; | ||||||
|  |  | ||||||
|  |         if (( | ||||||
|  |           oauth2.auth.schema.get("flow") === "accessCode" || | ||||||
|  |           oauth2.auth.schema.get("flow") === "authorizationCode" || | ||||||
|  |           oauth2.auth.schema.get("flow") === "authorization_code" | ||||||
|  |         ) && !oauth2.auth.code) { | ||||||
|  |             if (!isValid) { | ||||||
|  |                 oauth2.errCb({ | ||||||
|  |                     authId: oauth2.auth.name, | ||||||
|  |                     source: "auth", | ||||||
|  |                     level: "warning", | ||||||
|  |                     message: "Authorization may be unsafe, passed state was changed in server. The passed state wasn't returned from auth server." | ||||||
|  |                 }); | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             if (qp.code) { | ||||||
|  |                 delete oauth2.state; | ||||||
|  |                 oauth2.auth.code = qp.code; | ||||||
|  |                 oauth2.callback({auth: oauth2.auth, redirectUrl: redirectUrl}); | ||||||
|  |             } else { | ||||||
|  |                 let oauthErrorMsg; | ||||||
|  |                 if (qp.error) { | ||||||
|  |                     oauthErrorMsg = "["+qp.error+"]: " + | ||||||
|  |                         (qp.error_description ? qp.error_description+ ". " : "no accessCode received from the server. ") + | ||||||
|  |                         (qp.error_uri ? "More info: "+qp.error_uri : ""); | ||||||
|  |                 } | ||||||
|  |  | ||||||
|  |                 oauth2.errCb({ | ||||||
|  |                     authId: oauth2.auth.name, | ||||||
|  |                     source: "auth", | ||||||
|  |                     level: "error", | ||||||
|  |                     message: oauthErrorMsg || "[Authorization failed]: no accessCode received from the server." | ||||||
|  |                 }); | ||||||
|  |             } | ||||||
|  |         } else { | ||||||
|  |             oauth2.callback({auth: oauth2.auth, token: qp, isValid: isValid, redirectUrl: redirectUrl}); | ||||||
|  |         } | ||||||
|  |         window.close(); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if (document.readyState !== 'loading') { | ||||||
|  |         run(); | ||||||
|  |     } else { | ||||||
|  |         document.addEventListener('DOMContentLoaded', function () { | ||||||
|  |             run(); | ||||||
|  |         }); | ||||||
|  |     } | ||||||
|  | </script> | ||||||
|  | </body> | ||||||
|  | </html> | ||||||
							
								
								
									
										2
									
								
								static/swagger-ui/swagger-ui-bundle.js
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								static/swagger-ui/swagger-ui-bundle.js
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										3
									
								
								static/swagger-ui/swagger-ui.css
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								static/swagger-ui/swagger-ui.css
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										2
									
								
								static/swagger-ui/swagger-ui.js
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								static/swagger-ui/swagger-ui.js
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										35
									
								
								templates/swagger-ui.html
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								templates/swagger-ui.html
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,35 @@ | |||||||
|  | {# This is the HTML template for Swagger UI (the GUI for the API documentation at /api/latest/docs) #} | ||||||
|  | <!DOCTYPE html> | ||||||
|  | <html lang="en"> | ||||||
|  | 	<head> | ||||||
|  | 		<title>KoboldAI API</title> | ||||||
|  | 		<meta charset="UTF-8"> | ||||||
|  | 		<link rel="stylesheet" type="text/css" href="/static/swagger-ui/swagger-ui.css" /> | ||||||
|  | 		<link rel="stylesheet" type="text/css" href="/static/swagger-ui/index.css" /> | ||||||
|  | 		<script> | ||||||
|  | 			if (window.matchMedia && window.matchMedia("(prefers-color-scheme: dark)").matches) document.write('<link rel="stylesheet" type="text/css" href="/static/swagger-ui/SwaggerDark.css" />'); | ||||||
|  | 		</script> | ||||||
|  | 	</head> | ||||||
|  | 	<body> | ||||||
|  | 		<div id="swagger-ui"></div> | ||||||
|  | 		<script src="/static/swagger-ui/swagger-ui-bundle.js" charset="UTF-8"></script> | ||||||
|  | 		<script> | ||||||
|  | 			window.onload = function() { | ||||||
|  | 				 | ||||||
|  | 				window.ui = SwaggerUIBundle({ | ||||||
|  | 					url: "{{ url }}", | ||||||
|  | 					oauth2RedirectUrl: "/static/swagger-ui/oauth2-redirect.html", | ||||||
|  | 					dom_id: "#swagger-ui", | ||||||
|  | 					deepLinking: true, | ||||||
|  | 					presets: [ | ||||||
|  | 						SwaggerUIBundle.presets.apis | ||||||
|  | 					], | ||||||
|  | 					plugins: [ | ||||||
|  | 						SwaggerUIBundle.plugins.DownloadUrl | ||||||
|  | 					], | ||||||
|  | 					layout: "BaseLayout" | ||||||
|  | 				}); | ||||||
|  | 			}; | ||||||
|  | 		</script> | ||||||
|  | 	</body> | ||||||
|  | </html> | ||||||
		Reference in New Issue
	
	Block a user