Mirror of https://github.com/KoboldAI/KoboldAI-Client.git

Commit: Upload basic API with /generate POST endpoint

aiserver.py (468 changed lines)
@@ -36,8 +36,9 @@ import itertools
import bisect
import functools
import traceback
import inspect
from collections.abc import Iterable
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List, Optional, Type

import requests
import html
@@ -352,6 +353,10 @@ class vars:
    use_colab_tpu = os.environ.get("COLAB_TPU_ADDR", "") != "" or os.environ.get("TPU_NAME", "") != ""  # Whether or not we're in a Colab TPU instance or Kaggle TPU instance and are going to use the TPU rather than the CPU
    revision = None
    output_streaming = False
    standalone = False
    disable_set_aibusy = False
    disable_input_formatting = False
    disable_output_formatting = False
    token_stream_queue = []  # Queue for the token streaming

utils.vars = vars
@@ -372,9 +377,11 @@ log.setLevel(logging.ERROR)

# Start flask & SocketIO
print("{0}Initializing Flask... {1}".format(colors.PURPLE, colors.END), end="")
from flask import Flask, render_template, Response, request, copy_current_request_context, send_from_directory, session
from flask_socketio import SocketIO, emit
from flask import Flask, render_template, Response, request, copy_current_request_context, send_from_directory, session, jsonify, abort
from flask_socketio import SocketIO
from flask_socketio import emit as _emit
from flask_session import Session
from werkzeug.exceptions import HTTPException, ServiceUnavailable
import secrets
app = Flask(__name__, root_path=os.getcwd())
app.secret_key = secrets.token_hex()
@@ -384,6 +391,144 @@ Session(app)
socketio = SocketIO(app, async_method="eventlet")
print("{0}OK!{1}".format(colors.GREEN, colors.END))

def emit(*args, **kwargs):
    try:
        return _emit(*args, **kwargs)
    except AttributeError:
        return socketio.emit(*args, **kwargs)

# marshmallow/apispec setup
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from apispec.ext.marshmallow.field_converter import make_min_max_attributes
from apispec_webframeworks.flask import FlaskPlugin
from marshmallow import Schema, fields, validate
from marshmallow.exceptions import ValidationError

def new_make_min_max_attributes(validators, min_attr, max_attr) -> dict:
    # Patched apispec function that creates "exclusiveMinimum"/"exclusiveMaximum" OpenAPI attributes instead of "minimum"/"maximum" when using validators.Range or validators.Length with min_inclusive=False or max_inclusive=False
    attributes = {}
    min_list = [validator.min for validator in validators if validator.min is not None]
    max_list = [validator.max for validator in validators if validator.max is not None]
    min_inclusive_list = [getattr(validator, "min_inclusive", True) for validator in validators if validator.min is not None]
    max_inclusive_list = [getattr(validator, "max_inclusive", True) for validator in validators if validator.max is not None]
    if min_list:
        if min_attr == "minimum" and not min_inclusive_list[max(range(len(min_list)), key=min_list.__getitem__)]:
            min_attr = "exclusiveMinimum"
        attributes[min_attr] = max(min_list)
    if max_list:
        if max_attr == "maximum" and not max_inclusive_list[min(range(len(max_list)), key=max_list.__getitem__)]:
            max_attr = "exclusiveMaximum"
        attributes[max_attr] = min(max_list)
    return attributes
make_min_max_attributes.__code__ = new_make_min_max_attributes.__code__

def api_format_docstring(f):
    f.__doc__ = eval('f"""{}"""'.format(f.__doc__))
    return f

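A minimal sketch of what this decorator enables (the names status_fragment_demo and demo_view are hypothetical, not part of the commit): because the docstring is re-evaluated once as an f-string, module-level YAML fragments such as api_server_busy_response defined later in this file can be spliced into an endpoint's OpenAPI docstring.

# Hypothetical illustration only: the placeholder in the docstring is expanded
# from a module-level variable when api_format_docstring runs.
status_fragment_demo = "503:\n          description: Server is busy"

@api_format_docstring
def demo_view():
    """---
    responses:
      {status_fragment_demo}
    """
# demo_view.__doc__ now contains the expanded 503 block instead of the placeholder.
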
def api_catch_out_of_memory_errors(f):
    @functools.wraps(f)
    def decorated(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except Exception as e:
            if any(s in traceback.format_exc().lower() for s in ("out of memory", "not enough memory")):
                for line in reversed(traceback.format_exc().split("\n")):
                    if any(s in line.lower() for s in ("out of memory", "not enough memory")) and line.count(":"):
                        line = line.split(":", 1)[1]
                        line = re.sub(r"\[.+?\] +data\.", "", line).strip()
                        raise KoboldOutOfMemoryError("KoboldAI ran out of memory: " + line, type="out_of_memory.gpu.cuda" if "cuda out of memory" in line.lower() else "out_of_memory.gpu.hip" if "hip out of memory" in line.lower() else "out_of_memory.tpu.hbm" if "memory space hbm" in line.lower() else "out_of_memory.cpu.default_memory_allocator" if "defaultmemoryallocator" in line.lower() else "out_of_memory.unknown.unknown")
                raise KoboldOutOfMemoryError(type="out_of_memory.unknown.unknown")
            raise e
    return decorated

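In effect, together with the KoboldOutOfMemoryError class and handler defined just below, the decorator turns allocator failures into structured HTTP errors; the example message below is illustrative rather than captured output.

# e.g. a "CUDA out of memory" traceback raised inside a wrapped view becomes HTTP 507 with a body like:
# {"detail": {"type": "out_of_memory.gpu.cuda", "msg": "KoboldAI ran out of memory: CUDA out of memory. ..."}}
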
def api_schema_wrap(f):
    input_schema: Type[Schema] = next(iter(inspect.signature(f).parameters.values())).annotation
    assert inspect.isclass(input_schema) and issubclass(input_schema, Schema)
    f = api_format_docstring(f)
    f = api_catch_out_of_memory_errors(f)
    @functools.wraps(f)
    def decorated(*args, **kwargs):
        body = request.get_json()
        schema = input_schema.from_dict(input_schema().load(body))
        response = f(schema)
        if not isinstance(response, Response):
            response = jsonify(response)
        return response
    return decorated

@app.errorhandler(HTTPException)
def handler(e):
    return jsonify(detail={"type": "generic.error_" + str(e.code), "msg": str(e)}), e.code

class KoboldOutOfMemoryError(HTTPException):
    code = 507
    description = "KoboldAI ran out of memory."
    type = "out_of_memory.unknown"
    def __init__(self, *args, type=None, **kwargs):
        super().__init__(*args, **kwargs)
        if type is not None:
            self.type = type
@app.errorhandler(KoboldOutOfMemoryError)
def handler(e):
    return jsonify(detail={"type": e.type, "msg": e.description}), e.code

@app.errorhandler(ValidationError)
def handler(e):
    return jsonify(detail=e.messages), 422

@app.errorhandler(NotImplementedError)
def handler(e):
    return jsonify(detail={"type": "not_implemented", "msg": str(e).strip()}), 501

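For reference, a request that fails marshmallow validation comes back through the ValidationError handler above; the field name and message below are illustrative of marshmallow's defaults, not output captured from this commit.

# e.g. POST /api/v1/generate without a "prompt" key -> HTTP 422 with a body like:
# {"detail": {"prompt": ["Missing data for required field."]}}
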
class KoboldAPISpec(APISpec):
    class KoboldFlaskPlugin(FlaskPlugin):
        def __init__(self, api: "KoboldAPISpec", *args, **kwargs):
            self._kobold_api_spec = api
            super().__init__(*args, **kwargs)

        def path_helper(self, *args, **kwargs):
            return super().path_helper(*args, **kwargs)[len(self._kobold_api_spec._prefixes[0]):]

    def __init__(self, *args, title: str = "KoboldAI API", openapi_version: str = "3.0.3", prefixes: List[str] = None, **kwargs):
        plugins = [KoboldAPISpec.KoboldFlaskPlugin(self), MarshmallowPlugin()]
        self._prefixes = prefixes if prefixes is not None else [""]
        super().__init__(*args, title=title, openapi_version=openapi_version, plugins=plugins, servers=[{"url": self._prefixes[0]}], **kwargs)
        for prefix in self._prefixes:
            app.route(prefix + "/docs", endpoint="~KoboldAPISpec~" + prefix + "/docs")(lambda: render_template("swagger-ui.html", url=self._prefixes[0] + "/openapi.json"))
            app.route(prefix + "/openapi.json", endpoint="~KoboldAPISpec~" + prefix + "/openapi.json")(lambda: jsonify(self.to_dict()))

    def route(self, rule: str, methods=["GET"], **kwargs):
        __F = TypeVar("__F", bound=Callable[..., Any])
        def new_decorator(f: __F) -> __F:
            for prefix in self._prefixes:
                f = app.route(prefix + rule, methods=methods, **kwargs)(f)
            with app.test_request_context():
                self.path(view=f, **kwargs)
            return f
        return new_decorator

    def get(self, rule: str, **kwargs):
        return self.route(rule, methods=["GET"], **kwargs)

    def post(self, rule: str, **kwargs):
        return self.route(rule, methods=["POST"], **kwargs)

    def put(self, rule: str, **kwargs):
        return self.route(rule, methods=["PUT"], **kwargs)

    def patch(self, rule: str, **kwargs):
        return self.route(rule, methods=["PATCH"], **kwargs)

    def delete(self, rule: str, **kwargs):
        return self.route(rule, methods=["DELETE"], **kwargs)

api_v1 = KoboldAPISpec(
    version="1.0.0",
    prefixes=["/api/v1", "/api/latest"],
)

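A hedged sketch of how the pieces above fit together (the /echo route and EchoSchema are hypothetical and not added by this commit): a view registered through api_v1 is exposed under both /api/v1 and /api/latest, added to the OpenAPI document served at /api/v1/openapi.json and /api/v1/docs, and, when wrapped with api_schema_wrap, receives the validated request body via its annotated schema.

class EchoSchema(Schema):
    message: str = fields.String(required=True)

@api_v1.post("/echo")   # hypothetical; registers /api/v1/echo and /api/latest/echo
@api_schema_wrap
def post_echo(body: EchoSchema):
    """Echo text
    ---
    post:
      requestBody:
        required: true
        content:
          application/json:
            schema: EchoSchema
      responses:
        200:
          description: Successful request
    """
    # Plain dicts returned here are passed through jsonify() by api_schema_wrap.
    return {"message": body.message}
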
#==================================================================#
# Function to get model selection at startup
#==================================================================#
@@ -1492,6 +1637,9 @@ def patch_transformers():
            self.regeneration_required = False
            self.halt = False

            if(vars.standalone):
                return scores

            scores_shape = scores.shape
            scores_list = scores.tolist()
            vars.lua_koboldbridge.logits = vars.lua_state.table()
@@ -1595,12 +1743,14 @@ def patch_transformers():
            **kwargs,
        ) -> bool:
            vars.generated_tkns += 1
            if(vars.lua_koboldbridge.generated_cols and vars.generated_tkns != vars.lua_koboldbridge.generated_cols):
            if(not vars.standalone and vars.lua_koboldbridge.generated_cols and vars.generated_tkns != vars.lua_koboldbridge.generated_cols):
                raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({vars.generated_tkns} != {vars.lua_koboldbridge.generated_cols})")
            if(vars.abort or vars.generated_tkns >= vars.genamt):
                self.regeneration_required = False
                self.halt = False
                return True
            if(vars.standalone):
                return False

            assert input_ids.ndim == 2
            assert len(self.excluded_world_info) == input_ids.shape[0]
@@ -3767,6 +3917,97 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,
        emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
        break

def apiactionsubmit_generate(txt, minimum, maximum):
    vars.generated_tkns = 0

    if not vars.quiet:
        print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, utils.decodenewlines(tokenizer.decode(txt)), colors.END))

    # Clear CUDA cache if using GPU
    if(vars.hascuda and (vars.usegpu or vars.breakmodel)):
        gc.collect()
        torch.cuda.empty_cache()

    # Submit input text to generator
    _genout, already_generated = tpool.execute(_generate, txt, minimum, maximum, set())

    genout = [applyoutputformatting(utils.decodenewlines(tokenizer.decode(tokens[-already_generated:]))) for tokens in _genout]

    # Clear CUDA cache again if using GPU
    if(vars.hascuda and (vars.usegpu or vars.breakmodel)):
        del _genout
        gc.collect()
        torch.cuda.empty_cache()

    return genout

def apiactionsubmit_tpumtjgenerate(txt, minimum, maximum):
    vars.generated_tkns = 0

    if(vars.full_determinism):
        tpu_mtj_backend.set_rng_seed(vars.seed)

    if not vars.quiet:
        print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, utils.decodenewlines(tokenizer.decode(txt)), colors.END))

    vars._actions = vars.actions
    vars._prompt = vars.prompt
    if(vars.dynamicscan):
        vars._actions = vars._actions.copy()

    # Submit input text to generator
    soft_tokens = tpumtjgetsofttokens()
    genout = tpool.execute(
        tpu_mtj_backend.infer_static,
        np.uint32(txt),
        gen_len = maximum-minimum+1,
        temp=vars.temp,
        top_p=vars.top_p,
        top_k=vars.top_k,
        tfs=vars.tfs,
        typical=vars.typical,
        top_a=vars.top_a,
        numseqs=vars.numseqs,
        repetition_penalty=vars.rep_pen,
        rpslope=vars.rep_pen_slope,
        rprange=vars.rep_pen_range,
        soft_embeddings=vars.sp,
        soft_tokens=soft_tokens,
        sampler_order=vars.sampler_order,
    )
    genout = [applyoutputformatting(utils.decodenewlines(tokenizer.decode(txt))) for txt in genout]

    return genout

def apiactionsubmit(data, use_memory=False):
    if(vars.model == "Colab"):
        raise NotImplementedError("API generation is not supported in old Colab API mode.")
    elif(vars.model == "OAI"):
        raise NotImplementedError("API generation is not supported in OpenAI/GooseAI mode.")
    elif(vars.model == "ReadOnly"):
        raise NotImplementedError("API generation is not supported in read-only mode; please load a model and then try again.")

    if(vars.memory != "" and vars.memory[-1] != "\n"):
        mem = vars.memory + "\n"
    else:
        mem = vars.memory
    tokens = []
    if(use_memory):
        tokens += tokenizer.encode(utils.encodenewlines(mem))[-(vars.max_length - vars.sp_length - vars.genamt - len(tokenizer._koboldai_header) - len(tokens)):]
    tokens += tokenizer.encode(utils.encodenewlines(data))[-(vars.max_length - vars.sp_length - vars.genamt - len(tokenizer._koboldai_header) - len(tokens)):]
    tokens = tokenizer._koboldai_header + tokens
    minimum = len(tokens) + 1
    maximum = len(tokens) + vars.genamt

    if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
        genout = apiactionsubmit_generate(tokens, minimum, maximum)
    elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
        genout = apiactionsubmit_tpumtjgenerate(tokens, minimum, maximum)

    genout = [applyoutputformatting(txt) for txt in genout]

    return genout

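To make the budget arithmetic above concrete, a small worked sketch (the numbers are illustrative, not defaults asserted by this commit): the memory and the submitted prompt are each truncated from the left so that the encoded context plus the requested generation length never exceeds the model's maximum context.

# Illustrative only: with max_length=2048, sp_length=0, genamt=80 and an empty
# _koboldai_header, at most 2048 - 0 - 80 - 0 = 1968 memory/prompt tokens are
# kept (taken from the end, i.e. the most recent text); generation then targets
# minimum = len(tokens) + 1 through maximum = len(tokens) + genamt tokens.
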
#==================================================================#
#
#==================================================================#
@@ -4727,6 +4968,8 @@ def refresh_settings():
# Sets the logical and display states for the AI Busy condition
#==================================================================#
def set_aibusy(state):
    if(vars.disable_set_aibusy):
        return
    if(state):
        vars.aibusy = True
        emit('from_server', {'cmd': 'setgamestate', 'data': 'wait'}, broadcast=True)
@@ -6420,6 +6663,223 @@ def get_files_folders(starting_folder):
    socketio.emit("popup_breadcrumbs", breadcrumbs, broadcast=True)


class BasicErrorSchema(Schema):
    msg: str = fields.String(required=True)
    type: str = fields.String(required=True)

class OutOfMemoryErrorSchema(Schema):
    detail: BasicErrorSchema = fields.Nested(BasicErrorSchema, required=True)

api_out_of_memory_response = """507:
          description: Out of memory
          content:
            application/json:
              schema: OutOfMemoryErrorSchema
              examples:
                gpu.cuda:
                  value:
                    detail:
                      msg: "KoboldAI ran out of memory: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 4.00 GiB total capacity; 2.97 GiB already allocated; 0 bytes free; 2.99 GiB reserved in total by PyTorch)"
                      type: out_of_memory.gpu.cuda
                gpu.hip:
                  value:
                    detail:
                      msg: "KoboldAI ran out of memory: HIP out of memory. Tried to allocate 20.00 MiB (GPU 0; 4.00 GiB total capacity; 2.97 GiB already allocated; 0 bytes free; 2.99 GiB reserved in total by PyTorch)"
                      type: out_of_memory.gpu.hip
                tpu.hbm:
                  value:
                    detail:
                      msg: "KoboldAI ran out of memory: Compilation failed: Compilation failure: Ran out of memory in memory space hbm. Used 8.83G of 8.00G hbm. Exceeded hbm capacity by 848.88M."
                      type: out_of_memory.tpu.hbm
                cpu.default_cpu_allocator:
                  value:
                    detail:
                      msg: "KoboldAI ran out of memory: DefaultCPUAllocator: not enough memory: you tried to allocate 209715200 bytes."
                      type: out_of_memory.cpu.default_cpu_allocator
                unknown.unknown:
                  value:
                    detail:
                      msg: "KoboldAI ran out of memory."
                      type: out_of_memory.unknown.unknown"""

class ValidationErrorSchema(Schema):
    detail: Dict[str, List[str]] = fields.Dict(keys=fields.String(), values=fields.List(fields.String()), required=True)

api_validation_error_response = """422:
          description: Validation error
          content:
            application/json:
              schema: ValidationErrorSchema"""

class ServerBusyErrorSchema(Schema):
    detail: BasicErrorSchema = fields.Nested(BasicErrorSchema, required=True)

api_server_busy_response = """503:
          description: Server is busy
          content:
            application/json:
              schema: ServerBusyErrorSchema
              example:
                detail:
                  msg: Server is busy; please try again later.
                  type: service_unavailable"""

class NotImplementedErrorSchema(Schema):
    detail: BasicErrorSchema = fields.Nested(BasicErrorSchema, required=True)

api_not_implemented_response = """501:
          description: Not implemented
          content:
            application/json:
              schema: NotImplementedErrorSchema
              example:
                detail:
                  msg: API generation is not supported in read-only mode; please load a model and then try again.
                  type: not_implemented"""

class SamplerSettingsSchema(Schema):
    rep_pen: Optional[float] = fields.Float(validate=validate.Range(min=1), metadata={"description": "Base repetition penalty value."})
    rep_pen_range: Optional[int] = fields.Integer(validate=validate.Range(min=0), metadata={"description": "Repetition penalty range."})
    rep_pen_slope: Optional[float] = fields.Float(validate=validate.Range(min=0), metadata={"description": "Repetition penalty slope."})
    top_k: Optional[int] = fields.Int(validate=validate.Range(min=0), metadata={"description": "Top-k sampling value."})
    top_a: Optional[float] = fields.Float(validate=validate.Range(min=0), metadata={"description": "Top-a sampling value."})
    top_p: Optional[float] = fields.Float(validate=validate.Range(min=0, max=1), metadata={"description": "Top-p sampling value."})
    tfs: Optional[float] = fields.Float(validate=validate.Range(min=0, max=1), metadata={"description": "Tail free sampling value."})
    typical: Optional[float] = fields.Float(validate=validate.Range(min=0, max=1), metadata={"description": "Typical sampling value."})
    temperature: Optional[float] = fields.Float(validate=validate.Range(min=0, min_inclusive=False), metadata={"description": "Temperature value."})

class GenerationInputSchema(SamplerSettingsSchema):
    prompt: str = fields.String(required=True, metadata={"description": "This is the submission."})
    use_memory: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the memory from the KoboldAI GUI when generating text."})
    use_story: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the story from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."})
    use_world_info: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the world info from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."})
    use_userscripts: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the userscripts from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."})
    soft_prompt: Optional[str] = fields.String(metadata={"description": "Soft prompt to use when generating. If set to the empty string or any other string containing no non-whitespace characters, uses no soft prompt."})
    max_length: int = fields.Integer(validate=validate.Range(min=1, max=2048), metadata={"description": "Number of tokens to generate."})
    n: int = fields.Integer(validate=validate.Range(min=1, max=5), metadata={"description": "Number of outputs to generate."})
    disable_output_formatting: bool = fields.Boolean(load_default=True, metadata={"description": "When enabled, disables all output formatting options, overriding their individual enabled/disabled states."})
    frmttriminc: Optional[bool] = fields.Boolean(metadata={"description": "Output formatting option. When enabled, removes some characters from the end of the output such that the output doesn't end in the middle of a sentence. If the output is less than one sentence long, does nothing."})
    frmtrmblln: Optional[bool] = fields.Boolean(metadata={"description": "Output formatting option. When enabled, replaces all occurrences of two or more consecutive newlines in the output with one newline."})
    frmtrmspch: Optional[bool] = fields.Boolean(metadata={"description": "Output formatting option. When enabled, removes `#/@%{}+=~|\^<>` from the output."})
    singleline: Optional[bool] = fields.Boolean(metadata={"description": "Output formatting option. When enabled, removes everything after the first line of the output, including the newline."})
    disable_input_formatting: bool = fields.Boolean(load_default=True, metadata={"description": "When enabled, disables all input formatting options, overriding their individual enabled/disabled states."})
    frmtadsnsp: Optional[bool] = fields.Boolean(metadata={"description": "Input formatting option. When enabled, adds a leading space to your input if there is no trailing whitespace at the end of the previous action."})

class GenerationResultSchema(Schema):
    text: str = fields.String(required=True, metadata={"description": "Generated output as plain text."})

class GenerationOutputSchema(Schema):
    results: List[GenerationResultSchema] = fields.List(fields.Nested(GenerationResultSchema), required=True, metadata={"description": "Array of generated outputs."})

def _generate_text(body: GenerationInputSchema):
    if vars.aibusy or vars.genseqs:
        abort(Response(json.dumps({"detail": {
            "type": "service_unavailable",
            "msg": "Server is busy; please try again later.",
        }}), mimetype="application/json", status=503))
    if body.use_story:
        raise NotImplementedError("use_story is not currently supported.")
    if body.use_world_info:
        raise NotImplementedError("use_world_info is not currently supported.")
    if body.use_userscripts:
        raise NotImplementedError("use_userscripts is not currently supported.")
    mapping = {
        "rep_pen": (vars, "rep_pen"),
        "rep_pen_range": (vars, "rep_pen_range"),
        "rep_pen_slope": (vars, "rep_pen_slope"),
        "top_k": (vars, "top_k"),
        "top_a": (vars, "top_a"),
        "top_p": (vars, "top_p"),
        "tfs": (vars, "tfs"),
        "typical": (vars, "typical"),
        "temperature": (vars, "temp"),
        "frmtadsnsp": (vars.formatoptns, "@frmtadsnsp"),
        "frmttriminc": (vars.formatoptns, "@frmttriminc"),
        "frmtrmblln": (vars.formatoptns, "@frmtrmblln"),
        "frmtrmspch": (vars.formatoptns, "@frmtrmspch"),
        "singleline": (vars.formatoptns, "@singleline"),
        "disable_input_formatting": (vars, "disable_input_formatting"),
        "disable_output_formatting": (vars, "disable_output_formatting"),
        "max_length": (vars, "genamt"),
        "n": (vars, "numseqs"),
    }
    saved_settings = {}
    set_aibusy(1)
    disable_set_aibusy = vars.disable_set_aibusy
    vars.disable_set_aibusy = True
    _standalone = vars.standalone
    vars.standalone = True
    for key, entry in mapping.items():
        if getattr(body, key, None) is not None:
            if entry[1].startswith("@"):
                saved_settings[key] = entry[0][entry[1][1:]]
                entry[0][entry[1][1:]] = getattr(body, key)
            else:
                saved_settings[key] = getattr(entry[0], entry[1])
                setattr(entry[0], entry[1], getattr(body, key))
    try:
        if getattr(body, "soft_prompt", None) is not None:
            if any(q in body.soft_prompt for q in ("/", "\\")):
                raise RuntimeError
            old_spfilename = vars.spfilename
            spRequest(body.soft_prompt)
        genout = apiactionsubmit(body.prompt, use_memory=body.use_memory)
        output = {"results": [{"text": txt} for txt in genout]}
    finally:
        for key in saved_settings:
            entry = mapping[key]
            if getattr(body, key, None) is not None:
                if entry[1].startswith("@"):
                    if entry[0][entry[1][1:]] == getattr(body, key):
                        entry[0][entry[1][1:]] = saved_settings[key]
                else:
                    if getattr(entry[0], entry[1]) == getattr(body, key):
                        setattr(entry[0], entry[1], saved_settings[key])
        vars.disable_set_aibusy = disable_set_aibusy
        vars.standalone = _standalone
        if getattr(body, "soft_prompt", None) is not None:
            spRequest(old_spfilename)
        set_aibusy(0)
    return output

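A short illustration of the mapping-based override above (the request value 120 is illustrative): each entry pairs a request field with the GUI setting it temporarily shadows.

# Illustrative only: a request containing {"max_length": 120} sets vars.genamt = 120
# for the duration of the call; the finally-block restores the previous GUI value,
# unless the setting was changed through the GUI while the request was running.
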
@api_v1.post("/generate")
@api_schema_wrap
def post_completion_standalone(body: GenerationInputSchema):
    r"""Generate text
    ---
    post:
      description: |-2
        Generates text given a submission, sampler settings, soft prompt and number of return sequences.

        Unless otherwise specified, optional values default to the values in the KoboldAI GUI.
      requestBody:
        required: true
        content:
          application/json:
            schema: GenerationInputSchema
            example:
              prompt: |-2
                Explosions of suspicious origin occur at AMNAT satellite-receiver stations from Turkey to Labrador as three high-level Canadian defense ministers vanish and then a couple of days later are photographed at a Volgograd bistro hoisting shots of Stolichnaya with Slavic bimbos on their knee.
              top_p: 0.9
              temperature: 0.5
      responses:
        200:
          description: Successful request
          content:
            application/json:
              schema: GenerationOutputSchema
              example:
                results:
                  - text: |-2
                      It is later established that all of the cabinet members have died of old age.
                      MEGAMATRIX becomes involved in the growing number of mass abductions and kidnappings. Many disappearances occur along highways in western Canada, usually when traffic has come to a standstill because of a stalled truck or snowstorm. One or two abducted individuals will be released within a day or so but never
        {api_validation_error_response}
        {api_not_implemented_response}
        {api_server_busy_response}
        {api_out_of_memory_response}
    """
    return _generate_text(body)
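A hedged client-side sketch of calling the new endpoint (assumes a local KoboldAI instance with a model loaded, listening on its usual port 5000; adjust the URL for your setup):

import requests

# POST a prompt to the new endpoint; omitted sampler fields fall back to the GUI values.
resp = requests.post(
    "http://127.0.0.1:5000/api/v1/generate",
    json={"prompt": "Niko the kobold stalked carefully down the alley,", "top_p": 0.9, "temperature": 0.5},
)
resp.raise_for_status()
print(resp.json()["results"][0]["text"])
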
|
#==================================================================#
# Final startup commands to launch Flask app
|