mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Soft prompt support (6B Colabs not supported yet)
This commit is contained in:
72
fileops.py
72
fileops.py
@ -1,8 +1,10 @@
|
||||
import tkinter as tk
|
||||
from tkinter import filedialog
|
||||
from os import getcwd, listdir, path
|
||||
from typing import Tuple, Union, Optional
|
||||
import os
|
||||
import json
|
||||
import zipfile
|
||||
|
||||
#==================================================================#
|
||||
# Generic Method for prompting for file path
|
||||
@ -61,6 +63,12 @@ def getdirpath(dir, title):
|
||||
def storypath(name):
|
||||
return path.join(path.dirname(path.realpath(__file__)), "stories", name + ".json")
|
||||
|
||||
#==================================================================#
|
||||
# Returns the path (as a string) to the given soft prompt by its filename
|
||||
#==================================================================#
|
||||
def sppath(filename):
|
||||
return path.join(path.dirname(path.realpath(__file__)), "softprompts", filename)
|
||||
|
||||
#==================================================================#
|
||||
# Returns an array of dicts containing story files in /stories
|
||||
#==================================================================#
|
||||
@ -86,6 +94,70 @@ def getstoryfiles():
|
||||
list.append(ob)
|
||||
return list
|
||||
|
||||
#==================================================================#
|
||||
# Checks if the given soft prompt file is valid
|
||||
#==================================================================#
|
||||
def checksp(filename: str, model_dimension: int) -> Tuple[Union[zipfile.ZipFile, int], Optional[Tuple[int, int]], Optional[Tuple[int, int]], Optional[bool], Optional['np.dtype']]:
|
||||
global np
|
||||
if 'np' not in globals():
|
||||
import numpy as np
|
||||
try:
|
||||
z = zipfile.ZipFile(path.dirname(path.realpath(__file__))+"/softprompts/"+filename)
|
||||
with z.open('tensor.npy') as f:
|
||||
# Read only the header of the npy file, for efficiency reasons
|
||||
version: Tuple[int, int] = np.lib.format.read_magic(f)
|
||||
shape: Tuple[int, int]
|
||||
fortran_order: bool
|
||||
dtype: np.dtype
|
||||
shape, fortran_order, dtype = np.lib.format._read_array_header(f, version)
|
||||
assert len(shape) == 2
|
||||
except:
|
||||
z.close()
|
||||
return 1, None, None, None, None
|
||||
if dtype not in ('V2', np.float16, np.float32):
|
||||
z.close()
|
||||
return 2, version, shape, fortran_order, dtype
|
||||
if shape[1] != model_dimension:
|
||||
z.close()
|
||||
return 3, version, shape, fortran_order, dtype
|
||||
if shape[0] >= 2048:
|
||||
z.close()
|
||||
return 4, version, shape, fortran_order, dtype
|
||||
return z, version, shape, fortran_order, dtype
|
||||
|
||||
#==================================================================#
|
||||
# Returns an array of dicts containing softprompt files in /softprompts
|
||||
#==================================================================#
|
||||
def getspfiles(model_dimension: int):
|
||||
lst = []
|
||||
os.makedirs(path.dirname(path.realpath(__file__))+"/softprompts", exist_ok=True)
|
||||
for file in listdir(path.dirname(path.realpath(__file__))+"/softprompts"):
|
||||
if not file.endswith(".zip"):
|
||||
continue
|
||||
z, version, shape, fortran_order, dtype = checksp(file, model_dimension)
|
||||
if z == 1:
|
||||
print(f"Browser SP loading error: {file} is malformed or not a soft prompt ZIP file.")
|
||||
continue
|
||||
if z == 2:
|
||||
print(f"Browser SP loading error: {file} tensor.npy has unsupported dtype '{dtype.name}'.")
|
||||
continue
|
||||
if z == 3:
|
||||
print(f"Browser SP loading error: {file} tensor.npy has model dimension {shape[1]} which does not match your model's model dimension of {model_dimension}. This usually means this soft prompt is not compatible with your model.")
|
||||
continue
|
||||
if z == 4:
|
||||
print(f"Browser SP loading error: {file} tensor.npy has {shape[0]} tokens but it is supposed to have less than 2048 tokens.")
|
||||
continue
|
||||
assert isinstance(z, zipfile.ZipFile)
|
||||
try:
|
||||
with z.open('meta.json') as f:
|
||||
ob = json.load(f)
|
||||
except:
|
||||
ob = {}
|
||||
z.close()
|
||||
ob["filename"] = file
|
||||
lst.append(ob)
|
||||
return lst
|
||||
|
||||
#==================================================================#
|
||||
# Returns True if json file exists with requested save name
|
||||
#==================================================================#
|
||||
|
Reference in New Issue
Block a user