mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Fix TPU generation modifier
This commit is contained in:
30
aiserver.py
30
aiserver.py
@ -1001,19 +1001,7 @@ else:
|
||||
)
|
||||
return soft_tokens
|
||||
|
||||
def tpumtjgenerate_warper_callback(generated, scores, excluded_world_info, n_generated) -> Tuple[List[set], bool, bool]:
|
||||
vars.generated_tkns += 1
|
||||
|
||||
assert len(excluded_world_info) == len(generated)
|
||||
regeneration_required = vars.lua_koboldbridge.regeneration_required
|
||||
halt = not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt
|
||||
vars.lua_koboldbridge.regeneration_required = False
|
||||
|
||||
global past
|
||||
|
||||
for i in range(vars.numseqs):
|
||||
vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(generated[i, tpu_mtj_backend.params["seq"] + n_generated - 1].item())
|
||||
|
||||
def tpumtjgenerate_warper_callback(scores) -> "np.array":
|
||||
scores_shape = scores.shape
|
||||
scores_list = scores.tolist()
|
||||
vars.lua_koboldbridge.logits = vars.lua_state.table()
|
||||
@ -1029,6 +1017,21 @@ else:
|
||||
)
|
||||
assert scores.shape == scores_shape
|
||||
|
||||
return scores
|
||||
|
||||
def tpumtjgenerate_stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]:
|
||||
vars.generated_tkns += 1
|
||||
|
||||
assert len(excluded_world_info) == len(generated)
|
||||
regeneration_required = vars.lua_koboldbridge.regeneration_required
|
||||
halt = not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt
|
||||
vars.lua_koboldbridge.regeneration_required = False
|
||||
|
||||
global past
|
||||
|
||||
for i in range(vars.numseqs):
|
||||
vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(generated[i, tpu_mtj_backend.params["seq"] + n_generated - 1].item())
|
||||
|
||||
if(not vars.dynamicscan or halt):
|
||||
return excluded_world_info, regeneration_required, halt
|
||||
|
||||
@ -1054,6 +1057,7 @@ else:
|
||||
assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth)
|
||||
import tpu_mtj_backend
|
||||
tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
|
||||
tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback
|
||||
tpu_mtj_backend.load_model(vars.custmodpth)
|
||||
vars.allowsp = True
|
||||
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
|
||||
|
Reference in New Issue
Block a user