diff --git a/koboldai_settings.py b/koboldai_settings.py index 6e9b9f9d..00fdfb58 100644 --- a/koboldai_settings.py +++ b/koboldai_settings.py @@ -1,3 +1,4 @@ +from __future__ import annotations from dataclasses import dataclass import os, re, time, threading, json, pickle, base64, copy, tqdm, datetime, sys from typing import Union @@ -783,8 +784,8 @@ class model_settings(settings): process_variable_changes(self.socketio, self.__class__.__name__.replace("_settings", ""), name, value, old_value) class story_settings(settings): - local_only_variables = ['socketio', 'tokenizer', 'koboldai_vars', 'no_save', 'revisions', 'prompt'] - no_save_variables = ['socketio', 'tokenizer', 'koboldai_vars', 'context', 'no_save', 'prompt_in_ai', 'authornote_length', 'prompt_length', 'memory_length'] + local_only_variables = ['socketio', 'tokenizer', 'koboldai_vars', 'no_save', 'revisions', 'prompt', 'save_paths'] + no_save_variables = ['socketio', 'tokenizer', 'koboldai_vars', 'context', 'no_save', 'prompt_in_ai', 'authornote_length', 'prompt_length', 'memory_length', 'save_paths'] settings_name = "story" def __init__(self, socketio, koboldai_vars, tokenizer=None): self.socketio = socketio @@ -868,7 +869,8 @@ class story_settings(settings): self.an_attn_bias = 1 self.chat_style = 0 - + self.save_paths = SavePaths(os.path.join("stories", self.story_name or "Untitled")) + ################### must be at bottom ######################### self.no_save = False #Temporary disable save (doesn't save with the file) @@ -885,11 +887,11 @@ class story_settings(settings): # Disambiguate stories by adding (n) if needed disambiguator = 0 - save_path = os.path.join("stories", save_name) - while os.path.exists(save_path): + self.save_paths.base = os.path.join("stories", save_name) + while os.path.exists(self.save_paths.base): try: # If the stories share a story id, overwrite the existing one. - with open(os.path.join(save_path, "story.json"), "r") as file: + with open(self.save_paths.story, "r") as file: j = json.load(file) if self.story_id == j["story_id"]: break @@ -897,42 +899,18 @@ class story_settings(settings): raise FileNotFoundError("Malformed save file: Missing story.json") disambiguator += 1 - save_path = os.path.join("stories", save_name + (f" ({disambiguator})" if disambiguator else "")) + self.save_paths.base = os.path.join("stories", save_name + (f" ({disambiguator})" if disambiguator else "")) - if not os.path.exists(save_path): + if not os.path.exists(self.save_paths.base): # We are making the story for the first time. Setup the directory structure. - os.mkdir(save_path) - os.mkdir(os.path.join(save_path, "generated_audio")) - os.mkdir(os.path.join(save_path, "generated_images")) + os.mkdir(self.save_paths.base) + os.mkdir(self.save_paths.generated_audio) + os.mkdir(self.save_paths.generated_images) - with open(os.path.join(save_path, "story.json"), "w") as file: + with open(self.save_paths.story, "w") as file: file.write(self.to_json()) self.gamesaved = True - def old_save_story(self): - if not self.no_save: - if self.prompt != "" or self.memory != "" or self.authornote != "" or len(self.actions) > 0 or len(self.worldinfo_v2) > 0: - logger.debug("Saving story from koboldai_vars.story_settings.save_story()") - logger.info("Saving") - save_name = self.story_name if self.story_name != "" else "untitled" - adder = "" - while True: - if os.path.exists("stories/{}{}_v2.json".format(save_name, adder)): - with open("stories/{}{}_v2.json".format(save_name, adder), "r") as f: - temp = json.load(f) - if 'story_id' in temp: - if self.story_id != temp['story_id']: - adder = 0 if adder == "" else adder+1 - else: - break - else: - adder = 0 if adder == "" else adder+1 - else: - break - with open("stories/{}{}_v2.json".format(save_name, adder), "w") as settings_file: - settings_file.write(self.to_json()) - self.gamesaved = True - def save_revision(self): game = json.loads(self.to_json()) del game['revisions'] @@ -1062,6 +1040,8 @@ class story_settings(settings): elif name == 'chatmode' and value == False and self.adventure == False: self.storymode = 0 self.actionmode = 0 + elif name == "story_name": + self.save_paths.base = os.path.join("stories", self.story_name or "Untitled") class user_settings(settings): local_only_variables = ['socketio', 'importjs'] @@ -2276,6 +2256,22 @@ class KoboldWorldInfo(object): def get_used_wi(self): return [x['content'] for x in self.world_info if x['used_in_game']] + +@dataclass +class SavePaths: + base: str + + @property + def story(self) -> str: + return os.path.join(self.base, "story.json") + + @property + def generated_audio(self) -> str: + return os.path.join(self.base, "generated_audio") + + @property + def generated_images(self) -> str: + return os.path.join(self.base, "generated_images") default_rand_range = [0.1, 1, 2] default_creativity_range = [0.8, 1]