mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Add support for setting the RNG seed and full determinism
This commit is contained in:
45
aiserver.py
45
aiserver.py
@ -238,6 +238,9 @@ class vars:
|
||||
tfs = 1.0 # Default generator tfs (tail-free sampling)
|
||||
typical = 1.0 # Default generator typical sampling threshold
|
||||
numseqs = 1 # Number of sequences to ask the generator to create
|
||||
full_determinism = False # Whether or not full determinism is enabled
|
||||
seed_specified = False # Whether or not the current RNG seed was specified by the user (in their settings file)
|
||||
seed = None # The current RNG seed (as an int), or None if unknown
|
||||
gamestarted = False # Whether the game has started (disables UI elements)
|
||||
gamesaved = True # Whether or not current game is saved
|
||||
serverstarted = False # Whether or not the Flask server has started
|
||||
@ -785,8 +788,13 @@ def savesettings():
|
||||
js["nopromptgen"] = vars.nopromptgen
|
||||
js["rngpersist"] = vars.rngpersist
|
||||
js["nogenmod"] = vars.nogenmod
|
||||
js["fulldeterminism"] = vars.full_determinism
|
||||
js["autosave"] = vars.autosave
|
||||
js["welcome"] = vars.welcome
|
||||
|
||||
if(vars.seed_specified):
|
||||
js["seed"] = vars.seed
|
||||
|
||||
js["newlinemode"] = vars.newlinemode
|
||||
|
||||
js["antemplate"] = vars.setauthornotetemplate
|
||||
@ -886,12 +894,20 @@ def processsettings(js):
|
||||
vars.rngpersist = js["rngpersist"]
|
||||
if("nogenmod" in js):
|
||||
vars.nogenmod = js["nogenmod"]
|
||||
if("fulldeterminism" in js):
|
||||
vars.full_determinism = js["fulldeterminism"]
|
||||
if("autosave" in js):
|
||||
vars.autosave = js["autosave"]
|
||||
if("newlinemode" in js):
|
||||
vars.newlinemode = js["newlinemode"]
|
||||
if("welcome" in js):
|
||||
vars.welcome = js["welcome"]
|
||||
|
||||
if("seed" in js):
|
||||
vars.seed = js["seed"]
|
||||
vars.seed_specified = True
|
||||
else:
|
||||
vars.seed_specified = False
|
||||
|
||||
if("antemplate" in js):
|
||||
vars.setauthornotetemplate = js["antemplate"]
|
||||
@ -3383,6 +3399,10 @@ def get_message(msg):
|
||||
vars.nogenmod = msg['data']
|
||||
settingschanged()
|
||||
refresh_settings()
|
||||
elif(msg['cmd'] == 'setfulldeterminism'):
|
||||
vars.full_determinism = msg['data']
|
||||
settingschanged()
|
||||
refresh_settings()
|
||||
elif(not vars.host and msg['cmd'] == 'importwi'):
|
||||
wiimportrequest()
|
||||
elif(msg['cmd'] == 'debug'):
|
||||
@ -3942,6 +3962,9 @@ def calcsubmit(txt):
|
||||
#==================================================================#
|
||||
|
||||
def _generate(txt, minimum, maximum, found_entries):
|
||||
if(vars.full_determinism):
|
||||
torch.manual_seed(vars.seed)
|
||||
|
||||
gen_in = torch.tensor(txt, dtype=torch.long)[None]
|
||||
if(vars.sp is not None):
|
||||
soft_tokens = torch.arange(
|
||||
@ -4282,6 +4305,9 @@ def sendtocolab(txt, min, max):
|
||||
# Send text to TPU mesh transformer backend
|
||||
#==================================================================#
|
||||
def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
||||
if(vars.full_determinism):
|
||||
tpu_mtj_backend.set_rng_seed(vars.seed)
|
||||
|
||||
vars.generated_tkns = 0
|
||||
|
||||
if(found_entries is None):
|
||||
@ -4567,6 +4593,7 @@ def refresh_settings():
|
||||
emit('from_server', {'cmd': 'updatenopromptgen', 'data': vars.nopromptgen}, broadcast=True)
|
||||
emit('from_server', {'cmd': 'updaterngpersist', 'data': vars.rngpersist}, broadcast=True)
|
||||
emit('from_server', {'cmd': 'updatenogenmod', 'data': vars.nogenmod}, broadcast=True)
|
||||
emit('from_server', {'cmd': 'updatefulldeterminism', 'data': vars.full_determinism}, broadcast=True)
|
||||
|
||||
emit('from_server', {'cmd': 'updatefrmttriminc', 'data': vars.formatoptns["frmttriminc"]}, broadcast=True)
|
||||
emit('from_server', {'cmd': 'updatefrmtrmblln', 'data': vars.formatoptns["frmtrmblln"]}, broadcast=True)
|
||||
@ -5987,9 +6014,27 @@ def final_startup():
|
||||
},
|
||||
).start()
|
||||
|
||||
# Set the initial RNG seed
|
||||
if(vars.seed is not None):
|
||||
if(vars.use_colab_tpu):
|
||||
if(vars.seed_specified):
|
||||
__import__("tpu_mtj_backend").set_rng_seed(vars.seed)
|
||||
else:
|
||||
__import__("tpu_mtj_backend").randomize_rng_seed()
|
||||
else:
|
||||
if(vars.seed_specified):
|
||||
__import__("torch").manual_seed(vars.seed)
|
||||
else:
|
||||
__import__("torch").seed()
|
||||
vars.seed = __import__("tpu_mtj_backend").get_rng_seed() if vars.use_colab_tpu else __import__("torch").initial_seed()
|
||||
|
||||
def send_debug():
|
||||
if vars.debug:
|
||||
debug_info = ""
|
||||
try:
|
||||
debug_info = "{}Seed: {} ({})\n".format(debug_info, repr(__import__("tpu_mtj_backend").get_rng_seed() if vars.use_colab_tpu else __import__("torch").initial_seed()), "specified by user in settings file" if vars.seed_specified else "randomly generated")
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
debug_info = "{}Newline Mode: {}\n".format(debug_info, vars.newlinemode)
|
||||
except:
|
||||
|
Reference in New Issue
Block a user