Merge branch 'united' into patch

Gnome Ann
2022-01-16 00:36:55 -05:00
29 changed files with 2370 additions and 124 deletions


@@ -22,7 +22,7 @@ import packaging
import contextlib
import traceback
import threading
from typing import Any, Callable, TypeVar, Union, Dict, Set, List
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
import requests
import html
@@ -69,7 +69,7 @@ modellist = [
["C1 6B (Chatbot)", "hakurei/c1-6B", "12GB"],
["Picard 2.7B (Novel)", "KoboldAI/GPT-Neo-2.7B-Picard", "6GB"],
["Horni 2.7B (NSFW)", "KoboldAI/GPT-Neo-2.7B-Horni", "6GB"],
["Horni-LN 2.7B (Novel/NSFW)", "KoboldAI/GPT-Neo-2.7B-Horni-LN", "6GB"],
["Horni-LN 2.7B (Novel)", "KoboldAI/GPT-Neo-2.7B-Horni-LN", "6GB"],
["Shinen 2.7B (NSFW)", "KoboldAI/GPT-Neo-2.7B-Shinen", "6GB"],
["GPT-J 6B", "EleutherAI/gpt-j-6B", "12GB"],
["GPT-Neo 2.7B", "EleutherAI/gpt-neo-2.7B", "6GB"],
@@ -158,6 +158,7 @@ class vars:
spmeta = None # Metadata of current soft prompt, or None if not using a soft prompt
sp = None # Current soft prompt tensor (as a NumPy array)
sp_length = 0 # Length of current soft prompt in tokens, or 0 if not using a soft prompt
has_genmod = False # Whether or not at least one loaded Lua userscript has a generation modifier
svowname = "" # Filename that was flagged for overwrite confirm
saveow = False # Whether or not overwrite confirm has been displayed
genseqs = [] # Temporary storage for generated sequences
@@ -185,6 +186,7 @@ class vars:
remote = False
nopromptgen = False
rngpersist = False
nogenmod = False
#==================================================================#
# Function to get model selection at startup
@@ -387,6 +389,7 @@ parser.add_argument("--breakmodel_gpulayers", type=str, help="If using a model t
parser.add_argument("--override_delete", action='store_true', help="Deleting stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow deleting stories if using --remote and prevent deleting stories otherwise.")
parser.add_argument("--override_rename", action='store_true', help="Renaming stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow renaming stories if using --remote and prevent renaming stories otherwise.")
parser.add_argument("--configname", help="Force a fixed configuration name to aid with config management.")
parser.add_argument("--colab", action='store_true', help="Optimize for Google Colab.")
args: argparse.Namespace = None
if(os.environ.get("KOBOLDAI_ARGS") is not None):
@@ -394,8 +397,14 @@ if(os.environ.get("KOBOLDAI_ARGS") is not None):
args = parser.parse_args(shlex.split(os.environ["KOBOLDAI_ARGS"]))
else:
args = parser.parse_args()
vars.model = args.model;
if args.colab:
args.remote = True;
args.override_rename = True;
args.override_delete = True;
if args.remote:
vars.remote = True;
@@ -453,7 +462,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
vars.model_type = "gpt_neo"
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
vars.hascuda = torch.cuda.is_available()
vars.bmsupported = vars.model_type in ("gpt_neo", "gptj")
vars.bmsupported = vars.model_type in ("gpt_neo", "gptj") and not args.colab
if(args.breakmodel is not None and args.breakmodel):
print("WARNING: --breakmodel is no longer supported. Breakmodel mode is now automatically enabled when --layers is used (see --help for details).", file=sys.stderr)
if(args.breakmodel_layers is not None):
@@ -930,7 +939,6 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
except ValueError as e:
model = GPTNeoForCausalLM.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/", **lowmem)
else:
print("Model does not exist locally, attempting to download from Huggingface...")
try:
tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache/")
except ValueError as e:
@@ -940,11 +948,13 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache/", **lowmem)
except ValueError as e:
model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache/", **lowmem)
model = model.half()
import shutil
shutil.rmtree("cache/")
model.save_pretrained(vars.model.replace('/', '_'))
tokenizer.save_pretrained(vars.model.replace('/', '_'))
if not args.colab:
model = model.half()
import shutil
shutil.rmtree("cache/")
model.save_pretrained(vars.model.replace('/', '_'))
tokenizer.save_pretrained(vars.model.replace('/', '_'))
if(vars.hascuda):
if(vars.usegpu):
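
(The download path above converts the freshly fetched model to fp16 and re-saves it beside the server so later launches can load from disk; with --colab that re-save is skipped, presumably to spare Colab session time and disk. A minimal sketch of that caching pattern with the transformers API, using a model id from the list above purely as an example:)

    # Sketch only (not part of the diff): cache an fp16 copy locally so the next
    # launch loads from disk instead of re-downloading from Hugging Face.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "EleutherAI/gpt-neo-2.7B"          # example id from the model list above
    local_dir = model_id.replace("/", "_")

    tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir="cache/")
    model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir="cache/")
    model = model.half()                          # fp16 halves the on-disk and in-memory size
    model.save_pretrained(local_dir)
    tokenizer.save_pretrained(local_dir)
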
@@ -991,7 +1001,7 @@ else:
-1,
tpu_mtj_backend.params["d_model"],
)
vars.sp = tensor
vars.sp = tpu_mtj_backend.shard_xmap(tensor)
soft_tokens = np.arange(
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
@@ -999,6 +1009,49 @@ else:
)
return soft_tokens
def tpumtjgenerate_warper_callback(scores) -> "np.array":
scores_shape = scores.shape
scores_list = scores.tolist()
vars.lua_koboldbridge.logits = vars.lua_state.table()
for r, row in enumerate(scores_list):
vars.lua_koboldbridge.logits[r+1] = vars.lua_state.table(*row)
vars.lua_koboldbridge.vocab_size = scores_shape[-1]
execute_genmod()
scores = np.array(
tuple(tuple(row.values()) for row in vars.lua_koboldbridge.logits.values()),
dtype=scores.dtype,
)
assert scores.shape == scores_shape
return scores
def tpumtjgenerate_stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]:
vars.generated_tkns += 1
assert len(excluded_world_info) == len(generated)
regeneration_required = vars.lua_koboldbridge.regeneration_required
halt = not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt
vars.lua_koboldbridge.regeneration_required = False
global past
for i in range(vars.numseqs):
vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(generated[i, tpu_mtj_backend.params["seq"] + n_generated - 1].item())
if(not vars.dynamicscan or halt):
return excluded_world_info, regeneration_required, halt
for i, t in enumerate(generated):
decoded = tokenizer.decode(past[i]) + tokenizer.decode(t[tpu_mtj_backend.params["seq"] : tpu_mtj_backend.params["seq"] + n_generated])
_, found = checkworldinfo(decoded, force_use_txt=True)
found -= excluded_world_info[i]
if(len(found) != 0):
regeneration_required = True
break
return excluded_world_info, regeneration_required, halt
# If we're running Colab or OAI, we still need a tokenizer.
if(vars.model == "Colab"):
from transformers import GPT2TokenizerFast
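
(The new tpumtjgenerate_stopping_callback above is invoked by the TPU backend once per generated token: it mirrors the fresh tokens into the Lua bridge and its return value tells the backend whether to keep sampling, halt, or restart generation when a userscript or a newly triggered world-info entry requires it. A hedged sketch of the driving loop, with sample_step standing in for the backend's real per-token sampler — a hypothetical name, not the actual tpu_mtj_backend code:)

    # Hedged sketch of the callback contract only; the real loop lives in tpu_mtj_backend.
    def drive_generation(context, gen_len, excluded_world_info, sample_step, stopping_callback):
        generated = context.copy()
        n_generated = 0
        regeneration_required = halt = False
        for n_generated in range(1, gen_len + 1):
            generated = sample_step(generated)    # hypothetical: appends one token per sequence
            excluded_world_info, regeneration_required, halt = stopping_callback(
                generated, n_generated, excluded_world_info
            )
            if regeneration_required or halt:
                break
        return generated, n_generated, regeneration_required, halt
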
@@ -1011,21 +1064,12 @@ else:
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth)
import tpu_mtj_backend
tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback
tpu_mtj_backend.load_model(vars.custmodpth)
vars.allowsp = True
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
tokenizer = tpu_mtj_backend.tokenizer
soft_tokens = tpumtjgetsofttokens()
threading.Thread( # Compile backend code in background
target=tpu_mtj_backend.infer,
args=(np.uint32((23403, 727, 20185)),),
kwargs={
"soft_embeddings": vars.sp,
"soft_tokens": soft_tokens,
"gen_len": 1,
"numseqs": vars.numseqs,
},
).start()
# Set up Flask routes
@app.route('/')
@@ -1133,13 +1177,18 @@ def load_lua_scripts():
modulenames.append(lst[i]["modulename"])
descriptions.append(lst[i]["description"])
vars.has_genmod = False
try:
vars.lua_koboldbridge.obliterate_multiverse()
tpool.execute(vars.lua_koboldbridge.load_corescript, vars.corescript)
tpool.execute(vars.lua_koboldbridge.load_userscripts, filenames, modulenames, descriptions)
vars.has_genmod = tpool.execute(vars.lua_koboldbridge.load_userscripts, filenames, modulenames, descriptions)
vars.lua_running = True
except lupa.LuaError as e:
vars.lua_koboldbridge.obliterate_multiverse()
try:
vars.lua_koboldbridge.obliterate_multiverse()
except:
pass
vars.lua_running = False
if(vars.serverstarted):
emit('from_server', {'cmd': 'errmsg', 'data': 'Lua script error; please check console.'}, broadcast=True)
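
(load_lua_scripts now captures the return value of load_userscripts as vars.has_genmod and hardens the error path. The Lua calls still go through tpool.execute; assuming the usual `from eventlet import tpool`, that runs the blocking call in a real OS thread so the green-thread socket server is not stalled. A tiny stand-in example of that pattern:)

    # eventlet's tpool runs a blocking call in a native thread off the event loop.
    # blocking_work is a stand-in for a long-running Lua bridge call.
    import time
    from eventlet import tpool

    def blocking_work(n):
        time.sleep(n)
        return n * 2

    result = tpool.execute(blocking_work, 1)      # returns 2 without blocking green threads
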
@@ -1996,6 +2045,10 @@ def get_message(msg):
vars.rngpersist = msg['data']
settingschanged()
refresh_settings()
elif(msg['cmd'] == 'setnogenmod'):
vars.nogenmod = msg['data']
settingschanged()
refresh_settings()
elif(not vars.remote and msg['cmd'] == 'importwi'):
wiimportrequest()
@@ -2066,6 +2119,8 @@ def savesettings():
js["dynamicscan"] = vars.dynamicscan
js["nopromptgen"] = vars.nopromptgen
js["rngpersist"] = vars.rngpersist
js["nogenmod"] = vars.nogenmod
js["antemplate"] = vars.setauthornotetemplate
js["userscripts"] = vars.userscripts
@@ -2131,6 +2186,8 @@ def loadsettings():
vars.nopromptgen = js["nopromptgen"]
if("rngpersist" in js):
vars.rngpersist = js["rngpersist"]
if("nogenmod" in js):
vars.nogenmod = js["nogenmod"]
if("antemplate" in js):
vars.setauthornotetemplate = js["antemplate"]
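
(With the matching savesettings/loadsettings changes, the no-generation-modifiers toggle is persisted alongside the other client flags. Roughly, the relevant slice of the settings dict would carry keys like these; the values below are placeholders for illustration, not defaults taken from the source:)

    # Illustrative fragment of the persisted settings dict (placeholder values only):
    js = {
        "dynamicscan": False,
        "nopromptgen": False,
        "rngpersist": False,
        "nogenmod": False,        # new: skip Lua generation modifiers when True
        "antemplate": "",         # author's note template
        "userscripts": [],
    }
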
@@ -2896,32 +2953,90 @@ def sendtocolab(txt, min, max):
# Send text to TPU mesh transformer backend
#==================================================================#
def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
vars.generated_tkns = 0
if(found_entries is None):
found_entries = set()
found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, tokenizer.decode(txt), colors.END))
vars._actions = vars.actions
vars._prompt = vars.prompt
if(vars.dynamicscan):
vars._actions = vars._actions.copy()
# Submit input text to generator
try:
if(vars.dynamicscan):
raise ValueError("Dynamic world info scanning is not supported by the TPU backend yet")
soft_tokens = tpumtjgetsofttokens()
genout = tpool.execute(
tpu_mtj_backend.infer,
np.uint32(txt),
gen_len = maximum-minimum+1,
temp=vars.temp,
top_p=vars.top_p,
top_k=vars.top_k,
tfs=vars.tfs,
numseqs=vars.numseqs,
repetition_penalty=vars.rep_pen,
soft_embeddings=vars.sp,
soft_tokens=soft_tokens,
)
global past
if(vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)):
context = np.tile(np.uint32(txt), (vars.numseqs, 1))
past = np.empty((vars.numseqs, 0), dtype=np.uint32)
while(True):
genout, n_generated, regeneration_required, halt = tpool.execute(
tpu_mtj_backend.infer_dynamic,
context,
gen_len = maximum-minimum+1,
temp=vars.temp,
top_p=vars.top_p,
top_k=vars.top_k,
tfs=vars.tfs,
numseqs=vars.numseqs,
repetition_penalty=vars.rep_pen,
soft_embeddings=vars.sp,
soft_tokens=soft_tokens,
excluded_world_info=found_entries,
)
past = np.pad(past, ((0, 0), (0, n_generated)))
for r in range(vars.numseqs):
for c in range(vars.lua_koboldbridge.generated_cols):
assert vars.lua_koboldbridge.generated[r+1][c+1] is not None
past[r, c] = vars.lua_koboldbridge.generated[r+1][c+1]
if(halt or not regeneration_required):
break
print("(regeneration triggered)")
encoded = []
for i in range(vars.numseqs):
txt = tokenizer.decode(past[i])
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
found_entries[i].update(_found_entries)
txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
encoded.append(np.array(txt, dtype=np.uint32))
max_length = len(max(encoded, key=len))
encoded = np.stack(tuple(np.pad(e, (max_length - len(e), 0), constant_values=tpu_mtj_backend.pad_token_id) for e in encoded))
context = np.concatenate(
(
encoded,
past,
),
axis=-1,
)
else:
genout = tpool.execute(
tpu_mtj_backend.infer_static,
np.uint32(txt),
gen_len = maximum-minimum+1,
temp=vars.temp,
top_p=vars.top_p,
top_k=vars.top_k,
tfs=vars.tfs,
numseqs=vars.numseqs,
repetition_penalty=vars.rep_pen,
soft_embeddings=vars.sp,
soft_tokens=soft_tokens,
)
past = genout
for i in range(vars.numseqs):
vars.lua_koboldbridge.generated[i+1] = vars.lua_state.table(*genout[i].tolist())
except Exception as e:
if(issubclass(type(e), lupa.LuaError)):
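
(When regeneration is required mid-stream, the loop above re-encodes each sequence's context and left-pads the encodings to a common length with the backend's pad token before stacking them into one batch. A self-contained sketch of that padding step; the pad token value here is made up for the example:)

    # Sketch of the left-pad-and-stack step used above; 50256 is only an example value.
    import numpy as np

    pad_token_id = 50256
    encoded = [np.array([5, 6, 7], dtype=np.uint32), np.array([9], dtype=np.uint32)]
    max_length = len(max(encoded, key=len))
    batch = np.stack(tuple(
        np.pad(e, (max_length - len(e), 0), constant_values=pad_token_id)
        for e in encoded
    ))
    # batch -> [[5, 6, 7], [50256, 50256, 9]]
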
@@ -2937,10 +3052,10 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
print("{0}{1}{2}".format(colors.RED, traceback.format_exc().replace("\033", ""), colors.END), file=sys.stderr)
set_aibusy(0)
return
for i in range(vars.numseqs):
vars.lua_koboldbridge.generated[i+1] = vars.lua_state.table(*genout[i].tolist())
vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(genout[i])
vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(past[i])
genout = past
execute_outmod()
if(vars.lua_koboldbridge.regeneration_required):
@@ -3103,6 +3218,7 @@ def refresh_settings():
emit('from_server', {'cmd': 'updatedynamicscan', 'data': vars.dynamicscan}, broadcast=True)
emit('from_server', {'cmd': 'updatenopromptgen', 'data': vars.nopromptgen}, broadcast=True)
emit('from_server', {'cmd': 'updaterngpersist', 'data': vars.rngpersist}, broadcast=True)
emit('from_server', {'cmd': 'updatenogenmod', 'data': vars.nogenmod}, broadcast=True)
emit('from_server', {'cmd': 'updatefrmttriminc', 'data': vars.formatoptns["frmttriminc"]}, broadcast=True)
emit('from_server', {'cmd': 'updatefrmtrmblln', 'data': vars.formatoptns["frmtrmblln"]}, broadcast=True)
@@ -4014,7 +4130,7 @@ def spRequest(filename):
-1,
tpu_mtj_backend.params["d_model"],
)
vars.sp = np.float32(tensor)
vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor))
else:
vars.sp = torch.from_numpy(tensor)
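
(Both here and in tpumtjgetsofttokens earlier, soft-prompt tensors bound for the TPU backend are now wrapped in tpu_mtj_backend.shard_xmap before being stored, so they reach the backend already laid out per device, while the CPU/GPU path keeps a plain torch tensor. The real layout is backend-specific; the snippet below is purely an illustration of the general idea of giving every local accelerator device its own view of a host array, and is not the actual shard_xmap implementation:)

    # Illustrative only -- NOT the real tpu_mtj_backend.shard_xmap.
    import jax
    import numpy as np

    def replicate_for_devices(tensor: np.ndarray) -> np.ndarray:
        cores = jax.local_device_count()          # e.g. 8 on a TPU v3-8
        # Prepend a device axis by replication; a real shard function might instead
        # split one of the existing axes across the cores.
        return np.broadcast_to(tensor, (cores,) + tensor.shape)
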
@@ -4359,6 +4475,34 @@ def randomGameRequest(topic, memory=""):
loadmodelsettings()
loadsettings()
# Precompile TPU backend if required
if(vars.model in ("TPUMeshTransformerGPTJ",)):
soft_tokens = tpumtjgetsofttokens()
if(vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)):
threading.Thread(
target=tpu_mtj_backend.infer_dynamic,
args=(np.tile(np.uint32((23403, 727, 20185)), (vars.numseqs, 1)),),
kwargs={
"soft_embeddings": vars.sp,
"soft_tokens": soft_tokens,
"gen_len": 1,
"use_callback": False,
"numseqs": vars.numseqs,
"excluded_world_info": list(set() for _ in range(vars.numseqs)),
},
).start()
else:
threading.Thread(
target=tpu_mtj_backend.infer_static,
args=(np.uint32((23403, 727, 20185)),),
kwargs={
"soft_embeddings": vars.sp,
"soft_tokens": soft_tokens,
"gen_len": 1,
"numseqs": vars.numseqs,
},
).start()
#==================================================================#
# Final startup commands to launch Flask app
#==================================================================#
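
(The new "Precompile TPU backend" block replaces the background-compile thread removed earlier in this diff: a one-token dummy inference is launched in a thread so XLA compilation happens during startup rather than on the first user request. A generic illustration of why such a warm-up helps, unrelated to the model code itself: jax.jit compiles on the first call and reuses the compiled program afterwards.)

    # Generic warm-up illustration: the first call pays XLA compile time,
    # later calls with the same shapes reuse the compiled program.
    import time
    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
        return (x * x).sum()

    x = jnp.ones(1024)
    t0 = time.time(); f(x).block_until_ready(); print("first call :", time.time() - t0)
    t0 = time.time(); f(x).block_until_ready(); print("second call:", time.time() - t0)
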