Add more endpoints
parent a93087aecd
commit 1f629ee254

aiserver.py: 539 changed lines
@@ -448,15 +448,22 @@ def api_catch_out_of_memory_errors(f):
     return decorated
 
 def api_schema_wrap(f):
-    input_schema: Type[Schema] = next(iter(inspect.signature(f).parameters.values())).annotation
-    assert inspect.isclass(input_schema) and issubclass(input_schema, Schema)
+    try:
+        input_schema: Type[Schema] = next(iter(inspect.signature(f).parameters.values())).annotation
+    except:
+        HAS_SCHEMA = False
+    else:
+        HAS_SCHEMA = inspect.isclass(input_schema) and issubclass(input_schema, Schema)
     f = api_format_docstring(f)
     f = api_catch_out_of_memory_errors(f)
     @functools.wraps(f)
     def decorated(*args, **kwargs):
-        body = request.get_json()
-        schema = input_schema.from_dict(input_schema().load(body))
-        response = f(schema)
+        if HAS_SCHEMA:
+            body = request.get_json()
+            schema = input_schema.from_dict(input_schema().load(body))
+            response = f(schema)
+        else:
+            response = f()
         if not isinstance(response, Response):
            response = jsonify(response)
         return response
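
The reworked api_schema_wrap now degrades gracefully: if the wrapped function's first parameter is annotated with a Schema subclass, the request body is validated and passed in as before; if the function takes no schema (or no parameters at all, which is what the try/except around the signature lookup catches), it is simply called with no arguments. A minimal sketch of both shapes; the /echo and /ping routes here are hypothetical:

    @api_v1.post("/echo")
    @api_schema_wrap
    def post_echo(body: SubmissionInputSchema):  # annotated: HAS_SCHEMA is True
        # body arrives as a validated schema object built from the request JSON
        return {"result": {"text": body.prompt}}

    @api_v1.get("/ping")
    @api_schema_wrap
    def get_ping():  # no annotated parameter: HAS_SCHEMA is False
        return {}    # plain dicts are passed through jsonify() by the wrapper
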
@@ -531,9 +538,16 @@ class KoboldAPISpec(APISpec):
     def delete(self, rule: str, **kwargs):
         return self.route(rule, methods=["DELETE"], **kwargs)
 
+tags = [
+    {"name": "generate", "description": "Text generation endpoints"},
+    {"name": "story", "description": "Endpoints for managing the story in the KoboldAI GUI"},
+    {"name": "config", "description": "Allows you to get/set various setting values"},
+]
+
 api_v1 = KoboldAPISpec(
     version="1.0.0",
     prefixes=["/api/v1", "/api/latest"],
+    tags=tags,
 )
 
 #==================================================================#
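
Since api_v1 is constructed with two prefixes, every route registered through it is reachable under both /api/v1 and /api/latest, and the three tags above become the section headings in the generated API docs. A quick hedged check with the requests library (host and port are assumptions about a default local install):

    import requests  # assumes a KoboldAI server at localhost:5000

    for prefix in ("/api/v1", "/api/latest"):
        r = requests.get(f"http://localhost:5000{prefix}/config/memory")
        print(prefix, r.status_code, r.json())  # both prefixes hit the same handler
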
@@ -3763,9 +3777,9 @@ def check_for_backend_compilation():
             break
     vars.checking = False
 
-def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False, disable_recentrng=False):
+def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False, disable_recentrng=False, no_generate=False):
     # Ignore new submissions if the AI is currently busy
-    if(vars.aibusy):
+    if(not vars.standalone and vars.aibusy):
         return
 
     while(True):
@@ -3797,20 +3811,21 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,
 
         if(not vars.gamestarted):
             vars.submission = data
-            execute_inmod()
+            if(not no_generate):
+                execute_inmod()
             vars.submission = re.sub(r"[^\S\r\n]*([\r\n]*)$", r"\1", vars.submission) # Remove trailing whitespace, excluding newlines
             data = vars.submission
             if(not force_submit and len(data.strip()) == 0):
                 assert False
             # Start the game
             vars.gamestarted = True
-            if(not vars.noai and vars.lua_koboldbridge.generating and (not vars.nopromptgen or force_prompt_gen)):
+            if(not no_generate and not vars.noai and vars.lua_koboldbridge.generating and (not vars.nopromptgen or force_prompt_gen)):
                 # Save this first action as the prompt
                 vars.prompt = data
                 # Clear the startup text from game screen
                 emit('from_server', {'cmd': 'updatescreen', 'gamestarted': False, 'data': 'Please wait, generating story...'}, broadcast=True)
                 calcsubmit(data) # Run the first action through the generator
-                if(not vars.abort and vars.lua_koboldbridge.restart_sequence is not None and len(vars.genseqs) == 0):
+                if(not no_generate and not vars.abort and vars.lua_koboldbridge.restart_sequence is not None and len(vars.genseqs) == 0):
                     data = ""
                     force_submit = True
                     disable_recentrng = True
@@ -3822,7 +3837,8 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,
                 vars.prompt = data if len(data) > 0 else '"'
                 for i in range(vars.numseqs):
                     vars.lua_koboldbridge.outputs[i+1] = ""
-                execute_outmod()
+                if(not no_generate):
+                    execute_outmod()
                 vars.lua_koboldbridge.regeneration_required = False
                 genout = []
                 for i in range(vars.numseqs):
@@ -3856,7 +3872,8 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,
             if(vars.actionmode == 0):
                 data = applyinputformatting(data)
             vars.submission = data
-            execute_inmod()
+            if(not no_generate):
+                execute_inmod()
             vars.submission = re.sub(r"[^\S\r\n]*([\r\n]*)$", r"\1", vars.submission) # Remove trailing whitespace, excluding newlines
             data = vars.submission
             # Dont append submission if it's a blank/continue action
@@ -3886,7 +3903,7 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,
                 update_story_chunk('last')
             send_debug()
 
-            if(not vars.noai and vars.lua_koboldbridge.generating):
+            if(not no_generate and not vars.noai and vars.lua_koboldbridge.generating):
                 # Off to the tokenizer!
                 calcsubmit(data)
                 if(not vars.abort and vars.lua_koboldbridge.restart_sequence is not None and len(vars.genseqs) == 0):
@@ -3897,23 +3914,24 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,
                     emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
                     break
             else:
-                for i in range(vars.numseqs):
-                    vars.lua_koboldbridge.outputs[i+1] = ""
-                execute_outmod()
-                vars.lua_koboldbridge.regeneration_required = False
+                if(not no_generate):
+                    for i in range(vars.numseqs):
+                        vars.lua_koboldbridge.outputs[i+1] = ""
+                    execute_outmod()
+                    vars.lua_koboldbridge.regeneration_required = False
                 genout = []
                 for i in range(vars.numseqs):
-                    genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
+                    genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1] if not no_generate else ""})
                     assert type(genout[-1]["generated_text"]) is str
                 if(len(genout) == 1):
                     genresult(genout[0]["generated_text"])
-                    if(not vars.abort and vars.lua_koboldbridge.restart_sequence is not None):
+                    if(not no_generate and not vars.abort and vars.lua_koboldbridge.restart_sequence is not None):
                         data = ""
                         force_submit = True
                         disable_recentrng = True
                         continue
                 else:
-                    if(not vars.abort and vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
+                    if(not no_generate and not vars.abort and vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
                         genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
                         data = ""
                         force_submit = True
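
Taken together, the actionsubmit changes above gate every generation-side effect (the inmod/outmod userscript hooks, calcsubmit, and the restart-sequence retries) behind the new no_generate flag, so a caller can append text to the story while leaving the model untouched. A sketch of the intended call, as the /story/end endpoint added below uses it; the example text is invented:

    # Append an action without generating; force_submit=True bypasses the
    # empty-input check that interactive submissions rely on.
    actionsubmit("The knight sheathed her sword.", force_submit=True, no_generate=True)
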
@@ -4011,8 +4029,6 @@ def apiactionsubmit(data, use_memory=False):
     elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
         genout = apiactionsubmit_tpumtjgenerate(tokens, minimum, maximum)
 
-    genout = [applyoutputformatting(txt) for txt in genout]
-
     return genout
 
 #==================================================================#
@@ -6676,10 +6692,22 @@ def get_files_folders(starting_folder):
     socketio.emit("popup_breadcrumbs", breadcrumbs, broadcast=True)
 
 
+class EmptySchema(KoboldSchema):
+    pass
+
+class BasicTextResultInnerSchema(KoboldSchema):
+    text: str = fields.String(required=True)
+
+class BasicTextResultSchema(KoboldSchema):
+    result: BasicTextResultInnerSchema = fields.Nested(BasicTextResultInnerSchema)
+
 class BasicErrorSchema(KoboldSchema):
     msg: str = fields.String(required=True)
     type: str = fields.String(required=True)
 
+class StoryEmptyErrorSchema(KoboldSchema):
+    detail: BasicErrorSchema = fields.Nested(BasicErrorSchema, required=True)
+
 class OutOfMemoryErrorSchema(KoboldSchema):
     detail: BasicErrorSchema = fields.Nested(BasicErrorSchema, required=True)
 
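
The new schemas fix the JSON envelopes used by the endpoints added further down. Illustrative payloads (the values are invented):

    last_action = {"result": {"text": "You look around the cave."}}  # BasicTextResultSchema
    story_empty = {"detail": {                                       # StoryEmptyErrorSchema
        "msg": "Could not retrieve the last action of the story because the story is empty.",
        "type": "story_empty",
    }}
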
@@ -6754,7 +6782,7 @@ class SamplerSettingsSchema(KoboldSchema):
     rep_pen: Optional[float] = fields.Float(validate=validate.Range(min=1), metadata={"description": "Base repetition penalty value."})
     rep_pen_range: Optional[int] = fields.Integer(validate=validate.Range(min=0), metadata={"description": "Repetition penalty range."})
     rep_pen_slope: Optional[float] = fields.Float(validate=validate.Range(min=0), metadata={"description": "Repetition penalty slope."})
-    top_k: Optional[int] = fields.Int(validate=validate.Range(min=0), metadata={"description": "Top-k sampling value."})
+    top_k: Optional[int] = fields.Integer(validate=validate.Range(min=0), metadata={"description": "Top-k sampling value."})
     top_a: Optional[float] = fields.Float(validate=validate.Range(min=0), metadata={"description": "Top-a sampling value."})
     top_p: Optional[float] = fields.Float(validate=validate.Range(min=0, max=1), metadata={"description": "Top-p sampling value."})
     tfs: Optional[float] = fields.Float(validate=validate.Range(min=0, max=1), metadata={"description": "Tail free sampling value."})
@@ -6781,7 +6809,8 @@ class GenerationInputSchema(SamplerSettingsSchema):
     use_world_info: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the world info from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."})
     use_userscripts: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the userscripts from the KoboldAI GUI when generating text. NOTE: Currently unimplemented."})
     soft_prompt: Optional[str] = fields.String(metadata={"description": "Soft prompt to use when generating. If set to the empty string or any other string containing no non-whitespace characters, uses no soft prompt."}, validate=[soft_prompt_validator, validate.Regexp(r"^[^/\\]*$")])
-    max_length: int = fields.Integer(validate=validate.Range(min=1, max=2048), metadata={"description": "Number of tokens to generate."})
+    max_length: int = fields.Integer(validate=validate.Range(min=1, max=512), metadata={"description": "Number of tokens to generate."})
+    max_context_length: int = fields.Integer(validate=validate.Range(min=512, max=2048), metadata={"description": "Maximum number of tokens to send to the model."})
     n: int = fields.Integer(validate=validate.Range(min=1, max=5), metadata={"description": "Number of outputs to generate."})
     disable_output_formatting: bool = fields.Boolean(load_default=True, metadata={"description": "When enabled, disables all output formatting options, overriding their individual enabled/disabled states."})
     frmttriminc: Optional[bool] = fields.Boolean(metadata={"description": "Output formatting option. When enabled, removes some characters from the end of the output such that the output doesn't end in the middle of a sentence. If the output is less than one sentence long, does nothing."})
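
max_length is now capped at 512 generated tokens per request, and the new max_context_length field (512-2048) bounds how many tokens are fed to the model. A hedged example request; the URL, the prompt field, and the response shape are assumptions not shown in this hunk:

    import requests  # assumes a KoboldAI server at localhost:5000

    r = requests.post("http://localhost:5000/api/v1/generate", json={
        "prompt": "Niko the kobold stalked carefully down the alley",
        "max_length": 80,            # must now be <= 512
        "max_context_length": 1024,  # new field: cap on tokens sent to the model
        "n": 1,
    })
    print(r.json())
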
@@ -6800,8 +6829,8 @@ class GenerationOutputSchema(KoboldSchema):
 def _generate_text(body: GenerationInputSchema):
     if vars.aibusy or vars.genseqs:
         abort(Response(json.dumps({"detail": {
-            "type": "service_unavailable",
             "msg": "Server is busy; please try again later.",
+            "type": "service_unavailable",
         }}), mimetype="application/json", status=503))
     if body.use_story:
         raise NotImplementedError("use_story is not currently supported.")
@@ -6810,24 +6839,25 @@ def _generate_text(body: GenerationInputSchema):
     if body.use_userscripts:
         raise NotImplementedError("use_userscripts is not currently supported.")
     mapping = {
-        "rep_pen": (vars, "rep_pen"),
-        "rep_pen_range": (vars, "rep_pen_range"),
-        "rep_pen_slope": (vars, "rep_pen_slope"),
-        "top_k": (vars, "top_k"),
-        "top_a": (vars, "top_a"),
-        "top_p": (vars, "top_p"),
-        "tfs": (vars, "tfs"),
-        "typical": (vars, "typical"),
-        "temperature": (vars, "temp"),
-        "frmtadnsp": (vars.formatoptns, "@frmtadnsp"),
-        "frmttriminc": (vars.formatoptns, "@frmttriminc"),
-        "frmtrmblln": (vars.formatoptns, "@frmtrmblln"),
-        "frmtrmspch": (vars.formatoptns, "@frmtrmspch"),
-        "singleline": (vars.formatoptns, "@singleline"),
-        "disable_input_formatting": (vars, "disable_input_formatting"),
-        "disable_output_formatting": (vars, "disable_output_formatting"),
-        "max_length": (vars, "genamt"),
-        "n": (vars, "numseqs"),
+        "rep_pen": ("vars", "rep_pen"),
+        "rep_pen_range": ("vars", "rep_pen_range"),
+        "rep_pen_slope": ("vars", "rep_pen_slope"),
+        "top_k": ("vars", "top_k"),
+        "top_a": ("vars", "top_a"),
+        "top_p": ("vars", "top_p"),
+        "tfs": ("vars", "tfs"),
+        "typical": ("vars", "typical"),
+        "temperature": ("vars", "temp"),
+        "frmtadnsp": ("vars.formatoptns", "@frmtadnsp"),
+        "frmttriminc": ("vars.formatoptns", "@frmttriminc"),
+        "frmtrmblln": ("vars.formatoptns", "@frmtrmblln"),
+        "frmtrmspch": ("vars.formatoptns", "@frmtrmspch"),
+        "singleline": ("vars.formatoptns", "@singleline"),
+        "disable_input_formatting": ("vars", "disable_input_formatting"),
+        "disable_output_formatting": ("vars", "disable_output_formatting"),
+        "max_length": ("vars", "genamt"),
+        "max_context_length": ("vars", "max_length"),
+        "n": ("vars", "numseqs"),
     }
     saved_settings = {}
     set_aibusy(1)
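
The mapping now names its target objects by string key rather than holding direct references, which lets the lookup happen inside the loop. The resolution idiom used by the save/restore loops below, in isolation:

    entry = ("vars.formatoptns", "@singleline")
    obj = {"vars": vars, "vars.formatoptns": vars.formatoptns}[entry[0]]
    # "@"-prefixed names are dict keys on formatoptns; the rest are attributes on vars.
    value = obj[entry[1][1:]] if entry[1].startswith("@") else getattr(obj, entry[1])
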
@@ -6836,13 +6866,15 @@ def _generate_text(body: GenerationInputSchema):
     _standalone = vars.standalone
     vars.standalone = True
     for key, entry in mapping.items():
+        obj = {"vars": vars, "vars.formatoptns": vars.formatoptns}[entry[0]]
         if getattr(body, key, None) is not None:
             if entry[1].startswith("@"):
-                saved_settings[key] = entry[0][entry[1][1:]]
-                entry[0][entry[1][1:]] = getattr(body, key)
+                saved_settings[key] = obj[entry[1][1:]]
+                obj[entry[1][1:]] = getattr(body, key)
+                print(entry[1][1:], obj[entry[1][1:]])
             else:
-                saved_settings[key] = getattr(entry[0], entry[1])
-                setattr(entry[0], entry[1], getattr(body, key))
+                saved_settings[key] = getattr(obj, entry[1])
+                setattr(obj, entry[1], getattr(body, key))
     try:
         if vars.allowsp and getattr(body, "soft_prompt", None) is not None:
             if any(q in body.soft_prompt for q in ("/", "\\")):
@@ -6854,13 +6886,14 @@ def _generate_text(body: GenerationInputSchema):
     finally:
         for key in saved_settings:
             entry = mapping[key]
+            obj = {"vars": vars, "vars.formatoptns": vars.formatoptns}[entry[0]]
             if getattr(body, key, None) is not None:
                 if entry[1].startswith("@"):
-                    if entry[0][entry[1][1:]] == getattr(body, key):
-                        entry[0][entry[1][1:]] = saved_settings[key]
+                    if obj[entry[1][1:]] == getattr(body, key):
+                        obj[entry[1][1:]] = saved_settings[key]
                 else:
-                    if getattr(entry[0], entry[1]) == getattr(body, key):
-                        setattr(entry[0], entry[1], saved_settings[key])
+                    if getattr(obj, entry[1]) == getattr(body, key):
+                        setattr(obj, entry[1], saved_settings[key])
         vars.disable_set_aibusy = disable_set_aibusy
         vars.standalone = _standalone
         if vars.allowsp and getattr(body, "soft_prompt", None) is not None:
@@ -6871,9 +6904,11 @@ def _generate_text(body: GenerationInputSchema):
 @api_v1.post("/generate")
 @api_schema_wrap
 def post_completion_standalone(body: GenerationInputSchema):
-    r"""Generate text
-    ---
+    """---
     post:
       summary: Generate text
+      tags:
+        - generate
       description: |-2
         Generates text given a submission, sampler settings, soft prompt and number of return sequences.
 
@@ -6909,6 +6944,404 @@ def post_completion_standalone(body: GenerationInputSchema):
     return _generate_text(body)
 
 
+def prompt_validator(prompt: str):
+    if len(prompt.strip()) == 0:
+        raise ValidationError("String does not match expected pattern.")
+
+class SubmissionInputSchema(KoboldSchema):
+    prompt: str = fields.String(required=True, validate=prompt_validator, metadata={"pattern": r"^.*\S.*$", "description": "This is the submission."})
+    disable_input_formatting: bool = fields.Boolean(load_default=True, metadata={"description": "When enabled, disables all input formatting options, overriding their individual enabled/disabled states."})
+    frmtadsnsp: Optional[bool] = fields.Boolean(metadata={"description": "Input formatting option. When enabled, adds a leading space to your input if there is no trailing whitespace at the end of the previous action."})
+
+@api_v1.post("/story/end")
+@api_schema_wrap
+def post_story_end(body: SubmissionInputSchema):
+    """---
+    post:
+      summary: Add an action to the end of the story
+      tags:
+        - story
+      description: |-2
+        Inserts a single action at the end of the story in the KoboldAI GUI without generating text.
+      requestBody:
+        required: true
+        content:
+          application/json:
+            schema: SubmissionInputSchema
+            example:
+              prompt: |-2
+                 This is some text to put at the end of the story.
+      responses:
+        200:
+          description: Successful request
+          content:
+            application/json:
+              schema: EmptySchema
+        {api_validation_error_response}
+        {api_server_busy_response}
+    """
+    if vars.aibusy or vars.genseqs:
+        abort(Response(json.dumps({"detail": {
+            "msg": "Server is busy; please try again later.",
+            "type": "service_unavailable",
+        }}), mimetype="application/json", status=503))
+    set_aibusy(1)
+    disable_set_aibusy = vars.disable_set_aibusy
+    vars.disable_set_aibusy = True
+    _standalone = vars.standalone
+    vars.standalone = True
+    numseqs = vars.numseqs
+    vars.numseqs = 1
+    try:
+        actionsubmit(body.prompt, force_submit=True, no_generate=True)
+    finally:
+        vars.disable_set_aibusy = disable_set_aibusy
+        vars.standalone = _standalone
+        vars.numseqs = numseqs
+    set_aibusy(0)
+    return {}
+
+
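
A hedged client-side example for the new endpoint (URL and port are assumptions):

    import requests  # assumes a KoboldAI server at localhost:5000

    r = requests.post("http://localhost:5000/api/v1/story/end",
                      json={"prompt": "This is some text to put at the end of the story."})
    print(r.status_code)  # 200 with {} on success; 503 while the server is busy
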
+@api_v1.get("/story/end")
+@api_schema_wrap
+def get_story_end():
+    """---
+    get:
+      summary: Retrieve the last action of the story
+      tags:
+        - story
+      description: |-2
+        Returns the last action of the story in the KoboldAI GUI.
+      responses:
+        200:
+          description: Successful request
+          content:
+            application/json:
+              schema: BasicTextResultSchema
+        510:
+          description: Story is empty
+          content:
+            application/json:
+              schema: StoryEmptyErrorSchema
+              example:
+                detail:
+                  msg: Could not retrieve the last action of the story because the story is empty.
+                  type: story_empty
+    """
+    if not vars.gamestarted:
+        abort(Response(json.dumps({"detail": {
+            "msg": "Could not retrieve the last action of the story because the story is empty.",
+            "type": "story_empty",
+        }}), mimetype="application/json", status=510))
+    if len(vars.actions) == 0:
+        return {"result": {"text": vars.prompt}}
+    return {"result": {"text": vars.actions[vars.actions.get_last_key()]}}
+
+
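
And the matching read side; an unstarted story yields the 510 envelope described in the docstring (URL is the same assumption as above):

    import requests

    r = requests.get("http://localhost:5000/api/v1/story/end")
    if r.status_code == 510:
        print(r.json()["detail"]["msg"])   # story is empty
    else:
        print(r.json()["result"]["text"])  # text of the last action
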
+def _make_f_get(obj, _var_name, _name, _schema, _example_yaml_value):
+    def f_get():
+        """---
+        get:
+          summary: Retrieve the current {} setting value
+          tags:
+            - config
+          responses:
+            200:
+              description: Successful request
+              content:
+                application/json:
+                  schema: {}
+                  example:
+                    value: {}
+        """
+        _obj = {"vars": vars, "vars.formatoptns": vars.formatoptns}[obj]
+        if _var_name.startswith("@"):
+            return {"value": _obj[_var_name[1:]]}
+        else:
+            return {"value": getattr(_obj, _var_name)}
+    f_get.__doc__ = f_get.__doc__.format(_name, _schema, _example_yaml_value)
+    return f_get
+
+def _make_f_put(schema_class: Type[KoboldSchema], obj, _var_name, _name, _schema, _example_yaml_value):
+    def f_put(body: schema_class):
+        """---
+        put:
+          summary: Set {} setting to specified value
+          tags:
+            - config
+          requestBody:
+            required: true
+            content:
+              application/json:
+                schema: {}
+                example:
+                  value: {}
+          responses:
+            200:
+              description: Successful request
+              content:
+                application/json:
+                  schema: EmptySchema
+            {api_validation_error_response}
+        """
+        _obj = {"vars": vars, "vars.formatoptns": vars.formatoptns}[obj]
+        if _var_name.startswith("@"):
+            _obj[_var_name[1:]] = body.value
+        else:
+            setattr(_obj, _var_name, body.value)
+        settingschanged()
+        refresh_settings()
+        return {}
+    f_put.__doc__ = f_put.__doc__.format(_name, _schema, _example_yaml_value, api_validation_error_response=api_validation_error_response)
+    return f_put
+
+def create_config_endpoint(method="GET", schema="MemorySchema"):
+    _name = globals()[schema].KoboldMeta.name
+    _var_name = globals()[schema].KoboldMeta.var_name
+    _route_name = globals()[schema].KoboldMeta.route_name
+    _obj = globals()[schema].KoboldMeta.obj
+    _example_yaml_value = globals()[schema].KoboldMeta.example_yaml_value
+    _schema = schema
+    f = _make_f_get(_obj, _var_name, _name, _schema, _example_yaml_value) if method == "GET" else _make_f_put(globals()[schema], _obj, _var_name, _name, _schema, _example_yaml_value)
+    f.__name__ = f"{method.lower()}_config_{_name}"
+    f = api_schema_wrap(f)
+    for api in (api_v1,):
+        f = api.route(f"/config/{_route_name}", methods=[method])(f)
+
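
create_config_endpoint glues a generated handler onto /config/<route_name> using the KoboldMeta block of the named schema class. For instance, the loop at the bottom of this diff effectively performs calls like:

    # Registers GET and PUT /api/v1/config/memory, backed by vars.memory,
    # per MemorySettingSchema.KoboldMeta defined below.
    create_config_endpoint(schema="MemorySettingSchema", method="GET")
    create_config_endpoint(schema="MemorySettingSchema", method="PUT")
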
+class SoftPromptSettingSchema(KoboldSchema):
+    value: str = fields.String(required=True, validate=[soft_prompt_validator, validate.Regexp(r"^[^/\\]*$")], metadata={"description": "Soft prompt name, or a string containing only whitespace for no soft prompt. If using the PUT method and no soft prompt is loaded, this will always be the empty string."})
+
+@api_v1.get("/config/soft_prompt")
+@api_schema_wrap
+def get_config_soft_prompt():
+    """---
+    get:
+      summary: Retrieve the current soft prompt name
+      tags:
+        - config
+      responses:
+        200:
+          description: Successful request
+          content:
+            application/json:
+              schema: SoftPromptSettingSchema
+              example:
+                value: ""
+    """
+    return {"value": vars.spfilename.strip()}
+
+@api_v1.put("/config/soft_prompt")
+@api_schema_wrap
+def put_config_soft_prompt(body: SoftPromptSettingSchema):
+    """---
+    put:
+      summary: Set soft prompt by name
+      tags:
+        - config
+      requestBody:
+        required: true
+        content:
+          application/json:
+            schema: SoftPromptSettingSchema
+            example:
+              value: ""
+      responses:
+        200:
+          description: Successful request
+          content:
+            application/json:
+              schema: EmptySchema
+        {api_validation_error_response}
+    """
+    if vars.allowsp:
+        spRequest(body.value)
+        settingschanged()
+    return {}
+
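
The soft prompt endpoints are written by hand rather than generated because loading one goes through spRequest. A hedged round trip (URL is an assumption; a non-empty value must name a soft prompt that exists on the server):

    import requests  # assumes a KoboldAI server at localhost:5000

    requests.put("http://localhost:5000/api/v1/config/soft_prompt",
                 json={"value": ""})  # whitespace-only value unloads any soft prompt
    print(requests.get("http://localhost:5000/api/v1/config/soft_prompt").json())
    # -> {"value": ""}
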
+config_endpoint_schemas: List[Type[KoboldSchema]] = []
+
+def config_endpoint_schema(c: Type[KoboldSchema]):
+    config_endpoint_schemas.append(c)
+    return c
+
+
+@config_endpoint_schema
+class MemorySettingSchema(KoboldSchema):
+    value = fields.String(required=True)
+    class KoboldMeta:
+        route_name = "memory"
+        obj = "vars"
+        var_name = "memory"
+        name = "memory"
+        example_yaml_value = "Memory"
+
+@config_endpoint_schema
+class AuthorsNoteSettingSchema(KoboldSchema):
+    value = fields.String(required=True)
+    class KoboldMeta:
+        route_name = "authors_note"
+        obj = "vars"
+        var_name = "authornote"
+        name = "author's note"
+        example_yaml_value = "''"
+
+@config_endpoint_schema
+class AuthorsNoteTemplateSettingSchema(KoboldSchema):
+    value = fields.String(required=True)
+    class KoboldMeta:
+        route_name = "authors_note_template"
+        obj = "vars"
+        var_name = "authornotetemplate"
+        name = "author's note template"
+        example_yaml_value = "\"[Author's note: <|>]\""
+
+@config_endpoint_schema
+class TopKSamplingSettingSchema(KoboldSchema):
+    value = fields.Integer(validate=validate.Range(min=0), required=True)
+    class KoboldMeta:
+        route_name = "top_k"
+        obj = "vars"
+        var_name = "top_k"
+        name = "top-k sampling"
+        example_yaml_value = "0"
+
+@config_endpoint_schema
+class TopASamplingSettingSchema(KoboldSchema):
+    value = fields.Float(validate=validate.Range(min=0), required=True)
+    class KoboldMeta:
+        route_name = "top_a"
+        obj = "vars"
+        var_name = "top_a"
+        name = "top-a sampling"
+        example_yaml_value = "0.0"
+
+@config_endpoint_schema
+class TopPSamplingSettingSchema(KoboldSchema):
+    value = fields.Float(validate=validate.Range(min=0, max=1), required=True)
+    class KoboldMeta:
+        route_name = "top_p"
+        obj = "vars"
+        var_name = "top_p"
+        name = "top-p sampling"
+        example_yaml_value = "0.9"
+
+@config_endpoint_schema
+class TailFreeSamplingSettingSchema(KoboldSchema):
+    value = fields.Float(validate=validate.Range(min=0, max=1), required=True)
+    class KoboldMeta:
+        route_name = "tfs"
+        obj = "vars"
+        var_name = "tfs"
+        name = "tail free sampling"
+        example_yaml_value = "1.0"
+
+@config_endpoint_schema
+class TypicalSamplingSettingSchema(KoboldSchema):
+    value = fields.Float(validate=validate.Range(min=0, max=1), required=True)
+    class KoboldMeta:
+        route_name = "typical"
+        obj = "vars"
+        var_name = "typical"
+        name = "typical sampling"
+        example_yaml_value = "1.0"
+
+@config_endpoint_schema
+class TemperatureSamplingSettingSchema(KoboldSchema):
+    value = fields.Float(validate=validate.Range(min=0, min_inclusive=False), required=True)
+    class KoboldMeta:
+        route_name = "temperature"
+        obj = "vars"
+        var_name = "temp"
+        name = "temperature"
+        example_yaml_value = "0.5"
+
+@config_endpoint_schema
+class GensPerActionSettingSchema(KoboldSchema):
+    value = fields.Integer(validate=validate.Range(min=0, max=5), required=True)
+    class KoboldMeta:
+        route_name = "n"
+        obj = "vars"
+        var_name = "numseqs"
+        name = "Gens Per Action"
+        example_yaml_value = "1"
+
+@config_endpoint_schema
+class MaxLengthSettingSchema(KoboldSchema):
+    value = fields.Integer(validate=validate.Range(min=1, max=512), required=True)
+    class KoboldMeta:
+        route_name = "max_length"
+        obj = "vars"
+        var_name = "genamt"
+        name = "max length"
+        example_yaml_value = "80"
+
+@config_endpoint_schema
+class MaxContextLengthSettingSchema(KoboldSchema):
+    value = fields.Integer(validate=validate.Range(min=512, max=2048), required=True)
+    class KoboldMeta:
+        route_name = "max_context_length"
+        obj = "vars"
+        var_name = "max_length"
+        name = "max context length"
+        example_yaml_value = "2048"
+
+@config_endpoint_schema
+class TrimIncompleteSentencesSettingsSchema(KoboldSchema):
+    value = fields.Boolean(required=True)
+    class KoboldMeta:
+        route_name = "frmttriminc"
+        obj = "vars.formatoptns"
+        var_name = "@frmttriminc"
+        name = "trim incomplete sentences (output formatting)"
+        example_yaml_value = "false"
+
+@config_endpoint_schema
+class RemoveBlankLinesSettingsSchema(KoboldSchema):
+    value = fields.Boolean(required=True)
+    class KoboldMeta:
+        route_name = "frmtrmblln"
+        obj = "vars.formatoptns"
+        var_name = "@frmtrmblln"
+        name = "remove blank lines (output formatting)"
+        example_yaml_value = "false"
+
+@config_endpoint_schema
+class RemoveSpecialCharactersSettingsSchema(KoboldSchema):
+    value = fields.Boolean(required=True)
+    class KoboldMeta:
+        route_name = "frmtrmspch"
+        obj = "vars.formatoptns"
+        var_name = "@frmtrmspch"
+        name = "remove special characters (output formatting)"
+        example_yaml_value = "false"
+
+@config_endpoint_schema
+class SingleLineSettingsSchema(KoboldSchema):
+    value = fields.Boolean(required=True)
+    class KoboldMeta:
+        route_name = "singleline"
+        obj = "vars.formatoptns"
+        var_name = "@singleline"
+        name = "single line (output formatting)"
+        example_yaml_value = "false"
+
+@config_endpoint_schema
+class AddSentenceSpacingSettingsSchema(KoboldSchema):
+    value = fields.Boolean(required=True)
+    class KoboldMeta:
+        route_name = "frmtadsnsp"
+        obj = "vars.formatoptns"
+        var_name = "@frmtadsnsp"
+        name = "add sentence spacing (input formatting)"
+        example_yaml_value = "false"
+
+
+for schema in config_endpoint_schemas:
+    create_config_endpoint(schema=schema.__name__, method="GET")
+    create_config_endpoint(schema=schema.__name__, method="PUT")
+
+
 #==================================================================#
 # Final startup commands to launch Flask app
 #==================================================================#
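
Each decorated schema above therefore yields a GET/PUT pair under /config once the loop runs. A hedged smoke test of one generated pair (URL is an assumption):

    import requests  # assumes a KoboldAI server at localhost:5000

    base = "http://localhost:5000/api/v1/config/top_k"
    requests.put(base, json={"value": 40})  # set top-k sampling
    print(requests.get(base).json())        # -> {"value": 40}
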
@@ -2316,6 +2316,7 @@ $(document).ready(function(){
                     scrollToBottom();
                 } else if(msg.cmd == "updatechunk") {
                     hideMessage();
+                    game_text.attr('contenteditable', allowedit);
                     if (typeof submit_start !== 'undefined') {
                         $("#runtime")[0].innerHTML = `Generation time: ${Math.round((Date.now() - submit_start)/1000)} sec`;
                         delete submit_start;