mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2024-12-13 00:56:18 +01:00
239 lines
9.6 KiB
Python
239 lines
9.6 KiB
Python
import tkinter as tk
|
|
from tkinter import filedialog
|
|
from os import getcwd, listdir, path
|
|
from typing import Tuple, Union, Optional
|
|
import os
|
|
import json
|
|
import zipfile
|
|
|
|
#==================================================================#
|
|
# Generic Method for prompting for file path
|
|
#==================================================================#
|
|
def getsavepath(dir, title, types):
|
|
root = tk.Tk()
|
|
root.attributes("-topmost", True)
|
|
path = tk.filedialog.asksaveasfile(
|
|
initialdir=dir,
|
|
title=title,
|
|
filetypes = types,
|
|
defaultextension="*.*"
|
|
)
|
|
root.destroy()
|
|
if(path != "" and path != None):
|
|
return path.name
|
|
else:
|
|
return None
|
|
|
|
#==================================================================#
|
|
# Generic Method for prompting for file path
|
|
#==================================================================#
|
|
def getloadpath(dir, title, types):
|
|
root = tk.Tk()
|
|
root.attributes("-topmost", True)
|
|
path = tk.filedialog.askopenfilename(
|
|
initialdir=dir,
|
|
title=title,
|
|
filetypes = types
|
|
)
|
|
root.destroy()
|
|
if(path != "" and path != None):
|
|
return path
|
|
else:
|
|
return None
|
|
|
|
#==================================================================#
|
|
# Generic Method for prompting for directory path
|
|
#==================================================================#
|
|
def getdirpath(dir, title):
|
|
root = tk.Tk()
|
|
root.attributes("-topmost", True)
|
|
path = filedialog.askdirectory(
|
|
initialdir=dir,
|
|
title=title
|
|
)
|
|
root.destroy()
|
|
if(path != "" and path != None):
|
|
return path
|
|
else:
|
|
return None
|
|
|
|
#==================================================================#
|
|
# Returns the path (as a string) to the given story by its name
|
|
#==================================================================#
|
|
def storypath(name):
|
|
return path.join(path.dirname(path.realpath(__file__)), "stories", name + ".json")
|
|
|
|
#==================================================================#
|
|
# Returns the path (as a string) to the given soft prompt by its filename
|
|
#==================================================================#
|
|
def sppath(filename):
|
|
return path.join(path.dirname(path.realpath(__file__)), "softprompts", filename)
|
|
|
|
#==================================================================#
|
|
# Returns the path (as a string) to the given username by its filename
|
|
#==================================================================#
|
|
def uspath(filename):
|
|
return path.join(path.dirname(path.realpath(__file__)), "userscripts", filename)
|
|
|
|
#==================================================================#
|
|
# Returns an array of dicts containing story files in /stories
|
|
#==================================================================#
|
|
def getstoryfiles():
|
|
list = []
|
|
for file in listdir(path.dirname(path.realpath(__file__))+"/stories"):
|
|
if file.endswith(".json"):
|
|
ob = {}
|
|
ob["name"] = file.replace(".json", "")
|
|
f = open(path.dirname(path.realpath(__file__))+"/stories/"+file, "r")
|
|
try:
|
|
js = json.load(f)
|
|
except:
|
|
print(f"Browser loading error: {file} is malformed or not a JSON file.")
|
|
f.close()
|
|
continue
|
|
f.close()
|
|
try:
|
|
ob["actions"] = len(js["actions"])
|
|
except TypeError:
|
|
print(f"Browser loading error: {file} has incorrect format.")
|
|
continue
|
|
list.append(ob)
|
|
return list
|
|
|
|
#==================================================================#
|
|
# Checks if the given soft prompt file is valid
|
|
#==================================================================#
|
|
def checksp(filename: str, model_dimension: int) -> Tuple[Union[zipfile.ZipFile, int], Optional[Tuple[int, int]], Optional[Tuple[int, int]], Optional[bool], Optional['np.dtype']]:
|
|
global np
|
|
if 'np' not in globals():
|
|
import numpy as np
|
|
try:
|
|
z = zipfile.ZipFile(path.dirname(path.realpath(__file__))+"/softprompts/"+filename)
|
|
with z.open('tensor.npy') as f:
|
|
# Read only the header of the npy file, for efficiency reasons
|
|
version: Tuple[int, int] = np.lib.format.read_magic(f)
|
|
shape: Tuple[int, int]
|
|
fortran_order: bool
|
|
dtype: np.dtype
|
|
shape, fortran_order, dtype = np.lib.format._read_array_header(f, version)
|
|
assert len(shape) == 2
|
|
except:
|
|
z.close()
|
|
return 1, None, None, None, None
|
|
if dtype not in ('V2', np.float16, np.float32):
|
|
z.close()
|
|
return 2, version, shape, fortran_order, dtype
|
|
if shape[1] != model_dimension:
|
|
z.close()
|
|
return 3, version, shape, fortran_order, dtype
|
|
if shape[0] >= 2048:
|
|
z.close()
|
|
return 4, version, shape, fortran_order, dtype
|
|
return z, version, shape, fortran_order, dtype
|
|
|
|
#==================================================================#
|
|
# Returns an array of dicts containing softprompt files in /softprompts
|
|
#==================================================================#
|
|
def getspfiles(model_dimension: int):
|
|
lst = []
|
|
os.makedirs(path.dirname(path.realpath(__file__))+"/softprompts", exist_ok=True)
|
|
for file in listdir(path.dirname(path.realpath(__file__))+"/softprompts"):
|
|
if not file.endswith(".zip"):
|
|
continue
|
|
z, version, shape, fortran_order, dtype = checksp(file, model_dimension)
|
|
if z == 1:
|
|
print(f"Browser SP loading error: {file} is malformed or not a soft prompt ZIP file.")
|
|
continue
|
|
if z == 2:
|
|
print(f"Browser SP loading error: {file} tensor.npy has unsupported dtype '{dtype.name}'.")
|
|
continue
|
|
if z == 3:
|
|
print(f"Browser SP loading error: {file} tensor.npy has model dimension {shape[1]} which does not match your model's model dimension of {model_dimension}. This usually means this soft prompt is not compatible with your model.")
|
|
continue
|
|
if z == 4:
|
|
print(f"Browser SP loading error: {file} tensor.npy has {shape[0]} tokens but it is supposed to have less than 2048 tokens.")
|
|
continue
|
|
assert isinstance(z, zipfile.ZipFile)
|
|
try:
|
|
with z.open('meta.json') as f:
|
|
ob = json.load(f)
|
|
except:
|
|
ob = {}
|
|
z.close()
|
|
ob["filename"] = file
|
|
ob["n_tokens"] = shape[-2]
|
|
lst.append(ob)
|
|
return lst
|
|
|
|
#==================================================================#
|
|
# Returns an array of dicts containing userscript files in /userscripts
|
|
#==================================================================#
|
|
def getusfiles(long_desc=False):
|
|
lst = []
|
|
os.makedirs(path.dirname(path.realpath(__file__))+"/userscripts", exist_ok=True)
|
|
for file in listdir(path.dirname(path.realpath(__file__))+"/userscripts"):
|
|
if file.endswith(".lua"):
|
|
ob = {}
|
|
ob["filename"] = file
|
|
description = []
|
|
multiline = False
|
|
with open(uspath(file)) as f:
|
|
ob["modulename"] = f.readline().strip().replace("\033", "")
|
|
if ob["modulename"][:2] != "--":
|
|
ob["modulename"] = file
|
|
else:
|
|
ob["modulename"] = ob["modulename"][2:]
|
|
if ob["modulename"][:2] == "[[":
|
|
ob["modulename"] = ob["modulename"][2:]
|
|
multiline = True
|
|
ob["modulename"] = ob["modulename"].lstrip("-").strip()
|
|
for line in f:
|
|
line = line.strip().replace("\033", "")
|
|
if multiline:
|
|
index = line.find("]]")
|
|
if index > -1:
|
|
description.append(line[:index])
|
|
if index != len(line) - 2:
|
|
break
|
|
multiline = False
|
|
else:
|
|
description.append(line)
|
|
else:
|
|
if line[:2] != "--":
|
|
break
|
|
line = line[2:]
|
|
if line[:2] == "[[":
|
|
multiline = True
|
|
line = line[2:]
|
|
description.append(line.strip())
|
|
ob["description"] = "\n".join(description)
|
|
if not long_desc:
|
|
if len(ob["description"]) > 250:
|
|
ob["description"] = ob["description"][:247] + "..."
|
|
lst.append(ob)
|
|
return lst
|
|
|
|
#==================================================================#
|
|
# Returns True if json file exists with requested save name
|
|
#==================================================================#
|
|
def saveexists(name):
|
|
return path.exists(storypath(name))
|
|
|
|
#==================================================================#
|
|
# Delete save file by name; returns None if successful, or the exception if not
|
|
#==================================================================#
|
|
def deletesave(name):
|
|
try:
|
|
os.remove(storypath(name))
|
|
except Exception as e:
|
|
return e
|
|
|
|
#==================================================================#
|
|
# Rename save file; returns None if successful, or the exception if not
|
|
#==================================================================#
|
|
def renamesave(name, new_name):
|
|
try:
|
|
os.replace(storypath(name), storypath(new_name))
|
|
except Exception as e:
|
|
return e
|