mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Nicer api to access save directories and files
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
from __future__ import annotations
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import os, re, time, threading, json, pickle, base64, copy, tqdm, datetime, sys
|
import os, re, time, threading, json, pickle, base64, copy, tqdm, datetime, sys
|
||||||
from typing import Union
|
from typing import Union
|
||||||
@@ -783,8 +784,8 @@ class model_settings(settings):
|
|||||||
process_variable_changes(self.socketio, self.__class__.__name__.replace("_settings", ""), name, value, old_value)
|
process_variable_changes(self.socketio, self.__class__.__name__.replace("_settings", ""), name, value, old_value)
|
||||||
|
|
||||||
class story_settings(settings):
|
class story_settings(settings):
|
||||||
local_only_variables = ['socketio', 'tokenizer', 'koboldai_vars', 'no_save', 'revisions', 'prompt']
|
local_only_variables = ['socketio', 'tokenizer', 'koboldai_vars', 'no_save', 'revisions', 'prompt', 'save_paths']
|
||||||
no_save_variables = ['socketio', 'tokenizer', 'koboldai_vars', 'context', 'no_save', 'prompt_in_ai', 'authornote_length', 'prompt_length', 'memory_length']
|
no_save_variables = ['socketio', 'tokenizer', 'koboldai_vars', 'context', 'no_save', 'prompt_in_ai', 'authornote_length', 'prompt_length', 'memory_length', 'save_paths']
|
||||||
settings_name = "story"
|
settings_name = "story"
|
||||||
def __init__(self, socketio, koboldai_vars, tokenizer=None):
|
def __init__(self, socketio, koboldai_vars, tokenizer=None):
|
||||||
self.socketio = socketio
|
self.socketio = socketio
|
||||||
@@ -868,7 +869,8 @@ class story_settings(settings):
|
|||||||
self.an_attn_bias = 1
|
self.an_attn_bias = 1
|
||||||
self.chat_style = 0
|
self.chat_style = 0
|
||||||
|
|
||||||
|
self.save_paths = SavePaths(os.path.join("stories", self.story_name or "Untitled"))
|
||||||
|
|
||||||
################### must be at bottom #########################
|
################### must be at bottom #########################
|
||||||
self.no_save = False #Temporary disable save (doesn't save with the file)
|
self.no_save = False #Temporary disable save (doesn't save with the file)
|
||||||
|
|
||||||
@@ -885,11 +887,11 @@ class story_settings(settings):
|
|||||||
|
|
||||||
# Disambiguate stories by adding (n) if needed
|
# Disambiguate stories by adding (n) if needed
|
||||||
disambiguator = 0
|
disambiguator = 0
|
||||||
save_path = os.path.join("stories", save_name)
|
self.save_paths.base = os.path.join("stories", save_name)
|
||||||
while os.path.exists(save_path):
|
while os.path.exists(self.save_paths.base):
|
||||||
try:
|
try:
|
||||||
# If the stories share a story id, overwrite the existing one.
|
# If the stories share a story id, overwrite the existing one.
|
||||||
with open(os.path.join(save_path, "story.json"), "r") as file:
|
with open(self.save_paths.story, "r") as file:
|
||||||
j = json.load(file)
|
j = json.load(file)
|
||||||
if self.story_id == j["story_id"]:
|
if self.story_id == j["story_id"]:
|
||||||
break
|
break
|
||||||
@@ -897,42 +899,18 @@ class story_settings(settings):
|
|||||||
raise FileNotFoundError("Malformed save file: Missing story.json")
|
raise FileNotFoundError("Malformed save file: Missing story.json")
|
||||||
|
|
||||||
disambiguator += 1
|
disambiguator += 1
|
||||||
save_path = os.path.join("stories", save_name + (f" ({disambiguator})" if disambiguator else ""))
|
self.save_paths.base = os.path.join("stories", save_name + (f" ({disambiguator})" if disambiguator else ""))
|
||||||
|
|
||||||
if not os.path.exists(save_path):
|
if not os.path.exists(self.save_paths.base):
|
||||||
# We are making the story for the first time. Setup the directory structure.
|
# We are making the story for the first time. Setup the directory structure.
|
||||||
os.mkdir(save_path)
|
os.mkdir(self.save_paths.base)
|
||||||
os.mkdir(os.path.join(save_path, "generated_audio"))
|
os.mkdir(self.save_paths.generated_audio)
|
||||||
os.mkdir(os.path.join(save_path, "generated_images"))
|
os.mkdir(self.save_paths.generated_images)
|
||||||
|
|
||||||
with open(os.path.join(save_path, "story.json"), "w") as file:
|
with open(self.save_paths.story, "w") as file:
|
||||||
file.write(self.to_json())
|
file.write(self.to_json())
|
||||||
self.gamesaved = True
|
self.gamesaved = True
|
||||||
|
|
||||||
def old_save_story(self):
|
|
||||||
if not self.no_save:
|
|
||||||
if self.prompt != "" or self.memory != "" or self.authornote != "" or len(self.actions) > 0 or len(self.worldinfo_v2) > 0:
|
|
||||||
logger.debug("Saving story from koboldai_vars.story_settings.save_story()")
|
|
||||||
logger.info("Saving")
|
|
||||||
save_name = self.story_name if self.story_name != "" else "untitled"
|
|
||||||
adder = ""
|
|
||||||
while True:
|
|
||||||
if os.path.exists("stories/{}{}_v2.json".format(save_name, adder)):
|
|
||||||
with open("stories/{}{}_v2.json".format(save_name, adder), "r") as f:
|
|
||||||
temp = json.load(f)
|
|
||||||
if 'story_id' in temp:
|
|
||||||
if self.story_id != temp['story_id']:
|
|
||||||
adder = 0 if adder == "" else adder+1
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
adder = 0 if adder == "" else adder+1
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
with open("stories/{}{}_v2.json".format(save_name, adder), "w") as settings_file:
|
|
||||||
settings_file.write(self.to_json())
|
|
||||||
self.gamesaved = True
|
|
||||||
|
|
||||||
def save_revision(self):
|
def save_revision(self):
|
||||||
game = json.loads(self.to_json())
|
game = json.loads(self.to_json())
|
||||||
del game['revisions']
|
del game['revisions']
|
||||||
@@ -1062,6 +1040,8 @@ class story_settings(settings):
|
|||||||
elif name == 'chatmode' and value == False and self.adventure == False:
|
elif name == 'chatmode' and value == False and self.adventure == False:
|
||||||
self.storymode = 0
|
self.storymode = 0
|
||||||
self.actionmode = 0
|
self.actionmode = 0
|
||||||
|
elif name == "story_name":
|
||||||
|
self.save_paths.base = os.path.join("stories", self.story_name or "Untitled")
|
||||||
|
|
||||||
class user_settings(settings):
|
class user_settings(settings):
|
||||||
local_only_variables = ['socketio', 'importjs']
|
local_only_variables = ['socketio', 'importjs']
|
||||||
@@ -2276,6 +2256,22 @@ class KoboldWorldInfo(object):
|
|||||||
|
|
||||||
def get_used_wi(self):
|
def get_used_wi(self):
|
||||||
return [x['content'] for x in self.world_info if x['used_in_game']]
|
return [x['content'] for x in self.world_info if x['used_in_game']]
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SavePaths:
|
||||||
|
base: str
|
||||||
|
|
||||||
|
@property
|
||||||
|
def story(self) -> str:
|
||||||
|
return os.path.join(self.base, "story.json")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def generated_audio(self) -> str:
|
||||||
|
return os.path.join(self.base, "generated_audio")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def generated_images(self) -> str:
|
||||||
|
return os.path.join(self.base, "generated_images")
|
||||||
|
|
||||||
default_rand_range = [0.1, 1, 2]
|
default_rand_range = [0.1, 1, 2]
|
||||||
default_creativity_range = [0.8, 1]
|
default_creativity_range = [0.8, 1]
|
||||||
|
Reference in New Issue
Block a user