Add support for setting the RNG seed and full determinism

vfbd
2022-06-28 13:21:05 -04:00
parent 496f6dcf3f
commit 048bd0ff3b
4 changed files with 77 additions and 2 deletions


@@ -238,6 +238,9 @@ class vars:
    tfs = 1.0 # Default generator tfs (tail-free sampling)
    typical = 1.0 # Default generator typical sampling threshold
    numseqs = 1 # Number of sequences to ask the generator to create
    full_determinism = False # Whether or not full determinism is enabled
    seed_specified = False # Whether or not the current RNG seed was specified by the user (in their settings file)
    seed = None # The current RNG seed (as an int), or None if unknown
    gamestarted = False # Whether the game has started (disables UI elements)
    gamesaved = True # Whether or not current game is saved
    serverstarted = False # Whether or not the Flask server has started
@@ -785,8 +788,13 @@ def savesettings():
    js["nopromptgen"] = vars.nopromptgen
    js["rngpersist"] = vars.rngpersist
    js["nogenmod"] = vars.nogenmod
    js["fulldeterminism"] = vars.full_determinism
    js["autosave"] = vars.autosave
    js["welcome"] = vars.welcome
    if(vars.seed_specified):
        js["seed"] = vars.seed
    js["newlinemode"] = vars.newlinemode
    js["antemplate"] = vars.setauthornotetemplate
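For reference, here is roughly what the js dict would contain just before it is written out with these changes: "fulldeterminism" is always saved, while "seed" is only included when the seed was specified by the user. The values below are illustrative, not taken from the diff.

    # Hypothetical settings contents (illustrative values only)
    js = {
        "nopromptgen": False,
        "rngpersist": False,
        "nogenmod": False,
        "fulldeterminism": True,   # mirrors vars.full_determinism
        "autosave": False,
        "welcome": True,
        "seed": 3735928559,        # present only because vars.seed_specified was True
        "newlinemode": "n",
    }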
@@ -886,12 +894,20 @@ def processsettings(js):
        vars.rngpersist = js["rngpersist"]
    if("nogenmod" in js):
        vars.nogenmod = js["nogenmod"]
    if("fulldeterminism" in js):
        vars.full_determinism = js["fulldeterminism"]
    if("autosave" in js):
        vars.autosave = js["autosave"]
    if("newlinemode" in js):
        vars.newlinemode = js["newlinemode"]
    if("welcome" in js):
        vars.welcome = js["welcome"]
    if("seed" in js):
        vars.seed = js["seed"]
        vars.seed_specified = True
    else:
        vars.seed_specified = False
    if("antemplate" in js):
        vars.setauthornotetemplate = js["antemplate"]
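The loading side mirrors the save side: whether the settings file contains a "seed" key is what determines seed_specified. A standalone sketch of just that logic, using a plain dict and local names instead of the real vars object:

    import json

    def load_seed_settings(path):
        # Re-statement of the logic above: seed_specified is derived purely
        # from whether the settings file contains a "seed" key.
        with open(path) as f:
            js = json.load(f)
        full_determinism = js.get("fulldeterminism", False)
        seed = js.get("seed")            # None when no seed was ever saved
        seed_specified = "seed" in js
        return full_determinism, seed, seed_specified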
@@ -3383,6 +3399,10 @@ def get_message(msg):
        vars.nogenmod = msg['data']
        settingschanged()
        refresh_settings()
    elif(msg['cmd'] == 'setfulldeterminism'):
        vars.full_determinism = msg['data']
        settingschanged()
        refresh_settings()
    elif(not vars.host and msg['cmd'] == 'importwi'):
        wiimportrequest()
    elif(msg['cmd'] == 'debug'):
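The new branch follows the same shape as the other toggle commands: the browser sends a cmd/data pair, and the server stores the flag and refreshes the UI. Presumably the client emits something along these lines when the setting is toggled (a sketch; the client-side change is in one of the other files of this commit and is not shown here):

    # Hypothetical message from the UI turning full determinism on
    msg = {'cmd': 'setfulldeterminism', 'data': True}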
@@ -3942,6 +3962,9 @@ def calcsubmit(txt):
#==================================================================#
def _generate(txt, minimum, maximum, found_entries):
    if(vars.full_determinism):
        torch.manual_seed(vars.seed)
    gen_in = torch.tensor(txt, dtype=torch.long)[None]
    if(vars.sp is not None):
        soft_tokens = torch.arange(
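Re-seeding right before generation is what makes repeated generations with the same prompt come out identical. A minimal, self-contained illustration of that idea, independent of KoboldAI's actual sampling code:

    import torch

    def sample_once(seed):
        # Re-seed the global torch RNG before sampling, as _generate() now
        # does when full determinism is enabled.
        torch.manual_seed(seed)
        probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
        return torch.multinomial(probs, num_samples=3, replacement=True)

    assert torch.equal(sample_once(1337), sample_once(1337))  # same seed, same draws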
@@ -4282,6 +4305,9 @@ def sendtocolab(txt, min, max):
# Send text to TPU mesh transformer backend
#==================================================================#
def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
    if(vars.full_determinism):
        tpu_mtj_backend.set_rng_seed(vars.seed)
    vars.generated_tkns = 0
    if(found_entries is None):
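The tpu_mtj_backend seed helpers (set_rng_seed, randomize_rng_seed, get_rng_seed, also used in final_startup below) live in one of the other changed files and are not shown in this excerpt. As a rough idea of the interface they need to expose, here is a hypothetical stand-in that just tracks a module-level seed; the real TPU backend will of course do more than this:

    import random

    _rng_seed = None  # hypothetical module-level seed state

    def set_rng_seed(seed):
        global _rng_seed
        _rng_seed = int(seed)
        return _rng_seed

    def randomize_rng_seed():
        # Used when no seed was specified by the user
        return set_rng_seed(random.randrange(2**64))

    def get_rng_seed():
        return _rng_seed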
@@ -4567,6 +4593,7 @@ def refresh_settings():
    emit('from_server', {'cmd': 'updatenopromptgen', 'data': vars.nopromptgen}, broadcast=True)
    emit('from_server', {'cmd': 'updaterngpersist', 'data': vars.rngpersist}, broadcast=True)
    emit('from_server', {'cmd': 'updatenogenmod', 'data': vars.nogenmod}, broadcast=True)
    emit('from_server', {'cmd': 'updatefulldeterminism', 'data': vars.full_determinism}, broadcast=True)
    emit('from_server', {'cmd': 'updatefrmttriminc', 'data': vars.formatoptns["frmttriminc"]}, broadcast=True)
    emit('from_server', {'cmd': 'updatefrmtrmblln', 'data': vars.formatoptns["frmtrmblln"]}, broadcast=True)
@@ -5987,9 +6014,27 @@ def final_startup():
                },
            ).start()
    # Set the initial RNG seed
    if(vars.seed is not None):
        if(vars.use_colab_tpu):
            if(vars.seed_specified):
                __import__("tpu_mtj_backend").set_rng_seed(vars.seed)
            else:
                __import__("tpu_mtj_backend").randomize_rng_seed()
        else:
            if(vars.seed_specified):
                __import__("torch").manual_seed(vars.seed)
            else:
                __import__("torch").seed()
    vars.seed = __import__("tpu_mtj_backend").get_rng_seed() if vars.use_colab_tpu else __import__("torch").initial_seed()
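The startup logic reduces to: if the user pinned a seed in their settings file, apply it; otherwise reseed randomly; either way, record the seed that ended up in effect so it can be reported and re-saved. A condensed sketch of that decision tree for the torch path only (the helper name is made up for illustration):

    import torch

    def init_torch_seed(seed, seed_specified):
        # Honour a user-supplied seed, otherwise reseed from entropy,
        # then return the seed that is actually in effect.
        if seed is not None:
            if seed_specified:
                torch.manual_seed(seed)
            else:
                torch.seed()
        return torch.initial_seed()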
def send_debug():
    if vars.debug:
        debug_info = ""
        try:
            debug_info = "{}Seed: {} ({})\n".format(debug_info, repr(__import__("tpu_mtj_backend").get_rng_seed() if vars.use_colab_tpu else __import__("torch").initial_seed()), "specified by user in settings file" if vars.seed_specified else "randomly generated")
        except:
            pass
        try:
            debug_info = "{}Newline Mode: {}\n".format(debug_info, vars.newlinemode)
        except: