Use `vars._actions` in `tpumtjgenerate` and its callbacks
This commit is contained in:
parent
45bfde8d5d
commit
6502af086f
|
@ -1048,7 +1048,7 @@ else:
|
|||
|
||||
for i, t in enumerate(generated):
|
||||
decoded = tokenizer.decode(past[i]) + tokenizer.decode(t[tpu_mtj_backend.params["seq"] : tpu_mtj_backend.params["seq"] + n_generated])
|
||||
_, found = checkworldinfo(decoded, force_use_txt=True)
|
||||
_, found = checkworldinfo(decoded, force_use_txt=True, actions=vars._actions)
|
||||
found -= excluded_world_info[i]
|
||||
if(len(found) != 0):
|
||||
regeneration_required = True
|
||||
|
@ -3033,7 +3033,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
|||
encoded = []
|
||||
for i in range(vars.numseqs):
|
||||
txt = tokenizer.decode(past[i])
|
||||
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
|
||||
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=vars._actions)
|
||||
found_entries[i].update(_found_entries)
|
||||
txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
|
||||
encoded.append(np.array(txt, dtype=np.uint32))
|
||||
|
|
Loading…
Reference in New Issue