Make sure calcsubmitbudget uses the correct reference to vars.actions

This commit is contained in:
Gnome Ann 2021-11-03 18:57:02 -04:00
parent a2d7735a51
commit 81bd058caf

View File

@@ -1223,7 +1223,7 @@ def calcsubmitbudgetheader(txt, **kwargs):
return winfo, mem, anotetxt, found_entries
def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions=vars.actions):
def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions):
forceanote = False # In case we don't have enough actions to hit A.N. depth
anoteadded = False # In case our budget runs out before we hit A.N. depth
anotetkns = [] # Placeholder for Author's Note tokens
@@ -1322,7 +1322,7 @@ def calcsubmit(txt):
# For all transformers models
if(vars.model != "InferKit"):
subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt)
subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions)
if(actionlen == 0):
if(not vars.model in ["Colab", "OAI"]):
generate(subtxt, min, max, found_entries=found_entries)
@@ -1462,7 +1462,7 @@ def generate(txt, min, max, found_entries=set()):
txt = tokenizer.decode(genout[0, -already_generated:])
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
found_entries |= _found_entries
txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions=actions)
txt, _, _ = calcsubmitbudget(len(actions), winfo, mem, anotetxt, actions)
encoded = tokenizer.encode(txt, return_tensors="pt", truncation=True).long().to(genout.device)
genout = torch.cat(
(