From ad2c2b67222e195d14089a31bb552ece13e93577 Mon Sep 17 00:00:00 2001 From: whjms Date: Wed, 8 Mar 2023 22:51:27 -0500 Subject: [PATCH] move aetherroom import to separate module --- .gitignore | 4 +- aiserver.py | 45 ++++------- importers/aetherroom.py | 57 ++++++++++++++ importers/test_aetherroom.py | 149 +++++++++++++++++++++++++++++++++++ requirements.txt | 6 ++ 5 files changed, 230 insertions(+), 31 deletions(-) create mode 100644 importers/aetherroom.py create mode 100644 importers/test_aetherroom.py diff --git a/.gitignore b/.gitignore index 54a2a91a..d14ddec7 100644 --- a/.gitignore +++ b/.gitignore @@ -29,6 +29,7 @@ Uninstall flask_session accelerate-disk-cache .ipynb_checkpoints +unit_test_report.html # Temporary until HF port !models/RWKV-v4 @@ -36,8 +37,9 @@ models/RWKV-v4/20B_tokenizer.json models/RWKV-v4/src/__pycache__ models/RWKV-v4/models -# Ignore PyCharm project files. +# Ignore PyCharm, VSCode project files. .idea +.vscode # Ignore compiled Python files. *.pyc diff --git a/aiserver.py b/aiserver.py index 2aca109f..0608e50b 100644 --- a/aiserver.py +++ b/aiserver.py @@ -414,40 +414,25 @@ class ImportBuffer: self.world_infos[i][key] = self._replace_placeholders(self.world_infos[i][key]) def from_club(self, club_id): - # Maybe it is a better to parse the NAI Scenario (if available), it has more data - r = requests.get(f"https://aetherroom.club/api/{club_id}") - - if not r.ok: - print(f"[import] Got {r.status_code} on request to club :^(") - message = f"Club responded with {r.status_code}" - if r.status_code == "404": + from importers import aetherroom + import_data: aetherroom.ImportData + try: + import_data = aetherroom.import_scenario(club_id) + except aetherroom.RequestFailed as err: + status = err.status_code + print(f"[import] Got {status} on request to club :^(") + message = f"Club responded with {status}" + if status == "404": message = f"Prompt not found for ID {club_id}" show_error_notification("Error loading prompt", message) return - j = r.json() - - self.prompt = j["promptContent"] - self.memory = j["memory"] - self.authors_note = j["authorsNote"] - self.notes = j["description"] - self.title = j["title"] or "Imported Story" - - self.world_infos = [] - - for wi in j["worldInfos"]: - self.world_infos.append({ - "key_list": wi["keysList"], - "keysecondary": [], - "content": wi["entry"], - "comment": "", - "folder": wi.get("folder", None), - "num": 0, - "init": True, - "selective": wi.get("selective", False), - "constant": wi.get("constant", False), - "uid": None, - }) + self.prompt = import_data.prompt + self.memory = import_data.memory + self.authors_note = import_data.authors_note + self.notes = import_data.notes + self.title = import_data.title + self.world_infos = import_data.world_infos placeholders = self.extract_placeholders(self.prompt) if not placeholders: diff --git a/importers/aetherroom.py b/importers/aetherroom.py new file mode 100644 index 00000000..22c666f1 --- /dev/null +++ b/importers/aetherroom.py @@ -0,0 +1,57 @@ +from dataclasses import dataclass +import requests +from typing import List + +BASE_URL = "https://aetherroom.club/api/" + + +@dataclass +class ImportData: + prompt: str + memory: str + authors_note: str + notes: str + title: str + world_infos: List[object] + + +class RequestFailed(Exception): + def __init__(self, status_code: str) -> None: + self.status_code = status_code + super().__init__() + + +def import_scenario(id: int) -> ImportData: + """ + Fetches story info from the provided AetherRoom scenario ID. + """ + # Maybe it is a better to parse the NAI Scenario (if available), it has more data + req = requests.get(f"{BASE_URL}{id}") + if not req.ok: + raise RequestFailed(req.status_code) + + json = req.json() + prompt = json["promptContent"] + memory = json["memory"] + authors_note = json["authorsNote"] + notes = json["description"] + title = json.get("title", "Imported Story") + + world_infos = [] + for info in json["worldinfos"]: + world_infos.append( + { + "key_list": info["keysList"], + "keysecondary": [], + "content": info["entry"], + "comment": "", + "folder": info.get("folder", None), + "num": 0, + "init": True, + "selective": info.get("selective", False), + "constant": info.get("constant", False), + "uid": None, + } + ) + + return ImportData(prompt, memory, authors_note, notes, title, world_infos) diff --git a/importers/test_aetherroom.py b/importers/test_aetherroom.py new file mode 100644 index 00000000..4e3874e9 --- /dev/null +++ b/importers/test_aetherroom.py @@ -0,0 +1,149 @@ +import pytest +import requests_mock + +from importers.aetherroom import ( + ImportData, + RequestFailed, + import_scenario, +) + + +def test_import_scenario_http_error(requests_mock: requests_mock.mocker): + requests_mock.get("https://aetherroom.club/api/1", status_code=404) + with pytest.raises(RequestFailed): + import_scenario(1) + + +def test_import_scenario_success(requests_mock: requests_mock.Mocker): + json = { + "promptContent": "promptContent", + "memory": "memory", + "authorsNote": "authorsNote", + "description": "description", + "title": "title", + "worldinfos": [], + } + requests_mock.get("https://aetherroom.club/api/1", json=json) + + expected_import_data = ImportData( + "promptContent", "memory", "authorsNote", "description", "title", [] + ) + assert import_scenario(1) == expected_import_data + + +def test_import_scenario_no_title(requests_mock: requests_mock.Mocker): + json = { + "promptContent": "promptContent", + "memory": "memory", + "authorsNote": "authorsNote", + "description": "description", + "worldinfos": [], + } + requests_mock.get("https://aetherroom.club/api/1", json=json) + + expected_import_data = ImportData( + "promptContent", "memory", "authorsNote", "description", "Imported Story", [] + ) + assert import_scenario(1) == expected_import_data + + +def test_import_scenario_world_infos(requests_mock: requests_mock.Mocker): + json = { + "promptContent": "promptContent", + "memory": "memory", + "authorsNote": "authorsNote", + "description": "description", + "worldinfos": [ + { + "entry": "Info 1", + "keysList": ["a", "b", "c"], + "folder": "folder", + "selective": True, + "constant": True, + }, + { + "entry": "Info 2", + "keysList": ["d", "e", "f"], + "folder": "folder 2", + "selective": True, + "constant": True, + }, + ], + } + requests_mock.get("https://aetherroom.club/api/1", json=json) + + expected_import_data = ImportData( + "promptContent", + "memory", + "authorsNote", + "description", + "Imported Story", + [ + { + "content": "Info 1", + "key_list": ["a", "b", "c"], + "keysecondary": [], + "comment": "", + "num": 0, + "init": True, + "uid": None, + "folder": "folder", + "selective": True, + "constant": True, + }, + { + "content": "Info 2", + "key_list": ["d", "e", "f"], + "keysecondary": [], + "comment": "", + "num": 0, + "init": True, + "uid": None, + "folder": "folder 2", + "selective": True, + "constant": True, + }, + ], + ) + assert import_scenario(1) == expected_import_data + + +def test_import_scenario_world_info_missing_properties( + requests_mock: requests_mock.Mocker, +): + json = { + "promptContent": "promptContent", + "memory": "memory", + "authorsNote": "authorsNote", + "description": "description", + "worldinfos": [ + { + "entry": "Info 1", + "keysList": ["a", "b", "c"], + } + ], + } + requests_mock.get("https://aetherroom.club/api/1", json=json) + + expected_import_data = ImportData( + "promptContent", + "memory", + "authorsNote", + "description", + "Imported Story", + [ + { + "content": "Info 1", + "key_list": ["a", "b", "c"], + "keysecondary": [], + "comment": "", + "num": 0, + "init": True, + "uid": None, + "folder": None, + "selective": False, + "constant": False, + } + ], + ) + assert import_scenario(1) == expected_import_data diff --git a/requirements.txt b/requirements.txt index cdde17c1..037f1a6d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,6 +12,7 @@ dnspython==2.2.1 lupa==1.10 markdown bleach==4.1.0 +black sentencepiece protobuf accelerate @@ -29,5 +30,10 @@ flask_compress ijson bitsandbytes ftfy +py==1.11.0 pydub +pytest==7.2.2 +pytest-html==3.2.0 +pytest-metadata==2.0.4 +requests-mock==1.10.0 safetensors \ No newline at end of file