Fix settings callback, and `genout.shape[-1]` in `tpumtjgenerate()`
This commit is contained in:
parent
293b75e89f
commit
703c092577
|
@ -1087,6 +1087,7 @@ else:
|
||||||
tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback
|
tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback
|
||||||
tpu_mtj_backend.compiling_callback = tpumtjgenerate_compiling_callback
|
tpu_mtj_backend.compiling_callback = tpumtjgenerate_compiling_callback
|
||||||
tpu_mtj_backend.stopped_compiling_callback = tpumtjgenerate_stopped_compiling_callback
|
tpu_mtj_backend.stopped_compiling_callback = tpumtjgenerate_stopped_compiling_callback
|
||||||
|
tpu_mtj_backend.settings_callback = tpumtjgenerate_settings_callback
|
||||||
tpu_mtj_backend.load_model(vars.custmodpth)
|
tpu_mtj_backend.load_model(vars.custmodpth)
|
||||||
vars.allowsp = True
|
vars.allowsp = True
|
||||||
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
|
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
|
||||||
|
@ -3068,7 +3069,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
||||||
past = genout
|
past = genout
|
||||||
for i in range(vars.numseqs):
|
for i in range(vars.numseqs):
|
||||||
vars.lua_koboldbridge.generated[i+1] = vars.lua_state.table(*genout[i].tolist())
|
vars.lua_koboldbridge.generated[i+1] = vars.lua_state.table(*genout[i].tolist())
|
||||||
vars.lua_koboldbridge.generated_cols = vars.generated_tkns = genout.shape[-1]
|
vars.lua_koboldbridge.generated_cols = vars.generated_tkns = genout[0].shape[-1]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if(issubclass(type(e), lupa.LuaError)):
|
if(issubclass(type(e), lupa.LuaError)):
|
||||||
|
|
Loading…
Reference in New Issue