mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Fix TPU generation modifier
This commit is contained in:
@ -20,9 +20,12 @@ from mesh_transformer.transformer_shard import CausalTransformer, CausalTransfor
|
||||
params: Dict[str, Any] = {}
|
||||
|
||||
|
||||
def warper_callback(generated, logits, excluded_world_info, n_generated) -> Tuple[bool, bool]:
|
||||
def warper_callback(logits) -> np.array:
|
||||
raise NotImplementedError("`tpu_mtj_backend.warper_callback()` needs to be defined")
|
||||
|
||||
def stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]:
|
||||
raise NotImplementedError("`tpu_mtj_backend.stopping_callback()` needs to be defined")
|
||||
|
||||
|
||||
def show_spinner():
|
||||
bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='█')])
|
||||
@ -340,12 +343,18 @@ class PenalizingCausalTransformer(CausalTransformer):
|
||||
generate_data, = self.generate_once_xmap(generate_data, self.state, numseqs_aux, soft_embeddings)
|
||||
for i in range(numseqs):
|
||||
sample_data[i][2] = np.array(generate_data[i][0][0, 0], copy=True)
|
||||
if use_callback:
|
||||
logits = np.float32(tuple(d[2] for d in sample_data))
|
||||
logits = warper_callback(logits)
|
||||
for i in range(numseqs):
|
||||
sample_data[i][2] = logits[i]
|
||||
sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options)
|
||||
n_generated += 1
|
||||
for i in range(numseqs):
|
||||
generate_data[i][3] = np.tile(sample_data[i][0][sample_data[i][1]-1][np.newaxis, np.newaxis], (params["cores_per_replica"], 1, 1))
|
||||
n_generated += 1
|
||||
if use_callback:
|
||||
excluded_world_info, regeneration_required, halt = warper_callback(np.uint32(tuple(d[0] for d in sample_data)), np.float32(tuple(d[2] for d in sample_data)), excluded_world_info, n_generated)
|
||||
generated = np.uint32(tuple(d[0] for d in sample_data))
|
||||
excluded_world_info, regeneration_required, halt = stopping_callback(generated, n_generated, excluded_world_info)
|
||||
if regeneration_required or halt:
|
||||
break
|
||||
else:
|
||||
|
Reference in New Issue
Block a user