small commentator + WI image overhaul

This commit is contained in:
somebody
2022-12-17 16:23:39 -06:00
parent 74a46982ab
commit 2a556ea346
6 changed files with 89 additions and 152 deletions

View File

@@ -248,6 +248,8 @@ class koboldai_vars(object):
wi_uid = wi['uid']
if wi_uid in used_world_info:
return False
if wi["type"] == "commentator" and not (allowed_wi_entries and wi_uid in allowed_wi_entries):
return False
if allowed_wi_entries is not None and wi_uid not in allowed_wi_entries:
return False
if allowed_wi_folders is not None and wi['folder'] not in allowed_wi_folders:
@@ -906,8 +908,6 @@ class story_settings(settings):
# In percent!!!
self.commentary_chance = 0
# id: {name}
self.commentary_characters = {}
self.commentary_enabled = False
self.save_paths = SavePaths(os.path.join("stories", self.story_name or "Untitled"))
@@ -2099,7 +2099,6 @@ class KoboldWorldInfo(object):
self.world_info = {}
self.world_info_folder = OrderedDict()
self.world_info_folder['root'] = []
self.image_store = {}
self.story_settings = story_settings
def reset(self):
@@ -2325,8 +2324,10 @@ class KoboldWorldInfo(object):
def delete(self, uid):
del self.world_info[uid]
if uid in self.image_store:
del self.image_store[uid]
try:
os.remove(os.path.join(self.koboldai_vars.save_paths.wi_images, str(uid)))
except FileNotFoundError:
pass
for folder in self.world_info_folder:
if uid in self.world_info_folder[folder]:
@@ -2385,25 +2386,34 @@ class KoboldWorldInfo(object):
return {
"folders": {x: self.world_info_folder[x] for x in self.world_info_folder},
"entries": self.world_info,
"images": self.image_store
}
else:
return {
"folders": {x: self.world_info_folder[x] for x in self.world_info_folder if x == folder},
"entries": {x: self.world_info[x] for x in self.world_info if self.world_info[x]['folder'] == folder},
"images": self.image_store
}
def upgrade_entry(self, wi_entry: dict) -> dict:
# If we do not have a type, or it is incorrect, set to WI.
if wi_entry.get("type") not in ["constant", "chatcharacter", "wi"]:
if wi_entry.get("type") not in ["constant", "chatcharacter", "wi", "commentator"]:
wi_entry["type"] = "wi"
if wi_entry["type"] in ["commentator", "constant"]:
wi_entry["constant"] = True
return wi_entry
def load_json(self, data, folder=None):
# Legacy WI images (stored in json)
if "images" in data:
self.image_store = data["images"]
for uid, image_b64 in data["images"].items():
image_b64 = image_b64.split(",")[-1]
image_path = os.path.join(
self.koboldai_vars.save_paths.wi_images,
str(uid)
)
with open(image_path, "wb") as file:
file.write(base64.b64decode(image_b64))
data["entries"] = {k: self.upgrade_entry(v) for k,v in data["entries"].items()}
@@ -2540,6 +2550,14 @@ class KoboldWorldInfo(object):
)
return the_collection
def get_commentators(self) -> List[dict]:
ret = []
for entry in self.world_info.values():
if entry["type"] != "commentator":
continue
ret.append(entry)
return ret
@dataclass
@@ -2552,7 +2570,7 @@ class SavePaths:
self.base,
self.generated_audio,
self.generated_images,
self.commentator_pictures
self.wi_images
]
@property
@@ -2568,8 +2586,8 @@ class SavePaths:
return os.path.join(self.base, "generated_images")
@property
def commentator_pictures(self) -> str:
return os.path.join(self.base, "commentator_pictures")
def wi_images(self) -> str:
return os.path.join(self.base, "wi_images")
default_rand_range = [0.44, 1, 2]
default_creativity_range = [0.5, 1]