Merge branch 'united' into united

This commit is contained in:
ebolam 2022-02-28 08:37:45 -05:00 committed by GitHub
commit 47d102635e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 309 additions and 239 deletions

View File

@ -65,32 +65,68 @@ class colors:
UNDERLINE = '\033[4m'
# AI models
modellist = [
mainmenu = [
["Load a model from its directory", "NeoCustom", ""],
["Load an old GPT-2 model (eg CloverEdition)", "GPT2Custom", ""],
["Skein 6B (Hybrid)", "KoboldAI/GPT-J-6B-Skein", "16GB"],
["Adventure 6B", "KoboldAI/GPT-J-6B-Adventure", "16GB"],
["Lit 6B (NSFW)", "hakurei/lit-6B", "16GB"],
["C1 6B (Chatbot)", "hakurei/c1-6B", "16GB"],
["Janeway 2.7B (Novel)", "KoboldAI/GPT-Neo-2.7B-Janeway", "8GB"],
["Janeway Neo 2.7B (Novel)", "KoboldAI/GPT-Neo-2.7B-Janeway", "8GB"],
["Janeway FSD 2.7B (Novel)", "KoboldAI/fairseq-dense-2.7B-Janeway", "8GB"],
["Adventure 2.7B", "KoboldAI/GPT-Neo-2.7B-AID", "8GB"],
["Picard 2.7B (Novel)", "KoboldAI/GPT-Neo-2.7B-Picard", "8GB"],
["Horni 2.7B (NSFW)", "KoboldAI/GPT-Neo-2.7B-Horni", "8GB"],
["Horni-LN 2.7B (Novel)", "KoboldAI/GPT-Neo-2.7B-Horni-LN", "8GB"],
["Shinen 2.7B (NSFW)", "KoboldAI/GPT-Neo-2.7B-Shinen", "8GB"],
["Untuned GPT-Neo/J", "gptneolist", ""],
["Untuned Fairseq Dense", "fsdlist", ""],
["Untuned XGLM", "xglmlist", ""],
["Untuned GPT2", "gpt2list", ""],
["Online Services", "apilist", ""],
["Read Only (No AI)", "ReadOnly", ""]
]
gptneolist = [
["GPT-J 6B", "EleutherAI/gpt-j-6B", "16GB"],
["GPT-Neo 2.7B", "EleutherAI/gpt-neo-2.7B", "8GB"],
["GPT-Neo 1.3B", "EleutherAI/gpt-neo-1.3B", "6GB"],
["Return to Main Menu", "Return", ""],
]
gpt2list = [
["GPT-2 XL", "gpt2-xl", "6GB"],
["GPT-2 Large", "gpt2-large", "4GB"],
["GPT-2 Med", "gpt2-medium", "2GB"],
["GPT-2", "gpt2", "2GB"],
["Return to Main Menu", "Return", ""],
]
fsdlist = [
["Fairseq Dense 13B", "KoboldAI/fairseq-dense-13B", "32GB"],
["Fairseq Dense 6.7B", "KoboldAI/fairseq-dense-6.7B", "16GB"],
["Fairseq Dense 2.7B", "KoboldAI/fairseq-dense-2.7B", "8GB"],
["Fairseq Dense 1.3B", "KoboldAI/fairseq-dense-1.3B", "6GB"],
["Fairseq Dense 355M", "KoboldAI/fairseq-dense-355M", ""],
["Fairseq Dense 125M", "KoboldAI/fairseq-dense-125M", ""],
["Return to Main Menu", "Return", ""],
]
xglmlist = [
["XGLM 4.5B (Larger Dataset)", "facebook/xglm-4.5B", ""],
["XGLM 7.5B", "facebook/xglm-7.5B", ""],
["XGLM 2.9B", "facebook/xglm-2.9B", ""],
["XGLM 1.7B", "facebook/xglm-1.7B", ""],
["XGLM 564M", "facebook/xglm-564M", ""],
["Return to Main Menu", "Return", ""],
]
apilist = [
["OpenAI API (requires API key)", "OAI", ""],
["InferKit API (requires API key)", "InferKit", ""],
["KoboldAI Server API (Old Google Colab)", "Colab", ""],
["Read Only (No AI)", "ReadOnly", ""]
]
["Return to Main Menu", "Return", ""],
]
# Variables
class vars:
lastact = "" # The last action received from the user
@ -106,7 +142,7 @@ class vars:
ikgen = 200 # Number of characters for InferKit to generate
rep_pen = 1.1 # Default generator repetition_penalty
rep_pen_slope = 1.0 # Default generator repetition penalty slope
rep_pen_range = 512 # Default generator repetition penalty range
rep_pen_range = 1024 # Default generator repetition penalty range
temp = 0.5 # Default generator temperature
top_p = 0.9 # Default generator top_p
top_k = 0 # Default generator top_k
@ -134,6 +170,7 @@ class vars:
wifolders_d = {} # Dictionary of World Info folder UID-info pairs
wifolders_l = [] # List of World Info folder UIDs
wifolders_u = {} # Dictionary of pairs of folder UID - list of WI UID
modelconfig = {} # Raw contents of the model's config.json, or empty dictionary if none found
lua_state = None # Lua state of the Lua scripting system
lua_koboldbridge = None # `koboldbridge` from bridge.lua
lua_kobold = None # `kobold` from` bridge.lua
@ -217,11 +254,11 @@ utils.vars = vars
#==================================================================#
# Function to get model selection at startup
#==================================================================#
def getModelSelection():
print(" # Model VRAM\n =========================================")
def getModelSelection(modellist):
print(" # Model\t\t\t\t\t\tVRAM\n ========================================================")
i = 1
for m in modellist:
print(" {0} - {1}\t\t{2}".format("{:<2}".format(i), m[0].ljust(15), m[2]))
print(" {0} - {1}\t\t\t{2}".format("{:<2}".format(i), m[0].ljust(25), m[2]))
i += 1
print(" ");
modelsel = 0
@ -233,19 +270,26 @@ def getModelSelection():
else:
print("{0}Please enter a valid selection.{1}".format(colors.RED, colors.END))
# If custom model was selected, get the filesystem location and store it
if(vars.model == "NeoCustom" or vars.model == "GPT2Custom"):
print("{0}Please choose the folder where pytorch_model.bin is located:{1}\n".format(colors.CYAN, colors.END))
modpath = fileops.getdirpath(getcwd() + "/models", "Select Model Folder")
# Model Lists
try:
getModelSelection(eval(vars.model))
except Exception as e:
if(vars.model == "Return"):
getModelSelection(mainmenu)
# If custom model was selected, get the filesystem location and store it
if(vars.model == "NeoCustom" or vars.model == "GPT2Custom"):
print("{0}Please choose the folder where pytorch_model.bin is located:{1}\n".format(colors.CYAN, colors.END))
modpath = fileops.getdirpath(getcwd() + "/models", "Select Model Folder")
if(modpath):
# Save directory to vars
vars.custmodpth = modpath
else:
# Print error and retry model selection
print("{0}Model select cancelled!{1}".format(colors.RED, colors.END))
print("{0}Select an AI model to continue:{1}\n".format(colors.CYAN, colors.END))
getModelSelection()
if(modpath):
# Save directory to vars
vars.custmodpth = modpath
else:
# Print error and retry model selection
print("{0}Model select cancelled!{1}".format(colors.RED, colors.END))
print("{0}Select an AI model to continue:{1}\n".format(colors.CYAN, colors.END))
getModelSelection(mainmenu)
#==================================================================#
# Return all keys in tokenizer dictionary containing char
@ -413,14 +457,18 @@ def device_config(model):
#==================================================================#
def loadmodelsettings():
try:
model_js_config = str(model_config).partition(' ')[2]
js = json.loads(model_js_config)
js = json.loads(str(model_config).partition(' ')[2])
except Exception as e:
try:
model_js_config = open(vars.custmodpth + "/config.json", "r")
try:
js = json.load(open(vars.custmodpth + "/config.json", "r"))
except Exception as e:
js = json.load(open(vars.custmodpth.replace('/', '_') + "/config.json", "r"))
except Exception as e:
model_js_config = open(vars.custmodpth.replace('/', '_') + "/config.json", "r")
js = json.load(model_js_config)
js = {}
if vars.model_type == "xglm" or js.get("compat", "j") == "fairseq_lm":
vars.newlinemode = "s" # Default to </s> newline mode if using XGLM
vars.modelconfig = js
if("badwordsids" in js):
vars.badwordsids = js["badwordsids"]
if("nobreakmodel" in js):
@ -456,6 +504,200 @@ def loadmodelsettings():
if(not vars.gamestarted):
vars.authornotetemplate = vars.setauthornotetemplate
#==================================================================#
# Take settings from vars and write them to client settings file
#==================================================================#
def savesettings():
    """Write the current client settings from `vars` to the model's settings file.

    The target is `settings/<name>.settings`, where `<name>` is the result of
    `getmodelname()` with every "/" replaced by "_".  The `settings` folder is
    created first if it is missing.
    """
    # Build json to write
    js = {
        "apikey":        vars.apikey,
        "andepth":       vars.andepth,
        "temp":          vars.temp,
        "top_p":         vars.top_p,
        "top_k":         vars.top_k,
        "tfs":           vars.tfs,
        "rep_pen":       vars.rep_pen,
        "rep_pen_slope": vars.rep_pen_slope,
        "rep_pen_range": vars.rep_pen_range,
        "genamt":        vars.genamt,
        "max_length":    vars.max_length,
        "ikgen":         vars.ikgen,
        "formatoptns":   vars.formatoptns,
        "numseqs":       vars.numseqs,
        "widepth":       vars.widepth,
        "useprompt":     vars.useprompt,
        "adventure":     vars.adventure,
        "chatmode":      vars.chatmode,
        "chatname":      vars.chatname,
        "dynamicscan":   vars.dynamicscan,
        "nopromptgen":   vars.nopromptgen,
        "rngpersist":    vars.rngpersist,
        "nogenmod":      vars.nogenmod,
        "autosave":      vars.autosave,
        "welcome":       vars.welcome,
        "newlinemode":   vars.newlinemode,
        "antemplate":    vars.setauthornotetemplate,
        "userscripts":   vars.userscripts,
        "corescript":    vars.corescript,
        "softprompt":    vars.spfilename,
    }

    # Serialize before opening the file: "w" mode truncates it immediately, so
    # a serialization error must not be able to wipe the existing settings
    serialized = json.dumps(js, indent=3)

    # Write it (exist_ok avoids failing if the folder appears between check and create)
    os.makedirs('settings', exist_ok=True)
    with open("settings/" + getmodelname().replace('/', '_') + ".settings", "w") as file:
        file.write(serialized)
#==================================================================#
# Don't save settings unless 2 seconds have passed without modification
#==================================================================#
@debounce(2)
def settingschanged():
    """Announce on the console that settings are being saved, then save them.

    Wrapped in `debounce(2)` (defined elsewhere in the file); per the banner
    comment above, the intent is to save only after 2 seconds without changes.
    """
    notice = "{0}Saving settings!{1}".format(colors.GREEN, colors.END)
    print(notice)
    savesettings()
#==================================================================#
# Read settings from client file JSON and send to vars
#==================================================================#
def _is_safe_script_name(name):
    """Return True if `name` is a non-empty string without "..", without ":" and without a leading "/" or "\\"."""
    return (
        isinstance(name, str)
        and len(name) != 0
        and ".." not in name
        and ":" not in name
        and not name.startswith(("/", "\\"))
    )


def loadsettings():
    """Read the model's client settings file, if it exists, and copy it into `vars`.

    The file is `settings/<name>.settings`, where `<name>` is the result of
    `getmodelname()` with every "/" replaced by "_".  Keys missing from the
    file leave the corresponding `vars` attribute untouched, except for
    `corescript`, which falls back to "default.lua" when the stored value is
    absent or rejected.

    Raises whatever `open` / `json.load` raise if the file cannot be read or
    does not contain valid JSON.
    """
    settings_path = "settings/" + getmodelname().replace('/', '_') + ".settings"
    if(not path.exists(settings_path)):
        return

    # Read file contents into JSON object; the handle is released even if
    # the file turns out not to be valid JSON
    with open(settings_path, "r") as file:
        js = json.load(file)

    # Copy file contents to vars (these keys are stored under the same name in vars)
    same_named_keys = (
        "apikey", "andepth", "temp", "top_p", "top_k", "tfs",
        "rep_pen", "rep_pen_slope", "rep_pen_range",
        "genamt", "max_length", "ikgen", "formatoptns", "numseqs", "widepth",
        "useprompt", "adventure", "chatmode", "chatname", "dynamicscan",
        "nopromptgen", "rngpersist", "nogenmod", "autosave",
        "newlinemode", "welcome",
    )
    for key in same_named_keys:
        if(key in js):
            setattr(vars, key, js[key])

    if("antemplate" in js):
        vars.setauthornotetemplate = js["antemplate"]
        # Only replace the active template while no game is in progress
        if(not vars.gamestarted):
            vars.authornotetemplate = vars.setauthornotetemplate

    if("userscripts" in js):
        vars.userscripts = []
        for userscript in js["userscripts"]:
            if not isinstance(userscript, str):
                continue
            userscript = userscript.strip()
            # Keep only plain relative names that resolve to an existing userscript
            if _is_safe_script_name(userscript) and os.path.exists(fileops.uspath(userscript)):
                vars.userscripts.append(userscript)

    # An empty corescript string used to raise IndexError on `[0]`; it now
    # falls back to the default like any other rejected value
    if("corescript" in js and _is_safe_script_name(js["corescript"])):
        vars.corescript = js["corescript"]
    else:
        vars.corescript = "default.lua"
#==================================================================#
# Load a soft prompt from a file
#==================================================================#
def spRequest(filename):
    """Load the soft prompt stored in `filename` into `vars`, or unload it.

    An empty `filename` clears `vars.sp` and `vars.sp_length` and returns.
    Otherwise the file's `meta.json` is stored in `vars.spmeta`, its
    `tensor.npy` becomes `vars.sp`, and `vars.spfilename` is set to
    `filename`.  `settingschanged()` is called before and after loading.

    Raises AssertionError if `fileops.checksp` does not hand back a
    `zipfile.ZipFile`, or if the tensor contains inf or NaN values.
    """
    # Forget the previous soft prompt first so a failed load is not persisted
    vars.spfilename = ""
    settingschanged()

    if(len(filename) == 0):
        vars.sp = None
        vars.sp_length = 0
        return

    # numpy is imported on first use and kept as a module-level name
    global np
    if 'np' not in globals():
        import numpy as np

    z, version, shape, fortran_order, dtype = fileops.checksp(filename, vars.modeldim)
    assert isinstance(z, zipfile.ZipFile)
    # Close the archive even if meta.json is missing or is not valid JSON
    try:
        with z.open('meta.json') as f:
            vars.spmeta = json.load(f)
    finally:
        z.close()

    with np.load(fileops.sppath(filename), allow_pickle=False) as f:
        tensor = f['tensor.npy']

    # If the tensor is in bfloat16 format, convert it to float32
    if(tensor.dtype == 'V2'):
        # Reinterpret the 2-byte items as uint16, shift them into the upper
        # half of a 32-bit word, then reinterpret those 32 bits as float32
        tensor.dtype = np.uint16
        tensor = np.uint32(tensor) << 16
        tensor.dtype = np.float32

    if(tensor.dtype != np.float16):
        tensor = np.float32(tensor)
    assert not np.isinf(tensor).any() and not np.isnan(tensor).any()

    vars.sp_length = tensor.shape[-2]
    vars.spmeta["n_tokens"] = vars.sp_length

    if(vars.model in ("TPUMeshTransformerGPTJ",)):
        rows = tensor.shape[0]
        # `seq - (seq % -cores)` is `seq` rounded up to a multiple of
        # `cores_per_replica`; pad the rows up to that count
        padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
        tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
        tensor = tensor.reshape(
            tpu_mtj_backend.params["cores_per_replica"],
            -1,
            tpu_mtj_backend.params["d_model"],
        )
        vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor))
    else:
        vars.sp = torch.from_numpy(tensor)

    vars.spfilename = filename
    settingschanged()
#==================================================================#
# Startup
#==================================================================#
@ -525,7 +767,7 @@ if args.model:
else:
print("{0}Welcome to the KoboldAI Server!\nListed RAM is the optimal VRAM and CPU ram can be up to twice the amount.\nMost models can run at less VRAM with reduced max tokens or less layers on the GPU.\nSelect an AI model to continue:{1}\n".format(colors.CYAN, colors.END))
getModelSelection()
getModelSelection(mainmenu)
# If transformers model was selected & GPU available, ask to use CPU or GPU
if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
@ -568,6 +810,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
vars.model_type = "gpt_neo"
loadmodelsettings()
loadsettings()
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
vars.hascuda = torch.cuda.is_available()
vars.bmsupported = vars.model_type in ("gpt_neo", "gptj", "xglm") and not vars.nobreakmodel
@ -805,7 +1048,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
if(hasattr(self, "transformer")):
inputs_embeds = self.transformer.wte(input_ids)
else:
inputs_embeds = self.model.embed_tokens(input_ids) * self.model.embed_scale
inputs_embeds = self.model.embed_tokens(input_ids)
if(vars.sp is not None):
vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
inputs_embeds = torch.where(
@ -813,6 +1056,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
vars.sp[shifted_input_ids.clamp(min=0)],
inputs_embeds,
)
if(not hasattr(self, "transformer")):
inputs_embeds *= self.model.embed_scale
kwargs['inputs_embeds'] = inputs_embeds
return old_forward(self, *args, **kwargs)
cls.forward = new_causallm_forward
@ -1063,10 +1308,10 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
if not args.colab:
import shutil
shutil.rmtree("cache/")
model = model.half()
model.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
tokenizer.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
shutil.rmtree("cache/")
if(vars.hascuda):
if(vars.usegpu):
@ -1186,13 +1431,16 @@ else:
if(vars.model == "Colab"):
from transformers import GPT2TokenizerFast
tokenizer = GPT2TokenizerFast.from_pretrained("EleutherAI/gpt-neo-2.7B", cache_dir="cache/")
loadsettings()
elif(vars.model == "OAI"):
from transformers import GPT2TokenizerFast
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
loadsettings()
# Load the TPU backend if requested
elif(vars.model == "TPUMeshTransformerGPTJ"):
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth)
if not vars.custmodpth or not os.path.isdir(vars.custmodpth):
raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder")
import tpu_mtj_backend
tpu_mtj_backend.vars = vars
tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
@ -1200,10 +1448,14 @@ else:
tpu_mtj_backend.compiling_callback = tpumtjgenerate_compiling_callback
tpu_mtj_backend.stopped_compiling_callback = tpumtjgenerate_stopped_compiling_callback
tpu_mtj_backend.settings_callback = tpumtjgenerate_settings_callback
tpu_mtj_backend.load_model(vars.custmodpth)
vars.allowsp = True
loadmodelsettings()
loadsettings()
tpu_mtj_backend.load_model(vars.custmodpth, **vars.modelconfig)
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
tokenizer = tpu_mtj_backend.tokenizer
else:
loadsettings()
# Set up Flask routes
@app.route('/')
@ -2295,152 +2547,6 @@ def sendsettings():
if(not frm["id"] in vars.formatoptns):
vars.formatoptns[frm["id"]] = False;
#==================================================================#
# Take settings from vars and write them to client settings file
#==================================================================#
def savesettings():
    """Write the current client settings from `vars` to the model's settings file.

    The target is `settings/<name>.settings`, where `<name>` is the result of
    `getmodelname()` with every "/" replaced by "_".  The `settings` folder is
    created first if it is missing.
    """
    # Build json to write
    js = {
        "apikey":        vars.apikey,
        "andepth":       vars.andepth,
        "temp":          vars.temp,
        "top_p":         vars.top_p,
        "top_k":         vars.top_k,
        "tfs":           vars.tfs,
        "rep_pen":       vars.rep_pen,
        "rep_pen_slope": vars.rep_pen_slope,
        "rep_pen_range": vars.rep_pen_range,
        "genamt":        vars.genamt,
        "max_length":    vars.max_length,
        "ikgen":         vars.ikgen,
        "formatoptns":   vars.formatoptns,
        "numseqs":       vars.numseqs,
        "widepth":       vars.widepth,
        "useprompt":     vars.useprompt,
        "adventure":     vars.adventure,
        "chatmode":      vars.chatmode,
        "chatname":      vars.chatname,
        "dynamicscan":   vars.dynamicscan,
        "nopromptgen":   vars.nopromptgen,
        "rngpersist":    vars.rngpersist,
        "nogenmod":      vars.nogenmod,
        "autosave":      vars.autosave,
        "welcome":       vars.welcome,
        "newlinemode":   vars.newlinemode,
        "antemplate":    vars.setauthornotetemplate,
        "userscripts":   vars.userscripts,
        "corescript":    vars.corescript,
        "softprompt":    vars.spfilename,
    }

    # Serialize before opening the file: "w" mode truncates it immediately, so
    # a serialization error must not be able to wipe the existing settings
    serialized = json.dumps(js, indent=3)

    # Write it (exist_ok avoids failing if the folder appears between check and create)
    os.makedirs('settings', exist_ok=True)
    with open("settings/" + getmodelname().replace('/', '_') + ".settings", "w") as file:
        file.write(serialized)
#==================================================================#
# Read settings from client file JSON and send to vars
#==================================================================#
def _is_safe_relative_name(name, allow_empty=False):
    """Return True if `name` is a string without "..", without ":" and without a leading "/" or "\\".

    The empty string is accepted only when `allow_empty` is true.
    """
    if(not isinstance(name, str)):
        return False
    if(len(name) == 0):
        return allow_empty
    return ".." not in name and ":" not in name and not name.startswith(("/", "\\"))


def loadsettings():
    """Read the model's client settings file, if it exists, and copy it into `vars`.

    The file is `settings/<name>.settings`, where `<name>` is the result of
    `getmodelname()` with every "/" replaced by "_".  Keys missing from the
    file leave the corresponding `vars` attribute untouched, with two
    exceptions: `corescript` falls back to "default.lua", and the soft prompt
    file name is reset to "" when no acceptable soft prompt is stored (or
    `vars.allowsp` is false).

    Raises whatever `open` / `json.load` raise if the file cannot be read or
    does not contain valid JSON.
    """
    settings_path = "settings/" + getmodelname().replace('/', '_') + ".settings"
    if(not path.exists(settings_path)):
        return

    # Read file contents into JSON object; the handle is released even if
    # the file turns out not to be valid JSON
    with open(settings_path, "r") as file:
        js = json.load(file)

    # Copy file contents to vars (these keys are stored under the same name in vars)
    same_named_keys = (
        "apikey", "andepth", "temp", "top_p", "top_k", "tfs",
        "rep_pen", "rep_pen_slope", "rep_pen_range",
        "genamt", "max_length", "ikgen", "formatoptns", "numseqs", "widepth",
        "useprompt", "adventure", "chatmode", "chatname", "dynamicscan",
        "nopromptgen", "rngpersist", "nogenmod", "autosave",
        "newlinemode", "welcome",
    )
    for key in same_named_keys:
        if(key in js):
            setattr(vars, key, js[key])

    if("antemplate" in js):
        vars.setauthornotetemplate = js["antemplate"]
        # Only replace the active template while no game is in progress
        if(not vars.gamestarted):
            vars.authornotetemplate = vars.setauthornotetemplate

    if("userscripts" in js):
        vars.userscripts = []
        for userscript in js["userscripts"]:
            if not isinstance(userscript, str):
                continue
            userscript = userscript.strip()
            # Keep only plain relative names that resolve to an existing userscript
            if _is_safe_relative_name(userscript) and os.path.exists(fileops.uspath(userscript)):
                vars.userscripts.append(userscript)

    # An empty corescript string used to raise IndexError on `[0]`; it now
    # falls back to the default like any other rejected value
    if("corescript" in js and _is_safe_relative_name(js["corescript"])):
        vars.corescript = js["corescript"]
    else:
        vars.corescript = "default.lua"

    # An empty soft prompt name is accepted here and passed on to spRequest
    if(vars.allowsp and "softprompt" in js and _is_safe_relative_name(js["softprompt"], allow_empty=True)):
        spRequest(js["softprompt"])
    else:
        vars.spfilename = ""
#==================================================================#
# Don't save settings unless 2 seconds have passed without modification
#==================================================================#
@debounce(2)
def settingschanged():
    """Announce on the console that settings are being saved, then save them.

    Wrapped in `debounce(2)` (defined elsewhere in the file); per the banner
    comment above, the intent is to save only after 2 seconds without changes.
    """
    notice = "{0}Saving settings!{1}".format(colors.GREEN, colors.END)
    print(notice)
    savesettings()
#==================================================================#
# Set value of gamesaved
#==================================================================#
@ -4490,60 +4596,6 @@ def loadRequest(loadpath, filename=None):
emit('from_server', {'cmd': 'hidegenseqs', 'data': ''}, broadcast=True)
print("{0}Story loaded from {1}!{2}".format(colors.GREEN, filename, colors.END))
#==================================================================#
# Load a soft prompt from a file
#==================================================================#
def spRequest(filename):
    """Load the soft prompt stored in `filename` into `vars`, or unload it.

    An empty `filename` clears `vars.sp` and `vars.sp_length` and returns.
    Otherwise the file's `meta.json` is stored in `vars.spmeta`, its
    `tensor.npy` becomes `vars.sp`, and `vars.spfilename` is set to
    `filename`.  `settingschanged()` is called before and after loading.

    Raises AssertionError if `fileops.checksp` does not hand back a
    `zipfile.ZipFile`, or if the tensor contains inf or NaN values.
    """
    # Forget the previous soft prompt first so a failed load is not persisted
    vars.spfilename = ""
    settingschanged()

    if(len(filename) == 0):
        vars.sp = None
        vars.sp_length = 0
        return

    # numpy is imported on first use and kept as a module-level name
    global np
    if 'np' not in globals():
        import numpy as np

    z, version, shape, fortran_order, dtype = fileops.checksp(filename, vars.modeldim)
    assert isinstance(z, zipfile.ZipFile)
    # Close the archive even if meta.json is missing or is not valid JSON
    try:
        with z.open('meta.json') as f:
            vars.spmeta = json.load(f)
    finally:
        z.close()

    with np.load(fileops.sppath(filename), allow_pickle=False) as f:
        tensor = f['tensor.npy']

    # If the tensor is in bfloat16 format, convert it to float32
    if(tensor.dtype == 'V2'):
        # Reinterpret the 2-byte items as uint16, shift them into the upper
        # half of a 32-bit word, then reinterpret those 32 bits as float32
        tensor.dtype = np.uint16
        tensor = np.uint32(tensor) << 16
        tensor.dtype = np.float32

    if(tensor.dtype != np.float16):
        tensor = np.float32(tensor)
    assert not np.isinf(tensor).any() and not np.isnan(tensor).any()

    vars.sp_length = tensor.shape[-2]
    vars.spmeta["n_tokens"] = vars.sp_length

    if(vars.model in ("TPUMeshTransformerGPTJ",)):
        rows = tensor.shape[0]
        # `seq - (seq % -cores)` is `seq` rounded up to a multiple of
        # `cores_per_replica`; pad the rows up to that count
        padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
        tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
        tensor = tensor.reshape(
            tpu_mtj_backend.params["cores_per_replica"],
            -1,
            tpu_mtj_backend.params["d_model"],
        )
        vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor))
    else:
        vars.sp = torch.from_numpy(tensor)

    vars.spfilename = filename
    settingschanged()
#==================================================================#
# Import an AIDungeon game exported with Mimi's tool
#==================================================================#
@ -4888,9 +4940,6 @@ def randomGameRequest(topic, memory=""):
vars.memory = memory
emit('from_server', {'cmd': 'setmemory', 'data': vars.memory}, broadcast=True)
# Load desired settings from both the model and the users config file
loadsettings()
# Prevent tokenizer from taking extra time the first time it's used
def __preempt_tokenizer():
if("tokenizer" not in globals()):
@ -4899,6 +4948,16 @@ def __preempt_tokenizer():
tokenizer.encode(utils.encodenewlines("eunoia"))
threading.Thread(target=__preempt_tokenizer).start()
# Load soft prompt specified by the settings file, if applicable
if(path.exists("settings/" + getmodelname().replace('/', '_') + ".settings")):
    # Read the settings with a context manager so the handle is released even
    # if the file is not valid JSON or loading the soft prompt fails
    with open("settings/" + getmodelname().replace('/', '_') + ".settings", "r") as file:
        js = json.load(file)
    # Accept only a string without "..", ":" or a leading slash/backslash;
    # the empty string is allowed and is passed on to spRequest as-is
    if(vars.allowsp and "softprompt" in js and type(js["softprompt"]) is str and all(q not in js["softprompt"] for q in ("..", ":")) and not js["softprompt"].startswith(("/", "\\"))):
        spRequest(js["softprompt"])
    else:
        vars.spfilename = ""
# Precompile TPU backend if required
if(vars.model in ("TPUMeshTransformerGPTJ",)):
soft_tokens = tpumtjgetsofttokens()
@ -4930,7 +4989,7 @@ if(vars.model in ("TPUMeshTransformerGPTJ",)):
def send_debug():
if vars.debug:
debug_info = ""
for variable in [["Action Length", vars.actions.get_last_key()], ["Actions Metadata Length", max(vars.actions_metadata)], ["Actions Metadata", vars.actions_metadata], ["Newline Mode", vars.newlinemode]]:
for variable in [["Newline Mode", vars.newlinemode], ["Action Length", vars.actions.get_last_key()], ["Actions Metadata Length", max(vars.actions_metadata)], ["Actions Metadata", vars.actions_metadata]]:
debug_info = "{}{}: {}\n".format(debug_info, variable[0], variable[1])
emit('from_server', {'cmd': 'debug_info', 'data': debug_info}, broadcast=True)

View File

@ -443,9 +443,9 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_
return carry
class PenalizingCausalTransformer(CausalTransformer):
def __init__(self, config):
def __init__(self, config, **kwargs):
# Initialize
super().__init__(config)
super().__init__(config, **kwargs)
def generate_static(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None):
compiling_callback()
numseqs = numseqs_aux.shape[0]
@ -791,12 +791,24 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
"pe_rotary_dims": 64,
"seq": 2048,
"cores_per_replica": 8,
"tokenizer_class": "GPT2TokenizerFast",
"tokenizer": "gpt2",
}
params = kwargs
if "compat" in params:
default_params["compat"] = params["compat"]
if default_params["compat"] == "fairseq_lm":
default_params["tokenizer"] = "KoboldAI/fairseq-dense-125M"
for param in default_params:
if param not in params:
params[param] = default_params[param]
# Load tokenizer
if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")):
raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'")
tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"])
tokenizer = tokenizer_class.from_pretrained(params["tokenizer"])
# Disable JAX warnings about these two functions having been renamed
jax.host_count = jax.process_count
jax.host_id = jax.process_index
@ -819,7 +831,6 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
maps.thread_resources.env = thread_resources_env
tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
global shard_xmap, batch_xmap
shard_xmap = __shard_xmap()
@ -832,6 +843,6 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
if not path.endswith("/"):
path += "/"
network = PenalizingCausalTransformer(params)
network = PenalizingCausalTransformer(params, dematerialized=True)
network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])
network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))