diff --git a/aiserver.py b/aiserver.py index 82d046ec..337413bc 100644 --- a/aiserver.py +++ b/aiserver.py @@ -211,6 +211,8 @@ class vars: quiet = False # If set will suppress any story text from being printed to the console (will only be seen on the client web page) debug = False # If set to true, will send debug information to the client for display +utils.vars = vars + #==================================================================# # Function to get model selection at startup #==================================================================# @@ -916,7 +918,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme return self.regeneration_required or self.halt tail = input_ids[..., -vars.generated_tkns:] for i, t in enumerate(tail): - decoded = tokenizer.decode(t) + decoded = utils.decodenewlines(tokenizer.decode(t)) _, found = checkworldinfo(decoded, force_use_txt=True, actions=vars._actions) found -= self.excluded_world_info[i] if(len(found) != 0): @@ -1118,7 +1120,7 @@ else: return excluded_world_info, regeneration_required, halt for i, t in enumerate(generated): - decoded = tokenizer.decode(past[i]) + tokenizer.decode(t[tpu_mtj_backend.params["seq"] : tpu_mtj_backend.params["seq"] + n_generated]) + decoded = utils.decodenewlines(tokenizer.decode(past[i])) + utils.decodenewlines(tokenizer.decode(t[tpu_mtj_backend.params["seq"] : tpu_mtj_backend.params["seq"] + n_generated])) _, found = checkworldinfo(decoded, force_use_txt=True, actions=vars._actions) found -= excluded_world_info[i] if(len(found) != 0): @@ -1327,7 +1329,7 @@ def lua_decode(tokens): from transformers import GPT2TokenizerFast global tokenizer tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/") - return tokenizer.decode(tokens) + return utils.decodenewlines(tokenizer.decode(tokens)) #==================================================================# # Encode string into list of token IDs using current tokenizer @@ -1339,7 +1341,7 @@ def lua_encode(string): from transformers import GPT2TokenizerFast global tokenizer tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/") - return tokenizer.encode(string, max_length=int(4e9), truncation=True) + return tokenizer.encode(utils.encodenewlines(string), max_length=int(4e9), truncation=True) #==================================================================# # Computes context given a submission, Lua array of entry UIDs and a Lua array @@ -1379,7 +1381,7 @@ def lua_compute_context(submission, entries, folders, kwargs): anotetxt, actions, ) - return tokenizer.decode(txt) + return utils.decodenewlines(tokenizer.decode(txt)) #==================================================================# # Get property of a world info entry given its UID and property name @@ -2455,10 +2457,6 @@ def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False, if(len(data)): data = f"\n{vars.chatname} : {data}\n" - # mode - if(vars.newlinemode == "s"): - data = data.replace('\n', "") - # If we're not continuing, store a copy of the raw input if(data != ""): vars.lastact = data @@ -2706,21 +2704,21 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None, tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache/") # Calculate token budget - prompttkns = tokenizer.encode(vars.comregex_ai.sub('', vars.prompt), max_length=int(2e9), truncation=True) + prompttkns = tokenizer.encode(utils.encodenewlines(vars.comregex_ai.sub('', vars.prompt)), max_length=int(2e9), truncation=True) lnprompt = len(prompttkns) - memtokens = tokenizer.encode(mem, max_length=int(2e9), truncation=True) + memtokens = tokenizer.encode(utils.encodenewlines(mem), max_length=int(2e9), truncation=True) lnmem = len(memtokens) if(lnmem > vars.max_length - lnsp - vars.genamt - budget_deduction): raise OverflowError("The memory in your story is too long. Please either write a shorter memory text or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.") - witokens = tokenizer.encode(winfo, max_length=int(2e9), truncation=True) + witokens = tokenizer.encode(utils.encodenewlines(winfo), max_length=int(2e9), truncation=True) lnwi = len(witokens) if(lnmem + lnwi > vars.max_length - lnsp - vars.genamt - budget_deduction): raise OverflowError("The current active world info keys take up too many tokens. Please either write shorter world info, decrease World Info Depth or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.") if(anotetxt != ""): - anotetkns = tokenizer.encode(anotetxt, max_length=int(2e9), truncation=True) + anotetkns = tokenizer.encode(utils.encodenewlines(anotetxt), max_length=int(2e9), truncation=True) lnanote = len(anotetkns) if(lnmem + lnwi + lnanote > vars.max_length - lnsp - vars.genamt - budget_deduction): raise OverflowError("The author's note in your story is too long. Please either write a shorter author's note or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.") @@ -2730,7 +2728,7 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None, else: budget = vars.max_length - lnsp - lnmem - lnanote - lnwi - vars.genamt - budget_deduction - lnsubmission = len(tokenizer.encode(vars.comregex_ai.sub('', submission), max_length=int(2e9), truncation=True)) if submission is not None else 0 + lnsubmission = len(tokenizer.encode(utils.encodenewlines(vars.comregex_ai.sub('', submission)), max_length=int(2e9), truncation=True)) if submission is not None else 0 maybe_lnprompt = lnprompt if vars.useprompt and actionlen > 0 else 0 if(lnmem + lnwi + lnanote + maybe_lnprompt + lnsubmission > vars.max_length - lnsp - vars.genamt - budget_deduction): @@ -2759,7 +2757,7 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None, assert budget >= 0 if(budget <= 0): break - acttkns = tokenizer.encode(chunk, max_length=int(2e9), truncation=True) + acttkns = tokenizer.encode(utils.encodenewlines(chunk), max_length=int(2e9), truncation=True) tknlen = len(acttkns) if(tknlen < budget): tokens = acttkns + tokens @@ -2818,18 +2816,18 @@ def calcsubmit(txt): if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]): generate(subtxt, min, max, found_entries=found_entries) elif(vars.model == "Colab"): - sendtocolab(tokenizer.decode(subtxt), min, max) + sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max) elif(vars.model == "OAI"): - oairequest(tokenizer.decode(subtxt), min, max) + oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max) elif(vars.model == "TPUMeshTransformerGPTJ"): tpumtjgenerate(subtxt, min, max, found_entries=found_entries) else: if(not vars.model in ["Colab", "OAI", "TPUMeshTransformerGPTJ"]): generate(subtxt, min, max, found_entries=found_entries) elif(vars.model == "Colab"): - sendtocolab(tokenizer.decode(subtxt), min, max) + sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max) elif(vars.model == "OAI"): - oairequest(tokenizer.decode(subtxt), min, max) + oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max) elif(vars.model == "TPUMeshTransformerGPTJ"): tpumtjgenerate(subtxt, min, max, found_entries=found_entries) @@ -2947,7 +2945,7 @@ def _generate(txt, minimum, maximum, found_entries): genout[r][genout.shape[-1] - already_generated + c] = vars.lua_koboldbridge.generated[r+1][c+1] encoded = [] for i in range(vars.numseqs): - txt = tokenizer.decode(genout[i, -already_generated:]) + txt = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:])) winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=vars._actions) found_entries[i].update(_found_entries) txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt) @@ -2986,10 +2984,10 @@ def generate(txt, minimum, maximum, found_entries=None): found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs)) if not vars.quiet: - print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, tokenizer.decode(txt), colors.END)) + print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, utils.decodenewlines(tokenizer.decode(txt)), colors.END)) # Store context in memory to use it for comparison with generated content - vars.lastctx = tokenizer.decode(txt) + vars.lastctx = utils.decodenewlines(tokenizer.decode(txt)) # Clear CUDA cache if using GPU if(vars.hascuda and (vars.usegpu or vars.breakmodel)): @@ -3016,7 +3014,7 @@ def generate(txt, minimum, maximum, found_entries=None): for i in range(vars.numseqs): vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(genout[i, -1].item()) - vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(genout[i, -already_generated:]) + vars.lua_koboldbridge.outputs[i+1] = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:])) execute_outmod() if(vars.lua_koboldbridge.regeneration_required): @@ -3026,7 +3024,7 @@ def generate(txt, minimum, maximum, found_entries=None): genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]}) assert type(genout[-1]["generated_text"]) is str else: - genout = [{"generated_text": tokenizer.decode(tokens[-already_generated:])} for tokens in genout] + genout = [{"generated_text": utils.decodenewlines(tokenizer.decode(tokens[-already_generated:]))} for tokens in genout] if(len(genout) == 1): genresult(genout[0]["generated_text"]) @@ -3239,7 +3237,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs)) if not vars.quiet: - print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, tokenizer.decode(txt), colors.END)) + print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, utils.decodenewlines(tokenizer.decode(txt)), colors.END)) vars._actions = vars.actions vars._prompt = vars.prompt @@ -3282,7 +3280,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): encoded = [] for i in range(vars.numseqs): - txt = tokenizer.decode(past[i]) + txt = utils.decodenewlines(tokenizer.decode(past[i])) winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=vars._actions) found_entries[i].update(_found_entries) txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt) @@ -3334,7 +3332,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): return for i in range(vars.numseqs): - vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(past[i]) + vars.lua_koboldbridge.outputs[i+1] = utils.decodenewlines(tokenizer.decode(past[i])) genout = past execute_outmod() @@ -3345,7 +3343,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]}) assert type(genout[-1]["generated_text"]) is str else: - genout = [{"generated_text": tokenizer.decode(txt)} for txt in genout] + genout = [{"generated_text": utils.decodenewlines(tokenizer.decode(txt))} for txt in genout] if(len(genout) == 1): genresult(genout[0]["generated_text"]) @@ -3373,14 +3371,14 @@ def getnewcontent(txt): return txt # Tokenize the last context and the generated content - ctxtokens = tokenizer.encode(vars.lastctx, max_length=int(2e9), truncation=True) - txttokens = tokenizer.encode(txt, max_length=int(2e9), truncation=True) + ctxtokens = tokenizer.encode(utils.encodenewlines(vars.lastctx), max_length=int(2e9), truncation=True) + txttokens = tokenizer.encode(utils.encodenewlines(txt), max_length=int(2e9), truncation=True) dif = (len(txttokens) - len(ctxtokens)) * -1 # Remove the context from the returned text newtokens = txttokens[dif:] - return tokenizer.decode(newtokens) + return utils.decodenewlines(tokenizer.decode(newtokens)) #==================================================================# # Applies chosen formatting options to text submitted to AI @@ -3396,9 +3394,6 @@ def applyinputformatting(txt): # Applies chosen formatting options to text returned from AI #==================================================================# def applyoutputformatting(txt): - # Revert S mode on output to maintain compatibility - txt = txt.replace('', "\n") - # Use standard quotes and apostrophes txt = utils.fixquotes(txt) @@ -4856,8 +4851,8 @@ loadsettings() def __preempt_tokenizer(): if("tokenizer" not in globals()): return - tokenizer.decode([25678, 559]) - tokenizer.encode("eunoia") + utils.decodenewlines(tokenizer.decode([25678, 559])) + tokenizer.encode(utils.encodenewlines("eunoia")) threading.Thread(target=__preempt_tokenizer).start() # Precompile TPU backend if required @@ -4891,7 +4886,7 @@ if(vars.model in ("TPUMeshTransformerGPTJ",)): def send_debug(): if vars.debug: debug_info = "" - for variable in [["Action Length", len(vars.actions)], ["Actions Metadata Length", len(vars.actions_metadata)], ["Actions Metadata", vars.actions_metadata]]: + for variable in [["Action Length", len(vars.actions)], ["Actions Metadata Length", len(vars.actions_metadata)], ["Actions Metadata", vars.actions_metadata], ["Newline Mode", vars.newlinemode]]: debug_info = "{}{}: {}\n".format(debug_info, variable[0], variable[1]) emit('from_server', {'cmd': 'debug_info', 'data': debug_info}, broadcast=True) diff --git a/utils.py b/utils.py index 31184ed7..1c44c27f 100644 --- a/utils.py +++ b/utils.py @@ -1,6 +1,8 @@ from threading import Timer import re +vars = None + #==================================================================# # Decorator to prevent a function's actions from being run until # at least x seconds have passed without the function being called @@ -111,8 +113,15 @@ def cleanfilename(filename): filename = "".join(c for c in filename if c not in filteredcharacters).rstrip() return filename - - - - - \ No newline at end of file +#==================================================================# +# Newline substitution for fairseq models +#==================================================================# +def encodenewlines(txt): + if(vars.newlinemode == "s"): + return txt.replace('\n', "") + return txt + +def decodenewlines(txt): + if(vars.newlinemode == "s"): + return txt.replace("", '\n') + return txt