Upload basic API with /generate POST endpoint

vfbd
2022-08-08 02:27:48 -04:00
parent bd13a41eb7
commit 34c9535667
15 changed files with 1672 additions and 7 deletions
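A minimal usage sketch of the new endpoint (assumptions: a local KoboldAI server on its default port 5000; the payload fields follow GenerationInputSchema in the diff below):

import requests

resp = requests.post(
    "http://127.0.0.1:5000/api/v1/generate",
    json={"prompt": "Niko the kobold stalked carefully down the alley.", "temperature": 0.5, "top_p": 0.9, "max_length": 80},
)
resp.raise_for_status()  # errors come back as JSON with a "detail" object
print(resp.json()["results"][0]["text"])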


@@ -36,8 +36,9 @@ import itertools
import bisect
import functools
import traceback
import inspect
from collections.abc import Iterable
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List, Optional, Type
import requests
import html
@@ -352,6 +353,10 @@ class vars:
use_colab_tpu = os.environ.get("COLAB_TPU_ADDR", "") != "" or os.environ.get("TPU_NAME", "") != "" # Whether or not we're in a Colab TPU instance or Kaggle TPU instance and are going to use the TPU rather than the CPU
revision = None
output_streaming = False
standalone = False
disable_set_aibusy = False
disable_input_formatting = False
disable_output_formatting = False
token_stream_queue = [] # Queue for the token streaming
utils.vars = vars
@@ -372,9 +377,11 @@ log.setLevel(logging.ERROR)
# Start flask & SocketIO
print("{0}Initializing Flask... {1}".format(colors.PURPLE, colors.END), end="")
from flask import Flask, render_template, Response, request, copy_current_request_context, send_from_directory, session
from flask_socketio import SocketIO, emit
from flask import Flask, render_template, Response, request, copy_current_request_context, send_from_directory, session, jsonify, abort
from flask_socketio import SocketIO
from flask_socketio import emit as _emit
from flask_session import Session
from werkzeug.exceptions import HTTPException, ServiceUnavailable
import secrets
app = Flask(__name__, root_path=os.getcwd())
app.secret_key = secrets.token_hex()
@@ -384,6 +391,144 @@ Session(app)
socketio = SocketIO(app, async_method="eventlet")
print("{0}OK!{1}".format(colors.GREEN, colors.END))
def emit(*args, **kwargs):
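# Fall back to the server-level socketio.emit when flask_socketio's request-bound emit fails with an AttributeError (i.e. there is no active SocketIO request context)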
try:
return _emit(*args, **kwargs)
except AttributeError:
return socketio.emit(*args, **kwargs)
# marshmallow/apispec setup
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from apispec.ext.marshmallow.field_converter import make_min_max_attributes
from apispec_webframeworks.flask import FlaskPlugin
from marshmallow import Schema, fields, validate
from marshmallow.exceptions import ValidationError
def new_make_min_max_attributes(validators, min_attr, max_attr) -> dict:
# Patched apispec function that creates "exclusiveMinimum"/"exclusiveMaximum" OpenAPI attributes instead of "minimum"/"maximum" when using validators.Range or validators.Length with min_inclusive=False or max_inclusive=False
attributes = {}
min_list = [validator.min for validator in validators if validator.min is not None]
max_list = [validator.max for validator in validators if validator.max is not None]
min_inclusive_list = [getattr(validator, "min_inclusive", True) for validator in validators if validator.min is not None]
max_inclusive_list = [getattr(validator, "max_inclusive", True) for validator in validators if validator.max is not None]
if min_list:
if min_attr == "minimum" and not min_inclusive_list[max(range(len(min_list)), key=min_list.__getitem__)]:
min_attr = "exclusiveMinimum"
attributes[min_attr] = max(min_list)
if max_list:
if min_attr == "maximum" and not max_inclusive_list[min(range(len(max_list)), key=max_list.__getitem__)]:
min_attr = "exclusiveMaximum"
attributes[max_attr] = min(max_list)
return attributes
make_min_max_attributes.__code__ = new_make_min_max_attributes.__code__
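# Illustration (not part of this commit): with the patch applied, a field declared as
# fields.Float(validate=validate.Range(min=0, min_inclusive=False))
# is documented as {"type": "number", "exclusiveMinimum": 0} in the generated spec
# rather than {"type": "number", "minimum": 0}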
def api_format_docstring(f):
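# Evaluate the docstring as an f-string so that placeholders such as {api_validation_error_response}, defined later in this file, are expanded when the route is registered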
f.__doc__ = eval('f"""{}"""'.format(f.__doc__))
return f
def api_catch_out_of_memory_errors(f):
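# Scan the traceback for known out-of-memory phrases and re-raise them as a structured KoboldOutOfMemoryError (HTTP 507) with a classified type string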
@functools.wraps(f)
def decorated(*args, **kwargs):
try:
return f(*args, **kwargs)
except Exception as e:
if any(s in traceback.format_exc().lower() for s in ("out of memory", "not enough memory")):
for line in reversed(traceback.format_exc().split("\n")):
if any(s in line.lower() for s in ("out of memory", "not enough memory")) and line.count(":"):
line = line.split(":", 1)[1]
line = re.sub(r"\[.+?\] +data\.", "", line).strip()
raise KoboldOutOfMemoryError("KoboldAI ran out of memory: " + line, type="out_of_memory.gpu.cuda" if "cuda out of memory" in line.lower() else "out_of_memory.gpu.hip" if "hip out of memory" in line.lower() else "out_of_memory.tpu.hbm" if "memory space hbm" in line.lower() else "out_of_memory.cpu.default_memory_allocator" if "defaultmemoryallocator" in line.lower() else "out_of_memory.unknown.unknown")
raise KoboldOutOfMemoryError(type="out_of_memory.unknown.unknown")
raise e
return decorated
def api_schema_wrap(f):
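# The marshmallow schema annotated on the handler's first parameter is used to validate and deserialize the JSON request body before the handler runs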
input_schema: Type[Schema] = next(iter(inspect.signature(f).parameters.values())).annotation
assert inspect.isclass(input_schema) and issubclass(input_schema, Schema)
f = api_format_docstring(f)
f = api_catch_out_of_memory_errors(f)
@functools.wraps(f)
def decorated(*args, **kwargs):
body = request.get_json()
schema = input_schema.from_dict(input_schema().load(body))
response = f(schema)
if not isinstance(response, Response):
response = jsonify(response)
return response
return decorated
@app.errorhandler(HTTPException)
def handler(e):
return jsonify(detail={"type": "generic.error_" + str(e.code), "msg": str(e)}), e.code
class KoboldOutOfMemoryError(HTTPException):
code = 507
description = "KoboldAI ran out of memory."
type = "out_of_memory.unknown"
def __init__(self, *args, type=None, **kwargs):
super().__init__(*args, **kwargs)
if type is not None:
self.type = type
@app.errorhandler(KoboldOutOfMemoryError)
def handler(e):
return jsonify(detail={"type": e.type, "msg": e.description}), e.code
@app.errorhandler(ValidationError)
def handler(e):
return jsonify(detail=e.messages), 422
@app.errorhandler(NotImplementedError)
def handler(e):
return jsonify(detail={"type": "not_implemented", "msg": str(e).strip()}), 501
class KoboldAPISpec(APISpec):
class KoboldFlaskPlugin(FlaskPlugin):
def __init__(self, api: "KoboldAPISpec", *args, **kwargs):
self._kobold_api_spec = api
super().__init__(*args, **kwargs)
def path_helper(self, *args, **kwargs):
return super().path_helper(*args, **kwargs)[len(self._kobold_api_spec._prefixes[0]):]
def __init__(self, *args, title: str = "KoboldAI API", openapi_version: str = "3.0.3", prefixes: List[str] = None, **kwargs):
plugins = [KoboldAPISpec.KoboldFlaskPlugin(self), MarshmallowPlugin()]
self._prefixes = prefixes if prefixes is not None else [""]
super().__init__(*args, title=title, openapi_version=openapi_version, plugins=plugins, servers=[{"url": self._prefixes[0]}], **kwargs)
for prefix in self._prefixes:
app.route(prefix + "/docs", endpoint="~KoboldAPISpec~" + prefix + "/docs")(lambda: render_template("swagger-ui.html", url=self._prefixes[0] + "/openapi.json"))
app.route(prefix + "/openapi.json", endpoint="~KoboldAPISpec~" + prefix + "/openapi.json")(lambda: jsonify(self.to_dict()))
def route(self, rule: str, methods=["GET"], **kwargs):
__F = TypeVar("__F", bound=Callable[..., Any])
def new_decorator(f: __F) -> __F:
for prefix in self._prefixes:
f = app.route(prefix + rule, methods=methods, **kwargs)(f)
with app.test_request_context():
self.path(view=f, **kwargs)
return f
return new_decorator
def get(self, rule: str, **kwargs):
return self.route(rule, methods=["GET"], **kwargs)
def post(self, rule: str, **kwargs):
return self.route(rule, methods=["POST"], **kwargs)
def put(self, rule: str, **kwargs):
return self.route(rule, methods=["PUT"], **kwargs)
def patch(self, rule: str, **kwargs):
return self.route(rule, methods=["PATCH"], **kwargs)
def delete(self, rule: str, **kwargs):
return self.route(rule, methods=["DELETE"], **kwargs)
api_v1 = KoboldAPISpec(
version="1.0.0",
prefixes=["/api/v1", "/api/latest"],
)
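# Sketch (assuming a local server on KoboldAI's default port 5000): the constructor
# above registers the documentation routes, so the generated spec can be fetched with:
#     import requests
#     spec = requests.get("http://127.0.0.1:5000/api/v1/openapi.json").json()
#     print(spec["info"]["title"], spec["info"]["version"])  # "KoboldAI API" "1.0.0"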
#==================================================================#
# Function to get model selection at startup
#==================================================================#
@@ -1492,6 +1637,9 @@ def patch_transformers():
self.regeneration_required = False
self.halt = False
if(vars.standalone):
return scores
scores_shape = scores.shape
scores_list = scores.tolist()
vars.lua_koboldbridge.logits = vars.lua_state.table()
@@ -1595,12 +1743,14 @@ def patch_transformers():
**kwargs,
) -> bool:
vars.generated_tkns += 1
if(vars.lua_koboldbridge.generated_cols and vars.generated_tkns != vars.lua_koboldbridge.generated_cols):
if(not vars.standalone and vars.lua_koboldbridge.generated_cols and vars.generated_tkns != vars.lua_koboldbridge.generated_cols):
raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({vars.generated_tkns} != {vars.lua_koboldbridge.generated_cols})")
if(vars.abort or vars.generated_tkns >= vars.genamt):
self.regeneration_required = False
self.halt = False
return True
if(vars.standalone):
return False
assert input_ids.ndim == 2
assert len(self.excluded_world_info) == input_ids.shape[0]
@@ -3767,6 +3917,97 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
break
def apiactionsubmit_generate(txt, minimum, maximum):
vars.generated_tkns = 0
if not vars.quiet:
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, utils.decodenewlines(tokenizer.decode(txt)), colors.END))
# Clear CUDA cache if using GPU
if(vars.hascuda and (vars.usegpu or vars.breakmodel)):
gc.collect()
torch.cuda.empty_cache()
# Submit input text to generator
_genout, already_generated = tpool.execute(_generate, txt, minimum, maximum, set())
genout = [applyoutputformatting(utils.decodenewlines(tokenizer.decode(tokens[-already_generated:]))) for tokens in _genout]
# Clear CUDA cache again if using GPU
if(vars.hascuda and (vars.usegpu or vars.breakmodel)):
del _genout
gc.collect()
torch.cuda.empty_cache()
return genout
def apiactionsubmit_tpumtjgenerate(txt, minimum, maximum):
vars.generated_tkns = 0
if(vars.full_determinism):
tpu_mtj_backend.set_rng_seed(vars.seed)
if not vars.quiet:
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, utils.decodenewlines(tokenizer.decode(txt)), colors.END))
vars._actions = vars.actions
vars._prompt = vars.prompt
if(vars.dynamicscan):
vars._actions = vars._actions.copy()
# Submit input text to generator
soft_tokens = tpumtjgetsofttokens()
genout = tpool.execute(
tpu_mtj_backend.infer_static,
np.uint32(txt),
gen_len = maximum-minimum+1,
temp=vars.temp,
top_p=vars.top_p,
top_k=vars.top_k,
tfs=vars.tfs,
typical=vars.typical,
top_a=vars.top_a,
numseqs=vars.numseqs,
repetition_penalty=vars.rep_pen,
rpslope=vars.rep_pen_slope,
rprange=vars.rep_pen_range,
soft_embeddings=vars.sp,
soft_tokens=soft_tokens,
sampler_order=vars.sampler_order,
)
genout = [applyoutputformatting(utils.decodenewlines(tokenizer.decode(txt))) for txt in genout]
return genout
def apiactionsubmit(data, use_memory=False):
if(vars.model == "Colab"):
raise NotImplementedError("API generation is not supported in old Colab API mode.")
elif(vars.model == "OAI"):
raise NotImplementedError("API generation is not supported in OpenAI/GooseAI mode.")
elif(vars.model == "ReadOnly"):
raise NotImplementedError("API generation is not supported in read-only mode; please load a model and then try again.")
if(vars.memory != "" and vars.memory[-1] != "\n"):
mem = vars.memory + "\n"
else:
mem = vars.memory
tokens = []
if(use_memory):
tokens += tokenizer.encode(utils.encodenewlines(mem))[-(vars.max_length - vars.sp_length - vars.genamt - len(tokenizer._koboldai_header) - len(tokens)):]
tokens += tokenizer.encode(utils.encodenewlines(data))[-(vars.max_length - vars.sp_length - vars.genamt - len(tokenizer._koboldai_header) - len(tokens)):]
tokens = tokenizer._koboldai_header + tokens
minimum = len(tokens) + 1
maximum = len(tokens) + vars.genamt
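# Worked example of the budget above (hypothetical numbers): with max_length=2048,
# sp_length=0, genamt=80 and an empty _koboldai_header, at most 2048 - 0 - 80 = 1968
# tokens of memory plus prompt are kept, trimmed from the left so that the most
# recent context survives; generation then aims for 1 to genamt new tokens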
if(not vars.use_colab_tpu and vars.model not in ["Colab", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
genout = apiactionsubmit_generate(tokens, minimum, maximum)
elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
genout = apiactionsubmit_tpumtjgenerate(tokens, minimum, maximum)
genout = [applyoutputformatting(txt) for txt in genout]
return genout
#==================================================================#
#
#==================================================================#
@@ -4727,6 +4968,8 @@ def refresh_settings():
# Sets the logical and display states for the AI Busy condition
#==================================================================#
def set_aibusy(state):
if(vars.disable_set_aibusy):
return
if(state):
vars.aibusy = True
emit('from_server', {'cmd': 'setgamestate', 'data': 'wait'}, broadcast=True)
@@ -6420,6 +6663,223 @@ def get_files_folders(starting_folder):
socketio.emit("popup_breadcrumbs", breadcrumbs, broadcast=True)
class BasicErrorSchema(Schema):
msg: str = fields.String(required=True)
type: str = fields.String(required=True)
class OutOfMemoryErrorSchema(Schema):
detail: BasicErrorSchema = fields.Nested(BasicErrorSchema, required=True)
api_out_of_memory_response = """507:
description: Out of memory
content:
application/json:
schema: OutOfMemoryErrorSchema
examples:
gpu.cuda:
value:
detail:
msg: "KoboldAI ran out of memory: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 4.00 GiB total capacity; 2.97 GiB already allocated; 0 bytes free; 2.99 GiB reserved in total by PyTorch)"
type: out_of_memory.gpu.cuda
gpu.hip:
value:
detail:
msg: "KoboldAI ran out of memory: HIP out of memory. Tried to allocate 20.00 MiB (GPU 0; 4.00 GiB total capacity; 2.97 GiB already allocated; 0 bytes free; 2.99 GiB reserved in total by PyTorch)"
type: out_of_memory.gpu.hip
tpu.hbm:
value:
detail:
msg: "KoboldAI ran out of memory: Compilation failed: Compilation failure: Ran out of memory in memory space hbm. Used 8.83G of 8.00G hbm. Exceeded hbm capacity by 848.88M."
type: out_of_memory.tpu.hbm
cpu.default_cpu_allocator:
value:
detail:
msg: "KoboldAI ran out of memory: DefaultCPUAllocator: not enough memory: you tried to allocate 209715200 bytes."
type: out_of_memory.cpu.default_cpu_allocator
unknown.unknown:
value:
detail:
msg: "KoboldAI ran out of memory."
type: out_of_memory.unknown.unknown"""
class ValidationErrorSchema(Schema):
detail: Dict[str, List[str]] = fields.Dict(keys=fields.String(), values=fields.List(fields.String()), required=True)
api_validation_error_response = """422:
description: Validation error
content:
application/json:
schema: ValidationErrorSchema"""
class ServerBusyErrorSchema(Schema):
detail: BasicErrorSchema = fields.Nested(BasicErrorSchema, required=True)
api_server_busy_response = """503:
description: Server is busy
content:
application/json:
schema: ServerBusyErrorSchema
example:
detail:
msg: Server is busy; please try again later.
type: service_unavailable"""
class NotImplementedErrorSchema(Schema):
detail: BasicErrorSchema = fields.Nested(BasicErrorSchema, required=True)
api_not_implemented_response = """501:
description: Not implemented
content:
application/json:
schema: NotImplementedErrorSchema
example:
detail:
msg: API generation is not supported in read-only mode; please load a model and then try again.
type: not_implemented"""
class SamplerSettingsSchema(Schema):
rep_pen: Optional[float] = fields.Float(validate=validate.Range(min=1), metadata={"description": "Base repetition penalty value."})
rep_pen_range: Optional[int] = fields.Integer(validate=validate.Range(min=0), metadata={"description": "Repetition penalty range."})
rep_pen_slope: Optional[float] = fields.Float(validate=validate.Range(min=0), metadata={"description": "Repetition penalty slope."})
top_k: Optional[int] = fields.Int(validate=validate.Range(min=0), metadata={"description": "Top-k sampling value."})
top_a: Optional[float] = fields.Float(validate=validate.Range(min=0), metadata={"description": "Top-a sampling value."})
top_p: Optional[float] = fields.Float(validate=validate.Range(min=0, max=1), metadata={"description": "Top-p sampling value."})
tfs: Optional[float] = fields.Float(validate=validate.Range(min=0, max=1), metadata={"description": "Tail free sampling value."})
typical: Optional[float] = fields.Float(validate=validate.Range(min=0, max=1), metadata={"description": "Typical sampling value."})
temperature: Optional[float] = fields.Float(validate=validate.Range(min=0, min_inclusive=False), metadata={"description": "Temperature value."})
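# Note: temperature's Range(min=0, min_inclusive=False) is exactly the case the patched make_min_max_attributes above handles, emitting "exclusiveMinimum": 0 instead of "minimum": 0 in the generated spec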
class GenerationInputSchema(SamplerSettingsSchema):
prompt: str = fields.String(required=True, metadata={"description": "This is the submission."})
use_memory: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the memory from the KoboldAI GUI when generating text."})
use_story: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the story from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."})
use_world_info: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the world info from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."})
use_userscripts: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the userscripts from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."})
soft_prompt: Optional[str] = fields.String(metadata={"description": "Soft prompt to use when generating. If set to the empty string or any other string containing no non-whitespace characters, uses no soft prompt."})
max_length: int = fields.Integer(validate=validate.Range(min=1, max=2048), metadata={"description": "Number of tokens to generate."})
n: int = fields.Integer(validate=validate.Range(min=1, max=5), metadata={"description": "Number of outputs to generate."})
disable_output_formatting: bool = fields.Boolean(load_default=True, metadata={"description": "When enabled, disables all output formatting options, overriding their individual enabled/disabled states."})
frmttriminc: Optional[bool] = fields.Boolean(metadata={"description": "Output formatting option. When enabled, removes some characters from the end of the output such that the output doesn't end in the middle of a sentence. If the output is less than one sentence long, does nothing."})
frmtrmblln: Optional[bool] = fields.Boolean(metadata={"description": "Output formatting option. When enabled, replaces all occurrences of two or more consecutive newlines in the output with one newline."})
frmtrmspch: Optional[bool] = fields.Boolean(metadata={"description": "Output formatting option. When enabled, removes `#/@%{}+=~|\^<>` from the output."})
singleline: Optional[bool] = fields.Boolean(metadata={"description": "Output formatting option. When enabled, removes everything after the first line of the output, including the newline."})
disable_input_formatting: bool = fields.Boolean(load_default=True, metadata={"description": "When enabled, disables all input formatting options, overriding their individual enabled/disabled states."})
frmtadsnsp: Optional[bool] = fields.Boolean(metadata={"description": "Input formatting option. When enabled, adds a leading space to your input if there is no trailing whitespace at the end of the previous action."})
class GenerationResultSchema(Schema):
text: str = fields.String(required=True, metadata={"description": "Generated output as plain text."})
class GenerationOutputSchema(Schema):
results: List[GenerationResultSchema] = fields.List(fields.Nested(GenerationResultSchema), required=True, metadata={"description": "Array of generated outputs."})
def _generate_text(body: GenerationInputSchema):
if vars.aibusy or vars.genseqs:
abort(Response(json.dumps({"detail": {
"type": "service_unavailable",
"msg": "Server is busy; please try again later.",
}}), mimetype="application/json", status=503))
if body.use_story:
raise NotImplementedError("use_story is not currently supported.")
if body.use_world_info:
raise NotImplementedError("use_world_info is not currently supported.")
if body.use_userscripts:
raise NotImplementedError("use_userscripts is not currently supported.")
mapping = {
"rep_pen": (vars, "rep_pen"),
"rep_pen_range": (vars, "rep_pen_range"),
"rep_pen_slope": (vars, "rep_pen_slope"),
"top_k": (vars, "top_k"),
"top_a": (vars, "top_a"),
"top_p": (vars, "top_p"),
"tfs": (vars, "tfs"),
"typical": (vars, "typical"),
"temperature": (vars, "temp"),
"frmtadnsp": (vars.formatoptns, "@frmtadnsp"),
"frmttriminc": (vars.formatoptns, "@frmttriminc"),
"frmtrmblln": (vars.formatoptns, "@frmtrmblln"),
"frmtrmspch": (vars.formatoptns, "@frmtrmspch"),
"singleline": (vars.formatoptns, "@singleline"),
"disable_input_formatting": (vars, "disable_input_formatting"),
"disable_output_formatting": (vars, "disable_output_formatting"),
"max_length": (vars, "genamt"),
"n": (vars, "numseqs"),
}
saved_settings = {}
set_aibusy(1)
disable_set_aibusy = vars.disable_set_aibusy
vars.disable_set_aibusy = True
_standalone = vars.standalone
vars.standalone = True
for key, entry in mapping.items():
if getattr(body, key, None) is not None:
if entry[1].startswith("@"):
saved_settings[key] = entry[0][entry[1][1:]]
entry[0][entry[1][1:]] = getattr(body, key)
else:
saved_settings[key] = getattr(entry[0], entry[1])
setattr(entry[0], entry[1], getattr(body, key))
try:
if getattr(body, "soft_prompt", None) is not None:
if any(q in body.soft_prompt for q in ("/", "\\")):
raise RuntimeError
old_spfilename = vars.spfilename
spRequest(body.soft_prompt)
genout = apiactionsubmit(body.prompt, use_memory=body.use_memory)
output = {"results": [{"text": txt} for txt in genout]}
finally:
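# Restore each overridden setting only if it still holds the request's value, so that changes made elsewhere (e.g. from the GUI) while generating are not clobbered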
for key in saved_settings:
entry = mapping[key]
if getattr(body, key, None) is not None:
if entry[1].startswith("@"):
if entry[0][entry[1][1:]] == getattr(body, key):
entry[0][entry[1][1:]] = saved_settings[key]
else:
if getattr(entry[0], entry[1]) == getattr(body, key):
setattr(entry[0], entry[1], saved_settings[key])
vars.disable_set_aibusy = disable_set_aibusy
vars.standalone = _standalone
if getattr(body, "soft_prompt", None) is not None:
spRequest(old_spfilename)
set_aibusy(0)
return output
@api_v1.post("/generate")
@api_schema_wrap
def post_completion_standalone(body: GenerationInputSchema):
r"""Generate text
---
post:
description: |-2
Generates text given a submission, sampler settings, soft prompt and number of return sequences.
Unless otherwise specified, optional values default to the values in the KoboldAI GUI.
requestBody:
required: true
content:
application/json:
schema: GenerationInputSchema
example:
prompt: |-2
Explosions of suspicious origin occur at AMNAT satellite-receiver stations from Turkey to Labrador as three high-level Canadian defense ministers vanish and then a couple of days later are photographed at a Volgograd bistro hoisting shots of Stolichnaya with Slavic bimbos on their knee.
top_p: 0.9
temperature: 0.5
responses:
200:
description: Successful request
content:
application/json:
schema: GenerationOutputSchema
example:
results:
- text: |-2
It is later established that all of the cabinet members have died of old age.
MEGAMATRIX becomes involved in the growing number of mass abductions and kidnappings. Many disappearances occur along highways in western Canada, usually when traffic has come to a standstill because of a stalled truck or snowstorm. One or two abducted individuals will be released within a day or so but never
{api_validation_error_response}
{api_not_implemented_response}
{api_server_busy_response}
{api_out_of_memory_response}
"""
return _generate_text(body)
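# The route above is also listed in the Swagger UI that KoboldAPISpec registers at /api/v1/docs and /api/latest/docs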
#==================================================================#
# Final startup commands to launch Flask app