Merge branch 'united' into united

Commit 47d102635e by ebolam (committed via GitHub), 2022-02-28 08:37:45 -05:00
2 changed files with 309 additions and 239 deletions

File: aiserver.py

@@ -65,32 +65,68 @@ class colors:
     UNDERLINE = '\033[4m'
 
 # AI models
-modellist = [
+mainmenu = [
     ["Load a model from its directory", "NeoCustom", ""],
     ["Load an old GPT-2 model (eg CloverEdition)", "GPT2Custom", ""],
     ["Skein 6B (Hybrid)", "KoboldAI/GPT-J-6B-Skein", "16GB"],
     ["Adventure 6B", "KoboldAI/GPT-J-6B-Adventure", "16GB"],
     ["Lit 6B (NSFW)", "hakurei/lit-6B", "16GB"],
     ["C1 6B (Chatbot)", "hakurei/c1-6B", "16GB"],
-    ["Janeway 2.7B (Novel)", "KoboldAI/GPT-Neo-2.7B-Janeway", "8GB"],
+    ["Janeway Neo 2.7B (Novel)", "KoboldAI/GPT-Neo-2.7B-Janeway", "8GB"],
+    ["Janeway FSD 2.7B (Novel)", "KoboldAI/fairseq-dense-2.7B-Janeway", "8GB"],
     ["Adventure 2.7B", "KoboldAI/GPT-Neo-2.7B-AID", "8GB"],
     ["Picard 2.7B (Novel)", "KoboldAI/GPT-Neo-2.7B-Picard", "8GB"],
     ["Horni 2.7B (NSFW)", "KoboldAI/GPT-Neo-2.7B-Horni", "8GB"],
     ["Horni-LN 2.7B (Novel)", "KoboldAI/GPT-Neo-2.7B-Horni-LN", "8GB"],
     ["Shinen 2.7B (NSFW)", "KoboldAI/GPT-Neo-2.7B-Shinen", "8GB"],
+    ["Untuned GPT-Neo/J", "gptneolist", ""],
+    ["Untuned Fairseq Dense", "fsdlist", ""],
+    ["Untuned XGLM", "xglmlist", ""],
+    ["Untuned GPT2", "gpt2list", ""],
+    ["Online Services", "apilist", ""],
+    ["Read Only (No AI)", "ReadOnly", ""]
+    ]
+
+gptneolist = [
     ["GPT-J 6B", "EleutherAI/gpt-j-6B", "16GB"],
     ["GPT-Neo 2.7B", "EleutherAI/gpt-neo-2.7B", "8GB"],
     ["GPT-Neo 1.3B", "EleutherAI/gpt-neo-1.3B", "6GB"],
+    ["Return to Main Menu", "Return", ""],
+    ]
+
+gpt2list = [
     ["GPT-2 XL", "gpt2-xl", "6GB"],
     ["GPT-2 Large", "gpt2-large", "4GB"],
     ["GPT-2 Med", "gpt2-medium", "2GB"],
     ["GPT-2", "gpt2", "2GB"],
+    ["Return to Main Menu", "Return", ""],
+    ]
+
+fsdlist = [
+    ["Fairseq Dense 13B", "KoboldAI/fairseq-dense-13B", "32GB"],
+    ["Fairseq Dense 6.7B", "KoboldAI/fairseq-dense-6.7B", "16GB"],
+    ["Fairseq Dense 2.7B", "KoboldAI/fairseq-dense-2.7B", "8GB"],
+    ["Fairseq Dense 1.3B", "KoboldAI/fairseq-dense-1.3B", "6GB"],
+    ["Fairseq Dense 355M", "KoboldAI/fairseq-dense-355M", ""],
+    ["Fairseq Dense 125M", "KoboldAI/fairseq-dense-125M", ""],
+    ["Return to Main Menu", "Return", ""],
+    ]
+
+xglmlist = [
+    ["XGLM 4.5B (Larger Dataset)", "facebook/xglm-4.5B", ""],
+    ["XGLM 7.5B", "facebook/xglm-7.5B", ""],
+    ["XGLM 2.9B", "facebook/xglm-2.9B", ""],
+    ["XGLM 1.7B", "facebook/xglm-1.7B", ""],
+    ["XGLM 564M", "facebook/xglm-564M", ""],
+    ["Return to Main Menu", "Return", ""],
+    ]
+
+apilist = [
     ["OpenAI API (requires API key)", "OAI", ""],
     ["InferKit API (requires API key)", "InferKit", ""],
     ["KoboldAI Server API (Old Google Colab)", "Colab", ""],
-    ["Read Only (No AI)", "ReadOnly", ""]
+    ["Return to Main Menu", "Return", ""],
     ]
 
 # Variables
 class vars:
     lastact     = ""     # The last action received from the user
@@ -106,7 +142,7 @@ class vars:
     ikgen       = 200    # Number of characters for InferKit to generate
     rep_pen     = 1.1    # Default generator repetition_penalty
     rep_pen_slope = 1.0  # Default generator repetition penalty slope
-    rep_pen_range = 512  # Default generator repetition penalty range
+    rep_pen_range = 1024 # Default generator repetition penalty range
     temp        = 0.5    # Default generator temperature
     top_p       = 0.9    # Default generator top_p
     top_k       = 0      # Default generator top_k
@@ -134,6 +170,7 @@ class vars:
     wifolders_d = {}     # Dictionary of World Info folder UID-info pairs
     wifolders_l = []     # List of World Info folder UIDs
     wifolders_u = {}     # Dictionary of pairs of folder UID - list of WI UID
+    modelconfig = {}     # Raw contents of the model's config.json, or empty dictionary if none found
     lua_state   = None   # Lua state of the Lua scripting system
     lua_koboldbridge = None   # `koboldbridge` from bridge.lua
     lua_kobold  = None   # `kobold` from bridge.lua
@@ -217,11 +254,11 @@ utils.vars = vars
 #==================================================================#
 # Function to get model selection at startup
 #==================================================================#
-def getModelSelection():
-    print("    #   Model                          VRAM\n    =========================================")
+def getModelSelection(modellist):
+    print("    #   Model\t\t\t\t\t\tVRAM\n    ========================================================")
     i = 1
     for m in modellist:
-        print("    {0} - {1}\t\t{2}".format("{:<2}".format(i), m[0].ljust(15), m[2]))
+        print("    {0} - {1}\t\t\t{2}".format("{:<2}".format(i), m[0].ljust(25), m[2]))
         i += 1
     print(" ");
     modelsel = 0
@@ -233,19 +270,26 @@ def getModelSelection():
         else:
             print("{0}Please enter a valid selection.{1}".format(colors.RED, colors.END))
 
-    # If custom model was selected, get the filesystem location and store it
-    if(vars.model == "NeoCustom" or vars.model == "GPT2Custom"):
-        print("{0}Please choose the folder where pytorch_model.bin is located:{1}\n".format(colors.CYAN, colors.END))
-        modpath = fileops.getdirpath(getcwd() + "/models", "Select Model Folder")
-
-        if(modpath):
-            # Save directory to vars
-            vars.custmodpth = modpath
-        else:
-            # Print error and retry model selection
-            print("{0}Model select cancelled!{1}".format(colors.RED, colors.END))
-            print("{0}Select an AI model to continue:{1}\n".format(colors.CYAN, colors.END))
-            getModelSelection()
+    # Model Lists
+    try:
+        getModelSelection(eval(vars.model))
+    except Exception as e:
+        if(vars.model == "Return"):
+            getModelSelection(mainmenu)
+
+        # If custom model was selected, get the filesystem location and store it
+        if(vars.model == "NeoCustom" or vars.model == "GPT2Custom"):
+            print("{0}Please choose the folder where pytorch_model.bin is located:{1}\n".format(colors.CYAN, colors.END))
+            modpath = fileops.getdirpath(getcwd() + "/models", "Select Model Folder")
+
+            if(modpath):
+                # Save directory to vars
+                vars.custmodpth = modpath
+            else:
+                # Print error and retry model selection
+                print("{0}Model select cancelled!{1}".format(colors.RED, colors.END))
+                print("{0}Select an AI model to continue:{1}\n".format(colors.CYAN, colors.END))
+                getModelSelection(mainmenu)
 
 #==================================================================#
 # Return all keys in tokenizer dictionary containing char
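The submenu flow above hinges on `getModelSelection(eval(vars.model))`: choosing a menu entry stores its second field in `vars.model`, and when that field names one of the module-level lists (e.g. "gptneolist"), `eval` resolves it and the menu recurses; a real model identifier raises a `NameError` instead and drops into the `except` branch, which then handles "Return" and the custom-model paths. A minimal sketch of the same dispatch using an explicit dictionary rather than `eval` (illustrative names, not part of this commit):

    menus = {"gptneolist": gptneolist, "gpt2list": gpt2list, "fsdlist": fsdlist,
             "xglmlist": xglmlist, "apilist": apilist, "mainmenu": mainmenu}

    def get_model_selection(menu):
        for i, (label, target, vram) in enumerate(menu, start=1):
            print(" {:<2} - {:<40} {}".format(i, label, vram))
        target = menu[int(input("> ")) - 1][1]
        if target in menus:              # entry opens a submenu: recurse into it
            return get_model_selection(menus[target])
        if target == "Return":           # back out to the main menu
            return get_model_selection(menus["mainmenu"])
        return target                    # otherwise it is a model identifier

A dictionary lookup fails closed on unexpected input, whereas `eval` will evaluate any expression it is handed.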
@@ -413,14 +457,18 @@ def device_config(model):
 #==================================================================#
 def loadmodelsettings():
     try:
-        model_js_config = str(model_config).partition(' ')[2]
-        js = json.loads(model_js_config)
+        js = json.loads(str(model_config).partition(' ')[2])
     except Exception as e:
         try:
-            model_js_config = open(vars.custmodpth + "/config.json", "r")
+            try:
+                js = json.load(open(vars.custmodpth + "/config.json", "r"))
+            except Exception as e:
+                js = json.load(open(vars.custmodpth.replace('/', '_') + "/config.json", "r"))
         except Exception as e:
-            model_js_config = open(vars.custmodpth.replace('/', '_') + "/config.json", "r")
-        js = json.load(model_js_config)
+            js = {}
+    if vars.model_type == "xglm" or js.get("compat", "j") == "fairseq_lm":
+        vars.newlinemode = "s"  # Default to </s> newline mode if using XGLM
+    vars.modelconfig = js
     if("badwordsids" in js):
         vars.badwordsids = js["badwordsids"]
     if("nobreakmodel" in js):
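Config parsing now degrades in stages: first the in-memory `model_config` repr, then `<custmodpth>/config.json`, then the same path with slashes flattened to underscores, and finally an empty dict so the later `in js` / `js.get(...)` lookups still work; `vars.modelconfig` keeps the raw result, which is forwarded to the TPU backend further down. A standalone sketch of the file-side fallback (the helper name is mine, not the commit's):

    import json, os

    def read_model_config(custmodpth):
        # Try the model folder, then its underscore-flattened twin;
        # an empty dict keeps downstream js.get() lookups safe.
        for candidate in (custmodpth, custmodpth.replace('/', '_')):
            path = os.path.join(candidate, "config.json")
            if os.path.exists(path):
                with open(path, "r") as f:
                    return json.load(f)
        return {}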
@@ -456,6 +504,200 @@ def loadmodelsettings():
         if(not vars.gamestarted):
             vars.authornotetemplate = vars.setauthornotetemplate
 
+#==================================================================#
+#  Take settings from vars and write them to client settings file
+#==================================================================#
+def savesettings():
+    # Build json to write
+    js = {}
+    js["apikey"]        = vars.apikey
+    js["andepth"]       = vars.andepth
+    js["temp"]          = vars.temp
+    js["top_p"]         = vars.top_p
+    js["top_k"]         = vars.top_k
+    js["tfs"]           = vars.tfs
+    js["rep_pen"]       = vars.rep_pen
+    js["rep_pen_slope"] = vars.rep_pen_slope
+    js["rep_pen_range"] = vars.rep_pen_range
+    js["genamt"]        = vars.genamt
+    js["max_length"]    = vars.max_length
+    js["ikgen"]         = vars.ikgen
+    js["formatoptns"]   = vars.formatoptns
+    js["numseqs"]       = vars.numseqs
+    js["widepth"]       = vars.widepth
+    js["useprompt"]     = vars.useprompt
+    js["adventure"]     = vars.adventure
+    js["chatmode"]      = vars.chatmode
+    js["chatname"]      = vars.chatname
+    js["dynamicscan"]   = vars.dynamicscan
+    js["nopromptgen"]   = vars.nopromptgen
+    js["rngpersist"]    = vars.rngpersist
+    js["nogenmod"]      = vars.nogenmod
+    js["autosave"]      = vars.autosave
+    js["welcome"]       = vars.welcome
+    js["newlinemode"]   = vars.newlinemode
+    js["antemplate"]    = vars.setauthornotetemplate
+    js["userscripts"]   = vars.userscripts
+    js["corescript"]    = vars.corescript
+    js["softprompt"]    = vars.spfilename
+
+    # Write it
+    if not os.path.exists('settings'):
+        os.mkdir('settings')
+    file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "w")
+    try:
+        file.write(json.dumps(js, indent=3))
+    finally:
+        file.close()
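The file written here, settings/<model name>.settings, is plain JSON. With the defaults above it would look roughly like the following (an abridged, illustrative sample, not output captured from the program):

    {
       "apikey": "",
       "temp": 0.5,
       "top_p": 0.9,
       "top_k": 0,
       "rep_pen": 1.1,
       "rep_pen_slope": 1.0,
       "rep_pen_range": 1024,
       "newlinemode": "n",
       "userscripts": [],
       "corescript": "default.lua",
       "softprompt": ""
    }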
+
+#==================================================================#
+#  Don't save settings unless 2 seconds have passed without modification
+#==================================================================#
+@debounce(2)
+def settingschanged():
+    print("{0}Saving settings!{1}".format(colors.GREEN, colors.END))
+    savesettings()
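`@debounce(2)` collapses rapid-fire changes into a single disk write, issued two seconds after the last modification rather than on every change. The decorator itself is defined elsewhere in the codebase and is not part of this diff; a minimal timer-based version with the same contract might look like this sketch:

    import threading

    def debounce(wait):
        # Postpone calls to fn until `wait` seconds pass without a new call.
        def decorator(fn):
            timer = None
            def debounced(*args, **kwargs):
                nonlocal timer
                if timer is not None:
                    timer.cancel()    # a new call restarts the countdown
                timer = threading.Timer(wait, fn, args, kwargs)
                timer.start()
            return debounced
        return decorator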
+
+#==================================================================#
+#  Read settings from client file JSON and send to vars
+#==================================================================#
+def loadsettings():
+    if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
+        # Read file contents into JSON object
+        file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "r")
+        js   = json.load(file)
+
+        # Copy file contents to vars
+        if("apikey" in js):
+            vars.apikey = js["apikey"]
+        if("andepth" in js):
+            vars.andepth = js["andepth"]
+        if("temp" in js):
+            vars.temp = js["temp"]
+        if("top_p" in js):
+            vars.top_p = js["top_p"]
+        if("top_k" in js):
+            vars.top_k = js["top_k"]
+        if("tfs" in js):
+            vars.tfs = js["tfs"]
+        if("rep_pen" in js):
+            vars.rep_pen = js["rep_pen"]
+        if("rep_pen_slope" in js):
+            vars.rep_pen_slope = js["rep_pen_slope"]
+        if("rep_pen_range" in js):
+            vars.rep_pen_range = js["rep_pen_range"]
+        if("genamt" in js):
+            vars.genamt = js["genamt"]
+        if("max_length" in js):
+            vars.max_length = js["max_length"]
+        if("ikgen" in js):
+            vars.ikgen = js["ikgen"]
+        if("formatoptns" in js):
+            vars.formatoptns = js["formatoptns"]
+        if("numseqs" in js):
+            vars.numseqs = js["numseqs"]
+        if("widepth" in js):
+            vars.widepth = js["widepth"]
+        if("useprompt" in js):
+            vars.useprompt = js["useprompt"]
+        if("adventure" in js):
+            vars.adventure = js["adventure"]
+        if("chatmode" in js):
+            vars.chatmode = js["chatmode"]
+        if("chatname" in js):
+            vars.chatname = js["chatname"]
+        if("dynamicscan" in js):
+            vars.dynamicscan = js["dynamicscan"]
+        if("nopromptgen" in js):
+            vars.nopromptgen = js["nopromptgen"]
+        if("rngpersist" in js):
+            vars.rngpersist = js["rngpersist"]
+        if("nogenmod" in js):
+            vars.nogenmod = js["nogenmod"]
+        if("autosave" in js):
+            vars.autosave = js["autosave"]
+        if("newlinemode" in js):
+            vars.newlinemode = js["newlinemode"]
+        if("welcome" in js):
+            vars.welcome = js["welcome"]
+        if("antemplate" in js):
+            vars.setauthornotetemplate = js["antemplate"]
+            if(not vars.gamestarted):
+                vars.authornotetemplate = vars.setauthornotetemplate
+
+        if("userscripts" in js):
+            vars.userscripts = []
+            for userscript in js["userscripts"]:
+                if type(userscript) is not str:
+                    continue
+                userscript = userscript.strip()
+                if len(userscript) != 0 and all(q not in userscript for q in ("..", ":")) and all(userscript[0] not in q for q in ("/", "\\")) and os.path.exists(fileops.uspath(userscript)):
+                    vars.userscripts.append(userscript)
+
+        if("corescript" in js and type(js["corescript"]) is str and all(q not in js["corescript"] for q in ("..", ":")) and all(js["corescript"][0] not in q for q in ("/", "\\"))):
+            vars.corescript = js["corescript"]
+        else:
+            vars.corescript = "default.lua"
+
+        file.close()
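The `userscripts` and `corescript` filters above are path-traversal guards: a value passes only if it contains no ".." or ":", does not begin with "/" or "\", and (for userscripts) actually exists under the userscripts directory. Note that `all(userscript[0] not in q for q in ("/", "\\"))` is simply a roundabout first-character test. Unrolled into a named helper for clarity (my naming, purely illustrative):

    import os

    def is_safe_relative_path(name, must_exist_under=None):
        # Reject traversal ("..") and drive/stream separators (":") ...
        if ".." in name or ":" in name:
            return False
        # ... and absolute or root-relative paths.
        if name.startswith(("/", "\\")):
            return False
        if must_exist_under is not None:
            return os.path.exists(os.path.join(must_exist_under, name))
        return True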
+
+#==================================================================#
+# Load a soft prompt from a file
+#==================================================================#
+def spRequest(filename):
+    vars.spfilename = ""
+    settingschanged()
+
+    if(len(filename) == 0):
+        vars.sp = None
+        vars.sp_length = 0
+        return
+
+    global np
+    if 'np' not in globals():
+        import numpy as np
+
+    z, version, shape, fortran_order, dtype = fileops.checksp(filename, vars.modeldim)
+    assert isinstance(z, zipfile.ZipFile)
+    with z.open('meta.json') as f:
+        vars.spmeta = json.load(f)
+    z.close()
+
+    with np.load(fileops.sppath(filename), allow_pickle=False) as f:
+        tensor = f['tensor.npy']
+
+    # If the tensor is in bfloat16 format, convert it to float32
+    if(tensor.dtype == 'V2'):
+        tensor.dtype = np.uint16
+        tensor = np.uint32(tensor) << 16
+        tensor.dtype = np.float32
+
+    if(tensor.dtype != np.float16):
+        tensor = np.float32(tensor)
+    assert not np.isinf(tensor).any() and not np.isnan(tensor).any()
+
+    vars.sp_length = tensor.shape[-2]
+    vars.spmeta["n_tokens"] = vars.sp_length
+
+    if(vars.model in ("TPUMeshTransformerGPTJ",)):
+        rows = tensor.shape[0]
+        padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
+        tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
+        tensor = tensor.reshape(
+            tpu_mtj_backend.params["cores_per_replica"],
+            -1,
+            tpu_mtj_backend.params["d_model"],
+        )
+        vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor))
+    else:
+        vars.sp = torch.from_numpy(tensor)
+
+    vars.spfilename = filename
+    settingschanged()
+
 #==================================================================#
 # Startup
 #==================================================================#
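Two details in spRequest merit unpacking. First, np.load surfaces bfloat16 data as a raw 2-byte void dtype ('V2'); since bfloat16 is exactly the upper half of a float32, relabeling the buffer as uint16, widening to uint32, and shifting left 16 bits reconstructs the float32 bit pattern. Second, on the TPU path, `seq - (seq % -cores_per_replica)` rounds seq up to a multiple of the core count, so the pad brings the soft-prompt rows to a length that reshapes cleanly across cores. A self-contained check of both claims (values illustrative):

    import numpy as np

    # bfloat16 -> float32: bfloat16 is the top 16 bits of a float32.
    raw = np.array([0x3F80], dtype=np.uint16)              # bfloat16 bits for 1.0
    restored = (raw.astype(np.uint32) << 16).view(np.float32)
    assert restored[0] == 1.0

    # Padding arithmetic from the TPU branch.
    seq, cores, rows = 2048, 8, 20
    rounded_up = seq - (seq % -cores)    # seq rounded up to a multiple of cores (2048 here)
    padding_amount = rounded_up - rows   # 2028 zero rows appended after the soft prompt
    assert (rows + padding_amount) % cores == 0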
@@ -525,7 +767,7 @@ if args.model:
 else:
     print("{0}Welcome to the KoboldAI Server!\nListed RAM is the optimal VRAM and CPU ram can be up to twice the amount.\nMost models can run at less VRAM with reduced max tokens or less layers on the GPU.\nSelect an AI model to continue:{1}\n".format(colors.CYAN, colors.END))
-    getModelSelection()
+    getModelSelection(mainmenu)
 
 # If transformers model was selected & GPU available, ask to use CPU or GPU
 if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
@@ -568,6 +810,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
         print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
         vars.model_type = "gpt_neo"
     loadmodelsettings()
+    loadsettings()
     print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
     vars.hascuda = torch.cuda.is_available()
     vars.bmsupported = vars.model_type in ("gpt_neo", "gptj", "xglm") and not vars.nobreakmodel
@@ -805,7 +1048,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
                 if(hasattr(self, "transformer")):
                     inputs_embeds = self.transformer.wte(input_ids)
                 else:
-                    inputs_embeds = self.model.embed_tokens(input_ids) * self.model.embed_scale
+                    inputs_embeds = self.model.embed_tokens(input_ids)
                 if(vars.sp is not None):
                     vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
                     inputs_embeds = torch.where(
@@ -813,6 +1056,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
                         vars.sp[shifted_input_ids.clamp(min=0)],
                         inputs_embeds,
                     )
+                if(not hasattr(self, "transformer")):
+                    inputs_embeds *= self.model.embed_scale
                 kwargs['inputs_embeds'] = inputs_embeds
                 return old_forward(self, *args, **kwargs)
             cls.forward = new_causallm_forward
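These two hunks change where `embed_scale` is applied for models that lack a `transformer` attribute (the XGLM/fairseq-dense code path): previously the token embeddings were scaled before the soft-prompt rows were spliced in, leaving the soft prompt unscaled; now the splice happens first and the scale is applied to the combined tensor. A toy demonstration of the patched ordering (shapes and values invented for illustration):

    import torch

    embed_scale  = 2.0
    token_embeds = torch.ones(1, 4, 8)       # stand-in for embed_tokens(input_ids)
    soft_prompt  = torch.full((4, 8), 5.0)   # stand-in soft-prompt rows, one per position
    is_soft      = torch.tensor([[True, True, False, False]])

    # Patched order: splice first, scale once, so soft-prompt rows
    # receive the same embed_scale as ordinary token embeddings.
    spliced = torch.where(is_soft[..., None], soft_prompt, token_embeds)
    scaled  = spliced * embed_scale          # soft rows -> 10.0, token rows -> 2.0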
@@ -1063,10 +1308,10 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
             if not args.colab:
                 import shutil
-                shutil.rmtree("cache/")
                 model = model.half()
                 model.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
                 tokenizer.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
+                shutil.rmtree("cache/")
 
 if(vars.hascuda):
     if(vars.usegpu):
@@ -1186,13 +1431,16 @@ else:
     if(vars.model == "Colab"):
         from transformers import GPT2TokenizerFast
         tokenizer = GPT2TokenizerFast.from_pretrained("EleutherAI/gpt-neo-2.7B", cache_dir="cache/")
+        loadsettings()
     elif(vars.model == "OAI"):
         from transformers import GPT2TokenizerFast
         tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
+        loadsettings()
     # Load the TPU backend if requested
     elif(vars.model == "TPUMeshTransformerGPTJ"):
         print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
-        assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth)
+        if not vars.custmodpth or not os.path.isdir(vars.custmodpth):
+            raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder")
         import tpu_mtj_backend
         tpu_mtj_backend.vars = vars
         tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
@@ -1200,10 +1448,14 @@ else:
         tpu_mtj_backend.compiling_callback = tpumtjgenerate_compiling_callback
         tpu_mtj_backend.stopped_compiling_callback = tpumtjgenerate_stopped_compiling_callback
         tpu_mtj_backend.settings_callback = tpumtjgenerate_settings_callback
-        tpu_mtj_backend.load_model(vars.custmodpth)
         vars.allowsp = True
+        loadmodelsettings()
+        loadsettings()
+        tpu_mtj_backend.load_model(vars.custmodpth, **vars.modelconfig)
         vars.modeldim = int(tpu_mtj_backend.params["d_model"])
         tokenizer = tpu_mtj_backend.tokenizer
+    else:
+        loadsettings()
 
 # Set up Flask routes
 @app.route('/')
@@ -2295,152 +2547,6 @@ def sendsettings():
         if(not frm["id"] in vars.formatoptns):
             vars.formatoptns[frm["id"]] = False;
 
-#==================================================================#
-#  Take settings from vars and write them to client settings file
-#==================================================================#
-def savesettings():
-    # Build json to write
-    js = {}
-    js["apikey"]        = vars.apikey
-    js["andepth"]       = vars.andepth
-    js["temp"]          = vars.temp
-    js["top_p"]         = vars.top_p
-    js["top_k"]         = vars.top_k
-    js["tfs"]           = vars.tfs
-    js["rep_pen"]       = vars.rep_pen
-    js["rep_pen_slope"] = vars.rep_pen_slope
-    js["rep_pen_range"] = vars.rep_pen_range
-    js["genamt"]        = vars.genamt
-    js["max_length"]    = vars.max_length
-    js["ikgen"]         = vars.ikgen
-    js["formatoptns"]   = vars.formatoptns
-    js["numseqs"]       = vars.numseqs
-    js["widepth"]       = vars.widepth
-    js["useprompt"]     = vars.useprompt
-    js["adventure"]     = vars.adventure
-    js["chatmode"]      = vars.chatmode
-    js["chatname"]      = vars.chatname
-    js["dynamicscan"]   = vars.dynamicscan
-    js["nopromptgen"]   = vars.nopromptgen
-    js["rngpersist"]    = vars.rngpersist
-    js["nogenmod"]      = vars.nogenmod
-    js["autosave"]      = vars.autosave
-    js["welcome"]       = vars.welcome
-    js["newlinemode"]   = vars.newlinemode
-    js["antemplate"]    = vars.setauthornotetemplate
-    js["userscripts"]   = vars.userscripts
-    js["corescript"]    = vars.corescript
-    js["softprompt"]    = vars.spfilename
-
-    # Write it
-    if not os.path.exists('settings'):
-        os.mkdir('settings')
-    file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "w")
-    try:
-        file.write(json.dumps(js, indent=3))
-    finally:
-        file.close()
-
-#==================================================================#
-#  Read settings from client file JSON and send to vars
-#==================================================================#
-def loadsettings():
-    if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
-        # Read file contents into JSON object
-        file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "r")
-        js   = json.load(file)
-
-        # Copy file contents to vars
-        if("apikey" in js):
-            vars.apikey = js["apikey"]
-        if("andepth" in js):
-            vars.andepth = js["andepth"]
-        if("temp" in js):
-            vars.temp = js["temp"]
-        if("top_p" in js):
-            vars.top_p = js["top_p"]
-        if("top_k" in js):
-            vars.top_k = js["top_k"]
-        if("tfs" in js):
-            vars.tfs = js["tfs"]
-        if("rep_pen" in js):
-            vars.rep_pen = js["rep_pen"]
-        if("rep_pen_slope" in js):
-            vars.rep_pen_slope = js["rep_pen_slope"]
-        if("rep_pen_range" in js):
-            vars.rep_pen_range = js["rep_pen_range"]
-        if("genamt" in js):
-            vars.genamt = js["genamt"]
-        if("max_length" in js):
-            vars.max_length = js["max_length"]
-        if("ikgen" in js):
-            vars.ikgen = js["ikgen"]
-        if("formatoptns" in js):
-            vars.formatoptns = js["formatoptns"]
-        if("numseqs" in js):
-            vars.numseqs = js["numseqs"]
-        if("widepth" in js):
-            vars.widepth = js["widepth"]
-        if("useprompt" in js):
-            vars.useprompt = js["useprompt"]
-        if("adventure" in js):
-            vars.adventure = js["adventure"]
-        if("chatmode" in js):
-            vars.chatmode = js["chatmode"]
-        if("chatname" in js):
-            vars.chatname = js["chatname"]
-        if("dynamicscan" in js):
-            vars.dynamicscan = js["dynamicscan"]
-        if("nopromptgen" in js):
-            vars.nopromptgen = js["nopromptgen"]
-        if("rngpersist" in js):
-            vars.rngpersist = js["rngpersist"]
-        if("nogenmod" in js):
-            vars.nogenmod = js["nogenmod"]
-        if("autosave" in js):
-            vars.autosave = js["autosave"]
-        if("newlinemode" in js):
-            vars.newlinemode = js["newlinemode"]
-        if("welcome" in js):
-            vars.welcome = js["welcome"]
-        if("antemplate" in js):
-            vars.setauthornotetemplate = js["antemplate"]
-            if(not vars.gamestarted):
-                vars.authornotetemplate = vars.setauthornotetemplate
-
-        if("userscripts" in js):
-            vars.userscripts = []
-            for userscript in js["userscripts"]:
-                if type(userscript) is not str:
-                    continue
-                userscript = userscript.strip()
-                if len(userscript) != 0 and all(q not in userscript for q in ("..", ":")) and all(userscript[0] not in q for q in ("/", "\\")) and os.path.exists(fileops.uspath(userscript)):
-                    vars.userscripts.append(userscript)
-
-        if("corescript" in js and type(js["corescript"]) is str and all(q not in js["corescript"] for q in ("..", ":")) and all(js["corescript"][0] not in q for q in ("/", "\\"))):
-            vars.corescript = js["corescript"]
-        else:
-            vars.corescript = "default.lua"
-
-        if(vars.allowsp and "softprompt" in js and type(js["softprompt"]) is str and all(q not in js["softprompt"] for q in ("..", ":")) and (len(js["softprompt"]) == 0 or all(js["softprompt"][0] not in q for q in ("/", "\\")))):
-            spRequest(js["softprompt"])
-        else:
-            vars.spfilename = ""
-
-        file.close()
-
-#==================================================================#
-#  Don't save settings unless 2 seconds have passed without modification
-#==================================================================#
-@debounce(2)
-def settingschanged():
-    print("{0}Saving settings!{1}".format(colors.GREEN, colors.END))
-    savesettings()
 
 #==================================================================#
 # Set value of gamesaved
 #==================================================================#
@@ -4490,60 +4596,6 @@ def loadRequest(loadpath, filename=None):
         emit('from_server', {'cmd': 'hidegenseqs', 'data': ''}, broadcast=True)
 
     print("{0}Story loaded from {1}!{2}".format(colors.GREEN, filename, colors.END))
 
-#==================================================================#
-# Load a soft prompt from a file
-#==================================================================#
-def spRequest(filename):
-    vars.spfilename = ""
-    settingschanged()
-
-    if(len(filename) == 0):
-        vars.sp = None
-        vars.sp_length = 0
-        return
-
-    global np
-    if 'np' not in globals():
-        import numpy as np
-
-    z, version, shape, fortran_order, dtype = fileops.checksp(filename, vars.modeldim)
-    assert isinstance(z, zipfile.ZipFile)
-    with z.open('meta.json') as f:
-        vars.spmeta = json.load(f)
-    z.close()
-
-    with np.load(fileops.sppath(filename), allow_pickle=False) as f:
-        tensor = f['tensor.npy']
-
-    # If the tensor is in bfloat16 format, convert it to float32
-    if(tensor.dtype == 'V2'):
-        tensor.dtype = np.uint16
-        tensor = np.uint32(tensor) << 16
-        tensor.dtype = np.float32
-
-    if(tensor.dtype != np.float16):
-        tensor = np.float32(tensor)
-    assert not np.isinf(tensor).any() and not np.isnan(tensor).any()
-
-    vars.sp_length = tensor.shape[-2]
-    vars.spmeta["n_tokens"] = vars.sp_length
-
-    if(vars.model in ("TPUMeshTransformerGPTJ",)):
-        rows = tensor.shape[0]
-        padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
-        tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
-        tensor = tensor.reshape(
-            tpu_mtj_backend.params["cores_per_replica"],
-            -1,
-            tpu_mtj_backend.params["d_model"],
-        )
-        vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor))
-    else:
-        vars.sp = torch.from_numpy(tensor)
-
-    vars.spfilename = filename
-    settingschanged()
 
 #==================================================================#
 # Import an AIDungon game exported with Mimi's tool
 #==================================================================#
@@ -4888,9 +4940,6 @@ def randomGameRequest(topic, memory=""):
     vars.memory = memory
     emit('from_server', {'cmd': 'setmemory', 'data': vars.memory}, broadcast=True)
 
-# Load desired settings from both the model and the users config file
-loadsettings()
-
 # Prevent tokenizer from taking extra time the first time it's used
 def __preempt_tokenizer():
     if("tokenizer" not in globals()):
@@ -4899,6 +4948,16 @@ def __preempt_tokenizer():
     tokenizer.encode(utils.encodenewlines("eunoia"))
 threading.Thread(target=__preempt_tokenizer).start()
 
+# Load soft prompt specified by the settings file, if applicable
+if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
+    file = open("settings/" + getmodelname().replace('/', '_') + ".settings", "r")
+    js   = json.load(file)
+    if(vars.allowsp and "softprompt" in js and type(js["softprompt"]) is str and all(q not in js["softprompt"] for q in ("..", ":")) and (len(js["softprompt"]) == 0 or all(js["softprompt"][0] not in q for q in ("/", "\\")))):
+        spRequest(js["softprompt"])
+    else:
+        vars.spfilename = ""
+    file.close()
+
 # Precompile TPU backend if required
 if(vars.model in ("TPUMeshTransformerGPTJ",)):
     soft_tokens = tpumtjgetsofttokens()
@@ -4930,7 +4989,7 @@ if(vars.model in ("TPUMeshTransformerGPTJ",)):
 def send_debug():
     if vars.debug:
         debug_info = ""
-        for variable in [["Action Length", vars.actions.get_last_key()], ["Actions Metadata Length", max(vars.actions_metadata)], ["Actions Metadata", vars.actions_metadata], ["Newline Mode", vars.newlinemode]]:
+        for variable in [["Newline Mode", vars.newlinemode], ["Action Length", vars.actions.get_last_key()], ["Actions Metadata Length", max(vars.actions_metadata)], ["Actions Metadata", vars.actions_metadata]]:
             debug_info = "{}{}: {}\n".format(debug_info, variable[0], variable[1])
         emit('from_server', {'cmd': 'debug_info', 'data': debug_info}, broadcast=True)

File: tpu_mtj_backend.py

@@ -443,9 +443,9 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
         return carry
 
 class PenalizingCausalTransformer(CausalTransformer):
-    def __init__(self, config):
+    def __init__(self, config, **kwargs):
         # Initialize
-        super().__init__(config)
+        super().__init__(config, **kwargs)
     def generate_static(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None):
         compiling_callback()
         numseqs = numseqs_aux.shape[0]
@@ -791,12 +791,24 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
         "pe_rotary_dims": 64,
         "seq": 2048,
         "cores_per_replica": 8,
+        "tokenizer_class": "GPT2TokenizerFast",
+        "tokenizer": "gpt2",
     }
     params = kwargs
+    if "compat" in params:
+        default_params["compat"] = params["compat"]
+    if default_params["compat"] == "fairseq_lm":
+        default_params["tokenizer"] = "KoboldAI/fairseq-dense-125M"
     for param in default_params:
         if param not in params:
             params[param] = default_params[param]
+
+    # Load tokenizer
+    if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")):
+        raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'")
+    tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"])
+    tokenizer = tokenizer_class.from_pretrained(params["tokenizer"])
 
     # Disable JAX warnings about these two functions having been renamed
     jax.host_count = jax.process_count
     jax.host_id = jax.process_index
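With the hard-coded GPT-2 tokenizer gone (removed in the next hunk), the tokenizer is now resolved from the model's own config: `tokenizer_class` must be a string naming a class in the transformers namespace that ends in Tokenizer or TokenizerFast, which keeps the getattr lookup from being pointed at arbitrary attributes while still letting fairseq-dense checkpoints pick their matching vocabulary. The same resolution as a standalone sketch:

    import transformers

    def resolve_tokenizer(class_name: str, pretrained: str):
        # Only accept names that plausibly denote tokenizer classes.
        if not isinstance(class_name, str) or not class_name.endswith(("Tokenizer", "TokenizerFast")):
            raise ValueError("tokenizer_class must end in 'Tokenizer' or 'TokenizerFast'")
        cls = getattr(transformers, class_name)
        return cls.from_pretrained(pretrained)

    # e.g. resolve_tokenizer("GPT2TokenizerFast", "gpt2")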
@@ -819,7 +831,6 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
     devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
     thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
     maps.thread_resources.env = thread_resources_env
-    tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
 
     global shard_xmap, batch_xmap
     shard_xmap = __shard_xmap()
@@ -832,6 +843,6 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
     if not path.endswith("/"):
         path += "/"
-    network = PenalizingCausalTransformer(params)
+    network = PenalizingCausalTransformer(params, dematerialized=True)
     network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])
     network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))