Use `vars._actions` in `tpumtjgenerate` and its callbacks

This commit is contained in:
Gnome Ann 2022-01-17 13:24:11 -05:00
parent 45bfde8d5d
commit 6502af086f
1 changed files with 2 additions and 2 deletions

View File

@ -1048,7 +1048,7 @@ else:
for i, t in enumerate(generated):
decoded = tokenizer.decode(past[i]) + tokenizer.decode(t[tpu_mtj_backend.params["seq"] : tpu_mtj_backend.params["seq"] + n_generated])
_, found = checkworldinfo(decoded, force_use_txt=True)
_, found = checkworldinfo(decoded, force_use_txt=True, actions=vars._actions)
found -= excluded_world_info[i]
if(len(found) != 0):
regeneration_required = True
@ -3033,7 +3033,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
encoded = []
for i in range(vars.numseqs):
txt = tokenizer.decode(past[i])
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=vars._actions)
found_entries[i].update(_found_entries)
txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
encoded.append(np.array(txt, dtype=np.uint32))