Fix fairseq newline handling issues
parent c1af8f72c3
commit f682c1229a
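In short: fairseq-style models use </s> where other models use a newline character, so this commit centralizes that substitution in two new utils.py helpers, encodenewlines() and decodenewlines(), routes the tokenizer.encode()/tokenizer.decode() call sites in aiserver.py through them, and drops the one-off replacements in actionsubmit() and applyoutputformatting(). A minimal round-trip sketch of what the helpers do (the _Vars stand-in below is hypothetical; in aiserver.py the real vars class is shared via the new utils.vars = vars assignment shown in the first hunk):

    import utils

    class _Vars:               # hypothetical stand-in for aiserver.py's vars class
        newlinemode = "s"      # "s" is the fairseq mode in which newlines are stored as </s>

    utils.vars = _Vars         # mirrors the new utils.vars = vars assignment in the diff

    text = "First line\nSecond line"
    encoded = utils.encodenewlines(text)     # -> "First line</s>Second line"
    decoded = utils.decodenewlines(encoded)  # -> "First line\nSecond line"
    assert decoded == text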
aiserver.py (69 changed lines)
@@ -211,6 +211,8 @@ class vars:
 quiet = False # If set will suppress any story text from being printed to the console (will only be seen on the client web page)
 debug = False # If set to true, will send debug information to the client for display
 
+utils.vars = vars
+
 #==================================================================#
 # Function to get model selection at startup
 #==================================================================#
@@ -916,7 +918,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
 return self.regeneration_required or self.halt
 tail = input_ids[..., -vars.generated_tkns:]
 for i, t in enumerate(tail):
-decoded = tokenizer.decode(t)
+decoded = utils.decodenewlines(tokenizer.decode(t))
 _, found = checkworldinfo(decoded, force_use_txt=True, actions=vars._actions)
 found -= self.excluded_world_info[i]
 if(len(found) != 0):
@@ -1118,7 +1120,7 @@ else:
 return excluded_world_info, regeneration_required, halt
 
 for i, t in enumerate(generated):
-decoded = tokenizer.decode(past[i]) + tokenizer.decode(t[tpu_mtj_backend.params["seq"] : tpu_mtj_backend.params["seq"] + n_generated])
+decoded = utils.decodenewlines(tokenizer.decode(past[i])) + utils.decodenewlines(tokenizer.decode(t[tpu_mtj_backend.params["seq"] : tpu_mtj_backend.params["seq"] + n_generated]))
 _, found = checkworldinfo(decoded, force_use_txt=True, actions=vars._actions)
 found -= excluded_world_info[i]
 if(len(found) != 0):
@@ -1327,7 +1329,7 @@ def lua_decode(tokens):
 from transformers import GPT2TokenizerFast
 global tokenizer
 tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
-return tokenizer.decode(tokens)
+return utils.decodenewlines(tokenizer.decode(tokens))
 
 #==================================================================#
 # Encode string into list of token IDs using current tokenizer
@@ -1339,7 +1341,7 @@ def lua_encode(string):
 from transformers import GPT2TokenizerFast
 global tokenizer
 tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
-return tokenizer.encode(string, max_length=int(4e9), truncation=True)
+return tokenizer.encode(utils.encodenewlines(string), max_length=int(4e9), truncation=True)
 
 #==================================================================#
 # Computes context given a submission, Lua array of entry UIDs and a Lua array
@@ -1379,7 +1381,7 @@ def lua_compute_context(submission, entries, folders, kwargs):
 anotetxt,
 actions,
 )
-return tokenizer.decode(txt)
+return utils.decodenewlines(tokenizer.decode(txt))
 
 #==================================================================#
 # Get property of a world info entry given its UID and property name
@@ -2455,10 +2457,6 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False,
 if(len(data)):
 data = f"\n{vars.chatname} : {data}\n"
 
-# </s> mode
-if(vars.newlinemode == "s"):
-data = data.replace('\n', "</s>")
-
 # If we're not continuing, store a copy of the raw input
 if(data != ""):
 vars.lastact = data
@@ -2706,21 +2704,21 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None,
 tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")
 
 # Calculate token budget
-prompttkns = tokenizer.encode(vars.comregex_ai.sub('', vars.prompt), max_length=int(2e9), truncation=True)
+prompttkns = tokenizer.encode(utils.encodenewlines(vars.comregex_ai.sub('', vars.prompt)), max_length=int(2e9), truncation=True)
 lnprompt = len(prompttkns)
 
-memtokens = tokenizer.encode(mem, max_length=int(2e9), truncation=True)
+memtokens = tokenizer.encode(utils.encodenewlines(mem), max_length=int(2e9), truncation=True)
 lnmem = len(memtokens)
 if(lnmem > vars.max_length - lnsp - vars.genamt - budget_deduction):
 raise OverflowError("The memory in your story is too long. Please either write a shorter memory text or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.")
 
-witokens = tokenizer.encode(winfo, max_length=int(2e9), truncation=True)
+witokens = tokenizer.encode(utils.encodenewlines(winfo), max_length=int(2e9), truncation=True)
 lnwi = len(witokens)
 if(lnmem + lnwi > vars.max_length - lnsp - vars.genamt - budget_deduction):
 raise OverflowError("The current active world info keys take up too many tokens. Please either write shorter world info, decrease World Info Depth or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.")
 
 if(anotetxt != ""):
-anotetkns = tokenizer.encode(anotetxt, max_length=int(2e9), truncation=True)
+anotetkns = tokenizer.encode(utils.encodenewlines(anotetxt), max_length=int(2e9), truncation=True)
 lnanote = len(anotetkns)
 if(lnmem + lnwi + lnanote > vars.max_length - lnsp - vars.genamt - budget_deduction):
 raise OverflowError("The author's note in your story is too long. Please either write a shorter author's note or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.")
@@ -2730,7 +2728,7 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None,
 else:
 budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt - budget_deduction
 
-lnsubmission = len(tokenizer.encode(vars.comregex_ai.sub('', submission), max_length=int(2e9), truncation=True)) if submission is not None else 0
+lnsubmission = len(tokenizer.encode(utils.encodenewlines(vars.comregex_ai.sub('', submission)), max_length=int(2e9), truncation=True)) if submission is not None else 0
 maybe_lnprompt = lnprompt if vars.useprompt and actionlen > 0 else 0
 
 if(lnmem + lnwi + lnanote + maybe_lnprompt + lnsubmission > vars.max_length - lnsp - vars.genamt - budget_deduction):
@@ -2759,7 +2757,7 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None,
 assert budget >= 0
 if(budget <= 0):
 break
-acttkns = tokenizer.encode(chunk, max_length=int(2e9), truncation=True)
+acttkns = tokenizer.encode(utils.encodenewlines(chunk), max_length=int(2e9), truncation=True)
 tknlen = len(acttkns)
 if(tknlen < budget):
 tokens = acttkns + tokens
@@ -2818,18 +2816,18 @@ def calcsubmit(txt):
 if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
 generate(subtxt, min, max, found_entries=found_entries)
 elif(vars.model == "Colab"):
-sendtocolab(tokenizer.decode(subtxt), min, max)
+sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
 elif(vars.model == "OAI"):
-oairequest(tokenizer.decode(subtxt), min, max)
+oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
 elif(vars.model == "TPUMeshTransformerGPTJ"):
 tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
 else:
 if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]):
 generate(subtxt, min, max, found_entries=found_entries)
 elif(vars.model == "Colab"):
-sendtocolab(tokenizer.decode(subtxt), min, max)
+sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
 elif(vars.model == "OAI"):
-oairequest(tokenizer.decode(subtxt), min, max)
+oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
 elif(vars.model == "TPUMeshTransformerGPTJ"):
 tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
 
@@ -2947,7 +2945,7 @@ def _generate(txt, minimum, maximum, found_entries):
 genout[r][genout.shape[-1] - already_generated + c] = vars.lua_koboldbridge.generated[r+1][c+1]
 encoded = []
 for i in range(vars.numseqs):
-txt = tokenizer.decode(genout[i, -already_generated:])
+txt = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:]))
 winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=vars._actions)
 found_entries[i].update(_found_entries)
 txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
@@ -2986,10 +2984,10 @@ def generate(txt, minimum, maximum, found_entries=None):
 found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
 
 if not vars.quiet:
-print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, tokenizer.decode(txt), colors.END))
+print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, utils.decodenewlines(tokenizer.decode(txt)), colors.END))
 
 # Store context in memory to use it for comparison with generated content
-vars.lastctx = tokenizer.decode(txt)
+vars.lastctx = utils.decodenewlines(tokenizer.decode(txt))
 
 # Clear CUDA cache if using GPU
 if(vars.hascuda and (vars.usegpu or vars.breakmodel)):
@@ -3016,7 +3014,7 @@ def generate(txt, minimum, maximum, found_entries=None):
 
 for i in range(vars.numseqs):
 vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(genout[i, -1].item())
-vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(genout[i, -already_generated:])
+vars.lua_koboldbridge.outputs[i+1] = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:]))
 
 execute_outmod()
 if(vars.lua_koboldbridge.regeneration_required):
@@ -3026,7 +3024,7 @@ def generate(txt, minimum, maximum, found_entries=None):
 genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
 assert type(genout[-1]["generated_text"]) is str
 else:
-genout = [{"generated_text": tokenizer.decode(tokens[-already_generated:])} for tokens in genout]
+genout = [{"generated_text": utils.decodenewlines(tokenizer.decode(tokens[-already_generated:]))} for tokens in genout]
 
 if(len(genout) == 1):
 genresult(genout[0]["generated_text"])
@@ -3239,7 +3237,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
 found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
 
 if not vars.quiet:
-print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, tokenizer.decode(txt), colors.END))
+print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, utils.decodenewlines(tokenizer.decode(txt)), colors.END))
 
 vars._actions = vars.actions
 vars._prompt = vars.prompt
@@ -3282,7 +3280,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
 
 encoded = []
 for i in range(vars.numseqs):
-txt = tokenizer.decode(past[i])
+txt = utils.decodenewlines(tokenizer.decode(past[i]))
 winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=vars._actions)
 found_entries[i].update(_found_entries)
 txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
@@ -3334,7 +3332,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
 return
 
 for i in range(vars.numseqs):
-vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(past[i])
+vars.lua_koboldbridge.outputs[i+1] = utils.decodenewlines(tokenizer.decode(past[i]))
 genout = past
 
 execute_outmod()
@@ -3345,7 +3343,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
 genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
 assert type(genout[-1]["generated_text"]) is str
 else:
-genout = [{"generated_text": tokenizer.decode(txt)} for txt in genout]
+genout = [{"generated_text": utils.decodenewlines(tokenizer.decode(txt))} for txt in genout]
 
 if(len(genout) == 1):
 genresult(genout[0]["generated_text"])
@@ -3373,14 +3371,14 @@ def getnewcontent(txt):
 return txt
 
 # Tokenize the last context and the generated content
-ctxtokens = tokenizer.encode(vars.lastctx, max_length=int(2e9), truncation=True)
+ctxtokens = tokenizer.encode(utils.encodenewlines(vars.lastctx), max_length=int(2e9), truncation=True)
-txttokens = tokenizer.encode(txt, max_length=int(2e9), truncation=True)
+txttokens = tokenizer.encode(utils.encodenewlines(txt), max_length=int(2e9), truncation=True)
 dif = (len(txttokens) - len(ctxtokens)) * -1
 
 # Remove the context from the returned text
 newtokens = txttokens[dif:]
 
-return tokenizer.decode(newtokens)
+return utils.decodenewlines(tokenizer.decode(newtokens))
 
 #==================================================================#
 # Applies chosen formatting options to text submitted to AI
@@ -3396,9 +3394,6 @@ def applyinputformatting(txt):
 # Applies chosen formatting options to text returned from AI
 #==================================================================#
 def applyoutputformatting(txt):
-# Revert S mode on output to maintain compatibility
-txt = txt.replace('</s>', "\n")
-
 # Use standard quotes and apostrophes
 txt = utils.fixquotes(txt)
 
@@ -4856,8 +4851,8 @@ loadsettings()
 def __preempt_tokenizer():
 if("tokenizer" not in globals()):
 return
-tokenizer.decode([25678, 559])
+utils.decodenewlines(tokenizer.decode([25678, 559]))
-tokenizer.encode("eunoia")
+tokenizer.encode(utils.encodenewlines("eunoia"))
 threading.Thread(target=__preempt_tokenizer).start()
 
 # Precompile TPU backend if required
@@ -4891,7 +4886,7 @@ if(vars.model in ("TPUMeshTransformerGPTJ",)):
 def send_debug():
 if vars.debug:
 debug_info = ""
-for variable in [["Action Length", len(vars.actions)], ["Actions Metadata Length", len(vars.actions_metadata)], ["Actions Metadata", vars.actions_metadata]]:
+for variable in [["Action Length", len(vars.actions)], ["Actions Metadata Length", len(vars.actions_metadata)], ["Actions Metadata", vars.actions_metadata], ["Newline Mode", vars.newlinemode]]:
 debug_info = "{}{}: {}\n".format(debug_info, variable[0], variable[1])
 emit('from_server', {'cmd': 'debug_info', 'data': debug_info}, broadcast=True)
 
utils.py (17 changed lines)
@@ -1,6 +1,8 @@
 from threading import Timer
 import re
 
+vars = None
+
 #==================================================================#
 # Decorator to prevent a function's actions from being run until
 # at least x seconds have passed without the function being called
@@ -111,8 +113,15 @@ def cleanfilename(filename):
 filename = "".join(c for c in filename if c not in filteredcharacters).rstrip()
 return filename
 
+#==================================================================#
+# Newline substitution for fairseq models
+#==================================================================#
+def encodenewlines(txt):
+    if(vars.newlinemode == "s"):
+        return txt.replace('\n', "</s>")
+    return txt
+
+def decodenewlines(txt):
+    if(vars.newlinemode == "s"):
+        return txt.replace("</s>", '\n')
+    return txt
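The pattern applied throughout aiserver.py is to wrap the existing tokenizer calls rather than change the tokenizer itself. A sketch of that wrapping, assuming the same GPT2 tokenizer and the utils.vars wiring shown earlier (the prompt text is only illustrative):

    from transformers import GPT2TokenizerFast
    import utils

    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/")

    prompt = "You wake up.\nThe room is dark."
    # Encode: swap newlines for </s> before tokenizing (a no-op unless vars.newlinemode == "s")...
    tokens = tokenizer.encode(utils.encodenewlines(prompt), max_length=int(2e9), truncation=True)
    # ...decode: restore newlines after detokenizing.
    text = utils.decodenewlines(tokenizer.decode(tokens))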