Add support for setting the RNG seed and full determinism

This commit is contained in:
vfbd
2022-06-28 13:21:05 -04:00
parent 496f6dcf3f
commit 048bd0ff3b
4 changed files with 77 additions and 2 deletions

View File

@ -56,6 +56,22 @@ from mesh_transformer.util import to_bf16
# Module-level parameter mapping; starts out empty.
# NOTE(review): presumably filled in elsewhere before use — confirm against callers.
params: Dict[str, Any] = {}
# Seed backing the module-level generator below. Drawn once at import time
# from the global `random` module, so it differs from run to run unless it
# is overridden afterwards.
__seed = random.randrange(sys.maxsize)
# Dedicated generator seeded with `__seed`, kept separate from the global
# `random` state so the values it yields depend only on the recorded seed.
rng = random.Random(__seed)
def get_rng_seed():
    """Return the seed currently recorded in the module-level ``__seed``."""
    current_seed = __seed
    return current_seed
def set_rng_seed(seed: int):
    """Reseed the module-level generator and record the seed.

    Replaces the module-level ``rng`` with a fresh ``random.Random(seed)``
    and stores ``seed`` in the module-level ``__seed``.

    Args:
        seed: Value used to seed the new generator.

    Returns:
        The same ``seed`` that was passed in.
    """
    global __seed, rng
    # The right-hand side is fully evaluated before either global is
    # rebound, so a failure while constructing the generator leaves the
    # previous state untouched.
    __seed, rng = seed, random.Random(seed)
    return __seed
def randomize_rng_seed():
    """Pick a new random seed, apply it via ``set_rng_seed``, and return it.

    The seed is drawn from the global ``random`` module in the range
    ``[0, sys.maxsize)``.
    """
    fresh_seed = random.randrange(sys.maxsize)
    return set_rng_seed(fresh_seed)
def warper_callback(logits) -> np.ndarray:
    """Placeholder hook that must be replaced before it is called.

    This default implementation never reads its argument and always raises.

    Args:
        logits: Unused by this stub.
            NOTE(review): presumably the logits a real implementation would
            transform — confirm against the code that installs the callback.

    Returns:
        Never returns here; a replacement is annotated as returning an
        ``np.ndarray``.

    Raises:
        NotImplementedError: Always, to signal that no real callback has
            been installed.
    """
    raise NotImplementedError("`tpu_mtj_backend.warper_callback()` needs to be defined")
@ -728,7 +744,7 @@ class PenalizingCausalTransformer(CausalTransformer):
assert not return_logits
assert gen_length.ndim == 1
assert soft_embeddings is not None
key = hk.PRNGSequence(random.randint(0, 2 ** 60))
key = hk.PRNGSequence(rng.randint(0, 2 ** 60))
batch_size = ctx.shape[0]
self.batch_size = batch_size
_numseqs_aux = jnp.empty((batch_size, numseqs), dtype=np.uint32)
@ -776,7 +792,7 @@ class PenalizingCausalTransformer(CausalTransformer):
return sample_data, n_generated, regeneration_required, halt
def generate_static(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None):
assert not return_logits
key = hk.PRNGSequence(random.randint(0, 2 ** 60))
key = hk.PRNGSequence(rng.randint(0, 2 ** 60))
batch_size = ctx.shape[0]
self.batch_size = batch_size
started_compiling_callback()