Merge pull request #43 from VE-FORBRYDERNE/dynamic-scan-patch

Dynamic scan patch
Merged by henk717 on 2021-12-15 09:45:07 +01:00 (committed via GitHub)

@@ -21,6 +21,7 @@ import collections
 import zipfile
 import packaging
 import contextlib
+import traceback
 from typing import Any, Union, Dict, Set, List
 import requests
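
The new `traceback` import backs the error-reporting change in `generate()` and `tpumtjgenerate()` further down: `traceback.format_exc()` prints the full stack trace where `str(e)` only gave the exception message. A minimal sketch of the difference (the `fail` helper is hypothetical):

    import traceback

    def fail():
        raise ValueError("boom")

    try:
        fail()
    except ValueError as e:
        print(str(e))                  # just the message: "boom"
        print(traceback.format_exc())  # full trace, including the line inside fail()
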
@@ -744,6 +745,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
             scores: torch.FloatTensor,
             **kwargs,
         ) -> bool:
+            if(vars.lua_koboldbridge.generated_cols >= vars.genamt):
+                self.regeneration_required = False
+                self.halt = False
+                return True
             assert input_ids.ndim == 2
             assert len(self.excluded_world_info) == input_ids.shape[0]
             self.regeneration_required = vars.lua_koboldbridge.regeneration_required
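
The added early return halts generation as soon as the Lua bridge reports that the requested number of tokens (`vars.genamt`) has been produced, clearing the regeneration/halt flags and skipping the per-step assertions below. In the transformers API this works because a stopping criterion is simply a callable that returns `True` to stop; a self-contained sketch of the same shape (class and parameter names here are illustrative, not KoboldAI's):

    import torch
    from transformers import StoppingCriteria

    class BudgetCriteria(StoppingCriteria):
        """Stop once `max_new_tokens` tokens exist past the prompt."""
        def __init__(self, prompt_length: int, max_new_tokens: int):
            self.prompt_length = prompt_length
            self.max_new_tokens = max_new_tokens

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            # Returning True tells generate() to stop the whole batch.
            return input_ids.shape[-1] - self.prompt_length >= self.max_new_tokens
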
@@ -773,7 +779,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
             excluded_world_info=self.kai_scanner_excluded_world_info,
             head_length=self.kai_scanner_head_length,
         )
-        stopping_criteria.append(self.kai_scanner)
+        stopping_criteria.insert(0, self.kai_scanner)
         return stopping_criteria
     transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
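
Switching from `append` to `insert(0, ...)` puts the scanner at the front of the `StoppingCriteriaList`. That matters because the list is evaluated roughly as `any(c(ids, scores) for c in self)`, and `any()` short-circuits: once an earlier criterion (such as the built-in `MaxLengthCriteria`) returns `True`, later criteria are never called that step, so a scanner appended at the end could miss the final step's bookkeeping. A sketch, assuming criteria run in list order:

    import torch
    from transformers import StoppingCriteria, StoppingCriteriaList, MaxLengthCriteria

    class Scanner(StoppingCriteria):
        # Toy stand-in for the world-info scanner: never stops generation
        # itself, but must observe every step to keep its state current.
        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            self.seen_length = input_ids.shape[-1]  # bookkeeping side effect
            return False

    criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=2048)])
    criteria.insert(0, Scanner())  # scanner now runs before MaxLengthCriteria
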
@@ -1995,17 +2001,17 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions):
     lnanote = 0 # Placeholder for Author's Note length
     # Calculate token budget
-    prompttkns = tokenizer.encode(vars.comregex_ai.sub('', vars.prompt))
+    prompttkns = tokenizer.encode(vars.comregex_ai.sub('', vars.prompt), max_length=1+int(vars.max_length), truncation=True)
     lnprompt = len(prompttkns)
-    memtokens = tokenizer.encode(mem)
+    memtokens = tokenizer.encode(mem, max_length=1+int(vars.max_length), truncation=True)
     lnmem = len(memtokens)
-    witokens = tokenizer.encode(winfo)
+    witokens = tokenizer.encode(winfo, max_length=1+int(vars.max_length), truncation=True)
     lnwi = len(witokens)
     if(anotetxt != ""):
-        anotetkns = tokenizer.encode(anotetxt)
+        anotetkns = tokenizer.encode(anotetxt, max_length=1+int(vars.max_length), truncation=True)
         lnanote = len(anotetkns)
     lnsp = vars.sp.shape[0] if vars.sp is not None else 0
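
Every `tokenizer.encode` call in the budget calculation now passes `max_length=1+int(vars.max_length)` with `truncation=True`, which bounds the work done on pathologically long prompt/memory/world-info strings and avoids the tokenizer's "sequence length is longer than the specified maximum" warning. The `1+` plausibly leaves one spare token so callers can still tell "exactly at the limit" apart from "truncated". The pattern with a stock Hugging Face tokenizer (`gpt2` chosen only for illustration):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    max_length = 2048  # stand-in for vars.max_length

    text = "word " * 100_000  # pathologically long input
    ids = tokenizer.encode(text, max_length=1 + max_length, truncation=True)
    assert len(ids) <= 1 + max_length  # never longer than the model context + 1
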
@@ -2034,7 +2040,7 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions):
             if(budget <= 0):
                 break
-            acttkns = tokenizer.encode(chunk)
+            acttkns = tokenizer.encode(chunk, max_length=int(vars.max_length), truncation=True)
             tknlen = len(acttkns)
             if(tknlen < budget):
                 tokens = acttkns + tokens
@@ -2168,7 +2174,7 @@ def calcsubmit(txt):
 #==================================================================#
 def _generate(txt, minimum, maximum, found_entries):
-    gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long()
+    gen_in = tokenizer.encode(txt, return_tensors="pt", max_length=int(vars.max_length), truncation=True).long()
     if(vars.sp is not None):
         soft_tokens = torch.arange(
             model.config.vocab_size,
@@ -2214,14 +2220,14 @@ def _generate(txt, minimum, maximum, found_entries):
         for r in range(vars.numseqs):
             for c in range(already_generated):
                 assert vars.lua_koboldbridge.generated[r+1][c+1] is not None
-                genout[r][genout.shape[-1] - already_generated - c] = vars.lua_koboldbridge.generated[r+1][c+1]
+                genout[r][genout.shape[-1] - already_generated + c] = vars.lua_koboldbridge.generated[r+1][c+1]
         encoded = []
         for i in range(vars.numseqs):
             txt = tokenizer.decode(genout[i, -already_generated:])
             winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
             found_entries[i].update(_found_entries)
             txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
-            encoded.append(tokenizer.encode(txt, return_tensors="pt", truncation=True)[0].long().to(genout.device))
+            encoded.append(tokenizer.encode(txt, return_tensors="pt", max_length=int(vars.max_length), truncation=True)[0].long().to(genout.device))
         max_length = len(max(encoded, key=len))
         encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
         genout = torch.cat(
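
The sign flip from `- c` to `+ c` is a genuine bug fix: the loop copies the `already_generated` tokens reported by the Lua bridge into the tail of `genout`, and the correct tail indices count forward from `shape[-1] - already_generated`. The old formula counted backwards from that point, writing the tokens reversed into the prompt region instead. A worked index check with a 10-slot row and 3 generated tokens:

    import torch

    row_length, already_generated = 10, 3
    genout = torch.zeros(row_length, dtype=torch.long)

    # The 3 generated tokens belong in the last 3 slots: indices 7, 8, 9.
    for c in range(already_generated):
        new_index = row_length - already_generated + c  # 7, 8, 9: correct tail
        old_index = row_length - already_generated - c  # 7, 6, 5: walks back into the prompt
        genout[new_index] = 100 + c  # stand-in for vars.lua_koboldbridge.generated[r+1][c+1]

    print(genout)  # tensor([  0,   0,   0,   0,   0,   0,   0, 100, 101, 102])
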
@@ -2275,7 +2281,7 @@ def generate(txt, minimum, maximum, found_entries=None):
             print("{0}{1}{2}".format(colors.YELLOW, "Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.", colors.END), file=sys.stderr)
         else:
             emit('from_server', {'cmd': 'errmsg', 'data': 'Error occured during generator call, please check console.'}, broadcast=True)
-            print("{0}{1}{2}".format(colors.RED, str(e).replace("\033", ""), colors.END), file=sys.stderr)
+            print("{0}{1}{2}".format(colors.RED, traceback.format_exc().replace("\033", ""), colors.END), file=sys.stderr)
         set_aibusy(0)
         return
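
With the `traceback` import added in the first hunk, the console error handler now prints the full stack trace of a failed generator call instead of only the exception message (see the sketch under that hunk), which is far more useful when triaging bug reports. The `.replace("\033", "")` survives, so ANSI escape codes are still stripped from the output. The TPU path below receives the identical fix.
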
@@ -2486,7 +2492,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
             print("{0}{1}{2}".format(colors.YELLOW, "Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.", colors.END), file=sys.stderr)
         else:
             emit('from_server', {'cmd': 'errmsg', 'data': 'Error occured during generator call, please check console.'}, broadcast=True)
-            print("{0}{1}{2}".format(colors.RED, str(e).replace("\033", ""), colors.END), file=sys.stderr)
+            print("{0}{1}{2}".format(colors.RED, traceback.format_exc().replace("\033", ""), colors.END), file=sys.stderr)
         set_aibusy(0)
         return
@@ -2527,8 +2533,8 @@ def getnewcontent(txt):
         return txt
     # Tokenize the last context and the generated content
-    ctxtokens = tokenizer.encode(vars.lastctx)
-    txttokens = tokenizer.encode(txt)
+    ctxtokens = tokenizer.encode(vars.lastctx, max_length=1+int(vars.max_length), truncation=True)
+    txttokens = tokenizer.encode(txt, max_length=1+int(vars.max_length), truncation=True)
     dif = (len(txttokens) - len(ctxtokens)) * -1
     # Remove the context from the returned text
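
`getnewcontent` isolates the newly generated text by comparing token counts of the last submitted context (`vars.lastctx`) and the returned text; the truncation caps added here keep both encodes bounded the same way as in `calcsubmitbudget`. A toy illustration of the `dif` arithmetic, assuming the (unshown) next lines slice with `txttokens[dif:]`:

    ctxtokens = [10, 11, 12, 13]          # tokens of vars.lastctx
    txttokens = [10, 11, 12, 13, 50, 51]  # context plus newly generated tokens

    dif = (len(txttokens) - len(ctxtokens)) * -1  # -2
    print(txttokens[dif:])  # [50, 51]: just the new tokens
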