mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-02-20 05:30:57 +01:00
Add more endpoints
This commit is contained in:
parent
a93087aecd
commit
1f629ee254
539
aiserver.py
539
aiserver.py
@ -448,15 +448,22 @@ def api_catch_out_of_memory_errors(f):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
def api_schema_wrap(f):
|
def api_schema_wrap(f):
|
||||||
input_schema: Type[Schema] = next(iter(inspect.signature(f).parameters.values())).annotation
|
try:
|
||||||
assert inspect.isclass(input_schema) and issubclass(input_schema, Schema)
|
input_schema: Type[Schema] = next(iter(inspect.signature(f).parameters.values())).annotation
|
||||||
|
except:
|
||||||
|
HAS_SCHEMA = False
|
||||||
|
else:
|
||||||
|
HAS_SCHEMA = inspect.isclass(input_schema) and issubclass(input_schema, Schema)
|
||||||
f = api_format_docstring(f)
|
f = api_format_docstring(f)
|
||||||
f = api_catch_out_of_memory_errors(f)
|
f = api_catch_out_of_memory_errors(f)
|
||||||
@functools.wraps(f)
|
@functools.wraps(f)
|
||||||
def decorated(*args, **Kwargs):
|
def decorated(*args, **Kwargs):
|
||||||
body = request.get_json()
|
if HAS_SCHEMA:
|
||||||
schema = input_schema.from_dict(input_schema().load(body))
|
body = request.get_json()
|
||||||
response = f(schema)
|
schema = input_schema.from_dict(input_schema().load(body))
|
||||||
|
response = f(schema)
|
||||||
|
else:
|
||||||
|
response = f()
|
||||||
if not isinstance(response, Response):
|
if not isinstance(response, Response):
|
||||||
response = jsonify(response)
|
response = jsonify(response)
|
||||||
return response
|
return response
|
||||||
@ -531,9 +538,16 @@ class KoboldAPISpec(APISpec):
|
|||||||
def delete(self, rule: str, **kwargs):
|
def delete(self, rule: str, **kwargs):
|
||||||
return self.route(rule, methods=["DELETE"], **kwargs)
|
return self.route(rule, methods=["DELETE"], **kwargs)
|
||||||
|
|
||||||
|
tags = [
|
||||||
|
{"name": "generate", "description": "Text generation endpoints"},
|
||||||
|
{"name": "story", "description": "Endpoints for managing the story in the KoboldAI GUI"},
|
||||||
|
{"name": "config", "description": "Allows you to get/set various setting values"},
|
||||||
|
]
|
||||||
|
|
||||||
api_v1 = KoboldAPISpec(
|
api_v1 = KoboldAPISpec(
|
||||||
version="1.0.0",
|
version="1.0.0",
|
||||||
prefixes=["/api/v1", "/api/latest"],
|
prefixes=["/api/v1", "/api/latest"],
|
||||||
|
tags=tags,
|
||||||
)
|
)
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
@ -3763,9 +3777,9 @@ def check_for_backend_compilation():
|
|||||||
break
|
break
|
||||||
vars.checking = False
|
vars.checking = False
|
||||||
|
|
||||||
def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False, disable_recentrng=False):
|
def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False, disable_recentrng=False, no_generate=False):
|
||||||
# Ignore new submissions if the AI is currently busy
|
# Ignore new submissions if the AI is currently busy
|
||||||
if(vars.aibusy):
|
if(not vars.standalone and vars.aibusy):
|
||||||
return
|
return
|
||||||
|
|
||||||
while(True):
|
while(True):
|
||||||
@ -3797,20 +3811,21 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,
|
|||||||
|
|
||||||
if(not vars.gamestarted):
|
if(not vars.gamestarted):
|
||||||
vars.submission = data
|
vars.submission = data
|
||||||
execute_inmod()
|
if(not no_generate):
|
||||||
|
execute_inmod()
|
||||||
vars.submission = re.sub(r"[^\S\r\n]*([\r\n]*)$", r"\1", vars.submission) # Remove trailing whitespace, excluding newlines
|
vars.submission = re.sub(r"[^\S\r\n]*([\r\n]*)$", r"\1", vars.submission) # Remove trailing whitespace, excluding newlines
|
||||||
data = vars.submission
|
data = vars.submission
|
||||||
if(not force_submit and len(data.strip()) == 0):
|
if(not force_submit and len(data.strip()) == 0):
|
||||||
assert False
|
assert False
|
||||||
# Start the game
|
# Start the game
|
||||||
vars.gamestarted = True
|
vars.gamestarted = True
|
||||||
if(not vars.noai and vars.lua_koboldbridge.generating and (not vars.nopromptgen or force_prompt_gen)):
|
if(not no_generate and not vars.noai and vars.lua_koboldbridge.generating and (not vars.nopromptgen or force_prompt_gen)):
|
||||||
# Save this first action as the prompt
|
# Save this first action as the prompt
|
||||||
vars.prompt = data
|
vars.prompt = data
|
||||||
# Clear the startup text from game screen
|
# Clear the startup text from game screen
|
||||||
emit('from_server', {'cmd': 'updatescreen', 'gamestarted': False, 'data': 'Please wait, generating story...'}, broadcast=True)
|
emit('from_server', {'cmd': 'updatescreen', 'gamestarted': False, 'data': 'Please wait, generating story...'}, broadcast=True)
|
||||||
calcsubmit(data) # Run the first action through the generator
|
calcsubmit(data) # Run the first action through the generator
|
||||||
if(not vars.abort and vars.lua_koboldbridge.restart_sequence is not None and len(vars.genseqs) == 0):
|
if(not no_generate and not vars.abort and vars.lua_koboldbridge.restart_sequence is not None and len(vars.genseqs) == 0):
|
||||||
data = ""
|
data = ""
|
||||||
force_submit = True
|
force_submit = True
|
||||||
disable_recentrng = True
|
disable_recentrng = True
|
||||||
@ -3822,7 +3837,8 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,
|
|||||||
vars.prompt = data if len(data) > 0 else '"'
|
vars.prompt = data if len(data) > 0 else '"'
|
||||||
for i in range(vars.numseqs):
|
for i in range(vars.numseqs):
|
||||||
vars.lua_koboldbridge.outputs[i+1] = ""
|
vars.lua_koboldbridge.outputs[i+1] = ""
|
||||||
execute_outmod()
|
if(not no_generate):
|
||||||
|
execute_outmod()
|
||||||
vars.lua_koboldbridge.regeneration_required = False
|
vars.lua_koboldbridge.regeneration_required = False
|
||||||
genout = []
|
genout = []
|
||||||
for i in range(vars.numseqs):
|
for i in range(vars.numseqs):
|
||||||
@ -3856,7 +3872,8 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,
|
|||||||
if(vars.actionmode == 0):
|
if(vars.actionmode == 0):
|
||||||
data = applyinputformatting(data)
|
data = applyinputformatting(data)
|
||||||
vars.submission = data
|
vars.submission = data
|
||||||
execute_inmod()
|
if(not no_generate):
|
||||||
|
execute_inmod()
|
||||||
vars.submission = re.sub(r"[^\S\r\n]*([\r\n]*)$", r"\1", vars.submission) # Remove trailing whitespace, excluding newlines
|
vars.submission = re.sub(r"[^\S\r\n]*([\r\n]*)$", r"\1", vars.submission) # Remove trailing whitespace, excluding newlines
|
||||||
data = vars.submission
|
data = vars.submission
|
||||||
# Dont append submission if it's a blank/continue action
|
# Dont append submission if it's a blank/continue action
|
||||||
@ -3886,7 +3903,7 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,
|
|||||||
update_story_chunk('last')
|
update_story_chunk('last')
|
||||||
send_debug()
|
send_debug()
|
||||||
|
|
||||||
if(not vars.noai and vars.lua_koboldbridge.generating):
|
if(not no_generate and not vars.noai and vars.lua_koboldbridge.generating):
|
||||||
# Off to the tokenizer!
|
# Off to the tokenizer!
|
||||||
calcsubmit(data)
|
calcsubmit(data)
|
||||||
if(not vars.abort and vars.lua_koboldbridge.restart_sequence is not None and len(vars.genseqs) == 0):
|
if(not vars.abort and vars.lua_koboldbridge.restart_sequence is not None and len(vars.genseqs) == 0):
|
||||||
@ -3897,23 +3914,24 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,
|
|||||||
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
for i in range(vars.numseqs):
|
if(not no_generate):
|
||||||
vars.lua_koboldbridge.outputs[i+1] = ""
|
for i in range(vars.numseqs):
|
||||||
execute_outmod()
|
vars.lua_koboldbridge.outputs[i+1] = ""
|
||||||
vars.lua_koboldbridge.regeneration_required = False
|
execute_outmod()
|
||||||
|
vars.lua_koboldbridge.regeneration_required = False
|
||||||
genout = []
|
genout = []
|
||||||
for i in range(vars.numseqs):
|
for i in range(vars.numseqs):
|
||||||
genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
|
genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1] if not no_generate else ""})
|
||||||
assert type(genout[-1]["generated_text"]) is str
|
assert type(genout[-1]["generated_text"]) is str
|
||||||
if(len(genout) == 1):
|
if(len(genout) == 1):
|
||||||
genresult(genout[0]["generated_text"])
|
genresult(genout[0]["generated_text"])
|
||||||
if(not vars.abort and vars.lua_koboldbridge.restart_sequence is not None):
|
if(not no_generate and not vars.abort and vars.lua_koboldbridge.restart_sequence is not None):
|
||||||
data = ""
|
data = ""
|
||||||
force_submit = True
|
force_submit = True
|
||||||
disable_recentrng = True
|
disable_recentrng = True
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
if(not vars.abort and vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
|
if(not no_generate and not vars.abort and vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
|
||||||
genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
|
genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
|
||||||
data = ""
|
data = ""
|
||||||
force_submit = True
|
force_submit = True
|
||||||
@ -4011,8 +4029,6 @@ def apiactionsubmit(data, use_memory=False):
|
|||||||
elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
|
elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
|
||||||
genout = apiactionsubmit_tpumtjgenerate(tokens, minimum, maximum)
|
genout = apiactionsubmit_tpumtjgenerate(tokens, minimum, maximum)
|
||||||
|
|
||||||
genout = [applyoutputformatting(txt) for txt in genout]
|
|
||||||
|
|
||||||
return genout
|
return genout
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
@ -6676,10 +6692,22 @@ def get_files_folders(starting_folder):
|
|||||||
socketio.emit("popup_breadcrumbs", breadcrumbs, broadcast=True)
|
socketio.emit("popup_breadcrumbs", breadcrumbs, broadcast=True)
|
||||||
|
|
||||||
|
|
||||||
|
class EmptySchema(KoboldSchema):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class BasicTextResultInnerSchema(KoboldSchema):
|
||||||
|
text: str = fields.String(required=True)
|
||||||
|
|
||||||
|
class BasicTextResultSchema(KoboldSchema):
|
||||||
|
result: BasicTextResultInnerSchema = fields.Nested(BasicTextResultInnerSchema)
|
||||||
|
|
||||||
class BasicErrorSchema(KoboldSchema):
|
class BasicErrorSchema(KoboldSchema):
|
||||||
msg: str = fields.String(required=True)
|
msg: str = fields.String(required=True)
|
||||||
type: str = fields.String(required=True)
|
type: str = fields.String(required=True)
|
||||||
|
|
||||||
|
class StoryEmptyErrorSchema(KoboldSchema):
|
||||||
|
detail: BasicErrorSchema = fields.Nested(BasicErrorSchema, required=True)
|
||||||
|
|
||||||
class OutOfMemoryErrorSchema(KoboldSchema):
|
class OutOfMemoryErrorSchema(KoboldSchema):
|
||||||
detail: BasicErrorSchema = fields.Nested(BasicErrorSchema, required=True)
|
detail: BasicErrorSchema = fields.Nested(BasicErrorSchema, required=True)
|
||||||
|
|
||||||
@ -6754,7 +6782,7 @@ class SamplerSettingsSchema(KoboldSchema):
|
|||||||
rep_pen: Optional[float] = fields.Float(validate=validate.Range(min=1), metadata={"description": "Base repetition penalty value."})
|
rep_pen: Optional[float] = fields.Float(validate=validate.Range(min=1), metadata={"description": "Base repetition penalty value."})
|
||||||
rep_pen_range: Optional[int] = fields.Integer(validate=validate.Range(min=0), metadata={"description": "Repetition penalty range."})
|
rep_pen_range: Optional[int] = fields.Integer(validate=validate.Range(min=0), metadata={"description": "Repetition penalty range."})
|
||||||
rep_pen_slope: Optional[float] = fields.Float(validate=validate.Range(min=0), metadata={"description": "Repetition penalty slope."})
|
rep_pen_slope: Optional[float] = fields.Float(validate=validate.Range(min=0), metadata={"description": "Repetition penalty slope."})
|
||||||
top_k: Optional[int] = fields.Int(validate=validate.Range(min=0), metadata={"description": "Top-k sampling value."})
|
top_k: Optional[int] = fields.Integer(validate=validate.Range(min=0), metadata={"description": "Top-k sampling value."})
|
||||||
top_a: Optional[float] = fields.Float(validate=validate.Range(min=0), metadata={"description": "Top-a sampling value."})
|
top_a: Optional[float] = fields.Float(validate=validate.Range(min=0), metadata={"description": "Top-a sampling value."})
|
||||||
top_p: Optional[float] = fields.Float(validate=validate.Range(min=0, max=1), metadata={"description": "Top-p sampling value."})
|
top_p: Optional[float] = fields.Float(validate=validate.Range(min=0, max=1), metadata={"description": "Top-p sampling value."})
|
||||||
tfs: Optional[float] = fields.Float(validate=validate.Range(min=0, max=1), metadata={"description": "Tail free sampling value."})
|
tfs: Optional[float] = fields.Float(validate=validate.Range(min=0, max=1), metadata={"description": "Tail free sampling value."})
|
||||||
@ -6781,7 +6809,8 @@ class GenerationInputSchema(SamplerSettingsSchema):
|
|||||||
use_world_info: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the world info from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."})
|
use_world_info: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the world info from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."})
|
||||||
use_userscripts: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the userscripts from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."})
|
use_userscripts: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the userscripts from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."})
|
||||||
soft_prompt: Optional[str] = fields.String(metadata={"description": "Soft prompt to use when generating. If set to the empty string or any other string containing no non-whitespace characters, uses no soft prompt."}, validate=[soft_prompt_validator, validate.Regexp(r"^[^/\\]*$")])
|
soft_prompt: Optional[str] = fields.String(metadata={"description": "Soft prompt to use when generating. If set to the empty string or any other string containing no non-whitespace characters, uses no soft prompt."}, validate=[soft_prompt_validator, validate.Regexp(r"^[^/\\]*$")])
|
||||||
max_length: int = fields.Integer(validate=validate.Range(min=1, max=2048), metadata={"description": "Number of tokens to generate."})
|
max_length: int = fields.Integer(validate=validate.Range(min=1, max=512), metadata={"description": "Number of tokens to generate."})
|
||||||
|
max_context_length: int = fields.Integer(validate=validate.Range(min=512, max=2048), metadata={"description": "Maximum number of tokens to send to the model."})
|
||||||
n: int = fields.Integer(validate=validate.Range(min=1, max=5), metadata={"description": "Number of outputs to generate."})
|
n: int = fields.Integer(validate=validate.Range(min=1, max=5), metadata={"description": "Number of outputs to generate."})
|
||||||
disable_output_formatting: bool = fields.Boolean(load_default=True, metadata={"description": "When enabled, disables all output formatting options, overriding their individual enabled/disabled states."})
|
disable_output_formatting: bool = fields.Boolean(load_default=True, metadata={"description": "When enabled, disables all output formatting options, overriding their individual enabled/disabled states."})
|
||||||
frmttriminc: Optional[bool] = fields.Boolean(metadata={"description": "Output formatting option. When enabled, removes some characters from the end of the output such that the output doesn't end in the middle of a sentence. If the output is less than one sentence long, does nothing."})
|
frmttriminc: Optional[bool] = fields.Boolean(metadata={"description": "Output formatting option. When enabled, removes some characters from the end of the output such that the output doesn't end in the middle of a sentence. If the output is less than one sentence long, does nothing."})
|
||||||
@ -6800,8 +6829,8 @@ class GenerationOutputSchema(KoboldSchema):
|
|||||||
def _generate_text(body: GenerationInputSchema):
|
def _generate_text(body: GenerationInputSchema):
|
||||||
if vars.aibusy or vars.genseqs:
|
if vars.aibusy or vars.genseqs:
|
||||||
abort(Response(json.dumps({"detail": {
|
abort(Response(json.dumps({"detail": {
|
||||||
"type": "service_unavailable",
|
|
||||||
"msg": "Server is busy; please try again later.",
|
"msg": "Server is busy; please try again later.",
|
||||||
|
"type": "service_unavailable",
|
||||||
}}), mimetype="application/json", status=503))
|
}}), mimetype="application/json", status=503))
|
||||||
if body.use_story:
|
if body.use_story:
|
||||||
raise NotImplementedError("use_story is not currently supported.")
|
raise NotImplementedError("use_story is not currently supported.")
|
||||||
@ -6810,24 +6839,25 @@ def _generate_text(body: GenerationInputSchema):
|
|||||||
if body.use_userscripts:
|
if body.use_userscripts:
|
||||||
raise NotImplementedError("use_userscripts is not currently supported.")
|
raise NotImplementedError("use_userscripts is not currently supported.")
|
||||||
mapping = {
|
mapping = {
|
||||||
"rep_pen": (vars, "rep_pen"),
|
"rep_pen": ("vars", "rep_pen"),
|
||||||
"rep_pen_range": (vars, "rep_pen_range"),
|
"rep_pen_range": ("vars", "rep_pen_range"),
|
||||||
"rep_pen_slope": (vars, "rep_pen_slope"),
|
"rep_pen_slope": ("vars", "rep_pen_slope"),
|
||||||
"top_k": (vars, "top_k"),
|
"top_k": ("vars", "top_k"),
|
||||||
"top_a": (vars, "top_a"),
|
"top_a": ("vars", "top_a"),
|
||||||
"top_p": (vars, "top_p"),
|
"top_p": ("vars", "top_p"),
|
||||||
"tfs": (vars, "tfs"),
|
"tfs": ("vars", "tfs"),
|
||||||
"typical": (vars, "typical"),
|
"typical": ("vars", "typical"),
|
||||||
"temperature": (vars, "temp"),
|
"temperature": ("vars", "temp"),
|
||||||
"frmtadnsp": (vars.formatoptns, "@frmtadnsp"),
|
"frmtadnsp": ("vars.formatoptns", "@frmtadnsp"),
|
||||||
"frmttriminc": (vars.formatoptns, "@frmttriminc"),
|
"frmttriminc": ("vars.formatoptns", "@frmttriminc"),
|
||||||
"frmtrmblln": (vars.formatoptns, "@frmtrmblln"),
|
"frmtrmblln": ("vars.formatoptns", "@frmtrmblln"),
|
||||||
"frmtrmspch": (vars.formatoptns, "@frmtrmspch"),
|
"frmtrmspch": ("vars.formatoptns", "@frmtrmspch"),
|
||||||
"singleline": (vars.formatoptns, "@singleline"),
|
"singleline": ("vars.formatoptns", "@singleline"),
|
||||||
"disable_input_formatting": (vars, "disable_input_formatting"),
|
"disable_input_formatting": ("vars", "disable_input_formatting"),
|
||||||
"disable_output_formatting": (vars, "disable_output_formatting"),
|
"disable_output_formatting": ("vars", "disable_output_formatting"),
|
||||||
"max_length": (vars, "genamt"),
|
"max_length": ("vars", "genamt"),
|
||||||
"n": (vars, "numseqs"),
|
"max_context_length": ("vars", "max_length"),
|
||||||
|
"n": ("vars", "numseqs"),
|
||||||
}
|
}
|
||||||
saved_settings = {}
|
saved_settings = {}
|
||||||
set_aibusy(1)
|
set_aibusy(1)
|
||||||
@ -6836,13 +6866,15 @@ def _generate_text(body: GenerationInputSchema):
|
|||||||
_standalone = vars.standalone
|
_standalone = vars.standalone
|
||||||
vars.standalone = True
|
vars.standalone = True
|
||||||
for key, entry in mapping.items():
|
for key, entry in mapping.items():
|
||||||
|
obj = {"vars": vars, "vars.formatoptns": vars.formatoptns}[entry[0]]
|
||||||
if getattr(body, key, None) is not None:
|
if getattr(body, key, None) is not None:
|
||||||
if entry[1].startswith("@"):
|
if entry[1].startswith("@"):
|
||||||
saved_settings[key] = entry[0][entry[1][1:]]
|
saved_settings[key] = obj[entry[1][1:]]
|
||||||
entry[0][entry[1][1:]] = getattr(body, key)
|
obj[entry[1][1:]] = getattr(body, key)
|
||||||
|
print(entry[1][1:], obj[entry[1][1:]])
|
||||||
else:
|
else:
|
||||||
saved_settings[key] = getattr(entry[0], entry[1])
|
saved_settings[key] = getattr(obj, entry[1])
|
||||||
setattr(entry[0], entry[1], getattr(body, key))
|
setattr(obj, entry[1], getattr(body, key))
|
||||||
try:
|
try:
|
||||||
if vars.allowsp and getattr(body, "soft_prompt", None) is not None:
|
if vars.allowsp and getattr(body, "soft_prompt", None) is not None:
|
||||||
if any(q in body.soft_prompt for q in ("/", "\\")):
|
if any(q in body.soft_prompt for q in ("/", "\\")):
|
||||||
@ -6854,13 +6886,14 @@ def _generate_text(body: GenerationInputSchema):
|
|||||||
finally:
|
finally:
|
||||||
for key in saved_settings:
|
for key in saved_settings:
|
||||||
entry = mapping[key]
|
entry = mapping[key]
|
||||||
|
obj = {"vars": vars, "vars.formatoptns": vars.formatoptns}[entry[0]]
|
||||||
if getattr(body, key, None) is not None:
|
if getattr(body, key, None) is not None:
|
||||||
if entry[1].startswith("@"):
|
if entry[1].startswith("@"):
|
||||||
if entry[0][entry[1][1:]] == getattr(body, key):
|
if obj[entry[1][1:]] == getattr(body, key):
|
||||||
entry[0][entry[1][1:]] = saved_settings[key]
|
obj[entry[1][1:]] = saved_settings[key]
|
||||||
else:
|
else:
|
||||||
if getattr(entry[0], entry[1]) == getattr(body, key):
|
if getattr(obj, entry[1]) == getattr(body, key):
|
||||||
setattr(entry[0], entry[1], saved_settings[key])
|
setattr(obj, entry[1], saved_settings[key])
|
||||||
vars.disable_set_aibusy = disable_set_aibusy
|
vars.disable_set_aibusy = disable_set_aibusy
|
||||||
vars.standalone = _standalone
|
vars.standalone = _standalone
|
||||||
if vars.allowsp and getattr(body, "soft_prompt", None) is not None:
|
if vars.allowsp and getattr(body, "soft_prompt", None) is not None:
|
||||||
@ -6871,9 +6904,11 @@ def _generate_text(body: GenerationInputSchema):
|
|||||||
@api_v1.post("/generate")
|
@api_v1.post("/generate")
|
||||||
@api_schema_wrap
|
@api_schema_wrap
|
||||||
def post_completion_standalone(body: GenerationInputSchema):
|
def post_completion_standalone(body: GenerationInputSchema):
|
||||||
r"""Generate text
|
"""---
|
||||||
---
|
|
||||||
post:
|
post:
|
||||||
|
summary: Generate text
|
||||||
|
tags:
|
||||||
|
- generate
|
||||||
description: |-2
|
description: |-2
|
||||||
Generates text given a submission, sampler settings, soft prompt and number of return sequences.
|
Generates text given a submission, sampler settings, soft prompt and number of return sequences.
|
||||||
|
|
||||||
@ -6909,6 +6944,404 @@ def post_completion_standalone(body: GenerationInputSchema):
|
|||||||
return _generate_text(body)
|
return _generate_text(body)
|
||||||
|
|
||||||
|
|
||||||
|
def prompt_validator(prompt: str):
|
||||||
|
if len(prompt.strip()) == 0:
|
||||||
|
raise ValidationError("String does not match expected pattern.")
|
||||||
|
|
||||||
|
class SubmissionInputSchema(KoboldSchema):
|
||||||
|
prompt: str = fields.String(required=True, validate=prompt_validator, metadata={"pattern": r"^.*\S.*$", "description": "This is the submission."})
|
||||||
|
disable_input_formatting: bool = fields.Boolean(load_default=True, metadata={"description": "When enabled, disables all input formatting options, overriding their individual enabled/disabled states."})
|
||||||
|
frmtadsnsp: Optional[bool] = fields.Boolean(metadata={"description": "Input formatting option. When enabled, adds a leading space to your input if there is no trailing whitespace at the end of the previous action."})
|
||||||
|
|
||||||
|
@api_v1.post("/story/end")
|
||||||
|
@api_schema_wrap
|
||||||
|
def post_story_end(body: SubmissionInputSchema):
|
||||||
|
"""---
|
||||||
|
post:
|
||||||
|
summary: Add an action to the end of the story
|
||||||
|
tags:
|
||||||
|
- story
|
||||||
|
description: |-2
|
||||||
|
Inserts a single action at the end of the story in the KoboldAI GUI without generating text.
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema: SubmissionInputSchema
|
||||||
|
example:
|
||||||
|
prompt: |-2
|
||||||
|
This is some text to put at the end of the story.
|
||||||
|
responses:
|
||||||
|
200:
|
||||||
|
description: Successful request
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema: EmptySchema
|
||||||
|
{api_validation_error_response}
|
||||||
|
{api_server_busy_response}
|
||||||
|
"""
|
||||||
|
if vars.aibusy or vars.genseqs:
|
||||||
|
abort(Response(json.dumps({"detail": {
|
||||||
|
"msg": "Server is busy; please try again later.",
|
||||||
|
"type": "service_unavailable",
|
||||||
|
}}), mimetype="application/json", status=503))
|
||||||
|
set_aibusy(1)
|
||||||
|
disable_set_aibusy = vars.disable_set_aibusy
|
||||||
|
vars.disable_set_aibusy = True
|
||||||
|
_standalone = vars.standalone
|
||||||
|
vars.standalone = True
|
||||||
|
numseqs = vars.numseqs
|
||||||
|
vars.numseqs = 1
|
||||||
|
try:
|
||||||
|
actionsubmit(body.prompt, force_submit=True, no_generate=True)
|
||||||
|
finally:
|
||||||
|
vars.disable_set_aibusy = disable_set_aibusy
|
||||||
|
vars.standalone = _standalone
|
||||||
|
vars.numseqs = numseqs
|
||||||
|
set_aibusy(0)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
@api_v1.get("/story/end")
|
||||||
|
@api_schema_wrap
|
||||||
|
def get_story_end():
|
||||||
|
"""---
|
||||||
|
get:
|
||||||
|
summary: Retrieve the last action of the story
|
||||||
|
tags:
|
||||||
|
- story
|
||||||
|
description: |-2
|
||||||
|
Returns the last action of the story in the KoboldAI GUI.
|
||||||
|
responses:
|
||||||
|
200:
|
||||||
|
description: Successful request
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema: BasicTextResultSchema
|
||||||
|
510:
|
||||||
|
description: Story is empty
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema: StoryEmptyErrorSchema
|
||||||
|
example:
|
||||||
|
detail:
|
||||||
|
msg: Could not retrieve the last action of the story because the story is empty.
|
||||||
|
type: story_empty
|
||||||
|
"""
|
||||||
|
if not vars.gamestarted:
|
||||||
|
abort(Response(json.dumps({"detail": {
|
||||||
|
"msg": "Could not retrieve the last action of the story because the story is empty.",
|
||||||
|
"type": "story_empty",
|
||||||
|
}}), mimetype="application/json", status=510))
|
||||||
|
if len(vars.actions) == 0:
|
||||||
|
return {"result": {"text": vars.prompt}}
|
||||||
|
return {"result": {"text": vars.actions[vars.actions.get_last_key()]}}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_f_get(obj, _var_name, _name, _schema, _example_yaml_value):
|
||||||
|
def f_get():
|
||||||
|
"""---
|
||||||
|
get:
|
||||||
|
summary: Retrieve the current {} setting value
|
||||||
|
tags:
|
||||||
|
- config
|
||||||
|
responses:
|
||||||
|
200:
|
||||||
|
description: Successful request
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema: {}
|
||||||
|
example:
|
||||||
|
value: {}
|
||||||
|
"""
|
||||||
|
_obj = {"vars": vars, "vars.formatoptns": vars.formatoptns}[obj]
|
||||||
|
if _var_name.startswith("@"):
|
||||||
|
return {"value": _obj[_var_name[1:]]}
|
||||||
|
else:
|
||||||
|
return {"value": getattr(_obj, _var_name)}
|
||||||
|
f_get.__doc__ = f_get.__doc__.format(_name, _schema, _example_yaml_value)
|
||||||
|
return f_get
|
||||||
|
|
||||||
|
def _make_f_put(schema_class: Type[KoboldSchema], obj, _var_name, _name, _schema, _example_yaml_value):
|
||||||
|
def f_put(body: schema_class):
|
||||||
|
"""---
|
||||||
|
put:
|
||||||
|
summary: Set {} setting to specified value
|
||||||
|
tags:
|
||||||
|
- config
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema: {}
|
||||||
|
example:
|
||||||
|
value: {}
|
||||||
|
responses:
|
||||||
|
200:
|
||||||
|
description: Successful request
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema: EmptySchema
|
||||||
|
{api_validation_error_response}
|
||||||
|
"""
|
||||||
|
_obj = {"vars": vars, "vars.formatoptns": vars.formatoptns}[obj]
|
||||||
|
if _var_name.startswith("@"):
|
||||||
|
_obj[_var_name[1:]] = body.value
|
||||||
|
else:
|
||||||
|
setattr(_obj, _var_name, body.value)
|
||||||
|
settingschanged()
|
||||||
|
refresh_settings()
|
||||||
|
return {}
|
||||||
|
f_put.__doc__ = f_put.__doc__.format(_name, _schema, _example_yaml_value, api_validation_error_response=api_validation_error_response)
|
||||||
|
return f_put
|
||||||
|
|
||||||
|
def create_config_endpoint(method="GET", schema="MemorySchema"):
|
||||||
|
_name = globals()[schema].KoboldMeta.name
|
||||||
|
_var_name = globals()[schema].KoboldMeta.var_name
|
||||||
|
_route_name = globals()[schema].KoboldMeta.route_name
|
||||||
|
_obj = globals()[schema].KoboldMeta.obj
|
||||||
|
_example_yaml_value = globals()[schema].KoboldMeta.example_yaml_value
|
||||||
|
_schema = schema
|
||||||
|
f = _make_f_get(_obj, _var_name, _name, _schema, _example_yaml_value) if method == "GET" else _make_f_put(globals()[schema], _obj, _var_name, _name, _schema, _example_yaml_value)
|
||||||
|
f.__name__ = f"{method.lower()}_config_{_name}"
|
||||||
|
f = api_schema_wrap(f)
|
||||||
|
for api in (api_v1,):
|
||||||
|
f = api.route(f"/config/{_route_name}", methods=[method])(f)
|
||||||
|
|
||||||
|
class SoftPromptSettingSchema(KoboldSchema):
|
||||||
|
value: str = fields.String(required=True, validate=[soft_prompt_validator, validate.Regexp(r"^[^/\\]*$")], metadata={"description": "Soft prompt name, or a string containing only whitespace for no soft prompt. If using the PUT method and no soft prompt is loaded, this will always be the empty string."})
|
||||||
|
|
||||||
|
@api_v1.get("/config/soft_prompt")
|
||||||
|
@api_schema_wrap
|
||||||
|
def get_config_soft_prompt():
|
||||||
|
"""---
|
||||||
|
get:
|
||||||
|
summary: Retrieve the current soft prompt name
|
||||||
|
tags:
|
||||||
|
- config
|
||||||
|
responses:
|
||||||
|
200:
|
||||||
|
description: Successful request
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema: SoftPromptSettingSchema
|
||||||
|
example:
|
||||||
|
value: ""
|
||||||
|
"""
|
||||||
|
return {"value": vars.spfilename.strip()}
|
||||||
|
|
||||||
|
@api_v1.put("/config/soft_prompt")
|
||||||
|
@api_schema_wrap
|
||||||
|
def put_config_soft_prompt(body: SoftPromptSettingSchema):
|
||||||
|
"""---
|
||||||
|
put:
|
||||||
|
summary: Set soft prompt by name
|
||||||
|
tags:
|
||||||
|
- config
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema: SoftPromptSettingSchema
|
||||||
|
example:
|
||||||
|
value: ""
|
||||||
|
responses:
|
||||||
|
200:
|
||||||
|
description: Successful request
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema: EmptySchema
|
||||||
|
{api_validation_error_response}
|
||||||
|
"""
|
||||||
|
if vars.allowsp:
|
||||||
|
spRequest(body.value)
|
||||||
|
settingschanged()
|
||||||
|
return {}
|
||||||
|
|
||||||
|
config_endpoint_schemas: List[Type[KoboldSchema]] = []
|
||||||
|
|
||||||
|
def config_endpoint_schema(c: Type[KoboldSchema]):
|
||||||
|
config_endpoint_schemas.append(c)
|
||||||
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
@config_endpoint_schema
|
||||||
|
class MemorySettingSchema(KoboldSchema):
|
||||||
|
value = fields.String(required=True)
|
||||||
|
class KoboldMeta:
|
||||||
|
route_name = "memory"
|
||||||
|
obj = "vars"
|
||||||
|
var_name = "memory"
|
||||||
|
name = "memory"
|
||||||
|
example_yaml_value = "Memory"
|
||||||
|
|
||||||
|
@config_endpoint_schema
|
||||||
|
class AuthorsNoteSettingSchema(KoboldSchema):
|
||||||
|
value = fields.String(required=True)
|
||||||
|
class KoboldMeta:
|
||||||
|
route_name = "authors_note"
|
||||||
|
obj = "vars"
|
||||||
|
var_name = "authornote"
|
||||||
|
name = "author's note"
|
||||||
|
example_yaml_value = "''"
|
||||||
|
|
||||||
|
@config_endpoint_schema
|
||||||
|
class AuthorsNoteTemplateSettingSchema(KoboldSchema):
|
||||||
|
value = fields.String(required=True)
|
||||||
|
class KoboldMeta:
|
||||||
|
route_name = "authors_note_template"
|
||||||
|
obj = "vars"
|
||||||
|
var_name = "authornotetemplate"
|
||||||
|
name = "author's note template"
|
||||||
|
example_yaml_value = "\"[Author's note: <|>]\""
|
||||||
|
|
||||||
|
@config_endpoint_schema
|
||||||
|
class TopKSamplingSettingSchema(KoboldSchema):
|
||||||
|
value = fields.Integer(validate=validate.Range(min=0), required=True)
|
||||||
|
class KoboldMeta:
|
||||||
|
route_name = "top_k"
|
||||||
|
obj = "vars"
|
||||||
|
var_name = "top_k"
|
||||||
|
name = "top-k sampling"
|
||||||
|
example_yaml_value = "0"
|
||||||
|
|
||||||
|
@config_endpoint_schema
|
||||||
|
class TopASamplingSettingSchema(KoboldSchema):
|
||||||
|
value = fields.Float(validate=validate.Range(min=0), required=True)
|
||||||
|
class KoboldMeta:
|
||||||
|
route_name = "top_a"
|
||||||
|
obj = "vars"
|
||||||
|
var_name = "top_a"
|
||||||
|
name = "top-a sampling"
|
||||||
|
example_yaml_value = "0.0"
|
||||||
|
|
||||||
|
@config_endpoint_schema
|
||||||
|
class TopPSamplingSettingSchema(KoboldSchema):
|
||||||
|
value = fields.Float(validate=validate.Range(min=0, max=1), required=True)
|
||||||
|
class KoboldMeta:
|
||||||
|
route_name = "top_p"
|
||||||
|
obj = "vars"
|
||||||
|
var_name = "top_p"
|
||||||
|
name = "top-p sampling"
|
||||||
|
example_yaml_value = "0.9"
|
||||||
|
|
||||||
|
@config_endpoint_schema
|
||||||
|
class TailFreeSamplingSettingSchema(KoboldSchema):
|
||||||
|
value = fields.Float(validate=validate.Range(min=0, max=1), required=True)
|
||||||
|
class KoboldMeta:
|
||||||
|
route_name = "tfs"
|
||||||
|
obj = "vars"
|
||||||
|
var_name = "tfs"
|
||||||
|
name = "tail free sampling"
|
||||||
|
example_yaml_value = "1.0"
|
||||||
|
|
||||||
|
@config_endpoint_schema
|
||||||
|
class TypicalSamplingSettingSchema(KoboldSchema):
|
||||||
|
value = fields.Float(validate=validate.Range(min=0, max=1), required=True)
|
||||||
|
class KoboldMeta:
|
||||||
|
route_name = "typical"
|
||||||
|
obj = "vars"
|
||||||
|
var_name = "typical"
|
||||||
|
name = "typical sampling"
|
||||||
|
example_yaml_value = "1.0"
|
||||||
|
|
||||||
|
@config_endpoint_schema
|
||||||
|
class TemperatureSamplingSettingSchema(KoboldSchema):
|
||||||
|
value = fields.Float(validate=validate.Range(min=0, min_inclusive=False), required=True)
|
||||||
|
class KoboldMeta:
|
||||||
|
route_name = "temperature"
|
||||||
|
obj = "vars"
|
||||||
|
var_name = "temp"
|
||||||
|
name = "temperature"
|
||||||
|
example_yaml_value = "0.5"
|
||||||
|
|
||||||
|
@config_endpoint_schema
|
||||||
|
class GensPerActionSettingSchema(KoboldSchema):
|
||||||
|
value = fields.Integer(validate=validate.Range(min=0, max=5), required=True)
|
||||||
|
class KoboldMeta:
|
||||||
|
route_name = "n"
|
||||||
|
obj = "vars"
|
||||||
|
var_name = "numseqs"
|
||||||
|
name = "Gens Per Action"
|
||||||
|
example_yaml_value = "1"
|
||||||
|
|
||||||
|
@config_endpoint_schema
|
||||||
|
class MaxLengthSettingSchema(KoboldSchema):
|
||||||
|
value = fields.Integer(validate=validate.Range(min=1, max=512), required=True)
|
||||||
|
class KoboldMeta:
|
||||||
|
route_name = "max_length"
|
||||||
|
obj = "vars"
|
||||||
|
var_name = "genamt"
|
||||||
|
name = "max length"
|
||||||
|
example_yaml_value = "80"
|
||||||
|
|
||||||
|
@config_endpoint_schema
|
||||||
|
class MaxContextLengthSettingSchema(KoboldSchema):
|
||||||
|
value = fields.Integer(validate=validate.Range(min=512, max=2048), required=True)
|
||||||
|
class KoboldMeta:
|
||||||
|
route_name = "max_context_length"
|
||||||
|
obj = "vars"
|
||||||
|
var_name = "max_length"
|
||||||
|
name = "max context length"
|
||||||
|
example_yaml_value = "2048"
|
||||||
|
|
||||||
|
@config_endpoint_schema
|
||||||
|
class TrimIncompleteSentencesSettingsSchema(KoboldSchema):
|
||||||
|
value = fields.Boolean(required=True)
|
||||||
|
class KoboldMeta:
|
||||||
|
route_name = "frmttriminc"
|
||||||
|
obj = "vars.formatoptns"
|
||||||
|
var_name = "@frmttriminc"
|
||||||
|
name = "trim incomplete sentences (output formatting)"
|
||||||
|
example_yaml_value = "false"
|
||||||
|
|
||||||
|
@config_endpoint_schema
|
||||||
|
class RemoveBlankLinesSettingsSchema(KoboldSchema):
|
||||||
|
value = fields.Boolean(required=True)
|
||||||
|
class KoboldMeta:
|
||||||
|
route_name = "frmtrmblln"
|
||||||
|
obj = "vars.formatoptns"
|
||||||
|
var_name = "@frmtrmblln"
|
||||||
|
name = "remove blank lines (output formatting)"
|
||||||
|
example_yaml_value = "false"
|
||||||
|
|
||||||
|
@config_endpoint_schema
|
||||||
|
class RemoveSpecialCharactersSettingsSchema(KoboldSchema):
|
||||||
|
value = fields.Boolean(required=True)
|
||||||
|
class KoboldMeta:
|
||||||
|
route_name = "frmtrmspch"
|
||||||
|
obj = "vars.formatoptns"
|
||||||
|
var_name = "@frmtrmspch"
|
||||||
|
name = "remove special characters (output formatting)"
|
||||||
|
example_yaml_value = "false"
|
||||||
|
|
||||||
|
@config_endpoint_schema
|
||||||
|
class SingleLineSettingsSchema(KoboldSchema):
|
||||||
|
value = fields.Boolean(required=True)
|
||||||
|
class KoboldMeta:
|
||||||
|
route_name = "singleline"
|
||||||
|
obj = "vars.formatoptns"
|
||||||
|
var_name = "@singleline"
|
||||||
|
name = "single line (output formatting)"
|
||||||
|
example_yaml_value = "false"
|
||||||
|
|
||||||
|
@config_endpoint_schema
|
||||||
|
class AddSentenceSpacingSettingsSchema(KoboldSchema):
|
||||||
|
value = fields.Boolean(required=True)
|
||||||
|
class KoboldMeta:
|
||||||
|
route_name = "frmtadsnsp"
|
||||||
|
obj = "vars.formatoptns"
|
||||||
|
var_name = "@frmtadsnsp"
|
||||||
|
name = "add sentence spacing (input formatting)"
|
||||||
|
example_yaml_value = "false"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
for schema in config_endpoint_schemas:
|
||||||
|
create_config_endpoint(schema=schema.__name__, method="GET")
|
||||||
|
create_config_endpoint(schema=schema.__name__, method="PUT")
|
||||||
|
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Final startup commands to launch Flask app
|
# Final startup commands to launch Flask app
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
@ -2316,6 +2316,7 @@ $(document).ready(function(){
|
|||||||
scrollToBottom();
|
scrollToBottom();
|
||||||
} else if(msg.cmd == "updatechunk") {
|
} else if(msg.cmd == "updatechunk") {
|
||||||
hideMessage();
|
hideMessage();
|
||||||
|
game_text.attr('contenteditable', allowedit);
|
||||||
if (typeof submit_start !== 'undefined') {
|
if (typeof submit_start !== 'undefined') {
|
||||||
$("#runtime")[0].innerHTML = `Generation time: ${Math.round((Date.now() - submit_start)/1000)} sec`;
|
$("#runtime")[0].innerHTML = `Generation time: ${Math.round((Date.now() - submit_start)/1000)} sec`;
|
||||||
delete submit_start;
|
delete submit_start;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user