Fix settings callback, and `genout.shape[-1]` in `tpumtjgenerate()`
parent 293b75e89f
commit 703c092577
@@ -1087,6 +1087,7 @@ else:
         tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback
         tpu_mtj_backend.compiling_callback = tpumtjgenerate_compiling_callback
         tpu_mtj_backend.stopped_compiling_callback = tpumtjgenerate_stopped_compiling_callback
+        tpu_mtj_backend.settings_callback = tpumtjgenerate_settings_callback
         tpu_mtj_backend.load_model(vars.custmodpth)
         vars.allowsp = True
         vars.modeldim = int(tpu_mtj_backend.params["d_model"])
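The added line registers a settings callback on the TPU backend alongside the existing callbacks, presumably so generation reads the current sampler settings rather than values captured at load time. A minimal sketch of this wiring pattern, using hypothetical names (`FakeBackend`, `my_settings_callback`) rather than KoboldAI's actual API:

```python
# Minimal sketch (assumed names, not KoboldAI's code) of the callback-wiring
# pattern in the hunk above: the frontend assigns a function to a hook
# attribute on the backend, and the backend calls it to fetch fresh settings
# each time it generates.

class FakeBackend:
    settings_callback = None  # assigned by the frontend before generation

    def generate(self):
        # Pull the latest settings on every call, if a callback was wired up.
        settings = self.settings_callback() if self.settings_callback else {}
        return f"generating with {settings}"

def my_settings_callback():
    # In the real application this would read live sampler values from state.
    return {"temp": 0.7, "top_p": 0.9}

backend = FakeBackend()
backend.settings_callback = my_settings_callback  # analogous to the added line
print(backend.generate())
```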
@@ -3068,7 +3069,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
         past = genout
         for i in range(vars.numseqs):
             vars.lua_koboldbridge.generated[i+1] = vars.lua_state.table(*genout[i].tolist())
-        vars.lua_koboldbridge.generated_cols = vars.generated_tkns = genout.shape[-1]
+        vars.lua_koboldbridge.generated_cols = vars.generated_tkns = genout[0].shape[-1]

     except Exception as e:
         if(issubclass(type(e), lupa.LuaError)):
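The changed line switches from `genout.shape[-1]` to `genout[0].shape[-1]`, presumably because `genout` at this point holds per-sequence outputs (for example a list of arrays) rather than a single 2-D array, so the generated token count has to be read from an individual sequence. A small illustration under that assumption:

```python
# Assumed-shape illustration of the one-line fix: if `genout` is a container
# of per-sequence token arrays rather than a single 2-D array, the container
# itself has no `.shape`, so the token count must come from one sequence.
import numpy as np

# As one 2-D array, both expressions agree on the token count:
genout_2d = np.zeros((3, 80), dtype=np.int64)        # 3 sequences x 80 tokens
assert genout_2d.shape[-1] == genout_2d[0].shape[-1] == 80

# As a list of per-sequence arrays, only the elements carry a shape:
genout_list = [np.zeros(80, dtype=np.int64) for _ in range(3)]
# genout_list.shape[-1]          # AttributeError: list has no attribute 'shape'
print(genout_list[0].shape[-1])  # 80 -- tokens generated per sequence
```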