Merge branch 'united' into united

commit 47d102635e

aiserver.py | 529 lines changed
@@ -65,32 +65,68 @@ class colors:
     UNDERLINE = '\033[4m'
 
 # AI models
-modellist = [
+mainmenu = [
     ["Load a model from its directory", "NeoCustom", ""],
     ["Load an old GPT-2 model (eg CloverEdition)", "GPT2Custom", ""],
     ["Skein 6B (Hybrid)", "KoboldAI/GPT-J-6B-Skein", "16GB"],
     ["Adventure 6B", "KoboldAI/GPT-J-6B-Adventure", "16GB"],
     ["Lit 6B (NSFW)", "hakurei/lit-6B", "16GB"],
     ["C1 6B (Chatbot)", "hakurei/c1-6B", "16GB"],
-    ["Janeway 2.7B (Novel)", "KoboldAI/GPT-Neo-2.7B-Janeway", "8GB"],
+    ["Janeway Neo 2.7B (Novel)", "KoboldAI/GPT-Neo-2.7B-Janeway", "8GB"],
+    ["Janeway FSD 2.7B (Novel)", "KoboldAI/fairseq-dense-2.7B-Janeway", "8GB"],
     ["Adventure 2.7B", "KoboldAI/GPT-Neo-2.7B-AID", "8GB"],
     ["Picard 2.7B (Novel)", "KoboldAI/GPT-Neo-2.7B-Picard", "8GB"],
     ["Horni 2.7B (NSFW)", "KoboldAI/GPT-Neo-2.7B-Horni", "8GB"],
     ["Horni-LN 2.7B (Novel)", "KoboldAI/GPT-Neo-2.7B-Horni-LN", "8GB"],
     ["Shinen 2.7B (NSFW)", "KoboldAI/GPT-Neo-2.7B-Shinen", "8GB"],
+    ["Untuned GPT-Neo/J", "gptneolist", ""],
+    ["Untuned Fairseq Dense", "fsdlist", ""],
+    ["Untuned XGLM", "xglmlist", ""],
+    ["Untuned GPT2", "gpt2list", ""],
+    ["Online Services", "apilist", ""],
+    ["Read Only (No AI)", "ReadOnly", ""]
+]
+
+gptneolist = [
     ["GPT-J 6B", "EleutherAI/gpt-j-6B", "16GB"],
     ["GPT-Neo 2.7B", "EleutherAI/gpt-neo-2.7B", "8GB"],
     ["GPT-Neo 1.3B", "EleutherAI/gpt-neo-1.3B", "6GB"],
+    ["Return to Main Menu", "Return", ""],
+]
+
+gpt2list = [
     ["GPT-2 XL", "gpt2-xl", "6GB"],
     ["GPT-2 Large", "gpt2-large", "4GB"],
     ["GPT-2 Med", "gpt2-medium", "2GB"],
     ["GPT-2", "gpt2", "2GB"],
+    ["Return to Main Menu", "Return", ""],
+]
+
+fsdlist = [
+    ["Fairseq Dense 13B", "KoboldAI/fairseq-dense-13B", "32GB"],
+    ["Fairseq Dense 6.7B", "KoboldAI/fairseq-dense-6.7B", "16GB"],
+    ["Fairseq Dense 2.7B", "KoboldAI/fairseq-dense-2.7B", "8GB"],
+    ["Fairseq Dense 1.3B", "KoboldAI/fairseq-dense-1.3B", "6GB"],
+    ["Fairseq Dense 355M", "KoboldAI/fairseq-dense-355M", ""],
+    ["Fairseq Dense 125M", "KoboldAI/fairseq-dense-125M", ""],
+    ["Return to Main Menu", "Return", ""],
+]
+
+xglmlist = [
+    ["XGLM 4.5B (Larger Dataset)", "facebook/xglm-4.5B", ""],
+    ["XGLM 7.5B", "facebook/xglm-7.5B", ""],
+    ["XGLM 2.9B", "facebook/xglm-2.9B", ""],
+    ["XGLM 1.7B", "facebook/xglm-1.7B", ""],
+    ["XGLM 564M", "facebook/xglm-564M", ""],
+    ["Return to Main Menu", "Return", ""],
+]
+
+apilist = [
     ["OpenAI API (requires API key)", "OAI", ""],
     ["InferKit API (requires API key)", "InferKit", ""],
     ["KoboldAI Server API (Old Google Colab)", "Colab", ""],
-    ["Read Only (No AI)", "ReadOnly", ""]
-]
-
+    ["Return to Main Menu", "Return", ""],
+]
 # Variables
 class vars:
     lastact = ""  # The last action received from the user
@@ -106,7 +142,7 @@ class vars:
     ikgen = 200  # Number of characters for InferKit to generate
     rep_pen = 1.1  # Default generator repetition_penalty
     rep_pen_slope = 1.0  # Default generator repetition penalty slope
-    rep_pen_range = 512  # Default generator repetition penalty range
+    rep_pen_range = 1024  # Default generator repetition penalty range
     temp = 0.5  # Default generator temperature
     top_p = 0.9  # Default generator top_p
     top_k = 0  # Default generator top_k
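Note on the hunk above: the default repetition-penalty window doubles from 512 to 1024 tokens. As an illustration of how the three knobs interact — a hedged sketch of the general ranged/sloped penalty technique, not the repository's exact formula — recently used tokens have their logits penalized, with the penalty fading for older tokens inside the window:

import numpy as np

def apply_rep_pen(logits, generated, rep_pen=1.1, rep_pen_range=1024, rep_pen_slope=1.0):
    # Only the most recent `rep_pen_range` tokens are eligible for penalty.
    window = generated[-rep_pen_range:]
    seen = {}
    for age, token in enumerate(reversed(window)):
        seen.setdefault(token, age)  # keep the most recent occurrence
    for token, age in seen.items():
        # Fade the penalty with age inside the window; `rep_pen_slope`
        # controls how quickly it decays toward no penalty at all.
        fade = max(0.0, 1.0 - rep_pen_slope * age / max(1, rep_pen_range))
        strength = 1.0 + (rep_pen - 1.0) * fade
        # Dividing positive logits / multiplying negative ones both lower
        # the token's probability (the usual CTRL-style penalty).
        if logits[token] > 0:
            logits[token] /= strength
        else:
            logits[token] *= strength
    return logits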
@@ -134,6 +170,7 @@ class vars:
     wifolders_d = {}  # Dictionary of World Info folder UID-info pairs
     wifolders_l = []  # List of World Info folder UIDs
     wifolders_u = {}  # Dictionary of pairs of folder UID - list of WI UID
+    modelconfig = {}  # Raw contents of the model's config.json, or empty dictionary if none found
     lua_state = None  # Lua state of the Lua scripting system
     lua_koboldbridge = None  # `koboldbridge` from bridge.lua
     lua_kobold = None  # `kobold` from bridge.lua
@@ -217,11 +254,11 @@ utils.vars = vars
 #==================================================================#
 # Function to get model selection at startup
 #==================================================================#
-def getModelSelection():
-    print(" # Model VRAM\n =========================================")
+def getModelSelection(modellist):
+    print(" # Model\t\t\t\t\t\tVRAM\n ========================================================")
     i = 1
     for m in modellist:
-        print(" {0} - {1}\t\t{2}".format("{:<2}".format(i), m[0].ljust(15), m[2]))
+        print(" {0} - {1}\t\t\t{2}".format("{:<2}".format(i), m[0].ljust(25), m[2]))
         i += 1
     print(" ");
     modelsel = 0
@@ -233,19 +270,26 @@ def getModelSelection():
         else:
             print("{0}Please enter a valid selection.{1}".format(colors.RED, colors.END))
 
-    # If custom model was selected, get the filesystem location and store it
-    if(vars.model == "NeoCustom" or vars.model == "GPT2Custom"):
-        print("{0}Please choose the folder where pytorch_model.bin is located:{1}\n".format(colors.CYAN, colors.END))
-        modpath = fileops.getdirpath(getcwd() + "/models", "Select Model Folder")
+    # Model Lists
+    try:
+        getModelSelection(eval(vars.model))
+    except Exception as e:
+        if(vars.model == "Return"):
+            getModelSelection(mainmenu)
+
+    # If custom model was selected, get the filesystem location and store it
+    if(vars.model == "NeoCustom" or vars.model == "GPT2Custom"):
+        print("{0}Please choose the folder where pytorch_model.bin is located:{1}\n".format(colors.CYAN, colors.END))
+        modpath = fileops.getdirpath(getcwd() + "/models", "Select Model Folder")
 
-        if(modpath):
-            # Save directory to vars
-            vars.custmodpth = modpath
-        else:
-            # Print error and retry model selection
-            print("{0}Model select cancelled!{1}".format(colors.RED, colors.END))
-            print("{0}Select an AI model to continue:{1}\n".format(colors.CYAN, colors.END))
-            getModelSelection()
+        if(modpath):
+            # Save directory to vars
+            vars.custmodpth = modpath
+        else:
+            # Print error and retry model selection
+            print("{0}Model select cancelled!{1}".format(colors.RED, colors.END))
+            print("{0}Select an AI model to continue:{1}\n".format(colors.CYAN, colors.END))
+            getModelSelection(mainmenu)
 
 #==================================================================#
 # Return all keys in tokenizer dictionary containing char
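The `getModelSelection(eval(vars.model))` call above dispatches on the second column of each menu row: for submenu entries that string is the name of a module-level list (`gptneolist`, `fsdlist`, ...), so `eval` resolves it to the list itself, while a real model id raises and falls through to the `except`. A dictionary lookup expresses the same dispatch without `eval` — an illustrative alternative, not what the commit does:

submenus = {
    "gptneolist": gptneolist,
    "gpt2list": gpt2list,
    "fsdlist": fsdlist,
    "xglmlist": xglmlist,
    "apilist": apilist,
}

selection = "fsdlist"  # value taken from the chosen menu row
if selection in submenus:
    getModelSelection(submenus[selection])  # descend into the submenu
elif selection == "Return":
    getModelSelection(mainmenu)             # back out to the top level
# otherwise `selection` is a model id and loading proceeds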
@@ -413,14 +457,18 @@ def device_config(model):
 #==================================================================#
 def loadmodelsettings():
     try:
-        model_js_config = str(model_config).partition(' ')[2]
-        js = json.loads(model_js_config)
+        js = json.loads(str(model_config).partition(' ')[2])
     except Exception as e:
         try:
-            model_js_config = open(vars.custmodpth + "/config.json", "r")
+            try:
+                js = json.load(open(vars.custmodpth + "/config.json", "r"))
+            except Exception as e:
+                js = json.load(open(vars.custmodpth.replace('/', '_') + "/config.json", "r"))
         except Exception as e:
-            model_js_config = open(vars.custmodpth.replace('/', '_') + "/config.json", "r")
-            js = json.load(model_js_config)
+            js = {}
+    if vars.model_type == "xglm" or js.get("compat", "j") == "fairseq_lm":
+        vars.newlinemode = "s"  # Default to </s> newline mode if using XGLM
+    vars.modelconfig = js
     if("badwordsids" in js):
         vars.badwordsids = js["badwordsids"]
     if("nobreakmodel" in js):
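A side note on the `str(model_config).partition(' ')[2]` pattern retained above: a transformers config object prints as the class name followed by its JSON body, so splitting on the first space leaves parseable JSON. A quick illustration (hedged; the exact repr format depends on the transformers version):

import json

# transformers configs print roughly as "ClassName { ...json... }":
fake_repr = 'GPTNeoConfig {\n  "activation_function": "gelu_new",\n  "vocab_size": 50257\n}\n'

# partition(' ')[2] drops everything up to and including the first space,
# leaving just the JSON object body.
js = json.loads(fake_repr.partition(' ')[2])
print(js["vocab_size"])  # -> 50257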
@@ -456,6 +504,200 @@ def loadmodelsettings():
         if(not vars.gamestarted):
             vars.authornotetemplate = vars.setauthornotetemplate
 
+#==================================================================#
+#  Take settings from vars and write them to client settings file
+#==================================================================#
+def savesettings():
+    # Build json to write
+    js = {}
+    js["apikey"] = vars.apikey
+    js["andepth"] = vars.andepth
+    js["temp"] = vars.temp
+    js["top_p"] = vars.top_p
+    js["top_k"] = vars.top_k
+    js["tfs"] = vars.tfs
+    js["rep_pen"] = vars.rep_pen
+    js["rep_pen_slope"] = vars.rep_pen_slope
+    js["rep_pen_range"] = vars.rep_pen_range
+    js["genamt"] = vars.genamt
+    js["max_length"] = vars.max_length
+    js["ikgen"] = vars.ikgen
+    js["formatoptns"] = vars.formatoptns
+    js["numseqs"] = vars.numseqs
+    js["widepth"] = vars.widepth
+    js["useprompt"] = vars.useprompt
+    js["adventure"] = vars.adventure
+    js["chatmode"] = vars.chatmode
+    js["chatname"] = vars.chatname
+    js["dynamicscan"] = vars.dynamicscan
+    js["nopromptgen"] = vars.nopromptgen
+    js["rngpersist"] = vars.rngpersist
+    js["nogenmod"] = vars.nogenmod
+    js["autosave"] = vars.autosave
+    js["welcome"] = vars.welcome
+    js["newlinemode"] = vars.newlinemode
+
+    js["antemplate"] = vars.setauthornotetemplate
+
+    js["userscripts"] = vars.userscripts
+    js["corescript"] = vars.corescript
+    js["softprompt"] = vars.spfilename
+
+    # Write it
+    if not os.path.exists('settings'):
+        os.mkdir('settings')
+    file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "w")
+    try:
+        file.write(json.dumps(js, indent=3))
+    finally:
+        file.close()
+
+#==================================================================#
+#  Don't save settings unless 2 seconds have passed without modification
+#==================================================================#
+@debounce(2)
+def settingschanged():
+    print("{0}Saving settings!{1}".format(colors.GREEN, colors.END))
+    savesettings()
+
+#==================================================================#
+#  Read settings from client file JSON and send to vars
+#==================================================================#
+def loadsettings():
+    if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
+        # Read file contents into JSON object
+        file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "r")
+        js = json.load(file)
+
+        # Copy file contents to vars
+        if("apikey" in js):
+            vars.apikey = js["apikey"]
+        if("andepth" in js):
+            vars.andepth = js["andepth"]
+        if("temp" in js):
+            vars.temp = js["temp"]
+        if("top_p" in js):
+            vars.top_p = js["top_p"]
+        if("top_k" in js):
+            vars.top_k = js["top_k"]
+        if("tfs" in js):
+            vars.tfs = js["tfs"]
+        if("rep_pen" in js):
+            vars.rep_pen = js["rep_pen"]
+        if("rep_pen_slope" in js):
+            vars.rep_pen_slope = js["rep_pen_slope"]
+        if("rep_pen_range" in js):
+            vars.rep_pen_range = js["rep_pen_range"]
+        if("genamt" in js):
+            vars.genamt = js["genamt"]
+        if("max_length" in js):
+            vars.max_length = js["max_length"]
+        if("ikgen" in js):
+            vars.ikgen = js["ikgen"]
+        if("formatoptns" in js):
+            vars.formatoptns = js["formatoptns"]
+        if("numseqs" in js):
+            vars.numseqs = js["numseqs"]
+        if("widepth" in js):
+            vars.widepth = js["widepth"]
+        if("useprompt" in js):
+            vars.useprompt = js["useprompt"]
+        if("adventure" in js):
+            vars.adventure = js["adventure"]
+        if("chatmode" in js):
+            vars.chatmode = js["chatmode"]
+        if("chatname" in js):
+            vars.chatname = js["chatname"]
+        if("dynamicscan" in js):
+            vars.dynamicscan = js["dynamicscan"]
+        if("nopromptgen" in js):
+            vars.nopromptgen = js["nopromptgen"]
+        if("rngpersist" in js):
+            vars.rngpersist = js["rngpersist"]
+        if("nogenmod" in js):
+            vars.nogenmod = js["nogenmod"]
+        if("autosave" in js):
+            vars.autosave = js["autosave"]
+        if("newlinemode" in js):
+            vars.newlinemode = js["newlinemode"]
+        if("welcome" in js):
+            vars.welcome = js["welcome"]
+
+        if("antemplate" in js):
+            vars.setauthornotetemplate = js["antemplate"]
+            if(not vars.gamestarted):
+                vars.authornotetemplate = vars.setauthornotetemplate
+
+        if("userscripts" in js):
+            vars.userscripts = []
+            for userscript in js["userscripts"]:
+                if type(userscript) is not str:
+                    continue
+                userscript = userscript.strip()
+                if len(userscript) != 0 and all(q not in userscript for q in ("..", ":")) and all(userscript[0] not in q for q in ("/", "\\")) and os.path.exists(fileops.uspath(userscript)):
+                    vars.userscripts.append(userscript)
+
+        if("corescript" in js and type(js["corescript"]) is str and all(q not in js["corescript"] for q in ("..", ":")) and all(js["corescript"][0] not in q for q in ("/", "\\"))):
+            vars.corescript = js["corescript"]
+        else:
+            vars.corescript = "default.lua"
+
+        file.close()
+
+#==================================================================#
+#  Load a soft prompt from a file
+#==================================================================#
+def spRequest(filename):
+    vars.spfilename = ""
+    settingschanged()
+
+    if(len(filename) == 0):
+        vars.sp = None
+        vars.sp_length = 0
+        return
+
+    global np
+    if 'np' not in globals():
+        import numpy as np
+
+    z, version, shape, fortran_order, dtype = fileops.checksp(filename, vars.modeldim)
+    assert isinstance(z, zipfile.ZipFile)
+    with z.open('meta.json') as f:
+        vars.spmeta = json.load(f)
+    z.close()
+
+    with np.load(fileops.sppath(filename), allow_pickle=False) as f:
+        tensor = f['tensor.npy']
+
+    # If the tensor is in bfloat16 format, convert it to float32
+    if(tensor.dtype == 'V2'):
+        tensor.dtype = np.uint16
+        tensor = np.uint32(tensor) << 16
+        tensor.dtype = np.float32
+
+    if(tensor.dtype != np.float16):
+        tensor = np.float32(tensor)
+    assert not np.isinf(tensor).any() and not np.isnan(tensor).any()
+
+    vars.sp_length = tensor.shape[-2]
+    vars.spmeta["n_tokens"] = vars.sp_length
+
+    if(vars.model in ("TPUMeshTransformerGPTJ",)):
+        rows = tensor.shape[0]
+        padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
+        tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
+        tensor = tensor.reshape(
+            tpu_mtj_backend.params["cores_per_replica"],
+            -1,
+            tpu_mtj_backend.params["d_model"],
+        )
+        vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor))
+    else:
+        vars.sp = torch.from_numpy(tensor)
+
+    vars.spfilename = filename
+    settingschanged()
+
 #==================================================================#
 # Startup
 #==================================================================#
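The `V2` branch in `spRequest` above handles soft prompts saved as bfloat16, for which NumPy has no native dtype: the raw 16-bit words are reinterpreted as uint16, widened to 32 bits, shifted into the high half, and reinterpreted as float32 (bfloat16 is exactly the top 16 bits of a float32). A standalone demonstration of the same round trip, with hypothetical values:

import numpy as np

# bfloat16 is simply a float32's sign, exponent and top 7 mantissa bits.
x = np.array([3.14159], dtype=np.float32)
bf16_bits = (x.view(np.uint32) >> 16).astype(np.uint16)  # "stored" form

# The recovery used in spRequest: widen to 32 bits, shift the payload back
# into the high half, then reinterpret the bit pattern as float32.
restored = (bf16_bits.astype(np.uint32) << 16).view(np.float32)
print(x[0], restored[0])  # 3.14159 3.140625 (precision lost to truncation)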
@@ -525,7 +767,7 @@ if args.model:
 
 else:
     print("{0}Welcome to the KoboldAI Server!\nListed RAM is the optimal VRAM and CPU ram can be up to twice the amount.\nMost models can run at less VRAM with reduced max tokens or less layers on the GPU.\nSelect an AI model to continue:{1}\n".format(colors.CYAN, colors.END))
-    getModelSelection()
+    getModelSelection(mainmenu)
 
 # If transformers model was selected & GPU available, ask to use CPU or GPU
 if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
@@ -568,6 +810,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
         print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
         vars.model_type = "gpt_neo"
     loadmodelsettings()
+    loadsettings()
     print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
     vars.hascuda = torch.cuda.is_available()
     vars.bmsupported = vars.model_type in ("gpt_neo", "gptj", "xglm") and not vars.nobreakmodel
@@ -805,7 +1048,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
                 if(hasattr(self, "transformer")):
                     inputs_embeds = self.transformer.wte(input_ids)
                 else:
-                    inputs_embeds = self.model.embed_tokens(input_ids) * self.model.embed_scale
+                    inputs_embeds = self.model.embed_tokens(input_ids)
                 if(vars.sp is not None):
                     vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
                     inputs_embeds = torch.where(
@@ -813,6 +1056,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
                         vars.sp[shifted_input_ids.clamp(min=0)],
                         inputs_embeds,
                     )
+                if(not hasattr(self, "transformer")):
+                    inputs_embeds *= self.model.embed_scale
                 kwargs['inputs_embeds'] = inputs_embeds
                 return old_forward(self, *args, **kwargs)
             cls.forward = new_causallm_forward
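The two hunks above move the XGLM `embed_scale` multiplication to after soft-prompt injection, so the learned soft-prompt vectors are scaled the same way as real token embeddings. The injection itself reserves token ids past the vocabulary for soft-prompt positions and selects between the two embedding sources with `torch.where`; a self-contained sketch of that pattern, with hypothetical sizes rather than KoboldAI's actual tensors:

import torch

VOCAB = 100
embed = torch.nn.Embedding(VOCAB, 8)  # ordinary token embeddings
sp = torch.randn(4, 8)                # 4 learned soft-prompt vectors

# ids below VOCAB are real tokens; ids >= VOCAB address the soft prompt.
input_ids = torch.tensor([[100, 101, 102, 103, 7, 42, 3]])
shifted_input_ids = input_ids - VOCAB  # soft ids -> 0..3, real ids -> negative

inputs_embeds = torch.where(
    (shifted_input_ids >= 0).unsqueeze(-1),      # mask, broadcast over embed dim
    sp[shifted_input_ids.clamp(min=0)],          # clamp keeps real ids in range
    embed(input_ids.clamp(max=VOCAB - 1)),       # clamp keeps soft ids in range
)
print(inputs_embeds.shape)  # torch.Size([1, 7, 8])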
@@ -1063,10 +1308,10 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
 
         if not args.colab:
             import shutil
-            shutil.rmtree("cache/")
             model = model.half()
             model.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
             tokenizer.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
+            shutil.rmtree("cache/")
 
     if(vars.hascuda):
         if(vars.usegpu):
@@ -1186,13 +1431,16 @@ else:
     if(vars.model == "Colab"):
         from transformers import GPT2TokenizerFast
         tokenizer = GPT2TokenizerFast.from_pretrained("EleutherAI/gpt-neo-2.7B", cache_dir="cache/")
+        loadsettings()
     elif(vars.model == "OAI"):
         from transformers import GPT2TokenizerFast
         tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
+        loadsettings()
     # Load the TPU backend if requested
     elif(vars.model == "TPUMeshTransformerGPTJ"):
         print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
-        assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth)
+        if not vars.custmodpth or not os.path.isdir(vars.custmodpth):
+            raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder")
         import tpu_mtj_backend
         tpu_mtj_backend.vars = vars
         tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
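Replacing the `assert` with an explicit `raise FileNotFoundError(...)` keeps the check active even under optimized runs, since `assert` statements are stripped when Python is started with `-O`; the explicit exception also carries a clearer user-facing message. A minimal illustration of the difference:

import os

def check_with_assert(path):
    # Disappears entirely under `python -O`; failure is a bare AssertionError.
    assert os.path.isdir(path), "invalid model folder"

def check_with_raise(path):
    # Always enforced, and carries a user-facing message.
    if not os.path.isdir(path):
        raise FileNotFoundError(f"The specified model path {path!r} is not the path to a valid folder")

try:
    check_with_raise("models/does-not-exist")  # hypothetical path
except FileNotFoundError as e:
    print(e)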
@@ -1200,10 +1448,14 @@ else:
         tpu_mtj_backend.compiling_callback = tpumtjgenerate_compiling_callback
         tpu_mtj_backend.stopped_compiling_callback = tpumtjgenerate_stopped_compiling_callback
         tpu_mtj_backend.settings_callback = tpumtjgenerate_settings_callback
-        tpu_mtj_backend.load_model(vars.custmodpth)
         vars.allowsp = True
+        loadmodelsettings()
+        loadsettings()
+        tpu_mtj_backend.load_model(vars.custmodpth, **vars.modelconfig)
         vars.modeldim = int(tpu_mtj_backend.params["d_model"])
         tokenizer = tpu_mtj_backend.tokenizer
+    else:
+        loadsettings()
 
 # Set up Flask routes
 @app.route('/')
@@ -2295,152 +2547,6 @@ def sendsettings():
         if(not frm["id"] in vars.formatoptns):
             vars.formatoptns[frm["id"]] = False;
 
-#==================================================================#
-#  Take settings from vars and write them to client settings file
-#==================================================================#
-def savesettings():
-    # Build json to write
-    js = {}
-    js["apikey"] = vars.apikey
-    js["andepth"] = vars.andepth
-    js["temp"] = vars.temp
-    js["top_p"] = vars.top_p
-    js["top_k"] = vars.top_k
-    js["tfs"] = vars.tfs
-    js["rep_pen"] = vars.rep_pen
-    js["rep_pen_slope"] = vars.rep_pen_slope
-    js["rep_pen_range"] = vars.rep_pen_range
-    js["genamt"] = vars.genamt
-    js["max_length"] = vars.max_length
-    js["ikgen"] = vars.ikgen
-    js["formatoptns"] = vars.formatoptns
-    js["numseqs"] = vars.numseqs
-    js["widepth"] = vars.widepth
-    js["useprompt"] = vars.useprompt
-    js["adventure"] = vars.adventure
-    js["chatmode"] = vars.chatmode
-    js["chatname"] = vars.chatname
-    js["dynamicscan"] = vars.dynamicscan
-    js["nopromptgen"] = vars.nopromptgen
-    js["rngpersist"] = vars.rngpersist
-    js["nogenmod"] = vars.nogenmod
-    js["autosave"] = vars.autosave
-    js["welcome"] = vars.welcome
-    js["newlinemode"] = vars.newlinemode
-
-    js["antemplate"] = vars.setauthornotetemplate
-
-    js["userscripts"] = vars.userscripts
-    js["corescript"] = vars.corescript
-    js["softprompt"] = vars.spfilename
-
-    # Write it
-    if not os.path.exists('settings'):
-        os.mkdir('settings')
-    file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "w")
-    try:
-        file.write(json.dumps(js, indent=3))
-    finally:
-        file.close()
-
-#==================================================================#
-#  Read settings from client file JSON and send to vars
-#==================================================================#
-def loadsettings():
-    if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
-        # Read file contents into JSON object
-        file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "r")
-        js = json.load(file)
-
-        # Copy file contents to vars
-        if("apikey" in js):
-            vars.apikey = js["apikey"]
-        if("andepth" in js):
-            vars.andepth = js["andepth"]
-        if("temp" in js):
-            vars.temp = js["temp"]
-        if("top_p" in js):
-            vars.top_p = js["top_p"]
-        if("top_k" in js):
-            vars.top_k = js["top_k"]
-        if("tfs" in js):
-            vars.tfs = js["tfs"]
-        if("rep_pen" in js):
-            vars.rep_pen = js["rep_pen"]
-        if("rep_pen_slope" in js):
-            vars.rep_pen_slope = js["rep_pen_slope"]
-        if("rep_pen_range" in js):
-            vars.rep_pen_range = js["rep_pen_range"]
-        if("genamt" in js):
-            vars.genamt = js["genamt"]
-        if("max_length" in js):
-            vars.max_length = js["max_length"]
-        if("ikgen" in js):
-            vars.ikgen = js["ikgen"]
-        if("formatoptns" in js):
-            vars.formatoptns = js["formatoptns"]
-        if("numseqs" in js):
-            vars.numseqs = js["numseqs"]
-        if("widepth" in js):
-            vars.widepth = js["widepth"]
-        if("useprompt" in js):
-            vars.useprompt = js["useprompt"]
-        if("adventure" in js):
-            vars.adventure = js["adventure"]
-        if("chatmode" in js):
-            vars.chatmode = js["chatmode"]
-        if("chatname" in js):
-            vars.chatname = js["chatname"]
-        if("dynamicscan" in js):
-            vars.dynamicscan = js["dynamicscan"]
-        if("nopromptgen" in js):
-            vars.nopromptgen = js["nopromptgen"]
-        if("rngpersist" in js):
-            vars.rngpersist = js["rngpersist"]
-        if("nogenmod" in js):
-            vars.nogenmod = js["nogenmod"]
-        if("autosave" in js):
-            vars.autosave = js["autosave"]
-        if("newlinemode" in js):
-            vars.newlinemode = js["newlinemode"]
-        if("welcome" in js):
-            vars.welcome = js["welcome"]
-
-        if("antemplate" in js):
-            vars.setauthornotetemplate = js["antemplate"]
-            if(not vars.gamestarted):
-                vars.authornotetemplate = vars.setauthornotetemplate
-
-        if("userscripts" in js):
-            vars.userscripts = []
-            for userscript in js["userscripts"]:
-                if type(userscript) is not str:
-                    continue
-                userscript = userscript.strip()
-                if len(userscript) != 0 and all(q not in userscript for q in ("..", ":")) and all(userscript[0] not in q for q in ("/", "\\")) and os.path.exists(fileops.uspath(userscript)):
-                    vars.userscripts.append(userscript)
-
-        if("corescript" in js and type(js["corescript"]) is str and all(q not in js["corescript"] for q in ("..", ":")) and all(js["corescript"][0] not in q for q in ("/", "\\"))):
-            vars.corescript = js["corescript"]
-        else:
-            vars.corescript = "default.lua"
-
-        if(vars.allowsp and "softprompt" in js and type(js["softprompt"]) is str and all(q not in js["softprompt"] for q in ("..", ":")) and (len(js["softprompt"]) == 0 or all(js["softprompt"][0] not in q for q in ("/", "\\")))):
-            spRequest(js["softprompt"])
-        else:
-            vars.spfilename = ""
-
-        file.close()
-
-
-#==================================================================#
-#  Don't save settings unless 2 seconds have passed without modification
-#==================================================================#
-@debounce(2)
-def settingschanged():
-    print("{0}Saving settings!{1}".format(colors.GREEN, colors.END))
-    savesettings()
-
 #==================================================================#
 #  Set value of gamesaved
 #==================================================================#
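`settingschanged` (relocated earlier in the file, removed here) stays wrapped in `@debounce(2)`, so a burst of slider changes produces a single disk write two seconds after the last one. The decorator itself comes from the repo's utils and is not shown in this diff; a minimal timer-based sketch of the idea:

import threading

def debounce(wait):
    """Call the wrapped function only after `wait` seconds have passed
    without another call (minimal sketch; not the repo's implementation)."""
    def decorator(fn):
        timer = None
        def debounced(*args, **kwargs):
            nonlocal timer
            if timer is not None:
                timer.cancel()  # restart the countdown
            timer = threading.Timer(wait, fn, args, kwargs)
            timer.start()
        return debounced
    return decorator

@debounce(2)
def settingschanged():
    print("Saving settings!")

for _ in range(10):    # rapid-fire changes...
    settingschanged()  # ...still only one save, 2s after the last call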
@@ -4490,60 +4596,6 @@ def loadRequest(loadpath, filename=None):
         emit('from_server', {'cmd': 'hidegenseqs', 'data': ''}, broadcast=True)
     print("{0}Story loaded from {1}!{2}".format(colors.GREEN, filename, colors.END))
 
-#==================================================================#
-#  Load a soft prompt from a file
-#==================================================================#
-def spRequest(filename):
-    vars.spfilename = ""
-    settingschanged()
-
-    if(len(filename) == 0):
-        vars.sp = None
-        vars.sp_length = 0
-        return
-
-    global np
-    if 'np' not in globals():
-        import numpy as np
-
-    z, version, shape, fortran_order, dtype = fileops.checksp(filename, vars.modeldim)
-    assert isinstance(z, zipfile.ZipFile)
-    with z.open('meta.json') as f:
-        vars.spmeta = json.load(f)
-    z.close()
-
-    with np.load(fileops.sppath(filename), allow_pickle=False) as f:
-        tensor = f['tensor.npy']
-
-    # If the tensor is in bfloat16 format, convert it to float32
-    if(tensor.dtype == 'V2'):
-        tensor.dtype = np.uint16
-        tensor = np.uint32(tensor) << 16
-        tensor.dtype = np.float32
-
-    if(tensor.dtype != np.float16):
-        tensor = np.float32(tensor)
-    assert not np.isinf(tensor).any() and not np.isnan(tensor).any()
-
-    vars.sp_length = tensor.shape[-2]
-    vars.spmeta["n_tokens"] = vars.sp_length
-
-    if(vars.model in ("TPUMeshTransformerGPTJ",)):
-        rows = tensor.shape[0]
-        padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
-        tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
-        tensor = tensor.reshape(
-            tpu_mtj_backend.params["cores_per_replica"],
-            -1,
-            tpu_mtj_backend.params["d_model"],
-        )
-        vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor))
-    else:
-        vars.sp = torch.from_numpy(tensor)
-
-    vars.spfilename = filename
-    settingschanged()
-
 #==================================================================#
 #  Import an AIDungon game exported with Mimi's tool
 #==================================================================#
@@ -4888,9 +4940,6 @@ def randomGameRequest(topic, memory=""):
     vars.memory = memory
     emit('from_server', {'cmd': 'setmemory', 'data': vars.memory}, broadcast=True)
 
-# Load desired settings from both the model and the users config file
-loadsettings()
-
 # Prevent tokenizer from taking extra time the first time it's used
 def __preempt_tokenizer():
     if("tokenizer" not in globals()):
@@ -4899,6 +4948,16 @@ def __preempt_tokenizer():
     tokenizer.encode(utils.encodenewlines("eunoia"))
 threading.Thread(target=__preempt_tokenizer).start()
 
+# Load soft prompt specified by the settings file, if applicable
+if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
+    file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "r")
+    js = json.load(file)
+    if(vars.allowsp and "softprompt" in js and type(js["softprompt"]) is str and all(q not in js["softprompt"] for q in ("..", ":")) and (len(js["softprompt"]) == 0 or all(js["softprompt"][0] not in q for q in ("/", "\\")))):
+        spRequest(js["softprompt"])
+    else:
+        vars.spfilename = ""
+    file.close()
+
 # Precompile TPU backend if required
 if(vars.model in ("TPUMeshTransformerGPTJ",)):
     soft_tokens = tpumtjgetsofttokens()
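The long predicate guarding `spRequest` above is a path-sanitization check: the value must be a string, contain no ".." or ":", and not begin with a path separator, so a settings file cannot point outside the soft-prompt directory. Factored into a named helper for readability — an illustrative refactor, not code from this commit:

def is_safe_relative_name(name):
    """True only for plain relative names like 'my_softprompt.zip':
    no parent-directory escapes, no drive colons, no absolute paths."""
    return (
        isinstance(name, str)
        and ".." not in name
        and ":" not in name
        and (len(name) == 0 or name[0] not in ("/", "\\"))
    )

assert is_safe_relative_name("prompt.zip")
assert not is_safe_relative_name("../../etc/passwd")
assert not is_safe_relative_name("/absolute/path")
assert not is_safe_relative_name("C:\\windows")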
@@ -4930,7 +4989,7 @@ if(vars.model in ("TPUMeshTransformerGPTJ",)):
 def send_debug():
     if vars.debug:
         debug_info = ""
-        for variable in [["Action Length", vars.actions.get_last_key()], ["Actions Metadata Length", max(vars.actions_metadata)], ["Actions Metadata", vars.actions_metadata], ["Newline Mode", vars.newlinemode]]:
+        for variable in [["Newline Mode", vars.newlinemode], ["Action Length", vars.actions.get_last_key()], ["Actions Metadata Length", max(vars.actions_metadata)], ["Actions Metadata", vars.actions_metadata]]:
            debug_info = "{}{}: {}\n".format(debug_info, variable[0], variable[1])
        emit('from_server', {'cmd': 'debug_info', 'data': debug_info}, broadcast=True)
tpu_mtj_backend.py

@@ -443,9 +443,9 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
         return carry
 
 class PenalizingCausalTransformer(CausalTransformer):
-    def __init__(self, config):
+    def __init__(self, config, **kwargs):
         # Initialize
-        super().__init__(config)
+        super().__init__(config, **kwargs)
         def generate_static(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None):
             compiling_callback()
             numseqs = numseqs_aux.shape[0]
@@ -791,12 +791,24 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
         "pe_rotary_dims": 64,
         "seq": 2048,
         "cores_per_replica": 8,
+        "tokenizer_class": "GPT2TokenizerFast",
+        "tokenizer": "gpt2",
     }
     params = kwargs
+    if "compat" in params:
+        default_params["compat"] = params["compat"]
+        if default_params["compat"] == "fairseq_lm":
+            default_params["tokenizer"] = "KoboldAI/fairseq-dense-125M"
     for param in default_params:
         if param not in params:
             params[param] = default_params[param]
 
+    # Load tokenizer
+    if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")):
+        raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'")
+    tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"])
+    tokenizer = tokenizer_class.from_pretrained(params["tokenizer"])
+
     # Disable JAX warnings about these two functions having been renamed
     jax.host_count = jax.process_count
     jax.host_id = jax.process_index
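With this change, keys from the model's config.json — forwarded by aiserver.py as `tpu_mtj_backend.load_model(vars.custmodpth, **vars.modelconfig)` — override the GPT-J defaults, and the tokenizer is resolved dynamically instead of being hard-coded to GPT-2. The resolution step amounts to the following (illustrative call, assuming transformers is installed):

# Resolve a tokenizer class by name, as load_model() now does.
tokenizer_class = getattr(__import__("transformers"), "GPT2TokenizerFast")
tokenizer = tokenizer_class.from_pretrained("gpt2")

# A fairseq-style config.json carrying {"compat": "fairseq_lm"} instead
# selects the KoboldAI/fairseq-dense-125M tokenizer automatically.
print(tokenizer_class.__name__)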
@@ -819,7 +831,6 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
     devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
     thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
     maps.thread_resources.env = thread_resources_env
-    tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
 
     global shard_xmap, batch_xmap
     shard_xmap = __shard_xmap()
@@ -832,6 +843,6 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
     if not path.endswith("/"):
         path += "/"
 
-    network = PenalizingCausalTransformer(params)
+    network = PenalizingCausalTransformer(params, dematerialized=True)
     network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])
     network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))