diff --git a/aiserver.py b/aiserver.py index ad6e7fd6..94aa3b60 100644 --- a/aiserver.py +++ b/aiserver.py @@ -238,6 +238,9 @@ class vars: tfs = 1.0 # Default generator tfs (tail-free sampling) typical = 1.0 # Default generator typical sampling threshold numseqs = 1 # Number of sequences to ask the generator to create + full_determinism = False # Whether or not full determinism is enabled + seed_specified = False # Whether or not the current RNG seed was specified by the user (in their settings file) + seed = None # The current RNG seed (as an int), or None if unknown gamestarted = False # Whether the game has started (disables UI elements) gamesaved = True # Whether or not current game is saved serverstarted = False # Whether or not the Flask server has started @@ -785,8 +788,13 @@ def savesettings(): js["nopromptgen"] = vars.nopromptgen js["rngpersist"] = vars.rngpersist js["nogenmod"] = vars.nogenmod + js["fulldeterminism"] = vars.full_determinism js["autosave"] = vars.autosave js["welcome"] = vars.welcome + + if(vars.seed_specified): + js["seed"] = vars.seed + js["newlinemode"] = vars.newlinemode js["antemplate"] = vars.setauthornotetemplate @@ -886,12 +894,20 @@ def processsettings(js): vars.rngpersist = js["rngpersist"] if("nogenmod" in js): vars.nogenmod = js["nogenmod"] + if("fulldeterminism" in js): + vars.full_determinism = js["fulldeterminism"] if("autosave" in js): vars.autosave = js["autosave"] if("newlinemode" in js): vars.newlinemode = js["newlinemode"] if("welcome" in js): vars.welcome = js["welcome"] + + if("seed" in js): + vars.seed = js["seed"] + vars.seed_specified = True + else: + vars.seed_specified = False if("antemplate" in js): vars.setauthornotetemplate = js["antemplate"] @@ -3383,6 +3399,10 @@ def get_message(msg): vars.nogenmod = msg['data'] settingschanged() refresh_settings() + elif(msg['cmd'] == 'setfulldeterminism'): + vars.full_determinism = msg['data'] + settingschanged() + refresh_settings() elif(not vars.host and msg['cmd'] == 'importwi'): wiimportrequest() elif(msg['cmd'] == 'debug'): @@ -3942,6 +3962,9 @@ def calcsubmit(txt): #==================================================================# def _generate(txt, minimum, maximum, found_entries): + if(vars.full_determinism): + torch.manual_seed(vars.seed) + gen_in = torch.tensor(txt, dtype=torch.long)[None] if(vars.sp is not None): soft_tokens = torch.arange( @@ -4282,6 +4305,9 @@ def sendtocolab(txt, min, max): # Send text to TPU mesh transformer backend #==================================================================# def tpumtjgenerate(txt, minimum, maximum, found_entries=None): + if(vars.full_determinism): + tpu_mtj_backend.set_rng_seed(vars.seed) + vars.generated_tkns = 0 if(found_entries is None): @@ -4567,6 +4593,7 @@ def refresh_settings(): emit('from_server', {'cmd': 'updatenopromptgen', 'data': vars.nopromptgen}, broadcast=True) emit('from_server', {'cmd': 'updaterngpersist', 'data': vars.rngpersist}, broadcast=True) emit('from_server', {'cmd': 'updatenogenmod', 'data': vars.nogenmod}, broadcast=True) + emit('from_server', {'cmd': 'updatefulldeterminism', 'data': vars.full_determinism}, broadcast=True) emit('from_server', {'cmd': 'updatefrmttriminc', 'data': vars.formatoptns["frmttriminc"]}, broadcast=True) emit('from_server', {'cmd': 'updatefrmtrmblln', 'data': vars.formatoptns["frmtrmblln"]}, broadcast=True) @@ -5987,9 +6014,27 @@ def final_startup(): }, ).start() + # Set the initial RNG seed + if(vars.seed is not None): + if(vars.use_colab_tpu): + if(vars.seed_specified): + __import__("tpu_mtj_backend").set_rng_seed(vars.seed) + else: + __import__("tpu_mtj_backend").randomize_rng_seed() + else: + if(vars.seed_specified): + __import__("torch").manual_seed(vars.seed) + else: + __import__("torch").seed() + vars.seed = __import__("tpu_mtj_backend").get_rng_seed() if vars.use_colab_tpu else __import__("torch").initial_seed() + def send_debug(): if vars.debug: debug_info = "" + try: + debug_info = "{}Seed: {} ({})\n".format(debug_info, repr(__import__("tpu_mtj_backend").get_rng_seed() if vars.use_colab_tpu else __import__("torch").initial_seed()), "specified by user in settings file" if vars.seed_specified else "randomly generated") + except: + pass try: debug_info = "{}Newline Mode: {}\n".format(debug_info, vars.newlinemode) except: diff --git a/gensettings.py b/gensettings.py index b3007c91..3d188b16 100644 --- a/gensettings.py +++ b/gensettings.py @@ -230,6 +230,17 @@ gensettingstf = [ "default": 0, "tooltip": "Disables userscript generation modifiers." }, + { + "uitype": "toggle", + "unit": "bool", + "label": "Full Determinism", + "id": "setfulldeterminism", + "min": 0, + "max": 1, + "step": 1, + "default": 0, + "tooltip": "Causes generation to be fully deterministic -- the model will always output the same thing as long as your story, settings and RNG seed are the same. If this is off, only the sequence of outputs that the model makes will be deterministic." + }, { "uitype": "toggle", "unit": "bool", diff --git a/static/application.js b/static/application.js index 20a1860a..7da45e9f 100644 --- a/static/application.js +++ b/static/application.js @@ -2604,6 +2604,9 @@ $(document).ready(function(){ } else if(msg.cmd == "updatenogenmod") { // Update toggle state $("#setnogenmod").prop('checked', msg.data).change(); + } else if(msg.cmd == "updatefulldeterminism") { + // Update toggle state + $("#setfulldeterminism").prop('checked', msg.data).change(); } else if(msg.cmd == "runs_remotely") { remote = true; hide([button_savetofile, button_import, button_importwi]); diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index bc168f36..8daa8dee 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -56,6 +56,22 @@ from mesh_transformer.util import to_bf16 params: Dict[str, Any] = {} +__seed = random.randrange(sys.maxsize) +rng = random.Random(__seed) + + +def get_rng_seed(): + return __seed + +def set_rng_seed(seed: int): + global __seed, rng + rng = random.Random(seed) + __seed = seed + return seed + +def randomize_rng_seed(): + return set_rng_seed(random.randrange(sys.maxsize)) + def warper_callback(logits) -> np.array: raise NotImplementedError("`tpu_mtj_backend.warper_callback()` needs to be defined") @@ -728,7 +744,7 @@ class PenalizingCausalTransformer(CausalTransformer): assert not return_logits assert gen_length.ndim == 1 assert soft_embeddings is not None - key = hk.PRNGSequence(random.randint(0, 2 ** 60)) + key = hk.PRNGSequence(rng.randint(0, 2 ** 60)) batch_size = ctx.shape[0] self.batch_size = batch_size _numseqs_aux = jnp.empty((batch_size, numseqs), dtype=np.uint32) @@ -776,7 +792,7 @@ class PenalizingCausalTransformer(CausalTransformer): return sample_data, n_generated, regeneration_required, halt def generate_static(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None): assert not return_logits - key = hk.PRNGSequence(random.randint(0, 2 ** 60)) + key = hk.PRNGSequence(rng.randint(0, 2 ** 60)) batch_size = ctx.shape[0] self.batch_size = batch_size started_compiling_callback()