Use `vars._actions` in `tpumtjgenerate` and its callbacks
This commit is contained in:
parent
45bfde8d5d
commit
6502af086f
|
@ -1048,7 +1048,7 @@ else:
|
||||||
|
|
||||||
for i, t in enumerate(generated):
|
for i, t in enumerate(generated):
|
||||||
decoded = tokenizer.decode(past[i]) + tokenizer.decode(t[tpu_mtj_backend.params["seq"] : tpu_mtj_backend.params["seq"] + n_generated])
|
decoded = tokenizer.decode(past[i]) + tokenizer.decode(t[tpu_mtj_backend.params["seq"] : tpu_mtj_backend.params["seq"] + n_generated])
|
||||||
_, found = checkworldinfo(decoded, force_use_txt=True)
|
_, found = checkworldinfo(decoded, force_use_txt=True, actions=vars._actions)
|
||||||
found -= excluded_world_info[i]
|
found -= excluded_world_info[i]
|
||||||
if(len(found) != 0):
|
if(len(found) != 0):
|
||||||
regeneration_required = True
|
regeneration_required = True
|
||||||
|
@ -3033,7 +3033,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
||||||
encoded = []
|
encoded = []
|
||||||
for i in range(vars.numseqs):
|
for i in range(vars.numseqs):
|
||||||
txt = tokenizer.decode(past[i])
|
txt = tokenizer.decode(past[i])
|
||||||
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
|
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=vars._actions)
|
||||||
found_entries[i].update(_found_entries)
|
found_entries[i].update(_found_entries)
|
||||||
txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
|
txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
|
||||||
encoded.append(np.array(txt, dtype=np.uint32))
|
encoded.append(np.array(txt, dtype=np.uint32))
|
||||||
|
|
Loading…
Reference in New Issue