Fix TPU generation modifier
This commit is contained in:
parent
932c393d6a
commit
e0fdce2cc6
30
aiserver.py
30
aiserver.py
|
@ -1001,19 +1001,7 @@ else:
|
|||
)
|
||||
return soft_tokens
|
||||
|
||||
def tpumtjgenerate_warper_callback(generated, scores, excluded_world_info, n_generated) -> Tuple[List[set], bool, bool]:
|
||||
vars.generated_tkns += 1
|
||||
|
||||
assert len(excluded_world_info) == len(generated)
|
||||
regeneration_required = vars.lua_koboldbridge.regeneration_required
|
||||
halt = not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt
|
||||
vars.lua_koboldbridge.regeneration_required = False
|
||||
|
||||
global past
|
||||
|
||||
for i in range(vars.numseqs):
|
||||
vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(generated[i, tpu_mtj_backend.params["seq"] + n_generated - 1].item())
|
||||
|
||||
def tpumtjgenerate_warper_callback(scores) -> "np.array":
|
||||
scores_shape = scores.shape
|
||||
scores_list = scores.tolist()
|
||||
vars.lua_koboldbridge.logits = vars.lua_state.table()
|
||||
|
@ -1029,6 +1017,21 @@ else:
|
|||
)
|
||||
assert scores.shape == scores_shape
|
||||
|
||||
return scores
|
||||
|
||||
def tpumtjgenerate_stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]:
|
||||
vars.generated_tkns += 1
|
||||
|
||||
assert len(excluded_world_info) == len(generated)
|
||||
regeneration_required = vars.lua_koboldbridge.regeneration_required
|
||||
halt = not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt
|
||||
vars.lua_koboldbridge.regeneration_required = False
|
||||
|
||||
global past
|
||||
|
||||
for i in range(vars.numseqs):
|
||||
vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(generated[i, tpu_mtj_backend.params["seq"] + n_generated - 1].item())
|
||||
|
||||
if(not vars.dynamicscan or halt):
|
||||
return excluded_world_info, regeneration_required, halt
|
||||
|
||||
|
@ -1054,6 +1057,7 @@ else:
|
|||
assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth)
|
||||
import tpu_mtj_backend
|
||||
tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
|
||||
tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback
|
||||
tpu_mtj_backend.load_model(vars.custmodpth)
|
||||
vars.allowsp = True
|
||||
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
|
||||
|
|
|
@ -20,9 +20,12 @@ from mesh_transformer.transformer_shard import CausalTransformer, CausalTransfor
|
|||
params: Dict[str, Any] = {}
|
||||
|
||||
|
||||
def warper_callback(generated, logits, excluded_world_info, n_generated) -> Tuple[bool, bool]:
|
||||
def warper_callback(logits) -> np.array:
|
||||
raise NotImplementedError("`tpu_mtj_backend.warper_callback()` needs to be defined")
|
||||
|
||||
def stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]:
|
||||
raise NotImplementedError("`tpu_mtj_backend.stopping_callback()` needs to be defined")
|
||||
|
||||
|
||||
def show_spinner():
|
||||
bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='█')])
|
||||
|
@ -340,12 +343,18 @@ class PenalizingCausalTransformer(CausalTransformer):
|
|||
generate_data, = self.generate_once_xmap(generate_data, self.state, numseqs_aux, soft_embeddings)
|
||||
for i in range(numseqs):
|
||||
sample_data[i][2] = np.array(generate_data[i][0][0, 0], copy=True)
|
||||
if use_callback:
|
||||
logits = np.float32(tuple(d[2] for d in sample_data))
|
||||
logits = warper_callback(logits)
|
||||
for i in range(numseqs):
|
||||
sample_data[i][2] = logits[i]
|
||||
sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options)
|
||||
n_generated += 1
|
||||
for i in range(numseqs):
|
||||
generate_data[i][3] = np.tile(sample_data[i][0][sample_data[i][1]-1][np.newaxis, np.newaxis], (params["cores_per_replica"], 1, 1))
|
||||
n_generated += 1
|
||||
if use_callback:
|
||||
excluded_world_info, regeneration_required, halt = warper_callback(np.uint32(tuple(d[0] for d in sample_data)), np.float32(tuple(d[2] for d in sample_data)), excluded_world_info, n_generated)
|
||||
generated = np.uint32(tuple(d[0] for d in sample_data))
|
||||
excluded_world_info, regeneration_required, halt = stopping_callback(generated, n_generated, excluded_world_info)
|
||||
if regeneration_required or halt:
|
||||
break
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue