Fix fairseq newline handling issues
This commit is contained in:
parent
c1af8f72c3
commit
f682c1229a
69
aiserver.py
69
aiserver.py
|
@ -211,6 +211,8 @@ class vars:
|
|||
quiet = False # If set will suppress any story text from being printed to the console (will only be seen on the client web page)
|
||||
debug = False # If set to true, will send debug information to the client for display
|
||||
|
||||
utils.vars = vars
|
||||
|
||||
#==================================================================#
|
||||
# Function to get model selection at startup
|
||||
#==================================================================#
|
||||
|
@ -916,7 +918,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
|
|||
return self.regeneration_required or self.halt
|
||||
tail = input_ids[..., -vars.generated_tkns:]
|
||||
for i, t in enumerate(tail):
|
||||
decoded = tokenizer.decode(t)
|
||||
decoded = utils.decodenewlines(tokenizer.decode(t))
|
||||
_, found = checkworldinfo(decoded, force_use_txt=True, actions=vars._actions)
|
||||
found -= self.excluded_world_info[i]
|
||||
if(len(found) != 0):
|
||||
|
@ -1118,7 +1120,7 @@ else:
|
|||
return excluded_world_info, regeneration_required, halt
|
||||
|
||||
for i, t in enumerate(generated):
|
||||
decoded = tokenizer.decode(past[i]) + tokenizer.decode(t[tpu_mtj_backend.params["seq"] : tpu_mtj_backend.params["seq"] + n_generated])
|
||||
decoded = utils.decodenewlines(tokenizer.decode(past[i])) + utils.decodenewlines(tokenizer.decode(t[tpu_mtj_backend.params["seq"] : tpu_mtj_backend.params["seq"] + n_generated]))
|
||||
_, found = checkworldinfo(decoded, force_use_txt=True, actions=vars._actions)
|
||||
found -= excluded_world_info[i]
|
||||
if(len(found) != 0):
|
||||
|
@ -1327,7 +1329,7 @@ def lua_decode(tokens):
|
|||
from transformers import GPT2TokenizerFast
|
||||
global tokenizer
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
|
||||
return tokenizer.decode(tokens)
|
||||
return utils.decodenewlines(tokenizer.decode(tokens))
|
||||
|
||||
#==================================================================#
|
||||
# Encode string into list of token IDs using current tokenizer
|
||||
|
@ -1339,7 +1341,7 @@ def lua_encode(string):
|
|||
from transformers import GPT2TokenizerFast
|
||||
global tokenizer
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
|
||||
return tokenizer.encode(string, max_length=int(4e9), truncation=True)
|
||||
return tokenizer.encode(utils.encodenewlines(string), max_length=int(4e9), truncation=True)
|
||||
|
||||
#==================================================================#
|
||||
# Computes context given a submission, Lua array of entry UIDs and a Lua array
|
||||
|
@ -1379,7 +1381,7 @@ def lua_compute_context(submission, entries, folders, kwargs):
|
|||
anotetxt,
|
||||
actions,
|
||||
)
|
||||
return tokenizer.decode(txt)
|
||||
return utils.decodenewlines(tokenizer.decode(txt))
|
||||
|
||||
#==================================================================#
|
||||
# Get property of a world info entry given its UID and property name
|
||||
|
@ -2455,10 +2457,6 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,
|
|||
if(len(data)):
|
||||
data = f"\n{vars.chatname} : {data}\n"
|
||||
|
||||
# </s> mode
|
||||
if(vars.newlinemode == "s"):
|
||||
data = data.replace('\n', "</s>")
|
||||
|
||||
# If we're not continuing, store a copy of the raw input
|
||||
if(data != ""):
|
||||
vars.lastact = data
|
||||
|
@ -2706,21 +2704,21 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None,
|
|||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
|
||||
|
||||
# Calculate token budget
|
||||
prompttkns = tokenizer.encode(vars.comregex_ai.sub('', vars.prompt), max_length=int(2e9), truncation=True)
|
||||
prompttkns = tokenizer.encode(utils.encodenewlines(vars.comregex_ai.sub('', vars.prompt)), max_length=int(2e9), truncation=True)
|
||||
lnprompt = len(prompttkns)
|
||||
|
||||
memtokens = tokenizer.encode(mem, max_length=int(2e9), truncation=True)
|
||||
memtokens = tokenizer.encode(utils.encodenewlines(mem), max_length=int(2e9), truncation=True)
|
||||
lnmem = len(memtokens)
|
||||
if(lnmem > vars.max_length - lnsp - vars.genamt - budget_deduction):
|
||||
raise OverflowError("The memory in your story is too long. Please either write a shorter memory text or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.")
|
||||
|
||||
witokens = tokenizer.encode(winfo, max_length=int(2e9), truncation=True)
|
||||
witokens = tokenizer.encode(utils.encodenewlines(winfo), max_length=int(2e9), truncation=True)
|
||||
lnwi = len(witokens)
|
||||
if(lnmem + lnwi > vars.max_length - lnsp - vars.genamt - budget_deduction):
|
||||
raise OverflowError("The current active world info keys take up too many tokens. Please either write shorter world info, decrease World Info Depth or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.")
|
||||
|
||||
if(anotetxt != ""):
|
||||
anotetkns = tokenizer.encode(anotetxt, max_length=int(2e9), truncation=True)
|
||||
anotetkns = tokenizer.encode(utils.encodenewlines(anotetxt), max_length=int(2e9), truncation=True)
|
||||
lnanote = len(anotetkns)
|
||||
if(lnmem + lnwi + lnanote > vars.max_length - lnsp - vars.genamt - budget_deduction):
|
||||
raise OverflowError("The author's note in your story is too long. Please either write a shorter author's note or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.")
|
||||
|
@ -2730,7 +2728,7 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None,
|
|||
else:
|
||||
budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt - budget_deduction
|
||||
|
||||
lnsubmission = len(tokenizer.encode(vars.comregex_ai.sub('', submission), max_length=int(2e9), truncation=True)) if submission is not None else 0
|
||||
lnsubmission = len(tokenizer.encode(utils.encodenewlines(vars.comregex_ai.sub('', submission)), max_length=int(2e9), truncation=True)) if submission is not None else 0
|
||||
maybe_lnprompt = lnprompt if vars.useprompt and actionlen > 0 else 0
|
||||
|
||||
if(lnmem + lnwi + lnanote + maybe_lnprompt + lnsubmission > vars.max_length - lnsp - vars.genamt - budget_deduction):
|
||||
|
@ -2759,7 +2757,7 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None,
|
|||
assert budget >= 0
|
||||
if(budget <= 0):
|
||||
break
|
||||
acttkns = tokenizer.encode(chunk, max_length=int(2e9), truncation=True)
|
||||
acttkns = tokenizer.encode(utils.encodenewlines(chunk), max_length=int(2e9), truncation=True)
|
||||
tknlen = len(acttkns)
|
||||
if(tknlen < budget):
|
||||
tokens = acttkns + tokens
|
||||
|
@ -2818,18 +2816,18 @@ def calcsubmit(txt):
|
|||
if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
|
||||
generate(subtxt, min, max, found_entries=found_entries)
|
||||
elif(vars.model == "Colab"):
|
||||
sendtocolab(tokenizer.decode(subtxt), min, max)
|
||||
sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
|
||||
elif(vars.model == "OAI"):
|
||||
oairequest(tokenizer.decode(subtxt), min, max)
|
||||
oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
|
||||
elif(vars.model == "TPUMeshTransformerGPTJ"):
|
||||
tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
|
||||
else:
|
||||
if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
|
||||
generate(subtxt, min, max, found_entries=found_entries)
|
||||
elif(vars.model == "Colab"):
|
||||
sendtocolab(tokenizer.decode(subtxt), min, max)
|
||||
sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
|
||||
elif(vars.model == "OAI"):
|
||||
oairequest(tokenizer.decode(subtxt), min, max)
|
||||
oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
|
||||
elif(vars.model == "TPUMeshTransformerGPTJ"):
|
||||
tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
|
||||
|
||||
|
@ -2947,7 +2945,7 @@ def _generate(txt, minimum, maximum, found_entries):
|
|||
genout[r][genout.shape[-1] - already_generated + c] = vars.lua_koboldbridge.generated[r+1][c+1]
|
||||
encoded = []
|
||||
for i in range(vars.numseqs):
|
||||
txt = tokenizer.decode(genout[i, -already_generated:])
|
||||
txt = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:]))
|
||||
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=vars._actions)
|
||||
found_entries[i].update(_found_entries)
|
||||
txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
|
||||
|
@ -2986,10 +2984,10 @@ def generate(txt, minimum, maximum, found_entries=None):
|
|||
found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
|
||||
|
||||
if not vars.quiet:
|
||||
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, tokenizer.decode(txt), colors.END))
|
||||
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, utils.decodenewlines(tokenizer.decode(txt)), colors.END))
|
||||
|
||||
# Store context in memory to use it for comparison with generated content
|
||||
vars.lastctx = tokenizer.decode(txt)
|
||||
vars.lastctx = utils.decodenewlines(tokenizer.decode(txt))
|
||||
|
||||
# Clear CUDA cache if using GPU
|
||||
if(vars.hascuda and (vars.usegpu or vars.breakmodel)):
|
||||
|
@ -3016,7 +3014,7 @@ def generate(txt, minimum, maximum, found_entries=None):
|
|||
|
||||
for i in range(vars.numseqs):
|
||||
vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(genout[i, -1].item())
|
||||
vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(genout[i, -already_generated:])
|
||||
vars.lua_koboldbridge.outputs[i+1] = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:]))
|
||||
|
||||
execute_outmod()
|
||||
if(vars.lua_koboldbridge.regeneration_required):
|
||||
|
@ -3026,7 +3024,7 @@ def generate(txt, minimum, maximum, found_entries=None):
|
|||
genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
|
||||
assert type(genout[-1]["generated_text"]) is str
|
||||
else:
|
||||
genout = [{"generated_text": tokenizer.decode(tokens[-already_generated:])} for tokens in genout]
|
||||
genout = [{"generated_text": utils.decodenewlines(tokenizer.decode(tokens[-already_generated:]))} for tokens in genout]
|
||||
|
||||
if(len(genout) == 1):
|
||||
genresult(genout[0]["generated_text"])
|
||||
|
@ -3239,7 +3237,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
|||
found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
|
||||
|
||||
if not vars.quiet:
|
||||
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, tokenizer.decode(txt), colors.END))
|
||||
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, utils.decodenewlines(tokenizer.decode(txt)), colors.END))
|
||||
|
||||
vars._actions = vars.actions
|
||||
vars._prompt = vars.prompt
|
||||
|
@ -3282,7 +3280,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
|||
|
||||
encoded = []
|
||||
for i in range(vars.numseqs):
|
||||
txt = tokenizer.decode(past[i])
|
||||
txt = utils.decodenewlines(tokenizer.decode(past[i]))
|
||||
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=vars._actions)
|
||||
found_entries[i].update(_found_entries)
|
||||
txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
|
||||
|
@ -3334,7 +3332,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
|||
return
|
||||
|
||||
for i in range(vars.numseqs):
|
||||
vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(past[i])
|
||||
vars.lua_koboldbridge.outputs[i+1] = utils.decodenewlines(tokenizer.decode(past[i]))
|
||||
genout = past
|
||||
|
||||
execute_outmod()
|
||||
|
@ -3345,7 +3343,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
|||
genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
|
||||
assert type(genout[-1]["generated_text"]) is str
|
||||
else:
|
||||
genout = [{"generated_text": tokenizer.decode(txt)} for txt in genout]
|
||||
genout = [{"generated_text": utils.decodenewlines(tokenizer.decode(txt))} for txt in genout]
|
||||
|
||||
if(len(genout) == 1):
|
||||
genresult(genout[0]["generated_text"])
|
||||
|
@ -3373,14 +3371,14 @@ def getnewcontent(txt):
|
|||
return txt
|
||||
|
||||
# Tokenize the last context and the generated content
|
||||
ctxtokens = tokenizer.encode(vars.lastctx, max_length=int(2e9), truncation=True)
|
||||
txttokens = tokenizer.encode(txt, max_length=int(2e9), truncation=True)
|
||||
ctxtokens = tokenizer.encode(utils.encodenewlines(vars.lastctx), max_length=int(2e9), truncation=True)
|
||||
txttokens = tokenizer.encode(utils.encodenewlines(txt), max_length=int(2e9), truncation=True)
|
||||
dif = (len(txttokens) - len(ctxtokens)) * -1
|
||||
|
||||
# Remove the context from the returned text
|
||||
newtokens = txttokens[dif:]
|
||||
|
||||
return tokenizer.decode(newtokens)
|
||||
return utils.decodenewlines(tokenizer.decode(newtokens))
|
||||
|
||||
#==================================================================#
|
||||
# Applies chosen formatting options to text submitted to AI
|
||||
|
@ -3396,9 +3394,6 @@ def applyinputformatting(txt):
|
|||
# Applies chosen formatting options to text returned from AI
|
||||
#==================================================================#
|
||||
def applyoutputformatting(txt):
|
||||
# Revert S mode on output to maintain compatibility
|
||||
txt = txt.replace('</s>', "\n")
|
||||
|
||||
# Use standard quotes and apostrophes
|
||||
txt = utils.fixquotes(txt)
|
||||
|
||||
|
@ -4856,8 +4851,8 @@ loadsettings()
|
|||
def __preempt_tokenizer():
|
||||
if("tokenizer" not in globals()):
|
||||
return
|
||||
tokenizer.decode([25678, 559])
|
||||
tokenizer.encode("eunoia")
|
||||
utils.decodenewlines(tokenizer.decode([25678, 559]))
|
||||
tokenizer.encode(utils.encodenewlines("eunoia"))
|
||||
threading.Thread(target=__preempt_tokenizer).start()
|
||||
|
||||
# Precompile TPU backend if required
|
||||
|
@ -4891,7 +4886,7 @@ if(vars.model in ("TPUMeshTransformerGPTJ",)):
|
|||
def send_debug():
|
||||
if vars.debug:
|
||||
debug_info = ""
|
||||
for variable in [["Action Length", len(vars.actions)], ["Actions Metadata Length", len(vars.actions_metadata)], ["Actions Metadata", vars.actions_metadata]]:
|
||||
for variable in [["Action Length", len(vars.actions)], ["Actions Metadata Length", len(vars.actions_metadata)], ["Actions Metadata", vars.actions_metadata]], ["Newline Mode", vars.newlinemode]:
|
||||
debug_info = "{}{}: {}\n".format(debug_info, variable[0], variable[1])
|
||||
emit('from_server', {'cmd': 'debug_info', 'data': debug_info}, broadcast=True)
|
||||
|
||||
|
|
19
utils.py
19
utils.py
|
@ -1,6 +1,8 @@
|
|||
from threading import Timer
|
||||
import re
|
||||
|
||||
vars = None
|
||||
|
||||
#==================================================================#
|
||||
# Decorator to prevent a function's actions from being run until
|
||||
# at least x seconds have passed without the function being called
|
||||
|
@ -111,8 +113,15 @@ def cleanfilename(filename):
|
|||
filename = "".join(c for c in filename if c not in filteredcharacters).rstrip()
|
||||
return filename
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
#==================================================================#
|
||||
# Newline substitution for fairseq models
|
||||
#==================================================================#
|
||||
def encodenewlines(txt):
|
||||
if(vars.newlinemode == "s"):
|
||||
return txt.replace('\n', "</s>")
|
||||
return txt
|
||||
|
||||
def decodenewlines(txt):
|
||||
if(vars.newlinemode == "s"):
|
||||
return txt.replace("</s>", '\n')
|
||||
return txt
|
||||
|
|
Loading…
Reference in New Issue