Nicer api to access save directories and files

This commit is contained in:
somebody
2022-11-28 17:28:46 -06:00
parent e7930101c1
commit 735c4d770c

View File

@@ -1,3 +1,4 @@
from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
import os, re, time, threading, json, pickle, base64, copy, tqdm, datetime, sys import os, re, time, threading, json, pickle, base64, copy, tqdm, datetime, sys
from typing import Union from typing import Union
@@ -783,8 +784,8 @@ class model_settings(settings):
process_variable_changes(self.socketio, self.__class__.__name__.replace("_settings", ""), name, value, old_value) process_variable_changes(self.socketio, self.__class__.__name__.replace("_settings", ""), name, value, old_value)
class story_settings(settings): class story_settings(settings):
local_only_variables = ['socketio', 'tokenizer', 'koboldai_vars', 'no_save', 'revisions', 'prompt'] local_only_variables = ['socketio', 'tokenizer', 'koboldai_vars', 'no_save', 'revisions', 'prompt', 'save_paths']
no_save_variables = ['socketio', 'tokenizer', 'koboldai_vars', 'context', 'no_save', 'prompt_in_ai', 'authornote_length', 'prompt_length', 'memory_length'] no_save_variables = ['socketio', 'tokenizer', 'koboldai_vars', 'context', 'no_save', 'prompt_in_ai', 'authornote_length', 'prompt_length', 'memory_length', 'save_paths']
settings_name = "story" settings_name = "story"
def __init__(self, socketio, koboldai_vars, tokenizer=None): def __init__(self, socketio, koboldai_vars, tokenizer=None):
self.socketio = socketio self.socketio = socketio
@@ -868,7 +869,8 @@ class story_settings(settings):
self.an_attn_bias = 1 self.an_attn_bias = 1
self.chat_style = 0 self.chat_style = 0
self.save_paths = SavePaths(os.path.join("stories", self.story_name or "Untitled"))
################### must be at bottom ######################### ################### must be at bottom #########################
self.no_save = False #Temporary disable save (doesn't save with the file) self.no_save = False #Temporary disable save (doesn't save with the file)
@@ -885,11 +887,11 @@ class story_settings(settings):
# Disambiguate stories by adding (n) if needed # Disambiguate stories by adding (n) if needed
disambiguator = 0 disambiguator = 0
save_path = os.path.join("stories", save_name) self.save_paths.base = os.path.join("stories", save_name)
while os.path.exists(save_path): while os.path.exists(self.save_paths.base):
try: try:
# If the stories share a story id, overwrite the existing one. # If the stories share a story id, overwrite the existing one.
with open(os.path.join(save_path, "story.json"), "r") as file: with open(self.save_paths.story, "r") as file:
j = json.load(file) j = json.load(file)
if self.story_id == j["story_id"]: if self.story_id == j["story_id"]:
break break
@@ -897,42 +899,18 @@ class story_settings(settings):
raise FileNotFoundError("Malformed save file: Missing story.json") raise FileNotFoundError("Malformed save file: Missing story.json")
disambiguator += 1 disambiguator += 1
save_path = os.path.join("stories", save_name + (f" ({disambiguator})" if disambiguator else "")) self.save_paths.base = os.path.join("stories", save_name + (f" ({disambiguator})" if disambiguator else ""))
if not os.path.exists(save_path): if not os.path.exists(self.save_paths.base):
# We are making the story for the first time. Setup the directory structure. # We are making the story for the first time. Setup the directory structure.
os.mkdir(save_path) os.mkdir(self.save_paths.base)
os.mkdir(os.path.join(save_path, "generated_audio")) os.mkdir(self.save_paths.generated_audio)
os.mkdir(os.path.join(save_path, "generated_images")) os.mkdir(self.save_paths.generated_images)
with open(os.path.join(save_path, "story.json"), "w") as file: with open(self.save_paths.story, "w") as file:
file.write(self.to_json()) file.write(self.to_json())
self.gamesaved = True self.gamesaved = True
def old_save_story(self):
if not self.no_save:
if self.prompt != "" or self.memory != "" or self.authornote != "" or len(self.actions) > 0 or len(self.worldinfo_v2) > 0:
logger.debug("Saving story from koboldai_vars.story_settings.save_story()")
logger.info("Saving")
save_name = self.story_name if self.story_name != "" else "untitled"
adder = ""
while True:
if os.path.exists("stories/{}{}_v2.json".format(save_name, adder)):
with open("stories/{}{}_v2.json".format(save_name, adder), "r") as f:
temp = json.load(f)
if 'story_id' in temp:
if self.story_id != temp['story_id']:
adder = 0 if adder == "" else adder+1
else:
break
else:
adder = 0 if adder == "" else adder+1
else:
break
with open("stories/{}{}_v2.json".format(save_name, adder), "w") as settings_file:
settings_file.write(self.to_json())
self.gamesaved = True
def save_revision(self): def save_revision(self):
game = json.loads(self.to_json()) game = json.loads(self.to_json())
del game['revisions'] del game['revisions']
@@ -1062,6 +1040,8 @@ class story_settings(settings):
elif name == 'chatmode' and value == False and self.adventure == False: elif name == 'chatmode' and value == False and self.adventure == False:
self.storymode = 0 self.storymode = 0
self.actionmode = 0 self.actionmode = 0
elif name == "story_name":
self.save_paths.base = os.path.join("stories", self.story_name or "Untitled")
class user_settings(settings): class user_settings(settings):
local_only_variables = ['socketio', 'importjs'] local_only_variables = ['socketio', 'importjs']
@@ -2276,6 +2256,22 @@ class KoboldWorldInfo(object):
def get_used_wi(self): def get_used_wi(self):
return [x['content'] for x in self.world_info if x['used_in_game']] return [x['content'] for x in self.world_info if x['used_in_game']]
@dataclass
class SavePaths:
base: str
@property
def story(self) -> str:
return os.path.join(self.base, "story.json")
@property
def generated_audio(self) -> str:
return os.path.join(self.base, "generated_audio")
@property
def generated_images(self) -> str:
return os.path.join(self.base, "generated_images")
default_rand_range = [0.1, 1, 2] default_rand_range = [0.1, 1, 2]
default_creativity_range = [0.8, 1] default_creativity_range = [0.8, 1]