Merge pull request #43 from VE-FORBRYDERNE/dynamic-scan-patch

Dynamic scan patch
henk717 committed 2021-12-15 09:45:07 +01:00 (via GitHub)
commit 46b0473229
1 changed file with 19 additions and 13 deletions


@@ -21,6 +21,7 @@ import collections
 import zipfile
 import packaging
 import contextlib
+import traceback
 from typing import Any, Union, Dict, Set, List
 import requests
@@ -744,6 +745,11 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
             scores: torch.FloatTensor,
             **kwargs,
         ) -> bool:
+            if(vars.lua_koboldbridge.generated_cols >= vars.genamt):
+                self.regeneration_required = False
+                self.halt = False
+                return True
             assert input_ids.ndim == 2
             assert len(self.excluded_world_info) == input_ids.shape[0]
             self.regeneration_required = vars.lua_koboldbridge.regeneration_required
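
The guard added above follows the standard transformers stopping-criteria contract: __call__ receives the running input_ids and scores each step and returns True to end generation. A minimal sketch of a criterion built on that same contract (the TokenBudgetCriteria name and prompt-length bookkeeping are illustrative, not part of this patch):

    import torch
    import transformers

    class TokenBudgetCriteria(transformers.StoppingCriteria):
        # Stops generation once a fixed number of new tokens has been
        # produced, analogous to the generated_cols >= genamt check above.
        def __init__(self, prompt_length: int, budget: int):
            self.prompt_length = prompt_length  # tokens already in the prompt
            self.budget = budget                # maximum new tokens to allow

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            # input_ids has shape (batch, sequence_length);
            # returning True tells generate() to stop.
            generated = input_ids.shape[-1] - self.prompt_length
            return generated >= self.budget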
@@ -773,7 +779,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
             excluded_world_info=self.kai_scanner_excluded_world_info,
             head_length=self.kai_scanner_head_length,
         )
-        stopping_criteria.append(self.kai_scanner)
+        stopping_criteria.insert(0, self.kai_scanner)
         return stopping_criteria
     transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
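
Replacing append with insert(0, ...) puts the world-info scanner ahead of the built-in criteria, such as the max-length check. In the transformers releases contemporary with this patch, StoppingCriteriaList evaluates roughly as any(c(input_ids, scores) for c in self), and any() short-circuits, so a criterion that records state as a side effect is silently skipped whenever an earlier criterion already returned True. A self-contained sketch of that failure mode (class names are illustrative):

    import torch
    from transformers import StoppingCriteria, StoppingCriteriaList

    class AlwaysStop(StoppingCriteria):
        def __call__(self, input_ids, scores, **kwargs) -> bool:
            return True

    class RecordingCriteria(StoppingCriteria):
        # Stands in for the scanner: it must run every step to record state.
        def __init__(self):
            self.was_called = False
        def __call__(self, input_ids, scores, **kwargs) -> bool:
            self.was_called = True
            return False

    ids, scores = torch.zeros((1, 4), dtype=torch.long), torch.zeros((1, 8))

    recorder = RecordingCriteria()
    StoppingCriteriaList([AlwaysStop(), recorder])(ids, scores)
    print(recorder.was_called)  # False: an appended criterion can be skipped

    recorder = RecordingCriteria()
    StoppingCriteriaList([recorder, AlwaysStop()])(ids, scores)
    print(recorder.was_called)  # True: inserted at index 0, it always runs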
@@ -1995,17 +2001,17 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions):
     lnanote = 0 # Placeholder for Author's Note length
     # Calculate token budget
-    prompttkns = tokenizer.encode(vars.comregex_ai.sub('', vars.prompt))
+    prompttkns = tokenizer.encode(vars.comregex_ai.sub('', vars.prompt), max_length=1+int(vars.max_length), truncation=True)
     lnprompt = len(prompttkns)
-    memtokens = tokenizer.encode(mem)
+    memtokens = tokenizer.encode(mem, max_length=1+int(vars.max_length), truncation=True)
     lnmem = len(memtokens)
-    witokens = tokenizer.encode(winfo)
+    witokens = tokenizer.encode(winfo, max_length=1+int(vars.max_length), truncation=True)
     lnwi = len(witokens)
     if(anotetxt != ""):
-        anotetkns = tokenizer.encode(anotetxt)
+        anotetkns = tokenizer.encode(anotetxt, max_length=1+int(vars.max_length), truncation=True)
         lnanote = len(anotetkns)
     lnsp = vars.sp.shape[0] if vars.sp is not None else 0
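
Each encode call above now caps its output at 1+int(vars.max_length) tokens, so a single oversized prompt, memory, world-info, or Author's Note entry can no longer dominate the budget arithmetic (the extra token appears to be deliberate headroom; the patch itself does not say why). A sketch of the Hugging Face truncation behavior being relied on, using gpt2 purely as an illustrative tokenizer:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    text = "word " * 5000  # deliberately far past the model's context size

    ids = tokenizer.encode(text, max_length=1024, truncation=True)
    print(len(ids))        # 1024: clipped to max_length

    ids = tokenizer.encode(text)
    print(len(ids))        # full length, with a too-long-sequence warning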
@@ -2034,7 +2040,7 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions):
             if(budget <= 0):
                 break
-            acttkns = tokenizer.encode(chunk)
+            acttkns = tokenizer.encode(chunk, max_length=int(vars.max_length), truncation=True)
             tknlen = len(acttkns)
             if(tknlen < budget):
                 tokens = acttkns + tokens
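
This hunk sits inside the loop that walks story actions from newest to oldest, prepending each encoded chunk until the token budget runs out; capping each chunk at int(vars.max_length) keeps one huge action from stalling that walk. The loop's overall shape, sketched with hypothetical names (the real function handles comment stripping and partial chunks somewhat differently):

    def build_context(chunks, encode, budget):
        # Prepend newest-to-oldest chunks until the token budget is spent.
        tokens = []
        for chunk in reversed(chunks):
            if budget <= 0:
                break
            ids = encode(chunk)                  # capped by max_length upstream
            if len(ids) < budget:
                tokens = ids + tokens            # whole chunk fits: prepend it
                budget -= len(ids)
            else:
                tokens = ids[-budget:] + tokens  # keep only the newest tokens
                budget = 0
        return tokens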
@@ -2168,7 +2174,7 @@ def calcsubmit(txt):
 #==================================================================#
 def _generate(txt, minimum, maximum, found_entries):
-    gen_in = tokenizer.encode(txt, return_tensors="pt", truncation=True).long()
+    gen_in = tokenizer.encode(txt, return_tensors="pt", max_length=int(vars.max_length), truncation=True).long()
     if(vars.sp is not None):
         soft_tokens = torch.arange(
             model.config.vocab_size,
@@ -2214,14 +2220,14 @@ def _generate(txt, minimum, maximum, found_entries):
         for r in range(vars.numseqs):
             for c in range(already_generated):
                 assert vars.lua_koboldbridge.generated[r+1][c+1] is not None
-                genout[r][genout.shape[-1] - already_generated - c] = vars.lua_koboldbridge.generated[r+1][c+1]
+                genout[r][genout.shape[-1] - already_generated + c] = vars.lua_koboldbridge.generated[r+1][c+1]
         encoded = []
         for i in range(vars.numseqs):
             txt = tokenizer.decode(genout[i, -already_generated:])
             winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
             found_entries[i].update(_found_entries)
             txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
-            encoded.append(tokenizer.encode(txt, return_tensors="pt", truncation=True)[0].long().to(genout.device))
+            encoded.append(tokenizer.encode(txt, return_tensors="pt", max_length=int(vars.max_length), truncation=True)[0].long().to(genout.device))
         max_length = len(max(encoded, key=len))
         encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
         genout = torch.cat(
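
The re-encoded sequences can differ in length, so before stacking they are left-padded to the longest one; torch.nn.functional.pad takes a (left, right) pair for the last dimension, and padding on the left keeps the real tokens flush against the position where generation resumes, which is what a decoder-only model needs. A minimal sketch of that idiom (the pad id is illustrative):

    import torch
    import torch.nn.functional as F

    pad_id = 50256  # e.g. GPT-2's eos token doubling as pad; illustrative
    seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
    max_len = max(len(s) for s in seqs)

    batch = torch.stack(tuple(
        F.pad(s, (max_len - len(s), 0), value=pad_id) for s in seqs
    ))
    print(batch)
    # tensor([[    5,     6,     7],
    #         [50256,     8,     9]])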
@@ -2275,7 +2281,7 @@ def generate(txt, minimum, maximum, found_entries=None):
             print("{0}{1}{2}".format(colors.YELLOW, "Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.", colors.END), file=sys.stderr)
         else:
             emit('from_server', {'cmd': 'errmsg', 'data': 'Error occured during generator call, please check console.'}, broadcast=True)
-        print("{0}{1}{2}".format(colors.RED, str(e).replace("\033", ""), colors.END), file=sys.stderr)
+        print("{0}{1}{2}".format(colors.RED, traceback.format_exc().replace("\033", ""), colors.END), file=sys.stderr)
         set_aibusy(0)
         return
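
Swapping str(e) for traceback.format_exc() prints the full stack trace rather than just the exception message, which makes console reports of generator failures actionable; format_exc() must be called from inside an except block, as it is here, and the existing .replace("\033", "") carries over so escape characters in the trace cannot corrupt the colored console output. A toy comparison:

    import traceback

    try:
        {}["missing"]
    except Exception as e:
        print(str(e))                  # just: 'missing'
        print(traceback.format_exc())  # full "Traceback (most recent call last):" text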
@@ -2486,7 +2492,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
             print("{0}{1}{2}".format(colors.YELLOW, "Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.", colors.END), file=sys.stderr)
         else:
             emit('from_server', {'cmd': 'errmsg', 'data': 'Error occured during generator call, please check console.'}, broadcast=True)
-        print("{0}{1}{2}".format(colors.RED, str(e).replace("\033", ""), colors.END), file=sys.stderr)
+        print("{0}{1}{2}".format(colors.RED, traceback.format_exc().replace("\033", ""), colors.END), file=sys.stderr)
         set_aibusy(0)
         return
@@ -2527,8 +2533,8 @@ def getnewcontent(txt):
         return txt
     # Tokenize the last context and the generated content
-    ctxtokens = tokenizer.encode(vars.lastctx)
-    txttokens = tokenizer.encode(txt)
+    ctxtokens = tokenizer.encode(vars.lastctx, max_length=1+int(vars.max_length), truncation=True)
+    txttokens = tokenizer.encode(txt, max_length=1+int(vars.max_length), truncation=True)
     dif = (len(txttokens) - len(ctxtokens)) * -1
     # Remove the context from the returned text
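
Because txttokens holds the old context plus the new generation, the negated length difference dif is a negative index whose magnitude is the number of newly generated tokens, so slicing with it isolates the new content (presumably what the remainder of getnewcontent, not shown in this hunk, goes on to decode). In miniature:

    ctxtokens = [1, 2, 3, 4]        # tokens of the last context sent
    txttokens = [1, 2, 3, 4, 7, 8]  # tokens of context plus generated text
    dif = (len(txttokens) - len(ctxtokens)) * -1  # -2
    print(txttokens[dif:])          # [7, 8]: just the new tokens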