mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-02-25 16:08:00 +01:00
Fix TPU generation modifier
This commit is contained in:
parent
932c393d6a
commit
e0fdce2cc6
30
aiserver.py
30
aiserver.py
@ -1001,19 +1001,7 @@ else:
|
|||||||
)
|
)
|
||||||
return soft_tokens
|
return soft_tokens
|
||||||
|
|
||||||
def tpumtjgenerate_warper_callback(generated, scores, excluded_world_info, n_generated) -> Tuple[List[set], bool, bool]:
|
def tpumtjgenerate_warper_callback(scores) -> "np.array":
|
||||||
vars.generated_tkns += 1
|
|
||||||
|
|
||||||
assert len(excluded_world_info) == len(generated)
|
|
||||||
regeneration_required = vars.lua_koboldbridge.regeneration_required
|
|
||||||
halt = not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt
|
|
||||||
vars.lua_koboldbridge.regeneration_required = False
|
|
||||||
|
|
||||||
global past
|
|
||||||
|
|
||||||
for i in range(vars.numseqs):
|
|
||||||
vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(generated[i, tpu_mtj_backend.params["seq"] + n_generated - 1].item())
|
|
||||||
|
|
||||||
scores_shape = scores.shape
|
scores_shape = scores.shape
|
||||||
scores_list = scores.tolist()
|
scores_list = scores.tolist()
|
||||||
vars.lua_koboldbridge.logits = vars.lua_state.table()
|
vars.lua_koboldbridge.logits = vars.lua_state.table()
|
||||||
@ -1029,6 +1017,21 @@ else:
|
|||||||
)
|
)
|
||||||
assert scores.shape == scores_shape
|
assert scores.shape == scores_shape
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
def tpumtjgenerate_stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]:
|
||||||
|
vars.generated_tkns += 1
|
||||||
|
|
||||||
|
assert len(excluded_world_info) == len(generated)
|
||||||
|
regeneration_required = vars.lua_koboldbridge.regeneration_required
|
||||||
|
halt = not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt
|
||||||
|
vars.lua_koboldbridge.regeneration_required = False
|
||||||
|
|
||||||
|
global past
|
||||||
|
|
||||||
|
for i in range(vars.numseqs):
|
||||||
|
vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(generated[i, tpu_mtj_backend.params["seq"] + n_generated - 1].item())
|
||||||
|
|
||||||
if(not vars.dynamicscan or halt):
|
if(not vars.dynamicscan or halt):
|
||||||
return excluded_world_info, regeneration_required, halt
|
return excluded_world_info, regeneration_required, halt
|
||||||
|
|
||||||
@ -1054,6 +1057,7 @@ else:
|
|||||||
assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth)
|
assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth)
|
||||||
import tpu_mtj_backend
|
import tpu_mtj_backend
|
||||||
tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
|
tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
|
||||||
|
tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback
|
||||||
tpu_mtj_backend.load_model(vars.custmodpth)
|
tpu_mtj_backend.load_model(vars.custmodpth)
|
||||||
vars.allowsp = True
|
vars.allowsp = True
|
||||||
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
|
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
|
||||||
|
@ -20,9 +20,12 @@ from mesh_transformer.transformer_shard import CausalTransformer, CausalTransfor
|
|||||||
params: Dict[str, Any] = {}
|
params: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
def warper_callback(generated, logits, excluded_world_info, n_generated) -> Tuple[bool, bool]:
|
def warper_callback(logits) -> np.array:
|
||||||
raise NotImplementedError("`tpu_mtj_backend.warper_callback()` needs to be defined")
|
raise NotImplementedError("`tpu_mtj_backend.warper_callback()` needs to be defined")
|
||||||
|
|
||||||
|
def stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]:
|
||||||
|
raise NotImplementedError("`tpu_mtj_backend.stopping_callback()` needs to be defined")
|
||||||
|
|
||||||
|
|
||||||
def show_spinner():
|
def show_spinner():
|
||||||
bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='█')])
|
bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='█')])
|
||||||
@ -340,12 +343,18 @@ class PenalizingCausalTransformer(CausalTransformer):
|
|||||||
generate_data, = self.generate_once_xmap(generate_data, self.state, numseqs_aux, soft_embeddings)
|
generate_data, = self.generate_once_xmap(generate_data, self.state, numseqs_aux, soft_embeddings)
|
||||||
for i in range(numseqs):
|
for i in range(numseqs):
|
||||||
sample_data[i][2] = np.array(generate_data[i][0][0, 0], copy=True)
|
sample_data[i][2] = np.array(generate_data[i][0][0, 0], copy=True)
|
||||||
|
if use_callback:
|
||||||
|
logits = np.float32(tuple(d[2] for d in sample_data))
|
||||||
|
logits = warper_callback(logits)
|
||||||
|
for i in range(numseqs):
|
||||||
|
sample_data[i][2] = logits[i]
|
||||||
sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options)
|
sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options)
|
||||||
|
n_generated += 1
|
||||||
for i in range(numseqs):
|
for i in range(numseqs):
|
||||||
generate_data[i][3] = np.tile(sample_data[i][0][sample_data[i][1]-1][np.newaxis, np.newaxis], (params["cores_per_replica"], 1, 1))
|
generate_data[i][3] = np.tile(sample_data[i][0][sample_data[i][1]-1][np.newaxis, np.newaxis], (params["cores_per_replica"], 1, 1))
|
||||||
n_generated += 1
|
|
||||||
if use_callback:
|
if use_callback:
|
||||||
excluded_world_info, regeneration_required, halt = warper_callback(np.uint32(tuple(d[0] for d in sample_data)), np.float32(tuple(d[2] for d in sample_data)), excluded_world_info, n_generated)
|
generated = np.uint32(tuple(d[0] for d in sample_data))
|
||||||
|
excluded_world_info, regeneration_required, halt = stopping_callback(generated, n_generated, excluded_world_info)
|
||||||
if regeneration_required or halt:
|
if regeneration_required or halt:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user