mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Add support for setting the RNG seed and full determinism
This commit is contained in:
@ -56,6 +56,22 @@ from mesh_transformer.util import to_bf16
|
||||
|
||||
params: Dict[str, Any] = {}
|
||||
|
||||
__seed = random.randrange(sys.maxsize)
|
||||
rng = random.Random(__seed)
|
||||
|
||||
|
||||
def get_rng_seed():
|
||||
return __seed
|
||||
|
||||
def set_rng_seed(seed: int):
|
||||
global __seed, rng
|
||||
rng = random.Random(seed)
|
||||
__seed = seed
|
||||
return seed
|
||||
|
||||
def randomize_rng_seed():
|
||||
return set_rng_seed(random.randrange(sys.maxsize))
|
||||
|
||||
|
||||
def warper_callback(logits) -> np.array:
|
||||
raise NotImplementedError("`tpu_mtj_backend.warper_callback()` needs to be defined")
|
||||
@ -728,7 +744,7 @@ class PenalizingCausalTransformer(CausalTransformer):
|
||||
assert not return_logits
|
||||
assert gen_length.ndim == 1
|
||||
assert soft_embeddings is not None
|
||||
key = hk.PRNGSequence(random.randint(0, 2 ** 60))
|
||||
key = hk.PRNGSequence(rng.randint(0, 2 ** 60))
|
||||
batch_size = ctx.shape[0]
|
||||
self.batch_size = batch_size
|
||||
_numseqs_aux = jnp.empty((batch_size, numseqs), dtype=np.uint32)
|
||||
@ -776,7 +792,7 @@ class PenalizingCausalTransformer(CausalTransformer):
|
||||
return sample_data, n_generated, regeneration_required, halt
|
||||
def generate_static(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None):
|
||||
assert not return_logits
|
||||
key = hk.PRNGSequence(random.randint(0, 2 ** 60))
|
||||
key = hk.PRNGSequence(rng.randint(0, 2 ** 60))
|
||||
batch_size = ctx.shape[0]
|
||||
self.batch_size = batch_size
|
||||
started_compiling_callback()
|
||||
|
Reference in New Issue
Block a user