Fix settings callback, and `genout.shape[-1]` in `tpumtjgenerate()`

This commit is contained in:
Gnome Ann 2022-01-17 14:52:29 -05:00
parent 293b75e89f
commit 703c092577
1 changed files with 2 additions and 1 deletions

View File

@ -1087,6 +1087,7 @@ else:
tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback
tpu_mtj_backend.compiling_callback = tpumtjgenerate_compiling_callback tpu_mtj_backend.compiling_callback = tpumtjgenerate_compiling_callback
tpu_mtj_backend.stopped_compiling_callback = tpumtjgenerate_stopped_compiling_callback tpu_mtj_backend.stopped_compiling_callback = tpumtjgenerate_stopped_compiling_callback
tpu_mtj_backend.settings_callback = tpumtjgenerate_settings_callback
tpu_mtj_backend.load_model(vars.custmodpth) tpu_mtj_backend.load_model(vars.custmodpth)
vars.allowsp = True vars.allowsp = True
vars.modeldim = int(tpu_mtj_backend.params["d_model"]) vars.modeldim = int(tpu_mtj_backend.params["d_model"])
@ -3068,7 +3069,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
past = genout past = genout
for i in range(vars.numseqs): for i in range(vars.numseqs):
vars.lua_koboldbridge.generated[i+1] = vars.lua_state.table(*genout[i].tolist()) vars.lua_koboldbridge.generated[i+1] = vars.lua_state.table(*genout[i].tolist())
vars.lua_koboldbridge.generated_cols = vars.generated_tkns = genout.shape[-1] vars.lua_koboldbridge.generated_cols = vars.generated_tkns = genout[0].shape[-1]
except Exception as e: except Exception as e:
if(issubclass(type(e), lupa.LuaError)): if(issubclass(type(e), lupa.LuaError)):