KoboldAI-Client/fileops.py

184 lines
7.0 KiB
Python
Raw Normal View History

import tkinter as tk
from tkinter import filedialog
from os import getcwd, listdir, path
from typing import Tuple, Union, Optional
import os
import json
import zipfile
#==================================================================#
# Generic Method for prompting for file path
#==================================================================#
def getsavepath(dir, title, types):
root = tk.Tk()
root.attributes("-topmost", True)
path = tk.filedialog.asksaveasfile(
initialdir=dir,
title=title,
filetypes = types,
defaultextension="*.*"
)
root.destroy()
if(path != "" and path != None):
return path.name
else:
return None
#==================================================================#
# Generic Method for prompting for file path
#==================================================================#
def getloadpath(dir, title, types):
root = tk.Tk()
root.attributes("-topmost", True)
path = tk.filedialog.askopenfilename(
initialdir=dir,
title=title,
filetypes = types
)
root.destroy()
if(path != "" and path != None):
return path
else:
return None
#==================================================================#
# Generic Method for prompting for directory path
#==================================================================#
def getdirpath(dir, title):
root = tk.Tk()
root.attributes("-topmost", True)
path = filedialog.askdirectory(
initialdir=dir,
title=title
)
root.destroy()
if(path != "" and path != None):
return path
else:
return None
#==================================================================#
# Returns the path (as a string) to the given story by its name
#==================================================================#
def storypath(name):
return path.join(path.dirname(path.realpath(__file__)), "stories", name + ".json")
#==================================================================#
# Returns the path (as a string) to the given soft prompt by its filename
#==================================================================#
def sppath(filename):
return path.join(path.dirname(path.realpath(__file__)), "softprompts", filename)
#==================================================================#
# Returns an array of dicts containing story files in /stories
#==================================================================#
def getstoryfiles():
list = []
for file in listdir(path.dirname(path.realpath(__file__))+"/stories"):
if file.endswith(".json"):
ob = {}
ob["name"] = file.replace(".json", "")
f = open(path.dirname(path.realpath(__file__))+"/stories/"+file, "r")
try:
js = json.load(f)
except:
print(f"Browser loading error: {file} is malformed or not a JSON file.")
f.close()
continue
f.close()
try:
ob["actions"] = len(js["actions"])
except TypeError:
print(f"Browser loading error: {file} has incorrect format.")
2021-06-26 00:17:07 +02:00
continue
list.append(ob)
return list
#==================================================================#
# Checks if the given soft prompt file is valid
#==================================================================#
def checksp(filename: str, model_dimension: int) -> Tuple[Union[zipfile.ZipFile, int], Optional[Tuple[int, int]], Optional[Tuple[int, int]], Optional[bool], Optional['np.dtype']]:
global np
if 'np' not in globals():
import numpy as np
try:
z = zipfile.ZipFile(path.dirname(path.realpath(__file__))+"/softprompts/"+filename)
with z.open('tensor.npy') as f:
# Read only the header of the npy file, for efficiency reasons
version: Tuple[int, int] = np.lib.format.read_magic(f)
shape: Tuple[int, int]
fortran_order: bool
dtype: np.dtype
shape, fortran_order, dtype = np.lib.format._read_array_header(f, version)
assert len(shape) == 2
except:
z.close()
return 1, None, None, None, None
if dtype not in ('V2', np.float16, np.float32):
z.close()
return 2, version, shape, fortran_order, dtype
if shape[1] != model_dimension:
z.close()
return 3, version, shape, fortran_order, dtype
if shape[0] >= 2048:
z.close()
return 4, version, shape, fortran_order, dtype
return z, version, shape, fortran_order, dtype
#==================================================================#
# Returns an array of dicts containing softprompt files in /softprompts
#==================================================================#
def getspfiles(model_dimension: int):
lst = []
os.makedirs(path.dirname(path.realpath(__file__))+"/softprompts", exist_ok=True)
for file in listdir(path.dirname(path.realpath(__file__))+"/softprompts"):
if not file.endswith(".zip"):
continue
z, version, shape, fortran_order, dtype = checksp(file, model_dimension)
if z == 1:
print(f"Browser SP loading error: {file} is malformed or not a soft prompt ZIP file.")
continue
if z == 2:
print(f"Browser SP loading error: {file} tensor.npy has unsupported dtype '{dtype.name}'.")
continue
if z == 3:
print(f"Browser SP loading error: {file} tensor.npy has model dimension {shape[1]} which does not match your model's model dimension of {model_dimension}. This usually means this soft prompt is not compatible with your model.")
continue
if z == 4:
print(f"Browser SP loading error: {file} tensor.npy has {shape[0]} tokens but it is supposed to have less than 2048 tokens.")
continue
assert isinstance(z, zipfile.ZipFile)
try:
with z.open('meta.json') as f:
ob = json.load(f)
except:
ob = {}
z.close()
ob["filename"] = file
lst.append(ob)
return lst
#==================================================================#
# Returns True if json file exists with requested save name
#==================================================================#
def saveexists(name):
return path.exists(storypath(name))
#==================================================================#
# Delete save file by name; returns None if successful, or the exception if not
#==================================================================#
def deletesave(name):
try:
os.remove(storypath(name))
except Exception as e:
return e
#==================================================================#
# Rename save file; returns None if successful, or the exception if not
#==================================================================#
def renamesave(name, new_name):
try:
os.replace(storypath(name), storypath(new_name))
except Exception as e:
return e