Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Add safeguards for token budget and text formatting
* Error messages are now shown when the memory, author's note, etc. exceeds the token budget by itself
* Formatting options no longer break if there are empty chunks in the story (although there shouldn't be any in the first place)
* The number of generated tokens is now tracked from the Python side
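The first of these safeguards boils down to checking, before any story text is budgeted, whether the fixed sections of the context (memory, world info, author's note) already exceed what the model can accept. Below is a minimal, self-contained sketch of that pattern; the helper name check_fixed_sections, its parameters, and the generic tokenizer argument are illustrative stand-ins rather than the actual aiserver.py interfaces, which appear in the diff that follows.

# Hypothetical sketch of the budget safeguard described above (not the actual
# aiserver.py code): raise a clear error as soon as memory, world info, or the
# author's note alone no longer fits in the context window.
def check_fixed_sections(tokenizer, mem, winfo, anotetxt,
                         max_length, genamt, sp_length=0, budget_deduction=0):
    # Tokens that are always reserved: soft prompt, generation length,
    # and any extra deduction (e.g. the pending submission).
    reserved = sp_length + genamt + budget_deduction
    budget = max_length - reserved

    lnmem = len(tokenizer.encode(mem))
    if lnmem > budget:
        raise OverflowError("The memory in your story is too long.")

    lnwi = len(tokenizer.encode(winfo))
    if lnmem + lnwi > budget:
        raise OverflowError("The active world info takes up too many tokens.")

    lnanote = len(tokenizer.encode(anotetxt)) if anotetxt else 0
    if lnmem + lnwi + lnanote > budget:
        raise OverflowError("The author's note in your story is too long.")

    # Whatever remains can be spent on the prompt and story actions.
    return budget - (lnmem + lnwi + lnanote)

The real change in calcsubmitbudget applies the same cumulative checks and additionally accounts for the pending submission through the new submission and budget_deduction parameters.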
parent 6183ecd669
commit 8742453f95
aiserver.py (110 changed lines)
@@ -119,6 +119,7 @@ class vars:
     lua_running = False # Whether or not Lua is running (i.e. wasn't stopped due to an error)
     lua_edited = set() # Set of chunk numbers that were edited from a Lua generation modifier
     lua_deleted = set() # Set of chunk numbers that were deleted from a Lua generation modifier
+    generated_tkns = 0 # If using a backend that supports Lua generation modifiers, how many tokens have already been generated, otherwise 0
     spfilename = "" # Filename of soft prompt to load, or an empty string if not using a soft prompt
     userscripts = [] # List of userscripts to load
     last_userscripts = [] # List of previous userscript filenames from the previous time userscripts were send via usstatitems
@@ -796,7 +797,10 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
             scores: torch.FloatTensor,
             **kwargs,
         ) -> bool:
-            if(vars.lua_koboldbridge.generated_cols >= vars.genamt):
+            vars.generated_tkns += 1
+            if(vars.lua_koboldbridge.generated_cols and vars.generated_tkns != vars.lua_koboldbridge.generated_cols):
+                raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({vars.generated_tkns} != {vars.lua_koboldbridge.generated_cols})")
+            if(vars.generated_tkns >= vars.genamt):
                 self.regeneration_required = False
                 self.halt = False
                 return True
@@ -808,7 +812,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
             vars.lua_koboldbridge.regeneration_required = False

             for i in range(vars.numseqs):
-                vars.lua_koboldbridge.generated[i+1][vars.lua_koboldbridge.generated_cols] = input_ids[i, -1].item()
+                vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(input_ids[i, -1].item())

             if(not vars.dynamicscan):
                 return self.regeneration_required or self.halt
@@ -1145,7 +1149,7 @@ def lua_compute_context(submission, entries, folders):
         i += 1
     winfo, mem, anotetxt, _ = calcsubmitbudgetheader(submission, allowed_entries=allowed_entries, allowed_folders=allowed_folders, force_use_txt=True)
     txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
-    return txt
+    return tokenizer.decode(txt)

 #==================================================================#
 # Get property of a world info entry given its UID and property name
@@ -2241,38 +2245,53 @@ def calcsubmitbudgetheader(txt, **kwargs):

     return winfo, mem, anotetxt, found_entries

-def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions):
+def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None, budget_deduction=0):
     forceanote = False # In case we don't have enough actions to hit A.N. depth
     anoteadded = False # In case our budget runs out before we hit A.N. depth
     anotetkns = [] # Placeholder for Author's Note tokens
     lnanote = 0 # Placeholder for Author's Note length

-    # Calculate token budget
-    prompttkns = tokenizer.encode(vars.comregex_ai.sub('', vars.prompt), max_length=1+int(vars.max_length), truncation=True)
-    lnprompt = len(prompttkns)
-
-    memtokens = tokenizer.encode(mem, max_length=1+int(vars.max_length), truncation=True)
-    lnmem = len(memtokens)
-
-    witokens = tokenizer.encode(winfo, max_length=1+int(vars.max_length), truncation=True)
-    lnwi = len(witokens)
-
-    if(anotetxt != ""):
-        anotetkns = tokenizer.encode(anotetxt, max_length=1+int(vars.max_length), truncation=True)
-        lnanote = len(anotetkns)
-
     lnsp = vars.sp.shape[0] if vars.sp is not None else 0

+    # Calculate token budget
+    prompttkns = tokenizer.encode(vars.comregex_ai.sub('', vars.prompt), max_length=int(2e9), truncation=True)
+    lnprompt = len(prompttkns)
+
+    memtokens = tokenizer.encode(mem, max_length=int(2e9), truncation=True)
+    lnmem = len(memtokens)
+    if(lnmem > vars.max_length - lnsp - vars.genamt - budget_deduction):
+        raise OverflowError("The memory in your story is too long. Please either write a shorter memory text or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.")
+
+    witokens = tokenizer.encode(winfo, max_length=int(2e9), truncation=True)
+    lnwi = len(witokens)
+    if(lnmem + lnwi > vars.max_length - lnsp - vars.genamt - budget_deduction):
+        raise OverflowError("The current active world info keys take up too many tokens. Please either write shorter world info, decrease World Info Depth or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.")
+
+    if(anotetxt != ""):
+        anotetkns = tokenizer.encode(anotetxt, max_length=int(2e9), truncation=True)
+        lnanote = len(anotetkns)
+        if(lnmem + lnwi + lnanote > vars.max_length - lnsp - vars.genamt - budget_deduction):
+            raise OverflowError("The author's note in your story is too long. Please either write a shorter author's note or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.")
+
     if(vars.useprompt):
-        budget = vars.max_length - lnsp - lnprompt - lnmem - lnanote - lnwi - vars.genamt
+        budget = vars.max_length - lnsp - lnprompt - lnmem - lnanote - lnwi - vars.genamt - budget_deduction
     else:
-        budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt
+        budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt - budget_deduction

+    lnsubmission = len(tokenizer.encode(vars.comregex_ai.sub('', submission), max_length=int(2e9), truncation=True)) if submission is not None else 0
+    maybe_lnprompt = lnprompt if vars.useprompt and actionlen > 0 else 0
+
+    if(lnmem + lnwi + lnanote + maybe_lnprompt + lnsubmission > vars.max_length - lnsp - vars.genamt - budget_deduction):
+        raise OverflowError("Your submission is too long. Please either write a shorter submission or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt. If you are using the Always Add Prompt setting, turning it off may help.")
+
+    assert budget >= 0
+
     if(actionlen == 0):
         # First/Prompt action
-        subtxt = vars.memory + winfo + anotetxt + vars.comregex_ai.sub('', vars.prompt)
-        lnsub = lnsp + lnmem + lnwi + lnprompt + lnanote
-        return subtxt, lnsub+1, lnsub+vars.genamt
+        tokens = memtokens + witokens + anotetkns + prompttkns
+        assert len(tokens) <= vars.max_length - lnsp - vars.genamt - budget_deduction
+        ln = len(tokens) + lnsp
+        return tokens, ln+1, ln+vars.genamt
     else:
         tokens = []

@@ -2285,9 +2304,10 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions):
         for key in reversed(actions):
             chunk = vars.comregex_ai.sub('', actions[key])

+            assert budget >= 0
             if(budget <= 0):
                 break
-            acttkns = tokenizer.encode(chunk, max_length=int(vars.max_length), truncation=True)
+            acttkns = tokenizer.encode(chunk, max_length=int(2e9), truncation=True)
             tknlen = len(acttkns)
             if(tknlen < budget):
                 tokens = acttkns + tokens
@@ -2324,8 +2344,9 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions):
             tokens = memtokens + witokens + prompttkns + tokens

         # Send completed bundle to generator
+        assert len(tokens) <= vars.max_length - lnsp - vars.genamt - budget_deduction
         ln = len(tokens) + lnsp
-        return tokenizer.decode(tokens), ln+1, ln+vars.genamt
+        return tokens, ln+1, ln+vars.genamt

 #==================================================================#
 # Take submitted text and build the text to be given to generator
@@ -2340,23 +2361,23 @@ def calcsubmit(txt):

     # For all transformers models
     if(vars.model != "InferKit"):
-        subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions)
+        subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions, submission=txt)
         if(actionlen == 0):
             if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
                 generate(subtxt, min, max, found_entries=found_entries)
             elif(vars.model == "Colab"):
-                sendtocolab(subtxt, min, max)
+                sendtocolab(tokenizer.decode(subtxt), min, max)
             elif(vars.model == "OAI"):
-                oairequest(subtxt, min, max)
+                oairequest(tokenizer.decode(subtxt), min, max)
             elif(vars.model == "TPUMeshTransformerGPTJ"):
                 tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
         else:
             if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
                 generate(subtxt, min, max, found_entries=found_entries)
             elif(vars.model == "Colab"):
-                sendtocolab(subtxt, min, max)
+                sendtocolab(tokenizer.decode(subtxt), min, max)
             elif(vars.model == "OAI"):
-                oairequest(subtxt, min, max)
+                oairequest(tokenizer.decode(subtxt), min, max)
             elif(vars.model == "TPUMeshTransformerGPTJ"):
                 tpumtjgenerate(subtxt, min, max, found_entries=found_entries)

@@ -2421,13 +2442,14 @@ def calcsubmit(txt):
 #==================================================================#

 def _generate(txt, minimum, maximum, found_entries):
-    gen_in = tokenizer.encode(txt, return_tensors="pt", max_length=int(vars.max_length), truncation=True).long()
+    gen_in = torch.tensor(txt, dtype=torch.long)[None]
     if(vars.sp is not None):
         soft_tokens = torch.arange(
             model.config.vocab_size,
             model.config.vocab_size + vars.sp.shape[0],
         )
         gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
+    assert gen_in.shape[-1] + vars.genamt <= vars.max_length

     if(vars.hascuda and vars.usegpu):
         gen_in = gen_in.to(vars.gpu_device)
@@ -2459,11 +2481,14 @@ def _generate(txt, minimum, maximum, found_entries):
                 num_return_sequences=numseqs
                 )
             already_generated += len(genout[0]) - len(gen_in[0])
+            assert already_generated <= vars.genamt
             if(model.kai_scanner.halt or not model.kai_scanner.regeneration_required):
                 break
             assert genout.ndim >= 2
             assert genout.shape[0] == vars.numseqs
-            if(already_generated != vars.lua_koboldbridge.generated_cols):
+            if(vars.lua_koboldbridge.generated_cols and vars.generated_tkns != vars.lua_koboldbridge.generated_cols):
                 raise RuntimeError("Inconsistency detected between KoboldAI Python and Lua backends")
+            if(already_generated != vars.generated_tkns):
+                raise RuntimeError("WI scanning error")
             for r in range(vars.numseqs):
                 for c in range(already_generated):
@@ -2474,8 +2499,8 @@ def _generate(txt, minimum, maximum, found_entries):
                 txt = tokenizer.decode(genout[i, -already_generated:])
                 winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
                 found_entries[i].update(_found_entries)
-                txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions)
-                encoded.append(tokenizer.encode(txt, return_tensors="pt", max_length=int(vars.max_length), truncation=True)[0].long().to(genout.device))
+                txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
+                encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device))
             max_length = len(max(encoded, key=len))
             encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
             genout = torch.cat(
@@ -2492,6 +2517,7 @@ def _generate(txt, minimum, maximum, found_entries):
                     device=genout.device,
                 )
                 genout = torch.cat((soft_tokens.tile(vars.numseqs, 1), genout), dim=-1)
+            assert genout.shape[-1] + vars.genamt - already_generated <= vars.max_length
             diff = genout.shape[-1] - gen_in.shape[-1]
             minimum += diff
             maximum += diff
@@ -2503,14 +2529,16 @@ def _generate(txt, minimum, maximum, found_entries):


 def generate(txt, minimum, maximum, found_entries=None):
+    vars.generated_tkns = 0
+
     if(found_entries is None):
         found_entries = set()
     found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))

-    print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, txt, colors.END))
+    print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, tokenizer.decode(txt), colors.END))

     # Store context in memory to use it for comparison with generated content
-    vars.lastctx = txt
+    vars.lastctx = tokenizer.decode(txt)

     # Clear CUDA cache if using GPU
     if(vars.hascuda and (vars.usegpu or vars.breakmodel)):
@@ -2536,7 +2564,7 @@ def generate(txt, minimum, maximum, found_entries=None):
         return

     for i in range(vars.numseqs):
-        vars.lua_koboldbridge.generated[i+1][vars.lua_koboldbridge.generated_cols] = genout[i, -1].item()
+        vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(genout[i, -1].item())
         vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(genout[i, -already_generated:])

     execute_outmod()
@@ -2707,7 +2735,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
         found_entries = set()
     found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))

-    print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, txt, colors.END))
+    print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, tokenizer.decode(txt), colors.END))

     # Submit input text to generator
     try:
@@ -2737,7 +2765,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):

         genout = tpool.execute(
             tpu_mtj_backend.infer,
-            txt,
+            np.uint32(txt),
             gen_len = maximum-minimum+1,
             temp=vars.temp,
             top_p=vars.top_p,
@@ -2804,8 +2832,8 @@ def getnewcontent(txt):
         return txt

     # Tokenize the last context and the generated content
-    ctxtokens = tokenizer.encode(vars.lastctx, max_length=1+int(vars.max_length), truncation=True)
-    txttokens = tokenizer.encode(txt, max_length=1+int(vars.max_length), truncation=True)
+    ctxtokens = tokenizer.encode(vars.lastctx, max_length=int(2e9), truncation=True)
+    txttokens = tokenizer.encode(txt, max_length=int(2e9), truncation=True)
     dif = (len(txttokens) - len(ctxtokens)) * -1

     # Remove the context from the returned text
tpu_mtj_backend.py

@@ -268,7 +268,7 @@ class PenalizingCausalTransformer(CausalTransformer):


 def infer(
-    context: str,
+    context: np.array,
     top_p=0.9,
     temp=0.5,
     top_k=0,
@@ -281,7 +281,7 @@ def infer(
 ) -> List[str]:
     maps.thread_resources.env = thread_resources_env
     total_batch = 1
-    tokens = np.uint32(tokenizer.encode(context, max_length=params["seq"] - (soft_tokens.shape[0] if soft_tokens is not None else 0), truncation=True))
+    tokens = context
     if(soft_tokens is not None):
         tokens = np.uint32(np.concatenate((soft_tokens, tokens)))
     provided_ctx = tokens.shape[0]
utils.py (12 changed lines)
@@ -73,13 +73,15 @@ def addsentencespacing(txt, vars):
     # Get last character of last action
     if(len(vars.actions) > 0):
         if(len(vars.actions[vars.actions.get_last_key()]) > 0):
-            lastchar = vars.actions[vars.actions.get_last_key()][-1]
+            action = vars.actions[vars.actions.get_last_key()]
+            lastchar = action[-1] if len(action) else ""
         else:
             # Last action is blank, this should never happen, but
             # since it did let's bail out.
             return txt
     else:
-        lastchar = vars.prompt[-1]
+        action = vars.prompt
+        lastchar = action[-1] if len(action) else ""
     if(lastchar == "." or lastchar == "!" or lastchar == "?" or lastchar == "," or lastchar == ";" or lastchar == ":"):
         txt = " " + txt
     return txt
@@ -88,13 +90,15 @@ def singlelineprocessing(txt, vars):
     txt = vars.regex_sl.sub('', txt)
     if(len(vars.actions) > 0):
         if(len(vars.actions[vars.actions.get_last_key()]) > 0):
-            lastchar = vars.actions[vars.actions.get_last_key()][-1]
+            action = vars.actions[vars.actions.get_last_key()]
+            lastchar = action[-1] if len(action) else ""
         else:
             # Last action is blank, this should never happen, but
             # since it did let's bail out.
             return txt
     else:
-        lastchar = vars.prompt[-1]
+        action = vars.prompt
+        lastchar = action[-1] if len(action) else ""
     if(lastchar != "\n"):
         txt = txt + "\n"
     return txt