From 703c0925779e796288150b3ecd8f3a98e05393b0 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Mon, 17 Jan 2022 14:52:29 -0500 Subject: [PATCH] Fix settings callback, and `genout.shape[-1]` in `tpumtjgenerate()` --- aiserver.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/aiserver.py b/aiserver.py index 8da79094..7628cb18 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1087,6 +1087,7 @@ else: tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback tpu_mtj_backend.compiling_callback = tpumtjgenerate_compiling_callback tpu_mtj_backend.stopped_compiling_callback = tpumtjgenerate_stopped_compiling_callback + tpu_mtj_backend.settings_callback = tpumtjgenerate_settings_callback tpu_mtj_backend.load_model(vars.custmodpth) vars.allowsp = True vars.modeldim = int(tpu_mtj_backend.params["d_model"]) @@ -3068,7 +3069,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): past = genout for i in range(vars.numseqs): vars.lua_koboldbridge.generated[i+1] = vars.lua_state.table(*genout[i].tolist()) - vars.lua_koboldbridge.generated_cols = vars.generated_tkns = genout.shape[-1] + vars.lua_koboldbridge.generated_cols = vars.generated_tkns = genout[0].shape[-1] except Exception as e: if(issubclass(type(e), lupa.LuaError)):