diff --git a/aiserver.py b/aiserver.py index d25dfe11..60e1a4a5 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1048,7 +1048,7 @@ else: for i, t in enumerate(generated): decoded = tokenizer.decode(past[i]) + tokenizer.decode(t[tpu_mtj_backend.params["seq"] : tpu_mtj_backend.params["seq"] + n_generated]) - _, found = checkworldinfo(decoded, force_use_txt=True) + _, found = checkworldinfo(decoded, force_use_txt=True, actions=vars._actions) found -= excluded_world_info[i] if(len(found) != 0): regeneration_required = True @@ -3033,7 +3033,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): encoded = [] for i in range(vars.numseqs): txt = tokenizer.decode(past[i]) - winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True) + winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=vars._actions) found_entries[i].update(_found_entries) txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt) encoded.append(np.array(txt, dtype=np.uint32))