Add safeguards for token budget and text formatting

* An error message is now shown when memory, the author's note, etc.
  exceeds the token budget on its own (see the sketch after this list)
* Formatting options no longer break if there are empty chunks in the
  story (although there shouldn't be any in the first place)
* The number of generated tokens is now tracked on the Python side
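For readers skimming the diff, the budget safeguard boils down to one pattern: tokenize each fixed context component (memory, world info, author's note) in order and raise an OverflowError as soon as the running total can no longer fit alongside the soft prompt and the requested generation length. Below is a minimal, self-contained sketch of that idea; the names (check_budget, encode, max_length, sp_length, genamt) are illustrative stand-ins, not the exact identifiers used in the patch.

    # Illustrative sketch only -- the real logic lives in calcsubmitbudget()
    # and uses tokenizer.encode, vars.max_length, vars.genamt, etc.
    def check_budget(components, encode, max_length, sp_length, genamt):
        """Raise OverflowError as soon as the fixed context exceeds the budget.

        components: ordered list of (name, text) pairs, e.g. memory, world
        info, author's note. encode: any callable turning text into tokens.
        """
        budget = max_length - sp_length - genamt
        used = 0
        for name, text in components:
            used += len(encode(text))
            if used > budget:
                raise OverflowError(
                    f"The {name} in your story is too long. Write a shorter "
                    f"{name} or increase the Max Tokens setting."
                )
        return budget - used  # tokens left for the prompt and story chunks

    # Toy usage with whitespace "tokenization":
    leftover = check_budget(
        [("memory", "a " * 10), ("world info", "b " * 5), ("author's note", "c " * 3)],
        encode=str.split, max_length=64, sp_length=0, genamt=16,
    )

Checking cumulatively, in the same order the patch builds the context, means the error always names the first component that pushed the total over the budget.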
Author: Gnome Ann
Date:   2021-12-26 18:29:54 -05:00
Parent: 6183ecd669
Commit: 8742453f95

3 changed files with 82 additions and 50 deletions


@@ -119,6 +119,7 @@ class vars:
     lua_running  = False # Whether or not Lua is running (i.e. wasn't stopped due to an error)
     lua_edited   = set() # Set of chunk numbers that were edited from a Lua generation modifier
     lua_deleted  = set() # Set of chunk numbers that were deleted from a Lua generation modifier
+    generated_tkns = 0   # If using a backend that supports Lua generation modifiers, how many tokens have already been generated, otherwise 0
     spfilename   = ""    # Filename of soft prompt to load, or an empty string if not using a soft prompt
     userscripts  = []    # List of userscripts to load
     last_userscripts = [] # List of previous userscript filenames from the previous time userscripts were send via usstatitems
@@ -796,7 +797,10 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
             scores: torch.FloatTensor,
             **kwargs,
         ) -> bool:
-            if(vars.lua_koboldbridge.generated_cols >= vars.genamt):
+            vars.generated_tkns += 1
+            if(vars.lua_koboldbridge.generated_cols and vars.generated_tkns != vars.lua_koboldbridge.generated_cols):
+                raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({vars.generated_tkns} != {vars.lua_koboldbridge.generated_cols})")
+            if(vars.generated_tkns >= vars.genamt):
                 self.regeneration_required = False
                 self.halt = False
                 return True
@@ -808,7 +812,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
             vars.lua_koboldbridge.regeneration_required = False
             for i in range(vars.numseqs):
-                vars.lua_koboldbridge.generated[i+1][vars.lua_koboldbridge.generated_cols] = input_ids[i, -1].item()
+                vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(input_ids[i, -1].item())
             if(not vars.dynamicscan):
                 return self.regeneration_required or self.halt
@@ -1145,7 +1149,7 @@ def lua_compute_context(submission, entries, folders):
             i += 1
     winfo, mem, anotetxt, _ = calcsubmitbudgetheader(submission, allowed_entries=allowed_entries, allowed_folders=allowed_folders, force_use_txt=True)
     txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
-    return txt
+    return tokenizer.decode(txt)
 
 #==================================================================#
 # Get property of a world info entry given its UID and property name
@@ -2241,38 +2245,53 @@ def calcsubmitbudgetheader(txt, **kwargs):
     return winfo, mem, anotetxt, found_entries
 
-def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions):
+def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None, budget_deduction=0):
     forceanote   = False  # In case we don't have enough actions to hit A.N. depth
     anoteadded   = False  # In case our budget runs out before we hit A.N. depth
     anotetkns    = []     # Placeholder for Author's Note tokens
     lnanote      = 0      # Placeholder for Author's Note length
-    # Calculate token budget
-    prompttkns = tokenizer.encode(vars.comregex_ai.sub('', vars.prompt), max_length=1+int(vars.max_length), truncation=True)
-    lnprompt   = len(prompttkns)
-    memtokens = tokenizer.encode(mem, max_length=1+int(vars.max_length), truncation=True)
-    lnmem     = len(memtokens)
-    witokens  = tokenizer.encode(winfo, max_length=1+int(vars.max_length), truncation=True)
-    lnwi      = len(witokens)
-    if(anotetxt != ""):
-        anotetkns = tokenizer.encode(anotetxt, max_length=1+int(vars.max_length), truncation=True)
-        lnanote   = len(anotetkns)
     lnsp = vars.sp.shape[0] if vars.sp is not None else 0
+
+    # Calculate token budget
+    prompttkns = tokenizer.encode(vars.comregex_ai.sub('', vars.prompt), max_length=int(2e9), truncation=True)
+    lnprompt   = len(prompttkns)
+
+    memtokens = tokenizer.encode(mem, max_length=int(2e9), truncation=True)
+    lnmem     = len(memtokens)
+    if(lnmem > vars.max_length - lnsp - vars.genamt - budget_deduction):
+        raise OverflowError("The memory in your story is too long. Please either write a shorter memory text or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.")
+
+    witokens  = tokenizer.encode(winfo, max_length=int(2e9), truncation=True)
+    lnwi      = len(witokens)
+    if(lnmem + lnwi > vars.max_length - lnsp - vars.genamt - budget_deduction):
+        raise OverflowError("The current active world info keys take up too many tokens. Please either write shorter world info, decrease World Info Depth or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.")
+
+    if(anotetxt != ""):
+        anotetkns = tokenizer.encode(anotetxt, max_length=int(2e9), truncation=True)
+        lnanote   = len(anotetkns)
+        if(lnmem + lnwi + lnanote > vars.max_length - lnsp - vars.genamt - budget_deduction):
+            raise OverflowError("The author's note in your story is too long. Please either write a shorter author's note or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.")
 
     if(vars.useprompt):
-        budget = vars.max_length - lnsp - lnprompt - lnmem - lnanote - lnwi - vars.genamt
+        budget = vars.max_length - lnsp - lnprompt - lnmem - lnanote - lnwi - vars.genamt - budget_deduction
     else:
-        budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt
+        budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt - budget_deduction
+
+    lnsubmission = len(tokenizer.encode(vars.comregex_ai.sub('', submission), max_length=int(2e9), truncation=True)) if submission is not None else 0
+    maybe_lnprompt = lnprompt if vars.useprompt and actionlen > 0 else 0
+    if(lnmem + lnwi + lnanote + maybe_lnprompt + lnsubmission > vars.max_length - lnsp - vars.genamt - budget_deduction):
+        raise OverflowError("Your submission is too long. Please either write a shorter submission or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt. If you are using the Always Add Prompt setting, turning it off may help.")
+
+    assert budget >= 0
 
     if(actionlen == 0):
         # First/Prompt action
-        subtxt = vars.memory + winfo + anotetxt + vars.comregex_ai.sub('', vars.prompt)
-        lnsub  = lnsp + lnmem + lnwi + lnprompt + lnanote
-        return subtxt, lnsub+1, lnsub+vars.genamt
+        tokens = memtokens + witokens + anotetkns + prompttkns
+        assert len(tokens) <= vars.max_length - lnsp - vars.genamt - budget_deduction
+        ln = len(tokens) + lnsp
+        return tokens, ln+1, ln+vars.genamt
     else:
         tokens     = []
@@ -2285,9 +2304,10 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions):
         for key in reversed(actions):
             chunk = vars.comregex_ai.sub('', actions[key])
+            assert budget >= 0
             if(budget <= 0):
                 break
-            acttkns = tokenizer.encode(chunk, max_length=int(vars.max_length), truncation=True)
+            acttkns = tokenizer.encode(chunk, max_length=int(2e9), truncation=True)
             tknlen = len(acttkns)
             if(tknlen < budget):
                 tokens = acttkns + tokens
@@ -2312,7 +2332,7 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions):
                 prompttkns = prompttkns[-budget:]
             else:
                 prompttkns = []
         # Did we get to add the A.N.? If not, do it here
         if(anotetxt != ""):
             if((not anoteadded) or forceanote):
@@ -2322,10 +2342,11 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions):
         else:
             # Prepend Memory, WI, and Prompt before action tokens
             tokens = memtokens + witokens + prompttkns + tokens
 
         # Send completed bundle to generator
+        assert len(tokens) <= vars.max_length - lnsp - vars.genamt - budget_deduction
         ln = len(tokens) + lnsp
-        return tokenizer.decode(tokens), ln+1, ln+vars.genamt
+        return tokens, ln+1, ln+vars.genamt
 
 #==================================================================#
 # Take submitted text and build the text to be given to generator
@@ -2340,23 +2361,23 @@ def calcsubmit(txt):
     # For all transformers models
     if(vars.model != "InferKit"):
-        subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions)
+        subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions, submission=txt)
         if(actionlen == 0):
             if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
                 generate(subtxt, min, max, found_entries=found_entries)
             elif(vars.model == "Colab"):
-                sendtocolab(subtxt, min, max)
+                sendtocolab(tokenizer.decode(subtxt), min, max)
             elif(vars.model == "OAI"):
-                oairequest(subtxt, min, max)
+                oairequest(tokenizer.decode(subtxt), min, max)
             elif(vars.model == "TPUMeshTransformerGPTJ"):
                 tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
         else:
             if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
                 generate(subtxt, min, max, found_entries=found_entries)
             elif(vars.model == "Colab"):
-                sendtocolab(subtxt, min, max)
+                sendtocolab(tokenizer.decode(subtxt), min, max)
             elif(vars.model == "OAI"):
-                oairequest(subtxt, min, max)
+                oairequest(tokenizer.decode(subtxt), min, max)
             elif(vars.model == "TPUMeshTransformerGPTJ"):
                 tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
@@ -2421,13 +2442,14 @@ def calcsubmit(txt):
 #==================================================================#
 
 def _generate(txt, minimum, maximum, found_entries):
-    gen_in = tokenizer.encode(txt, return_tensors="pt", max_length=int(vars.max_length), truncation=True).long()
+    gen_in = torch.tensor(txt, dtype=torch.long)[None]
     if(vars.sp is not None):
         soft_tokens = torch.arange(
             model.config.vocab_size,
             model.config.vocab_size + vars.sp.shape[0],
         )
         gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
+    assert gen_in.shape[-1] + vars.genamt <= vars.max_length
 
     if(vars.hascuda and vars.usegpu):
         gen_in = gen_in.to(vars.gpu_device)
@@ -2459,11 +2481,14 @@ def _generate(txt, minimum, maximum, found_entries):
                 num_return_sequences=numseqs
             )
             already_generated += len(genout[0]) - len(gen_in[0])
+            assert already_generated <= vars.genamt
             if(model.kai_scanner.halt or not model.kai_scanner.regeneration_required):
                 break
             assert genout.ndim >= 2
             assert genout.shape[0] == vars.numseqs
-            if(already_generated != vars.lua_koboldbridge.generated_cols):
+            if(vars.lua_koboldbridge.generated_cols and vars.generated_tkns != vars.lua_koboldbridge.generated_cols):
+                raise RuntimeError("Inconsistency detected between KoboldAI Python and Lua backends")
+            if(already_generated != vars.generated_tkns):
                 raise RuntimeError("WI scanning error")
             for r in range(vars.numseqs):
                 for c in range(already_generated):
@@ -2474,8 +2499,8 @@ def _generate(txt, minimum, maximum, found_entries):
                 txt = tokenizer.decode(genout[i, -already_generated:])
                 winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
                 found_entries[i].update(_found_entries)
-                txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions)
-                encoded.append(tokenizer.encode(txt, return_tensors="pt", max_length=int(vars.max_length), truncation=True)[0].long().to(genout.device))
+                txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
+                encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device))
             max_length = len(max(encoded, key=len))
             encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
             genout = torch.cat(
@@ -2492,6 +2517,7 @@ def _generate(txt, minimum, maximum, found_entries):
                     device=genout.device,
                 )
                 genout = torch.cat((soft_tokens.tile(vars.numseqs, 1), genout), dim=-1)
+            assert genout.shape[-1] + vars.genamt - already_generated <= vars.max_length
             diff = genout.shape[-1] - gen_in.shape[-1]
             minimum += diff
             maximum += diff
@@ -2503,14 +2529,16 @@ def _generate(txt, minimum, maximum, found_entries):
 
 def generate(txt, minimum, maximum, found_entries=None):
+    vars.generated_tkns = 0
+
     if(found_entries is None):
         found_entries = set()
     found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
 
-    print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, txt, colors.END))
+    print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, tokenizer.decode(txt), colors.END))
 
     # Store context in memory to use it for comparison with generated content
-    vars.lastctx = txt
+    vars.lastctx = tokenizer.decode(txt)
 
     # Clear CUDA cache if using GPU
     if(vars.hascuda and (vars.usegpu or vars.breakmodel)):
@@ -2536,7 +2564,7 @@ def generate(txt, minimum, maximum, found_entries=None):
         return
 
     for i in range(vars.numseqs):
-        vars.lua_koboldbridge.generated[i+1][vars.lua_koboldbridge.generated_cols] = genout[i, -1].item()
+        vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(genout[i, -1].item())
         vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(genout[i, -already_generated:])
 
     execute_outmod()
@@ -2707,7 +2735,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
         found_entries = set()
     found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
 
-    print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, txt, colors.END))
+    print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, tokenizer.decode(txt), colors.END))
 
     # Submit input text to generator
     try:
@@ -2737,7 +2765,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
         genout = tpool.execute(
             tpu_mtj_backend.infer,
-            txt,
+            np.uint32(txt),
             gen_len = maximum-minimum+1,
             temp=vars.temp,
             top_p=vars.top_p,
@@ -2804,8 +2832,8 @@ def getnewcontent(txt):
         return txt
 
     # Tokenize the last context and the generated content
-    ctxtokens = tokenizer.encode(vars.lastctx, max_length=1+int(vars.max_length), truncation=True)
-    txttokens = tokenizer.encode(txt, max_length=1+int(vars.max_length), truncation=True)
+    ctxtokens = tokenizer.encode(vars.lastctx, max_length=int(2e9), truncation=True)
+    txttokens = tokenizer.encode(txt, max_length=int(2e9), truncation=True)
     dif = (len(txttokens) - len(ctxtokens)) * -1
 
     # Remove the context from the returned text


@@ -268,7 +268,7 @@ class PenalizingCausalTransformer(CausalTransformer):
 def infer(
-    context: str,
+    context: np.array,
     top_p=0.9,
     temp=0.5,
     top_k=0,
@@ -281,7 +281,7 @@ def infer(
 ) -> List[str]:
     maps.thread_resources.env = thread_resources_env
     total_batch = 1
-    tokens = np.uint32(tokenizer.encode(context, max_length=params["seq"] - (soft_tokens.shape[0] if soft_tokens is not None else 0), truncation=True))
+    tokens = context
     if(soft_tokens is not None):
         tokens = np.uint32(np.concatenate((soft_tokens, tokens)))
     provided_ctx = tokens.shape[0]


@@ -73,13 +73,15 @@ def addsentencespacing(txt, vars):
     # Get last character of last action
     if(len(vars.actions) > 0):
         if(len(vars.actions[vars.actions.get_last_key()]) > 0):
-            lastchar = vars.actions[vars.actions.get_last_key()][-1]
+            action = vars.actions[vars.actions.get_last_key()]
+            lastchar = action[-1] if len(action) else ""
         else:
             # Last action is blank, this should never happen, but
            # since it did let's bail out.
             return txt
     else:
-        lastchar = vars.prompt[-1]
+        action = vars.prompt
+        lastchar = action[-1] if len(action) else ""
     if(lastchar == "." or lastchar == "!" or lastchar == "?" or lastchar == "," or lastchar == ";" or lastchar == ":"):
         txt = " " + txt
     return txt
@@ -88,13 +90,15 @@ def singlelineprocessing(txt, vars):
     txt = vars.regex_sl.sub('', txt)
     if(len(vars.actions) > 0):
         if(len(vars.actions[vars.actions.get_last_key()]) > 0):
-            lastchar = vars.actions[vars.actions.get_last_key()][-1]
+            action = vars.actions[vars.actions.get_last_key()]
+            lastchar = action[-1] if len(action) else ""
         else:
             # Last action is blank, this should never happen, but
             # since it did let's bail out.
             return txt
     else:
-        lastchar = vars.prompt[-1]
+        action = vars.prompt
+        lastchar = action[-1] if len(action) else ""
     if(lastchar != "\n"):
         txt = txt + "\n"
     return txt