Add support for setting the RNG seed and full determinism

This commit is contained in:
vfbd 2022-06-28 13:21:05 -04:00
parent 496f6dcf3f
commit 048bd0ff3b
4 changed files with 77 additions and 2 deletions

View File

@ -238,6 +238,9 @@ class vars:
tfs = 1.0 # Default generator tfs (tail-free sampling)
typical = 1.0 # Default generator typical sampling threshold
numseqs = 1 # Number of sequences to ask the generator to create
full_determinism = False # Whether or not full determinism is enabled
seed_specified = False # Whether or not the current RNG seed was specified by the user (in their settings file)
seed = None # The current RNG seed (as an int), or None if unknown
gamestarted = False # Whether the game has started (disables UI elements)
gamesaved = True # Whether or not current game is saved
serverstarted = False # Whether or not the Flask server has started
@ -785,8 +788,13 @@ def savesettings():
js["nopromptgen"] = vars.nopromptgen
js["rngpersist"] = vars.rngpersist
js["nogenmod"] = vars.nogenmod
js["fulldeterminism"] = vars.full_determinism
js["autosave"] = vars.autosave
js["welcome"] = vars.welcome
if(vars.seed_specified):
js["seed"] = vars.seed
js["newlinemode"] = vars.newlinemode
js["antemplate"] = vars.setauthornotetemplate
@ -886,12 +894,20 @@ def processsettings(js):
vars.rngpersist = js["rngpersist"]
if("nogenmod" in js):
vars.nogenmod = js["nogenmod"]
if("fulldeterminism" in js):
vars.full_determinism = js["fulldeterminism"]
if("autosave" in js):
vars.autosave = js["autosave"]
if("newlinemode" in js):
vars.newlinemode = js["newlinemode"]
if("welcome" in js):
vars.welcome = js["welcome"]
if("seed" in js):
vars.seed = js["seed"]
vars.seed_specified = True
else:
vars.seed_specified = False
if("antemplate" in js):
vars.setauthornotetemplate = js["antemplate"]
@ -3383,6 +3399,10 @@ def get_message(msg):
vars.nogenmod = msg['data']
settingschanged()
refresh_settings()
elif(msg['cmd'] == 'setfulldeterminism'):
vars.full_determinism = msg['data']
settingschanged()
refresh_settings()
elif(not vars.host and msg['cmd'] == 'importwi'):
wiimportrequest()
elif(msg['cmd'] == 'debug'):
@ -3942,6 +3962,9 @@ def calcsubmit(txt):
#==================================================================#
def _generate(txt, minimum, maximum, found_entries):
if(vars.full_determinism):
torch.manual_seed(vars.seed)
gen_in = torch.tensor(txt, dtype=torch.long)[None]
if(vars.sp is not None):
soft_tokens = torch.arange(
@ -4282,6 +4305,9 @@ def sendtocolab(txt, min, max):
# Send text to TPU mesh transformer backend
#==================================================================#
def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
if(vars.full_determinism):
tpu_mtj_backend.set_rng_seed(vars.seed)
vars.generated_tkns = 0
if(found_entries is None):
@ -4567,6 +4593,7 @@ def refresh_settings():
emit('from_server', {'cmd': 'updatenopromptgen', 'data': vars.nopromptgen}, broadcast=True)
emit('from_server', {'cmd': 'updaterngpersist', 'data': vars.rngpersist}, broadcast=True)
emit('from_server', {'cmd': 'updatenogenmod', 'data': vars.nogenmod}, broadcast=True)
emit('from_server', {'cmd': 'updatefulldeterminism', 'data': vars.full_determinism}, broadcast=True)
emit('from_server', {'cmd': 'updatefrmttriminc', 'data': vars.formatoptns["frmttriminc"]}, broadcast=True)
emit('from_server', {'cmd': 'updatefrmtrmblln', 'data': vars.formatoptns["frmtrmblln"]}, broadcast=True)
@ -5987,9 +6014,27 @@ def final_startup():
},
).start()
# Set the initial RNG seed
if(vars.seed is not None):
if(vars.use_colab_tpu):
if(vars.seed_specified):
__import__("tpu_mtj_backend").set_rng_seed(vars.seed)
else:
__import__("tpu_mtj_backend").randomize_rng_seed()
else:
if(vars.seed_specified):
__import__("torch").manual_seed(vars.seed)
else:
__import__("torch").seed()
vars.seed = __import__("tpu_mtj_backend").get_rng_seed() if vars.use_colab_tpu else __import__("torch").initial_seed()
def send_debug():
if vars.debug:
debug_info = ""
try:
debug_info = "{}Seed: {} ({})\n".format(debug_info, repr(__import__("tpu_mtj_backend").get_rng_seed() if vars.use_colab_tpu else __import__("torch").initial_seed()), "specified by user in settings file" if vars.seed_specified else "randomly generated")
except:
pass
try:
debug_info = "{}Newline Mode: {}\n".format(debug_info, vars.newlinemode)
except:

View File

@ -230,6 +230,17 @@ gensettingstf = [
"default": 0,
"tooltip": "Disables userscript generation modifiers."
},
{
"uitype": "toggle",
"unit": "bool",
"label": "Full Determinism",
"id": "setfulldeterminism",
"min": 0,
"max": 1,
"step": 1,
"default": 0,
"tooltip": "Causes generation to be fully deterministic -- the model will always output the same thing as long as your story, settings and RNG seed are the same. If this is off, only the sequence of outputs that the model makes will be deterministic."
},
{
"uitype": "toggle",
"unit": "bool",

View File

@ -2606,6 +2606,9 @@ $(document).ready(function(){
} else if(msg.cmd == "updatenogenmod") {
// Update toggle state
$("#setnogenmod").prop('checked', msg.data).change();
} else if(msg.cmd == "updatefulldeterminism") {
// Update toggle state
$("#setfulldeterminism").prop('checked', msg.data).change();
} else if(msg.cmd == "runs_remotely") {
remote = true;
hide([button_savetofile, button_import, button_importwi]);

View File

@ -56,6 +56,22 @@ from mesh_transformer.util import to_bf16
params: Dict[str, Any] = {}
__seed = random.randrange(sys.maxsize)
rng = random.Random(__seed)
def get_rng_seed():
return __seed
def set_rng_seed(seed: int):
global __seed, rng
rng = random.Random(seed)
__seed = seed
return seed
def randomize_rng_seed():
return set_rng_seed(random.randrange(sys.maxsize))
def warper_callback(logits) -> np.array:
raise NotImplementedError("`tpu_mtj_backend.warper_callback()` needs to be defined")
@ -728,7 +744,7 @@ class PenalizingCausalTransformer(CausalTransformer):
assert not return_logits
assert gen_length.ndim == 1
assert soft_embeddings is not None
key = hk.PRNGSequence(random.randint(0, 2 ** 60))
key = hk.PRNGSequence(rng.randint(0, 2 ** 60))
batch_size = ctx.shape[0]
self.batch_size = batch_size
_numseqs_aux = jnp.empty((batch_size, numseqs), dtype=np.uint32)
@ -776,7 +792,7 @@ class PenalizingCausalTransformer(CausalTransformer):
return sample_data, n_generated, regeneration_required, halt
def generate_static(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None):
assert not return_logits
key = hk.PRNGSequence(random.randint(0, 2 ** 60))
key = hk.PRNGSequence(rng.randint(0, 2 ** 60))
batch_size = ctx.shape[0]
self.batch_size = batch_size
started_compiling_callback()