mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Work on import overhaul
This commit is contained in:
185
aiserver.py
185
aiserver.py
@@ -6,6 +6,7 @@
|
|||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
|
||||||
# External packages
|
# External packages
|
||||||
|
from dataclasses import dataclass
|
||||||
import eventlet
|
import eventlet
|
||||||
eventlet.monkey_patch(all=True, thread=False, os=False)
|
eventlet.monkey_patch(all=True, thread=False, os=False)
|
||||||
import os
|
import os
|
||||||
@@ -239,6 +240,182 @@ class Send_to_socketio(object):
|
|||||||
|
|
||||||
def flush(self):
|
def flush(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ImportBuffer:
|
||||||
|
# Singleton!!!
|
||||||
|
prompt: Optional[str] = None
|
||||||
|
memory: Optional[str] = None
|
||||||
|
authors_note: Optional[str] = None
|
||||||
|
world_infos: Optional[dict] = None
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PromptPlaceholder:
|
||||||
|
id: str
|
||||||
|
order: Optional[int] = None
|
||||||
|
default: Optional[str] = None
|
||||||
|
title: Optional[str] = None
|
||||||
|
description: Optional[str] = None
|
||||||
|
value: Optional[str] = None
|
||||||
|
|
||||||
|
def to_json(self) -> dict:
|
||||||
|
return {key: getattr(self, key) for key in [
|
||||||
|
"id",
|
||||||
|
"order",
|
||||||
|
"default",
|
||||||
|
"title",
|
||||||
|
"description"
|
||||||
|
]}
|
||||||
|
|
||||||
|
def request_client_configuration(self, placeholders: list[PromptPlaceholder]) -> None:
|
||||||
|
emit("request_prompt_config", [x.to_json() for x in placeholders], broadcast=False, room="UI_2")
|
||||||
|
|
||||||
|
def extract_placeholders(self, text: str) -> list[PromptPlaceholder]:
|
||||||
|
placeholders = []
|
||||||
|
|
||||||
|
for match in re.finditer(r"\${(.*?)}", text):
|
||||||
|
ph_text = match.group(1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
ph_order, ph_text = ph_text.split("#")
|
||||||
|
except ValueError:
|
||||||
|
ph_order = None
|
||||||
|
|
||||||
|
if "[" not in ph_text:
|
||||||
|
ph_id = ph_text
|
||||||
|
|
||||||
|
# Apparently, none of these characters are supported:
|
||||||
|
# "${}[]#:@^|", however I have found some prompts using these,
|
||||||
|
# so they will be allowed.
|
||||||
|
for char in "${}[]":
|
||||||
|
if char in ph_text:
|
||||||
|
print("[eph] Weird char")
|
||||||
|
print(f"{char=}")
|
||||||
|
print(f"{ph_id=}")
|
||||||
|
return
|
||||||
|
|
||||||
|
placeholders.append(self.PromptPlaceholder(
|
||||||
|
id=ph_id,
|
||||||
|
order=int(ph_order) if ph_order else None,
|
||||||
|
))
|
||||||
|
continue
|
||||||
|
|
||||||
|
ph_id, _ = ph_text.split("[")
|
||||||
|
ph_text = ph_text.replace(ph_id, "", 1)
|
||||||
|
|
||||||
|
# Match won't match it for some reason (???), so we use finditer and next()
|
||||||
|
try:
|
||||||
|
default_match = next(re.finditer(r"\[(.*?)\]", ph_text))
|
||||||
|
except StopIteration:
|
||||||
|
print("[eph] Weird brackets")
|
||||||
|
return placeholders
|
||||||
|
|
||||||
|
ph_default = default_match.group(1)
|
||||||
|
ph_text = ph_text.replace(default_match.group(0), "")
|
||||||
|
|
||||||
|
try:
|
||||||
|
ph_title, ph_desc = ph_text.split(":")
|
||||||
|
except ValueError:
|
||||||
|
ph_title = ph_text or None
|
||||||
|
ph_desc=None
|
||||||
|
|
||||||
|
placeholders.append(self.PromptPlaceholder(
|
||||||
|
id=ph_id,
|
||||||
|
order=int(ph_order) if ph_order else None,
|
||||||
|
default=ph_default,
|
||||||
|
title=ph_title,
|
||||||
|
description=ph_desc
|
||||||
|
))
|
||||||
|
return placeholders
|
||||||
|
|
||||||
|
def _replace_placeholders(self, text: str, ph_ids: dict):
|
||||||
|
for ph_id, value in ph_ids.items():
|
||||||
|
pattern = "\${(?:\d#)?{}.*}".format(ph_id)
|
||||||
|
for ph_text in re.findall(pattern, text):
|
||||||
|
text = text.replace(ph_text, value)
|
||||||
|
return text
|
||||||
|
|
||||||
|
def replace_placeholders(self, ph_ids: dict):
|
||||||
|
self.prompt = self._replace_placeholders(self.prompt, ph_ids)
|
||||||
|
self.memory = self._replace_placeholders(self.memory, ph_ids)
|
||||||
|
self.authors_note = self._replace_placeholders(self.authors_note, ph_ids)
|
||||||
|
|
||||||
|
for i in range(len(self.world_infos)):
|
||||||
|
for key in ["content", "comment"]:
|
||||||
|
self.world_infos[i][key] = self._replace_placeholders(self.world_infos[i][key])
|
||||||
|
|
||||||
|
def from_club(self, club_id):
|
||||||
|
# Maybe it is a better to parse the NAI Scenario (if available), it has more data
|
||||||
|
r = requests.get(f"https://aetherroom.club/api/{club_id}")
|
||||||
|
|
||||||
|
if not r.ok:
|
||||||
|
# TODO: Show error message on client
|
||||||
|
print(f"[import] Got {r.status_code} on request to club :^(")
|
||||||
|
return
|
||||||
|
|
||||||
|
j = r.json()
|
||||||
|
|
||||||
|
self.prompt = j["promptContent"]
|
||||||
|
self.memory = j["memory"]
|
||||||
|
self.authors_note = j["authorsNote"]
|
||||||
|
|
||||||
|
self.world_infos = []
|
||||||
|
|
||||||
|
for wi in j["worldInfos"]:
|
||||||
|
self.world_infos.append({
|
||||||
|
"key_list": wi["keysList"],
|
||||||
|
"keysecondary": [],
|
||||||
|
"content": wi["entry"],
|
||||||
|
"comment": "",
|
||||||
|
"folder": wi.get("folder", None),
|
||||||
|
"num": 0,
|
||||||
|
"init": True,
|
||||||
|
"selective": wi.get("selective", False),
|
||||||
|
"constant": wi.get("constant", False),
|
||||||
|
"uid": None,
|
||||||
|
})
|
||||||
|
|
||||||
|
placeholders = self.extract_placeholders(self.prompt)
|
||||||
|
if not placeholders:
|
||||||
|
self.commit()
|
||||||
|
else:
|
||||||
|
self.request_client_configuration(placeholders)
|
||||||
|
|
||||||
|
def commit(self):
|
||||||
|
# Push buffer story to actual story
|
||||||
|
exitModes()
|
||||||
|
|
||||||
|
koboldai_vars.create_story("")
|
||||||
|
koboldai_vars.gamestarted = True
|
||||||
|
koboldai_vars.prompt = self.prompt
|
||||||
|
koboldai_vars.memory = self.memory or ""
|
||||||
|
koboldai_vars.authornote = self.authors_note or ""
|
||||||
|
|
||||||
|
# ???: Was this supposed to increment?
|
||||||
|
num = 0
|
||||||
|
for wi in self.world_infos:
|
||||||
|
# koboldai_vars.worldinfo += self.world_infos
|
||||||
|
|
||||||
|
koboldai_vars.worldinfo_v2.add_item(
|
||||||
|
wi["key_list"][0],
|
||||||
|
wi["key_list"],
|
||||||
|
wi.get("keysecondary", []),
|
||||||
|
wi.get("folder", "root"),
|
||||||
|
wi.get("constant", False),
|
||||||
|
wi["content"],
|
||||||
|
wi.get("comment", "")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reset current save
|
||||||
|
koboldai_vars.savedir = getcwd()+"\\stories"
|
||||||
|
|
||||||
|
# Refresh game screen
|
||||||
|
koboldai_vars.laststory = None
|
||||||
|
setgamesaved(False)
|
||||||
|
sendwi()
|
||||||
|
refresh_story()
|
||||||
|
|
||||||
|
import_buffer = ImportBuffer()
|
||||||
|
|
||||||
# Set logging level to reduce chatter from Flask
|
# Set logging level to reduce chatter from Flask
|
||||||
import logging
|
import logging
|
||||||
@@ -7131,6 +7308,11 @@ def get_files_sorted(path, sort, desc=False):
|
|||||||
return [key[0] for key in sorted(data.items(), key=lambda kv: (kv[1], kv[0]), reverse=desc)]
|
return [key[0] for key in sorted(data.items(), key=lambda kv: (kv[1], kv[0]), reverse=desc)]
|
||||||
|
|
||||||
|
|
||||||
|
@socketio.on("configure_prompt")
|
||||||
|
def UI_2_configure_prompt(data):
|
||||||
|
import_buffer.replace_placeholders(data)
|
||||||
|
import_buffer.commit()
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
# Event triggered when browser SocketIO detects a variable change
|
# Event triggered when browser SocketIO detects a variable change
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
@@ -7789,7 +7971,8 @@ def UI_2_unload_userscripts(data):
|
|||||||
def UI_2_load_aidg_club(data):
|
def UI_2_load_aidg_club(data):
|
||||||
if koboldai_vars.debug:
|
if koboldai_vars.debug:
|
||||||
print("Load aidg.club: {}".format(data))
|
print("Load aidg.club: {}".format(data))
|
||||||
importAidgRequest(data)
|
import_buffer.from_club(data)
|
||||||
|
# importAidgRequest(data)
|
||||||
|
|
||||||
|
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
|
Reference in New Issue
Block a user