Merge pull request #76 from VE-FORBRYDERNE/newline

Fix fairseq newline handling issues
This commit is contained in:
henk717 2022-02-14 18:10:25 +01:00 committed by GitHub
commit ca5b9f968f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 46 additions and 42 deletions

View File

@ -211,6 +211,8 @@ class vars:
quiet = False # If set will suppress any story text from being printed to the console (will only be seen on the client web page) quiet = False # If set will suppress any story text from being printed to the console (will only be seen on the client web page)
debug = False # If set to true, will send debug information to the client for display debug = False # If set to true, will send debug information to the client for display
utils.vars = vars
#==================================================================# #==================================================================#
# Function to get model selection at startup # Function to get model selection at startup
#==================================================================# #==================================================================#
@ -916,7 +918,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
return self.regeneration_required or self.halt return self.regeneration_required or self.halt
tail = input_ids[..., -vars.generated_tkns:] tail = input_ids[..., -vars.generated_tkns:]
for i, t in enumerate(tail): for i, t in enumerate(tail):
decoded = tokenizer.decode(t) decoded = utils.decodenewlines(tokenizer.decode(t))
_, found = checkworldinfo(decoded, force_use_txt=True, actions=vars._actions) _, found = checkworldinfo(decoded, force_use_txt=True, actions=vars._actions)
found -= self.excluded_world_info[i] found -= self.excluded_world_info[i]
if(len(found) != 0): if(len(found) != 0):
@ -1118,7 +1120,7 @@ else:
return excluded_world_info, regeneration_required, halt return excluded_world_info, regeneration_required, halt
for i, t in enumerate(generated): for i, t in enumerate(generated):
decoded = tokenizer.decode(past[i]) + tokenizer.decode(t[tpu_mtj_backend.params["seq"] : tpu_mtj_backend.params["seq"] + n_generated]) decoded = utils.decodenewlines(tokenizer.decode(past[i])) + utils.decodenewlines(tokenizer.decode(t[tpu_mtj_backend.params["seq"] : tpu_mtj_backend.params["seq"] + n_generated]))
_, found = checkworldinfo(decoded, force_use_txt=True, actions=vars._actions) _, found = checkworldinfo(decoded, force_use_txt=True, actions=vars._actions)
found -= excluded_world_info[i] found -= excluded_world_info[i]
if(len(found) != 0): if(len(found) != 0):
@ -1327,7 +1329,7 @@ def lua_decode(tokens):
from transformers import GPT2TokenizerFast from transformers import GPT2TokenizerFast
global tokenizer global tokenizer
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/") tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
return tokenizer.decode(tokens) return utils.decodenewlines(tokenizer.decode(tokens))
#==================================================================# #==================================================================#
# Encode string into list of token IDs using current tokenizer # Encode string into list of token IDs using current tokenizer
@ -1339,7 +1341,7 @@ def lua_encode(string):
from transformers import GPT2TokenizerFast from transformers import GPT2TokenizerFast
global tokenizer global tokenizer
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/") tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
return tokenizer.encode(string, max_length=int(4e9), truncation=True) return tokenizer.encode(utils.encodenewlines(string), max_length=int(4e9), truncation=True)
#==================================================================# #==================================================================#
# Computes context given a submission, Lua array of entry UIDs and a Lua array # Computes context given a submission, Lua array of entry UIDs and a Lua array
@ -1379,7 +1381,7 @@ def lua_compute_context(submission, entries, folders, kwargs):
anotetxt, anotetxt,
actions, actions,
) )
return tokenizer.decode(txt) return utils.decodenewlines(tokenizer.decode(txt))
#==================================================================# #==================================================================#
# Get property of a world info entry given its UID and property name # Get property of a world info entry given its UID and property name
@ -2455,10 +2457,6 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,
if(len(data)): if(len(data)):
data = f"\n{vars.chatname} : {data}\n" data = f"\n{vars.chatname} : {data}\n"
# </s> mode
if(vars.newlinemode == "s"):
data = data.replace('\n', "</s>")
# If we're not continuing, store a copy of the raw input # If we're not continuing, store a copy of the raw input
if(data != ""): if(data != ""):
vars.lastact = data vars.lastact = data
@ -2706,21 +2704,21 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None,
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/") tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
# Calculate token budget # Calculate token budget
prompttkns = tokenizer.encode(vars.comregex_ai.sub('', vars.prompt), max_length=int(2e9), truncation=True) prompttkns = tokenizer.encode(utils.encodenewlines(vars.comregex_ai.sub('', vars.prompt)), max_length=int(2e9), truncation=True)
lnprompt = len(prompttkns) lnprompt = len(prompttkns)
memtokens = tokenizer.encode(mem, max_length=int(2e9), truncation=True) memtokens = tokenizer.encode(utils.encodenewlines(mem), max_length=int(2e9), truncation=True)
lnmem = len(memtokens) lnmem = len(memtokens)
if(lnmem > vars.max_length - lnsp - vars.genamt - budget_deduction): if(lnmem > vars.max_length - lnsp - vars.genamt - budget_deduction):
raise OverflowError("The memory in your story is too long. Please either write a shorter memory text or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.") raise OverflowError("The memory in your story is too long. Please either write a shorter memory text or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.")
witokens = tokenizer.encode(winfo, max_length=int(2e9), truncation=True) witokens = tokenizer.encode(utils.encodenewlines(winfo), max_length=int(2e9), truncation=True)
lnwi = len(witokens) lnwi = len(witokens)
if(lnmem + lnwi > vars.max_length - lnsp - vars.genamt - budget_deduction): if(lnmem + lnwi > vars.max_length - lnsp - vars.genamt - budget_deduction):
raise OverflowError("The current active world info keys take up too many tokens. Please either write shorter world info, decrease World Info Depth or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.") raise OverflowError("The current active world info keys take up too many tokens. Please either write shorter world info, decrease World Info Depth or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.")
if(anotetxt != ""): if(anotetxt != ""):
anotetkns = tokenizer.encode(anotetxt, max_length=int(2e9), truncation=True) anotetkns = tokenizer.encode(utils.encodenewlines(anotetxt), max_length=int(2e9), truncation=True)
lnanote = len(anotetkns) lnanote = len(anotetkns)
if(lnmem + lnwi + lnanote > vars.max_length - lnsp - vars.genamt - budget_deduction): if(lnmem + lnwi + lnanote > vars.max_length - lnsp - vars.genamt - budget_deduction):
raise OverflowError("The author's note in your story is too long. Please either write a shorter author's note or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.") raise OverflowError("The author's note in your story is too long. Please either write a shorter author's note or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.")
@ -2730,7 +2728,7 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None,
else: else:
budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt - budget_deduction budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt - budget_deduction
lnsubmission = len(tokenizer.encode(vars.comregex_ai.sub('', submission), max_length=int(2e9), truncation=True)) if submission is not None else 0 lnsubmission = len(tokenizer.encode(utils.encodenewlines(vars.comregex_ai.sub('', submission)), max_length=int(2e9), truncation=True)) if submission is not None else 0
maybe_lnprompt = lnprompt if vars.useprompt and actionlen > 0 else 0 maybe_lnprompt = lnprompt if vars.useprompt and actionlen > 0 else 0
if(lnmem + lnwi + lnanote + maybe_lnprompt + lnsubmission > vars.max_length - lnsp - vars.genamt - budget_deduction): if(lnmem + lnwi + lnanote + maybe_lnprompt + lnsubmission > vars.max_length - lnsp - vars.genamt - budget_deduction):
@ -2759,7 +2757,7 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None,
assert budget >= 0 assert budget >= 0
if(budget <= 0): if(budget <= 0):
break break
acttkns = tokenizer.encode(chunk, max_length=int(2e9), truncation=True) acttkns = tokenizer.encode(utils.encodenewlines(chunk), max_length=int(2e9), truncation=True)
tknlen = len(acttkns) tknlen = len(acttkns)
if(tknlen < budget): if(tknlen < budget):
tokens = acttkns + tokens tokens = acttkns + tokens
@ -2818,18 +2816,18 @@ def calcsubmit(txt):
if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]): if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
generate(subtxt, min, max, found_entries=found_entries) generate(subtxt, min, max, found_entries=found_entries)
elif(vars.model == "Colab"): elif(vars.model == "Colab"):
sendtocolab(tokenizer.decode(subtxt), min, max) sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif(vars.model == "OAI"): elif(vars.model == "OAI"):
oairequest(tokenizer.decode(subtxt), min, max) oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif(vars.model == "TPUMeshTransformerGPTJ"): elif(vars.model == "TPUMeshTransformerGPTJ"):
tpumtjgenerate(subtxt, min, max, found_entries=found_entries) tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
else: else:
if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]): if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
generate(subtxt, min, max, found_entries=found_entries) generate(subtxt, min, max, found_entries=found_entries)
elif(vars.model == "Colab"): elif(vars.model == "Colab"):
sendtocolab(tokenizer.decode(subtxt), min, max) sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif(vars.model == "OAI"): elif(vars.model == "OAI"):
oairequest(tokenizer.decode(subtxt), min, max) oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
elif(vars.model == "TPUMeshTransformerGPTJ"): elif(vars.model == "TPUMeshTransformerGPTJ"):
tpumtjgenerate(subtxt, min, max, found_entries=found_entries) tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
@ -2947,7 +2945,7 @@ def _generate(txt, minimum, maximum, found_entries):
genout[r][genout.shape[-1] - already_generated + c] = vars.lua_koboldbridge.generated[r+1][c+1] genout[r][genout.shape[-1] - already_generated + c] = vars.lua_koboldbridge.generated[r+1][c+1]
encoded = [] encoded = []
for i in range(vars.numseqs): for i in range(vars.numseqs):
txt = tokenizer.decode(genout[i, -already_generated:]) txt = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:]))
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=vars._actions) winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=vars._actions)
found_entries[i].update(_found_entries) found_entries[i].update(_found_entries)
txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt) txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
@ -2986,10 +2984,10 @@ def generate(txt, minimum, maximum, found_entries=None):
found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs)) found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
if not vars.quiet: if not vars.quiet:
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, tokenizer.decode(txt), colors.END)) print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, utils.decodenewlines(tokenizer.decode(txt)), colors.END))
# Store context in memory to use it for comparison with generated content # Store context in memory to use it for comparison with generated content
vars.lastctx = tokenizer.decode(txt) vars.lastctx = utils.decodenewlines(tokenizer.decode(txt))
# Clear CUDA cache if using GPU # Clear CUDA cache if using GPU
if(vars.hascuda and (vars.usegpu or vars.breakmodel)): if(vars.hascuda and (vars.usegpu or vars.breakmodel)):
@ -3016,7 +3014,7 @@ def generate(txt, minimum, maximum, found_entries=None):
for i in range(vars.numseqs): for i in range(vars.numseqs):
vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(genout[i, -1].item()) vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(genout[i, -1].item())
vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(genout[i, -already_generated:]) vars.lua_koboldbridge.outputs[i+1] = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:]))
execute_outmod() execute_outmod()
if(vars.lua_koboldbridge.regeneration_required): if(vars.lua_koboldbridge.regeneration_required):
@ -3026,7 +3024,7 @@ def generate(txt, minimum, maximum, found_entries=None):
genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]}) genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
assert type(genout[-1]["generated_text"]) is str assert type(genout[-1]["generated_text"]) is str
else: else:
genout = [{"generated_text": tokenizer.decode(tokens[-already_generated:])} for tokens in genout] genout = [{"generated_text": utils.decodenewlines(tokenizer.decode(tokens[-already_generated:]))} for tokens in genout]
if(len(genout) == 1): if(len(genout) == 1):
genresult(genout[0]["generated_text"]) genresult(genout[0]["generated_text"])
@ -3239,7 +3237,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs)) found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
if not vars.quiet: if not vars.quiet:
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, tokenizer.decode(txt), colors.END)) print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, utils.decodenewlines(tokenizer.decode(txt)), colors.END))
vars._actions = vars.actions vars._actions = vars.actions
vars._prompt = vars.prompt vars._prompt = vars.prompt
@ -3282,7 +3280,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
encoded = [] encoded = []
for i in range(vars.numseqs): for i in range(vars.numseqs):
txt = tokenizer.decode(past[i]) txt = utils.decodenewlines(tokenizer.decode(past[i]))
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=vars._actions) winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=vars._actions)
found_entries[i].update(_found_entries) found_entries[i].update(_found_entries)
txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt) txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
@ -3334,7 +3332,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
return return
for i in range(vars.numseqs): for i in range(vars.numseqs):
vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(past[i]) vars.lua_koboldbridge.outputs[i+1] = utils.decodenewlines(tokenizer.decode(past[i]))
genout = past genout = past
execute_outmod() execute_outmod()
@ -3345,7 +3343,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]}) genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
assert type(genout[-1]["generated_text"]) is str assert type(genout[-1]["generated_text"]) is str
else: else:
genout = [{"generated_text": tokenizer.decode(txt)} for txt in genout] genout = [{"generated_text": utils.decodenewlines(tokenizer.decode(txt))} for txt in genout]
if(len(genout) == 1): if(len(genout) == 1):
genresult(genout[0]["generated_text"]) genresult(genout[0]["generated_text"])
@ -3373,14 +3371,14 @@ def getnewcontent(txt):
return txt return txt
# Tokenize the last context and the generated content # Tokenize the last context and the generated content
ctxtokens = tokenizer.encode(vars.lastctx, max_length=int(2e9), truncation=True) ctxtokens = tokenizer.encode(utils.encodenewlines(vars.lastctx), max_length=int(2e9), truncation=True)
txttokens = tokenizer.encode(txt, max_length=int(2e9), truncation=True) txttokens = tokenizer.encode(utils.encodenewlines(txt), max_length=int(2e9), truncation=True)
dif = (len(txttokens) - len(ctxtokens)) * -1 dif = (len(txttokens) - len(ctxtokens)) * -1
# Remove the context from the returned text # Remove the context from the returned text
newtokens = txttokens[dif:] newtokens = txttokens[dif:]
return tokenizer.decode(newtokens) return utils.decodenewlines(tokenizer.decode(newtokens))
#==================================================================# #==================================================================#
# Applies chosen formatting options to text submitted to AI # Applies chosen formatting options to text submitted to AI
@ -3396,9 +3394,6 @@ def applyinputformatting(txt):
# Applies chosen formatting options to text returned from AI # Applies chosen formatting options to text returned from AI
#==================================================================# #==================================================================#
def applyoutputformatting(txt): def applyoutputformatting(txt):
# Revert S mode on output to maintain compatibility
txt = txt.replace('</s>', "\n")
# Use standard quotes and apostrophes # Use standard quotes and apostrophes
txt = utils.fixquotes(txt) txt = utils.fixquotes(txt)
@ -4856,8 +4851,8 @@ loadsettings()
def __preempt_tokenizer(): def __preempt_tokenizer():
if("tokenizer" not in globals()): if("tokenizer" not in globals()):
return return
tokenizer.decode([25678, 559]) utils.decodenewlines(tokenizer.decode([25678, 559]))
tokenizer.encode("eunoia") tokenizer.encode(utils.encodenewlines("eunoia"))
threading.Thread(target=__preempt_tokenizer).start() threading.Thread(target=__preempt_tokenizer).start()
# Precompile TPU backend if required # Precompile TPU backend if required
@ -4891,7 +4886,7 @@ if(vars.model in ("TPUMeshTransformerGPTJ",)):
def send_debug(): def send_debug():
if vars.debug: if vars.debug:
debug_info = "" debug_info = ""
for variable in [["Action Length", len(vars.actions)], ["Actions Metadata Length", len(vars.actions_metadata)], ["Actions Metadata", vars.actions_metadata]]: for variable in [["Action Length", len(vars.actions)], ["Actions Metadata Length", len(vars.actions_metadata)], ["Actions Metadata", vars.actions_metadata], ["Newline Mode", vars.newlinemode]]:
debug_info = "{}{}: {}\n".format(debug_info, variable[0], variable[1]) debug_info = "{}{}: {}\n".format(debug_info, variable[0], variable[1])
emit('from_server', {'cmd': 'debug_info', 'data': debug_info}, broadcast=True) emit('from_server', {'cmd': 'debug_info', 'data': debug_info}, broadcast=True)

View File

@ -1,6 +1,8 @@
from threading import Timer from threading import Timer
import re import re
vars = None
#==================================================================# #==================================================================#
# Decorator to prevent a function's actions from being run until # Decorator to prevent a function's actions from being run until
# at least x seconds have passed without the function being called # at least x seconds have passed without the function being called
@ -111,8 +113,15 @@ def cleanfilename(filename):
filename = "".join(c for c in filename if c not in filteredcharacters).rstrip() filename = "".join(c for c in filename if c not in filteredcharacters).rstrip()
return filename return filename
#==================================================================#
# Newline substitution for fairseq models
#==================================================================#
def encodenewlines(txt):
if(vars.newlinemode == "s"):
return txt.replace('\n', "</s>")
return txt
def decodenewlines(txt):
if(vars.newlinemode == "s"):
return txt.replace("</s>", '\n')
return txt