Add support for setting the RNG seed and full determinism

vfbd 2022-06-28 13:21:05 -04:00
parent 496f6dcf3f
commit 048bd0ff3b
4 changed files with 77 additions and 2 deletions

View File

@@ -238,6 +238,9 @@ class vars:
     tfs         = 1.0    # Default generator tfs (tail-free sampling)
     typical     = 1.0    # Default generator typical sampling threshold
     numseqs     = 1      # Number of sequences to ask the generator to create
+    full_determinism = False  # Whether or not full determinism is enabled
+    seed_specified = False    # Whether or not the current RNG seed was specified by the user (in their settings file)
+    seed        = None   # The current RNG seed (as an int), or None if unknown
     gamestarted = False  # Whether the game has started (disables UI elements)
     gamesaved   = True   # Whether or not current game is saved
     serverstarted = False  # Whether or not the Flask server has started
@@ -785,8 +788,13 @@ def savesettings():
     js["nopromptgen"] = vars.nopromptgen
     js["rngpersist"]  = vars.rngpersist
     js["nogenmod"]    = vars.nogenmod
+    js["fulldeterminism"] = vars.full_determinism
     js["autosave"]    = vars.autosave
     js["welcome"]     = vars.welcome
+    if(vars.seed_specified):
+        js["seed"] = vars.seed
     js["newlinemode"] = vars.newlinemode
     js["antemplate"]  = vars.setauthornotetemplate
@@ -886,12 +894,20 @@ def processsettings(js):
         vars.rngpersist = js["rngpersist"]
     if("nogenmod" in js):
         vars.nogenmod = js["nogenmod"]
+    if("fulldeterminism" in js):
+        vars.full_determinism = js["fulldeterminism"]
     if("autosave" in js):
         vars.autosave = js["autosave"]
     if("newlinemode" in js):
         vars.newlinemode = js["newlinemode"]
     if("welcome" in js):
         vars.welcome = js["welcome"]
+    if("seed" in js):
+        vars.seed = js["seed"]
+        vars.seed_specified = True
+    else:
+        vars.seed_specified = False
     if("antemplate" in js):
         vars.setauthornotetemplate = js["antemplate"]
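
For reference, how the two new keys round-trip through savesettings() and processsettings(); a minimal sketch using plain JSON (the settings file name and its other keys are outside this diff):

    import json

    # Hypothetical settings snippet containing only the keys added here.
    saved = json.dumps({"fulldeterminism": True, "seed": 1234})

    js = json.loads(saved)
    full_determinism = js.get("fulldeterminism", False)
    seed = js.get("seed")          # stays None when the key is absent
    seed_specified = "seed" in js  # mirrors the if/else branch above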
@@ -3383,6 +3399,10 @@ def get_message(msg):
         vars.nogenmod = msg['data']
         settingschanged()
         refresh_settings()
+    elif(msg['cmd'] == 'setfulldeterminism'):
+        vars.full_determinism = msg['data']
+        settingschanged()
+        refresh_settings()
     elif(not vars.host and msg['cmd'] == 'importwi'):
         wiimportrequest()
     elif(msg['cmd'] == 'debug'):
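
The new elif branch fixes the payload shape the UI has to send for this toggle; sketched here as a plain dict (the websocket transport itself is assumed):

    # What get_message() expects for the new toggle:
    msg = {"cmd": "setfulldeterminism", "data": True}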
@@ -3942,6 +3962,9 @@ def calcsubmit(txt):
 #==================================================================#

 def _generate(txt, minimum, maximum, found_entries):
+    if(vars.full_determinism):
+        torch.manual_seed(vars.seed)
+
     gen_in = torch.tensor(txt, dtype=torch.long)[None]
     if(vars.sp is not None):
         soft_tokens = torch.arange(
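
Reseeding torch's global RNG at the top of _generate() is what makes each generation a pure function of its inputs; a self-contained sketch of that property (requires only torch):

    import torch

    torch.manual_seed(1234)
    a = torch.randn(3)

    torch.manual_seed(1234)   # reseed before "generating" again
    b = torch.randn(3)

    assert torch.equal(a, b)  # same seed and inputs give the same samples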
@@ -4282,6 +4305,9 @@ def sendtocolab(txt, min, max):
 # Send text to TPU mesh transformer backend
 #==================================================================#
 def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
+    if(vars.full_determinism):
+        tpu_mtj_backend.set_rng_seed(vars.seed)
+
     vars.generated_tkns = 0

     if(found_entries is None):
@@ -4567,6 +4593,7 @@ def refresh_settings():
     emit('from_server', {'cmd': 'updatenopromptgen', 'data': vars.nopromptgen}, broadcast=True)
     emit('from_server', {'cmd': 'updaterngpersist', 'data': vars.rngpersist}, broadcast=True)
     emit('from_server', {'cmd': 'updatenogenmod', 'data': vars.nogenmod}, broadcast=True)
+    emit('from_server', {'cmd': 'updatefulldeterminism', 'data': vars.full_determinism}, broadcast=True)
     emit('from_server', {'cmd': 'updatefrmttriminc', 'data': vars.formatoptns["frmttriminc"]}, broadcast=True)
     emit('from_server', {'cmd': 'updatefrmtrmblln', 'data': vars.formatoptns["frmtrmblln"]}, broadcast=True)
@@ -5987,9 +6014,27 @@ def final_startup():
         },
     ).start()

+    # Set the initial RNG seed
+    if(vars.seed is not None):
+        if(vars.use_colab_tpu):
+            if(vars.seed_specified):
+                __import__("tpu_mtj_backend").set_rng_seed(vars.seed)
+            else:
+                __import__("tpu_mtj_backend").randomize_rng_seed()
+        else:
+            if(vars.seed_specified):
+                __import__("torch").manual_seed(vars.seed)
+            else:
+                __import__("torch").seed()
+    vars.seed = __import__("tpu_mtj_backend").get_rng_seed() if vars.use_colab_tpu else __import__("torch").initial_seed()
+
 def send_debug():
     if vars.debug:
         debug_info = ""
+        try:
+            debug_info = "{}Seed: {} ({})\n".format(debug_info, repr(__import__("tpu_mtj_backend").get_rng_seed() if vars.use_colab_tpu else __import__("torch").initial_seed()), "specified by user in settings file" if vars.seed_specified else "randomly generated")
+        except:
+            pass
         try:
             debug_info = "{}Newline Mode: {}\n".format(debug_info, vars.newlinemode)
         except:
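
The startup block relies on torch.seed() to draw a fresh nondeterministic seed and torch.initial_seed() to read back whichever seed is in effect, which is how vars.seed always ends up as a concrete int. A short sketch of those two calls:

    import torch

    torch.manual_seed(42)
    assert torch.initial_seed() == 42   # reads back the seed in effect

    new_seed = torch.seed()             # reseeds from a nondeterministic source
    assert torch.initial_seed() == new_seed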

View File

@@ -230,6 +230,17 @@ gensettingstf = [
         "default": 0,
         "tooltip": "Disables userscript generation modifiers."
     },
+    {
+        "uitype": "toggle",
+        "unit": "bool",
+        "label": "Full Determinism",
+        "id": "setfulldeterminism",
+        "min": 0,
+        "max": 1,
+        "step": 1,
+        "default": 0,
+        "tooltip": "Causes generation to be fully deterministic -- the model will always output the same thing as long as your story, settings and RNG seed are the same. If this is off, only the sequence of outputs that the model makes will be deterministic."
+    },
     {
         "uitype": "toggle",
         "unit": "bool",

View File

@@ -2606,6 +2606,9 @@ $(document).ready(function(){
         } else if(msg.cmd == "updatenogenmod") {
             // Update toggle state
             $("#setnogenmod").prop('checked', msg.data).change();
+        } else if(msg.cmd == "updatefulldeterminism") {
+            // Update toggle state
+            $("#setfulldeterminism").prop('checked', msg.data).change();
         } else if(msg.cmd == "runs_remotely") {
             remote = true;
             hide([button_savetofile, button_import, button_importwi]);

View File

@@ -56,6 +56,22 @@ from mesh_transformer.util import to_bf16

 params: Dict[str, Any] = {}

+__seed = random.randrange(sys.maxsize)
+rng = random.Random(__seed)
+
+
+def get_rng_seed():
+    return __seed
+
+def set_rng_seed(seed: int):
+    global __seed, rng
+    rng = random.Random(seed)
+    __seed = seed
+    return seed
+
+def randomize_rng_seed():
+    return set_rng_seed(random.randrange(sys.maxsize))
+
 def warper_callback(logits) -> np.array:
     raise NotImplementedError("`tpu_mtj_backend.warper_callback()` needs to be defined")
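
A usage sketch for the three new helpers, assuming tpu_mtj_backend imports cleanly (in practice it pulls in the whole TPU stack):

    import tpu_mtj_backend

    tpu_mtj_backend.set_rng_seed(1234)            # pin the backend RNG
    assert tpu_mtj_backend.get_rng_seed() == 1234

    tpu_mtj_backend.randomize_rng_seed()          # draw a fresh random seed
    print(tpu_mtj_backend.get_rng_seed())         # report what was chosen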
@@ -728,7 +744,7 @@ class PenalizingCausalTransformer(CausalTransformer):
         assert not return_logits
         assert gen_length.ndim == 1
         assert soft_embeddings is not None
-        key = hk.PRNGSequence(random.randint(0, 2 ** 60))
+        key = hk.PRNGSequence(rng.randint(0, 2 ** 60))
         batch_size = ctx.shape[0]
         self.batch_size = batch_size
         _numseqs_aux = jnp.empty((batch_size, numseqs), dtype=np.uint32)
@@ -776,7 +792,7 @@ class PenalizingCausalTransformer(CausalTransformer):
         return sample_data, n_generated, regeneration_required, halt
     def generate_static(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None):
         assert not return_logits
-        key = hk.PRNGSequence(random.randint(0, 2 ** 60))
+        key = hk.PRNGSequence(rng.randint(0, 2 ** 60))
         batch_size = ctx.shape[0]
         self.batch_size = batch_size
         started_compiling_callback()
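
These two one-line changes swap the global random module for the module-level rng defined above, which is exactly what set_rng_seed() reseeds; a sketch of why that makes the haiku key reproducible (hk.PRNGSequence itself is deterministic given the same integer):

    import random

    rng = random.Random(1234)        # what set_rng_seed(1234) installs
    k1 = rng.randint(0, 2 ** 60)

    rng = random.Random(1234)        # reseeding recreates the same stream
    k2 = rng.randint(0, 2 ** 60)

    assert k1 == k2  # hk.PRNGSequence(k1) and hk.PRNGSequence(k2) match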