diff --git a/aiserver.py b/aiserver.py index 8da79094..7628cb18 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1087,6 +1087,7 @@ else: tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback tpu_mtj_backend.compiling_callback = tpumtjgenerate_compiling_callback tpu_mtj_backend.stopped_compiling_callback = tpumtjgenerate_stopped_compiling_callback + tpu_mtj_backend.settings_callback = tpumtjgenerate_settings_callback tpu_mtj_backend.load_model(vars.custmodpth) vars.allowsp = True vars.modeldim = int(tpu_mtj_backend.params["d_model"]) @@ -3068,7 +3069,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): past = genout for i in range(vars.numseqs): vars.lua_koboldbridge.generated[i+1] = vars.lua_state.table(*genout[i].tolist()) - vars.lua_koboldbridge.generated_cols = vars.generated_tkns = genout.shape[-1] + vars.lua_koboldbridge.generated_cols = vars.generated_tkns = genout[0].shape[-1] except Exception as e: if(issubclass(type(e), lupa.LuaError)):