diff --git a/aiserver.py b/aiserver.py index 4a8c98fc..37e505af 100644 --- a/aiserver.py +++ b/aiserver.py @@ -6,6 +6,7 @@ #==================================================================# # External packages +from dataclasses import dataclass import eventlet eventlet.monkey_patch(all=True, thread=False, os=False) import os @@ -239,6 +240,182 @@ class Send_to_socketio(object): def flush(self): pass + +@dataclass +class ImportBuffer: + # Singleton!!! + prompt: Optional[str] = None + memory: Optional[str] = None + authors_note: Optional[str] = None + world_infos: Optional[dict] = None + + @dataclass + class PromptPlaceholder: + id: str + order: Optional[int] = None + default: Optional[str] = None + title: Optional[str] = None + description: Optional[str] = None + value: Optional[str] = None + + def to_json(self) -> dict: + return {key: getattr(self, key) for key in [ + "id", + "order", + "default", + "title", + "description" + ]} + + def request_client_configuration(self, placeholders: list[PromptPlaceholder]) -> None: + emit("request_prompt_config", [x.to_json() for x in placeholders], broadcast=False, room="UI_2") + + def extract_placeholders(self, text: str) -> list[PromptPlaceholder]: + placeholders = [] + + for match in re.finditer(r"\${(.*?)}", text): + ph_text = match.group(1) + + try: + ph_order, ph_text = ph_text.split("#") + except ValueError: + ph_order = None + + if "[" not in ph_text: + ph_id = ph_text + + # Apparently, none of these characters are supported: + # "${}[]#:@^|", however I have found some prompts using these, + # so they will be allowed. + for char in "${}[]": + if char in ph_text: + print("[eph] Weird char") + print(f"{char=}") + print(f"{ph_id=}") + return + + placeholders.append(self.PromptPlaceholder( + id=ph_id, + order=int(ph_order) if ph_order else None, + )) + continue + + ph_id, _ = ph_text.split("[") + ph_text = ph_text.replace(ph_id, "", 1) + + # Match won't match it for some reason (???), so we use finditer and next() + try: + default_match = next(re.finditer(r"\[(.*?)\]", ph_text)) + except StopIteration: + print("[eph] Weird brackets") + return placeholders + + ph_default = default_match.group(1) + ph_text = ph_text.replace(default_match.group(0), "") + + try: + ph_title, ph_desc = ph_text.split(":") + except ValueError: + ph_title = ph_text or None + ph_desc=None + + placeholders.append(self.PromptPlaceholder( + id=ph_id, + order=int(ph_order) if ph_order else None, + default=ph_default, + title=ph_title, + description=ph_desc + )) + return placeholders + + def _replace_placeholders(self, text: str, ph_ids: dict): + for ph_id, value in ph_ids.items(): + pattern = "\${(?:\d#)?{}.*}".format(ph_id) + for ph_text in re.findall(pattern, text): + text = text.replace(ph_text, value) + return text + + def replace_placeholders(self, ph_ids: dict): + self.prompt = self._replace_placeholders(self.prompt, ph_ids) + self.memory = self._replace_placeholders(self.memory, ph_ids) + self.authors_note = self._replace_placeholders(self.authors_note, ph_ids) + + for i in range(len(self.world_infos)): + for key in ["content", "comment"]: + self.world_infos[i][key] = self._replace_placeholders(self.world_infos[i][key]) + + def from_club(self, club_id): + # Maybe it is a better to parse the NAI Scenario (if available), it has more data + r = requests.get(f"https://aetherroom.club/api/{club_id}") + + if not r.ok: + # TODO: Show error message on client + print(f"[import] Got {r.status_code} on request to club :^(") + return + + j = r.json() + + self.prompt = j["promptContent"] + self.memory = j["memory"] + self.authors_note = j["authorsNote"] + + self.world_infos = [] + + for wi in j["worldInfos"]: + self.world_infos.append({ + "key_list": wi["keysList"], + "keysecondary": [], + "content": wi["entry"], + "comment": "", + "folder": wi.get("folder", None), + "num": 0, + "init": True, + "selective": wi.get("selective", False), + "constant": wi.get("constant", False), + "uid": None, + }) + + placeholders = self.extract_placeholders(self.prompt) + if not placeholders: + self.commit() + else: + self.request_client_configuration(placeholders) + + def commit(self): + # Push buffer story to actual story + exitModes() + + koboldai_vars.create_story("") + koboldai_vars.gamestarted = True + koboldai_vars.prompt = self.prompt + koboldai_vars.memory = self.memory or "" + koboldai_vars.authornote = self.authors_note or "" + + # ???: Was this supposed to increment? + num = 0 + for wi in self.world_infos: + # koboldai_vars.worldinfo += self.world_infos + + koboldai_vars.worldinfo_v2.add_item( + wi["key_list"][0], + wi["key_list"], + wi.get("keysecondary", []), + wi.get("folder", "root"), + wi.get("constant", False), + wi["content"], + wi.get("comment", "") + ) + + # Reset current save + koboldai_vars.savedir = getcwd()+"\\stories" + + # Refresh game screen + koboldai_vars.laststory = None + setgamesaved(False) + sendwi() + refresh_story() + +import_buffer = ImportBuffer() # Set logging level to reduce chatter from Flask import logging @@ -7131,6 +7308,11 @@ def get_files_sorted(path, sort, desc=False): return [key[0] for key in sorted(data.items(), key=lambda kv: (kv[1], kv[0]), reverse=desc)] +@socketio.on("configure_prompt") +def UI_2_configure_prompt(data): + import_buffer.replace_placeholders(data) + import_buffer.commit() + #==================================================================# # Event triggered when browser SocketIO detects a variable change #==================================================================# @@ -7789,7 +7971,8 @@ def UI_2_unload_userscripts(data): def UI_2_load_aidg_club(data): if koboldai_vars.debug: print("Load aidg.club: {}".format(data)) - importAidgRequest(data) + import_buffer.from_club(data) + # importAidgRequest(data) #==================================================================#