From 048bd0ff3b9b1f145c3d2be9e77fa9cebfc6edfb Mon Sep 17 00:00:00 2001
From: vfbd
Date: Tue, 28 Jun 2022 13:21:05 -0400
Subject: [PATCH] Add support for setting the RNG seed and full determinism

---
 aiserver.py           | 45 +++++++++++++++++++++++++++++++++++++++++++
 gensettings.py        | 11 +++++++++++
 static/application.js |  3 +++
 tpu_mtj_backend.py    | 20 +++++++++++++++++--
 4 files changed, 77 insertions(+), 2 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index ad6e7fd6..94aa3b60 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -238,6 +238,9 @@ class vars:
     tfs = 1.0 # Default generator tfs (tail-free sampling)
     typical = 1.0 # Default generator typical sampling threshold
     numseqs = 1 # Number of sequences to ask the generator to create
+    full_determinism = False # Whether or not full determinism is enabled
+    seed_specified = False # Whether or not the current RNG seed was specified by the user (in their settings file)
+    seed = None # The current RNG seed (as an int), or None if unknown
     gamestarted = False # Whether the game has started (disables UI elements)
     gamesaved = True # Whether or not current game is saved
     serverstarted = False # Whether or not the Flask server has started
@@ -785,8 +788,13 @@ def savesettings():
     js["nopromptgen"] = vars.nopromptgen
     js["rngpersist"] = vars.rngpersist
     js["nogenmod"] = vars.nogenmod
+    js["fulldeterminism"] = vars.full_determinism
     js["autosave"] = vars.autosave
     js["welcome"] = vars.welcome
+
+    if(vars.seed_specified):
+        js["seed"] = vars.seed
+
     js["newlinemode"] = vars.newlinemode
 
     js["antemplate"] = vars.setauthornotetemplate
@@ -886,12 +894,20 @@ def processsettings(js):
         vars.rngpersist = js["rngpersist"]
     if("nogenmod" in js):
         vars.nogenmod = js["nogenmod"]
+    if("fulldeterminism" in js):
+        vars.full_determinism = js["fulldeterminism"]
     if("autosave" in js):
         vars.autosave = js["autosave"]
     if("newlinemode" in js):
        vars.newlinemode = js["newlinemode"]
     if("welcome" in js):
         vars.welcome = js["welcome"]
+
+    if("seed" in js):
+        vars.seed = js["seed"]
+        vars.seed_specified = True
+    else:
+        vars.seed_specified = False
 
     if("antemplate" in js):
         vars.setauthornotetemplate = js["antemplate"]
@@ -3383,6 +3399,10 @@ def get_message(msg):
         vars.nogenmod = msg['data']
         settingschanged()
         refresh_settings()
+    elif(msg['cmd'] == 'setfulldeterminism'):
+        vars.full_determinism = msg['data']
+        settingschanged()
+        refresh_settings()
     elif(not vars.host and msg['cmd'] == 'importwi'):
         wiimportrequest()
     elif(msg['cmd'] == 'debug'):
@@ -3942,6 +3962,9 @@ def calcsubmit(txt):
 #==================================================================#
 
 def _generate(txt, minimum, maximum, found_entries):
+    if(vars.full_determinism):
+        torch.manual_seed(vars.seed)
+
     gen_in = torch.tensor(txt, dtype=torch.long)[None]
     if(vars.sp is not None):
         soft_tokens = torch.arange(
@@ -4282,6 +4305,9 @@ def sendtocolab(txt, min, max):
 # Send text to TPU mesh transformer backend
 #==================================================================#
 def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
+    if(vars.full_determinism):
+        tpu_mtj_backend.set_rng_seed(vars.seed)
+
     vars.generated_tkns = 0
 
     if(found_entries is None):
@@ -4567,6 +4593,7 @@ def refresh_settings():
     emit('from_server', {'cmd': 'updatenopromptgen', 'data': vars.nopromptgen}, broadcast=True)
     emit('from_server', {'cmd': 'updaterngpersist', 'data': vars.rngpersist}, broadcast=True)
     emit('from_server', {'cmd': 'updatenogenmod', 'data': vars.nogenmod}, broadcast=True)
+    emit('from_server', {'cmd': 'updatefulldeterminism', 'data': vars.full_determinism}, broadcast=True)
 
     emit('from_server', {'cmd': 'updatefrmttriminc', 'data': vars.formatoptns["frmttriminc"]}, broadcast=True)
     emit('from_server', {'cmd': 'updatefrmtrmblln', 'data': vars.formatoptns["frmtrmblln"]}, broadcast=True)
@@ -5987,9 +6014,27 @@ def final_startup():
         },
     ).start()
 
+    # Set the initial RNG seed
+    if(vars.seed is not None):
+        if(vars.use_colab_tpu):
+            if(vars.seed_specified):
+                __import__("tpu_mtj_backend").set_rng_seed(vars.seed)
+            else:
+                __import__("tpu_mtj_backend").randomize_rng_seed()
+        else:
+            if(vars.seed_specified):
+                __import__("torch").manual_seed(vars.seed)
+            else:
+                __import__("torch").seed()
+    vars.seed = __import__("tpu_mtj_backend").get_rng_seed() if vars.use_colab_tpu else __import__("torch").initial_seed()
+
 def send_debug():
     if vars.debug:
         debug_info = ""
+        try:
+            debug_info = "{}Seed: {} ({})\n".format(debug_info, repr(__import__("tpu_mtj_backend").get_rng_seed() if vars.use_colab_tpu else __import__("torch").initial_seed()), "specified by user in settings file" if vars.seed_specified else "randomly generated")
+        except:
+            pass
         try:
             debug_info = "{}Newline Mode: {}\n".format(debug_info, vars.newlinemode)
         except:
diff --git a/gensettings.py b/gensettings.py
index b3007c91..3d188b16 100644
--- a/gensettings.py
+++ b/gensettings.py
@@ -230,6 +230,17 @@ gensettingstf = [
     "default": 0,
     "tooltip": "Disables userscript generation modifiers."
     },
+    {
+    "uitype": "toggle",
+    "unit": "bool",
+    "label": "Full Determinism",
+    "id": "setfulldeterminism",
+    "min": 0,
+    "max": 1,
+    "step": 1,
+    "default": 0,
+    "tooltip": "Causes generation to be fully deterministic -- the model will always output the same thing as long as your story, settings and RNG seed are the same. If this is off, only the sequence of outputs that the model makes will be deterministic."
+    },
     {
     "uitype": "toggle",
     "unit": "bool",
diff --git a/static/application.js b/static/application.js
index a7a10016..d91f6029 100644
--- a/static/application.js
+++ b/static/application.js
@@ -2606,6 +2606,9 @@ $(document).ready(function(){
         } else if(msg.cmd == "updatenogenmod") {
             // Update toggle state
             $("#setnogenmod").prop('checked', msg.data).change();
+        } else if(msg.cmd == "updatefulldeterminism") {
+            // Update toggle state
+            $("#setfulldeterminism").prop('checked', msg.data).change();
         } else if(msg.cmd == "runs_remotely") {
             remote = true;
             hide([button_savetofile, button_import, button_importwi]);
diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index bc168f36..8daa8dee 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -56,6 +56,22 @@ from mesh_transformer.util import to_bf16
 
 params: Dict[str, Any] = {}
 
+__seed = random.randrange(sys.maxsize)
+rng = random.Random(__seed)
+
+
+def get_rng_seed():
+    return __seed
+
+def set_rng_seed(seed: int):
+    global __seed, rng
+    rng = random.Random(seed)
+    __seed = seed
+    return seed
+
+def randomize_rng_seed():
+    return set_rng_seed(random.randrange(sys.maxsize))
+
 def warper_callback(logits) -> np.array:
     raise NotImplementedError("`tpu_mtj_backend.warper_callback()` needs to be defined")
 
@@ -728,7 +744,7 @@ class PenalizingCausalTransformer(CausalTransformer):
             assert not return_logits
             assert gen_length.ndim == 1
             assert soft_embeddings is not None
-            key = hk.PRNGSequence(random.randint(0, 2 ** 60))
+            key = hk.PRNGSequence(rng.randint(0, 2 ** 60))
             batch_size = ctx.shape[0]
             self.batch_size = batch_size
             _numseqs_aux = jnp.empty((batch_size, numseqs), dtype=np.uint32)
@@ -776,7 +792,7 @@ class PenalizingCausalTransformer(CausalTransformer):
             return sample_data, n_generated, regeneration_required, halt
     def generate_static(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None):
         assert not return_logits
-        key = hk.PRNGSequence(random.randint(0, 2 ** 60))
+        key = hk.PRNGSequence(rng.randint(0, 2 ** 60))
         batch_size = ctx.shape[0]
         self.batch_size = batch_size
         started_compiling_callback()
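
Review notes (not part of the patch):

The tpu_mtj_backend.py change draws the hk.PRNGSequence keys from a dedicated module-level random.Random instance instead of the global random module, so unrelated code that happens to call random elsewhere in the process can no longer shift the generation sequence, and set_rng_seed() can restore an exact state. The following standalone sketch restates that seed-tracking pattern and appends a small determinism check; the function bodies mirror the patch, while the check at the bottom is illustrative only.

    import random
    import sys

    __seed = random.randrange(sys.maxsize)  # seed picked once at import time
    rng = random.Random(__seed)             # dedicated RNG, isolated from the global random state

    def get_rng_seed():
        # Report the seed currently backing `rng` (this is what send_debug() surfaces).
        return __seed

    def set_rng_seed(seed):
        # Rebuild `rng` from a known seed so subsequent draws are reproducible.
        global __seed, rng
        rng = random.Random(seed)
        __seed = seed
        return seed

    def randomize_rng_seed():
        # Pick a fresh seed when the user did not specify one in their settings file.
        return set_rng_seed(random.randrange(sys.maxsize))

    # Same seed, same draws: the property tpumtjgenerate() relies on when it
    # calls set_rng_seed(vars.seed) at the start of every generation.
    set_rng_seed(1234)
    first = rng.randint(0, 2 ** 60)
    set_rng_seed(1234)
    assert rng.randint(0, 2 ** 60) == first

Having set_rng_seed() return the seed keeps randomize_rng_seed() a one-liner and lets callers log the seed that actually took effect.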
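On the PyTorch path the patch instead reseeds torch's global generator: _generate() calls torch.manual_seed(vars.seed) before building the input tensor, so each generation starts from the same RNG state rather than continuing a shared stream, and final_startup() records whichever seed ended up in effect (torch.initial_seed(), or tpu_mtj_backend.get_rng_seed() on TPU) so savesettings() can write it back out when the user specified one. Below is a minimal sketch of the reseed-before-sampling behavior, independent of KoboldAI; the probability values are arbitrary placeholders.

    import torch

    def sample_once(seed):
        # Reseeding immediately before sampling pins the entire draw sequence,
        # which is what makes repeated generations with identical inputs identical.
        torch.manual_seed(seed)
        probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
        return torch.multinomial(probs, num_samples=3, replacement=True)

    assert torch.equal(sample_once(42), sample_once(42))  # same seed, same tokens

torch.manual_seed() also seeds the CUDA generators in current PyTorch builds, which is why the patch needs no separate GPU handling.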