Shallow copy story chunks when generating

Gnome Ann
2021-11-03 17:53:38 -04:00
parent b8c3d8c12e
commit 9b18068999


@@ -1223,7 +1223,7 @@ def calcsubmitbudgetheader(txt, **kwargs):
     return winfo, mem, anotetxt, found_entries

-def calcsubmitbudget(actionlen, winfo, mem, anotetxt):
+def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions=vars.actions):
     anotetkns = []  # Placeholder for Author's Note tokens
     lnanote   = 0   # Placeholder for Author's Note length
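
A note on the new default argument: Python evaluates actions=vars.actions once, when the def statement runs, so the parameter is bound to whatever object vars.actions refers to at that moment. In-place changes to that object remain visible through the default, but if vars.actions is ever rebound to a new object, the default keeps pointing at the old one. A minimal, self-contained sketch of this behavior (names are illustrative, not KoboldAI's):

    registry = {0: "first chunk"}

    def read_chunks(actions=registry):  # default bound once, at def time
        return list(actions.values())

    registry[1] = "second chunk"        # in-place mutation is visible...
    print(read_chunks())                # ['first chunk', 'second chunk']

    registry = {0: "a new story"}       # ...but rebinding the name is not
    print(read_chunks())                # ['first chunk', 'second chunk']

This is also why generate() below passes its copy explicitly rather than relying on the default.
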
@@ -1262,8 +1262,8 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt):
     # Get most recent action tokens up to our budget
     n = 0
-    for key in reversed(vars.actions):
-        chunk = vars.actions[key]
+    for key in reversed(actions):
+        chunk = actions[key]
         if(budget <= 0):
             break
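
With the loop reading from the actions parameter instead of vars.actions directly, the budget walk is decoupled from the live story while its logic stays the same: visit chunk keys newest-first and stop once the token budget is spent. A rough, runnable sketch of that shape, with a plain dict standing in for the story register and a chunk count standing in for the token budget (reversed() on a plain dict needs Python 3.8+):

    actions = {0: "first", 1: "second", 2: "third"}
    budget = 2
    picked = []
    for key in reversed(actions):   # newest chunk first
        if budget <= 0:
            break
        picked.append(actions[key])
        budget -= 1
    print(picked)                   # ['third', 'second']
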
@@ -1432,6 +1432,10 @@ def generate(txt, min, max):
         model.kai_scanner_head_length = gen_in.shape[-1]
         model.kai_scanner_excluded_world_info = set()
+        actions = vars.actions
+        if(vars.dynamicscan):
+            actions = actions.copy()
         with torch.no_grad():
             already_generated = 0
             numseqs = vars.numseqs if not vars.dynamicscan else 1
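
This is the core of the commit: when dynamic world info scanning is enabled, generate() works on a shallow copy of the story chunks rather than on vars.actions itself. A shallow copy duplicates only the container, not the chunk strings inside it, which is cheap and, because strings are immutable, is all the isolation that adding or dropping chunks requires. A minimal sketch with a plain dict standing in for the story register:

    story = {0: "Once upon a time,", 1: " a knight set out."}
    snapshot = story.copy()               # new container, same immutable values

    snapshot[2] = " mid-generation text"  # only the snapshot grows
    print(2 in story)                     # False -- the real story is untouched

    story[3] = " a concurrent edit"       # and edits to the real story...
    print(3 in snapshot)                  # False -- ...do not leak into the snapshot
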
@@ -1458,7 +1462,7 @@ def generate(txt, min, max):
                 txt = tokenizer.decode(genout[0, -already_generated:])
                 winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
                 found_entries |= _found_entries
-                txt, _, _ = calcsubmitbudget(len(vars.actions), winfo, mem, anotetxt)
+                txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions=actions)
                 encoded = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(genout.device)
                 genout = torch.cat(
                     (
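
Finally, the mid-generation re-budget call now takes both its action count and its chunks from the same actions object, so when dynamic scanning is active the prompt is rebuilt from the snapshot taken before generation started rather than from whatever the live story happens to contain at that instant. The pattern, reduced to a runnable toy with hypothetical names (rebuild_prompt stands in for calcsubmitbudget):

    def rebuild_prompt(actionlen, actions):
        keys = sorted(actions)[-actionlen:]       # most recent chunks
        return "".join(actions[k] for k in keys)

    live_story = {0: "A", 1: "B"}
    snapshot = live_story.copy()    # taken once, up front
    live_story[2] = "C"             # the story changes while we generate

    # length and chunks both come from the snapshot, so the rebuilt
    # prompt reflects a single consistent view of the story
    print(rebuild_prompt(len(snapshot), snapshot))   # AB
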