2021-05-22 11:28:40 +02:00
from os import getcwd , listdir , path
2021-10-22 20:18:10 +02:00
from typing import Tuple , Union , Optional
2021-09-01 00:22:30 +02:00
import os
2021-05-22 11:28:40 +02:00
import json
2021-10-22 20:18:10 +02:00
import zipfile
2022-09-12 12:00:30 +02:00
from logger import logger
2021-05-07 20:32:10 +02:00
#==================================================================#
# Generic Method for prompting for file path
#==================================================================#
def getsavepath(dir, title, types):
    """Prompt the user for a save location and return the chosen path.

    Shows a "save as" dialog kept on top of other windows. The dialog used
    here (``asksaveasfile``) opens the chosen file for writing; that handle is
    closed again before returning, so only the path is handed to the caller.

    Args:
        dir: Directory the dialog initially shows.
        title: Title of the dialog window.
        types: File type filters, passed through as ``filetypes``.

    Returns:
        The selected file path as a string, or None if nothing was selected.
    """
    # Imported here rather than at module level, so tkinter is only needed
    # when a dialog is actually requested.
    import tkinter as tk
    from tkinter import filedialog

    root = tk.Tk()
    try:
        root.attributes("-topmost", True)
        handle = filedialog.asksaveasfile(
            initialdir=dir,
            title=title,
            filetypes=types,
            defaultextension="*.*",
        )
    finally:
        # Always tear the helper window down, even if the dialog raised.
        root.destroy()

    if not handle:
        return None
    # Close the file object the dialog opened; callers only need its name.
    with handle:
        return handle.name
#==================================================================#
# Generic Method for prompting for file path
#==================================================================#
def getloadpath(dir, title, types):
    """Prompt the user for an existing file and return its path.

    Shows an "open file" dialog kept on top of other windows.

    Args:
        dir: Directory the dialog initially shows.
        title: Title of the dialog window.
        types: File type filters, passed through as ``filetypes``.

    Returns:
        The selected file path as a string, or None if nothing was selected.
    """
    # Imported here rather than at module level, so tkinter is only needed
    # when a dialog is actually requested.
    import tkinter as tk
    from tkinter import filedialog

    root = tk.Tk()
    try:
        root.attributes("-topmost", True)
        chosen = filedialog.askopenfilename(
            initialdir=dir,
            title=title,
            filetypes=types,
        )
    finally:
        # Always tear the helper window down, even if the dialog raised.
        root.destroy()

    # A cancelled dialog yields an empty value ("" and, reportedly on some Tk
    # builds, an empty tuple); treat every empty result as "no selection".
    return chosen if chosen else None
#==================================================================#
# Generic Method for prompting for directory path
#==================================================================#
def getdirpath(dir, title):
    """Prompt the user for a directory and return its path.

    Shows a "choose directory" dialog kept on top of other windows.

    Args:
        dir: Directory the dialog initially shows.
        title: Title of the dialog window.

    Returns:
        The selected directory path as a string, or None if nothing was
        selected.
    """
    # Imported here rather than at module level, so tkinter is only needed
    # when a dialog is actually requested.
    import tkinter as tk
    from tkinter import filedialog

    root = tk.Tk()
    try:
        root.attributes("-topmost", True)
        chosen = filedialog.askdirectory(
            initialdir=dir,
            title=title,
        )
    finally:
        # Always tear the helper window down, even if the dialog raised.
        root.destroy()

    # A cancelled dialog yields an empty value ("" and, reportedly on some Tk
    # builds, an empty tuple); treat every empty result as "no selection".
    return chosen if chosen else None
2021-09-01 00:22:30 +02:00
#==================================================================#
# Returns the path (as a string) to the given story by its name
#==================================================================#
def storypath(name):
    """Return the path (as a string) of the story file for story *name*.

    The file lives in the "stories" directory and carries a ".json" suffix.
    """
    story_filename = name + ".json"
    return path.join("stories", story_filename)
2021-09-01 00:22:30 +02:00
2021-10-22 20:18:10 +02:00
#==================================================================#
# Returns the path (as a string) to the given soft prompt by its filename
#==================================================================#
def sppath(filename):
    """Return the path (as a string) of soft prompt file *filename*.

    The file is looked up inside the "softprompts" directory.
    """
    softprompt_dir = "softprompts"
    return path.join(softprompt_dir, filename)
2021-10-22 20:18:10 +02:00
2021-12-13 07:03:26 +01:00
#==================================================================#
# Returns the path (as a string) to the given username by its filename
#==================================================================#
def uspath(filename):
    """Return the path (as a string) of userscript file *filename*.

    The file is looked up inside the "userscripts" directory.
    """
    userscript_dir = "userscripts"
    return path.join(userscript_dir, filename)
2021-12-13 07:03:26 +01:00
2021-05-22 11:28:40 +02:00
#==================================================================#
# Returns an array of dicts containing story files in /stories
#==================================================================#
def getstoryfiles():
    """List the loadable story files found in the "stories" directory.

    Every file ending in ".json" (but not ".v2.json") is parsed. Files that
    cannot be read or parsed, or that do not contain a sized "actions" entry,
    are reported on stdout and skipped.

    Returns:
        A list of dicts, one per valid story, with the keys:
            "name":    the file name with its trailing ".json" removed
            "actions": the number of entries under the story's "actions" key
    """
    suffix = ".json"
    stories = []
    for file in listdir("stories"):
        if not file.endswith(suffix) or file.endswith(".v2.json"):
            continue

        # Best effort: one unreadable or malformed file must not break the
        # whole listing, so any failure here just skips that file.
        try:
            with open(path.join("stories", file), "r") as f:
                js = json.load(f)
        except Exception:
            print(f"Browser loading error: {file} is malformed or not a JSON file.")
            continue

        # TypeError: top-level JSON value is not a mapping (or "actions" has
        # no length); KeyError: mapping without an "actions" entry.
        try:
            action_count = len(js["actions"])
        except (TypeError, KeyError):
            print(f"Browser loading error: {file} has incorrect format.")
            continue

        # Strip only the trailing suffix; a ".json" elsewhere in the file name
        # is part of the story name and must be kept so that the name maps
        # back to the same file.
        stories.append({
            "name": file[:-len(suffix)],
            "actions": action_count,
        })
    return stories
2021-10-22 20:18:10 +02:00
#==================================================================#
# Checks if the given soft prompt file is valid
#==================================================================#
def checksp(filename: str, model_dimension: int) -> Tuple[Union[zipfile.ZipFile, int], Optional[Tuple[int, int]], Optional[Tuple[int, int]], Optional[bool], Optional['np.dtype']]:
    """Check whether a file in "softprompts" is a usable soft prompt archive.

    Only the header of the archive's "tensor.npy" member is read; the tensor
    data itself is not loaded.

    Args:
        filename: Name of the ZIP file inside the "softprompts" directory.
        model_dimension: Required size of the tensor's second axis.

    Returns:
        A 5-tuple ``(result, version, shape, fortran_order, dtype)``.
        On success ``result`` is the still-open ``zipfile.ZipFile`` (the
        caller is responsible for closing it).  Otherwise the archive has
        been closed and ``result`` is an integer error code:
            1: not a readable ZIP, no valid "tensor.npy" header, or the
               tensor is not 2-dimensional (remaining fields are None)
            2: unsupported dtype
            3: second axis does not equal ``model_dimension``
            4: first axis is 2048 or larger
    """
    # numpy is imported on first use and cached as a module global.
    global np
    if 'np' not in globals():
        import numpy as np

    z = None
    try:
        z = zipfile.ZipFile("softprompts/" + filename)
        with z.open('tensor.npy') as f:
            # Read only the header of the npy file, for efficiency reasons
            version: Tuple[int, int] = np.lib.format.read_magic(f)
            shape: Tuple[int, int]
            fortran_order: bool
            dtype: np.dtype
            shape, fortran_order, dtype = np.lib.format._read_array_header(f, version)
        # Explicit check instead of `assert`, which is stripped under -O and
        # would let a non-2D tensor reach the shape[1] lookup below.
        if len(shape) != 2:
            raise ValueError("soft prompt tensor must be 2-dimensional")
    except Exception:
        # Any failure while opening or parsing means "not a soft prompt".
        if z is not None:
            z.close()
        return 1, None, None, None, None

    if dtype not in ('V2', np.float16, np.float32):
        z.close()
        return 2, version, shape, fortran_order, dtype
    if shape[1] != model_dimension:
        z.close()
        return 3, version, shape, fortran_order, dtype
    if shape[0] >= 2048:
        z.close()
        return 4, version, shape, fortran_order, dtype
    return z, version, shape, fortran_order, dtype
#==================================================================#
# Returns an array of dicts containing softprompt files in /softprompts
#==================================================================#
def getspfiles(model_dimension: int):
    """List the usable soft prompt archives found in "softprompts".

    The directory is created if it does not exist.  Every ".zip" file in it
    is validated with ``checksp``; rejected files are logged and skipped.

    Args:
        model_dimension: Model dimension forwarded to ``checksp``.

    Returns:
        A list of dicts, one per accepted soft prompt.  Each dict starts from
        the archive's "meta.json" contents (empty if that member is missing,
        unparsable, or not a JSON object) and additionally carries:
            "filename": the archive's file name
            "n_tokens": the size of the tensor's first axis
    """
    lst = []
    os.makedirs("softprompts", exist_ok=True)
    for file in listdir("softprompts"):
        if not file.endswith(".zip"):
            continue
        z, _version, shape, _fortran_order, dtype = checksp(file, model_dimension)
        if z == 1:
            logger.warning(f"Softprompt {file} is malformed or not a soft prompt ZIP file.")
            continue
        if z == 2:
            logger.warning(f"Softprompt {file} tensor.npy has unsupported dtype '{dtype.name}'.")
            continue
        if z == 3:
            logger.debug(f"Softprompt {file} tensor.npy has model dimension {shape[1]} which does not match your model's model dimension of {model_dimension}. This usually means this soft prompt is not compatible with your model.")
            continue
        if z == 4:
            logger.warning(f"Softprompt {file} tensor.npy has {shape[0]} tokens but it is supposed to have less than 2048 tokens.")
            continue
        assert isinstance(z, zipfile.ZipFile)

        # `with z` guarantees the archive handed back by checksp is closed,
        # whatever happens while reading the metadata.
        with z:
            try:
                with z.open('meta.json') as f:
                    ob = json.load(f)
            except Exception:
                # Metadata is optional: fall back to an empty record.
                ob = {}
        # meta.json may hold valid JSON that is not an object (e.g. a list);
        # the key assignments below need a dict.
        if not isinstance(ob, dict):
            ob = {}

        ob["filename"] = file
        ob["n_tokens"] = shape[-2]
        lst.append(ob)
    return lst
2021-12-13 07:03:26 +01:00
#==================================================================#
# Returns an array of dicts containing userscript files in /userscripts
#==================================================================#
def _parse_userscript_header(f, fallback_name):
    """Extract the module name and description lines from a Lua script header.

    The header is the run of Lua comments at the very top of the file: the
    first comment line gives the module name, following ``--`` lines and
    ``--[[ ... ]]`` block comments give the description.  ESC characters are
    removed from every line that is read.

    Args:
        f: Open text file positioned at the start of the script.
        fallback_name: Module name to use when the first line is not a
            ``--`` comment.

    Returns:
        A ``(modulename, description_lines)`` tuple.
    """
    first = f.readline().strip().replace("\033", "")
    if first[:2] != "--":
        # No header comment at all: no description is collected.
        return fallback_name, []

    modulename = first[2:]
    multiline = False
    if modulename[:2] == "[[":
        # The header starts with a block comment ("--[[").
        modulename = modulename[2:]
        multiline = True
    modulename = modulename.lstrip("-").strip()

    description = []
    for line in f:
        line = line.strip().replace("\033", "")
        if multiline:
            index = line.find("]]")
            if index == -1:
                # Still inside the block comment.
                description.append(line)
                continue
            description.append(line[:index])
            if index != len(line) - 2:
                # Something follows the closing "]]": the header ends here.
                break
            multiline = False
        else:
            if line[:2] != "--":
                # First non-comment line ends the header.
                break
            line = line[2:]
            if line[:2] == "[[":
                multiline = True
                line = line[2:]
            description.append(line.strip())
    return modulename, description


def getusfiles(long_desc=False):
    """List the userscripts found in the "userscripts" directory.

    The directory is created if it does not exist.  Only files ending in
    ".lua" are considered.

    Args:
        long_desc: When false (the default), descriptions longer than 250
            characters are cut to 247 characters followed by "...".

    Returns:
        A list of dicts, one per script, with the keys "filename",
        "modulename" (taken from the first comment line, or the file name if
        the script does not start with a comment) and "description".
    """
    lst = []
    os.makedirs("userscripts", exist_ok=True)
    for file in listdir("userscripts"):
        if not file.endswith(".lua"):
            continue
        with open(path.join("userscripts", file)) as f:
            modulename, description_lines = _parse_userscript_header(f, file)
        description = "\n".join(description_lines)
        if not long_desc and len(description) > 250:
            description = description[:247] + "..."
        lst.append({
            "filename": file,
            "modulename": modulename,
            "description": description,
        })
    return lst
2021-05-22 11:28:40 +02:00
#==================================================================#
# Returns True if json file exists with requested save name
#==================================================================#
def saveexists(name):
    """Return True if a story file already exists for the save called *name*."""
    story_file = storypath(name)
    return path.exists(story_file)
#==================================================================#
# Delete save file by name; returns None if successful, or the exception if not
#==================================================================#
def deletesave(name):
    """Delete the story file for the save called *name*.

    Returns:
        None when the file was removed, otherwise the exception that was
        raised (it is returned, not re-raised).
    """
    try:
        target = storypath(name)
        os.remove(target)
    except Exception as error:
        return error
    else:
        return None
#==================================================================#
# Rename save file; returns None if successful, or the exception if not
#==================================================================#
def renamesave(name, new_name):
    """Rename the story file of save *name* to that of save *new_name*.

    Uses ``os.replace`` to move the file to its new name.

    Returns:
        None when the rename succeeded, otherwise the exception that was
        raised (it is returned, not re-raised).
    """
    try:
        source = storypath(name)
        destination = storypath(new_name)
        os.replace(source, destination)
    except Exception as error:
        return error
    else:
        return None