Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-02-21 14:10:39 +01:00)

Add support for setting the RNG seed and full determinism

parent 496f6dcf3f
commit 048bd0ff3b

aiserver.py (45 additions)
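Overview: KoboldAI already produced a deterministic *sequence* of generations from a fixed starting seed; this commit adds a user-settable seed plus a "Full Determinism" toggle that re-seeds the generator before every generation, making each individual output repeatable. A toy sketch of that distinction, using Python's `random` module as a stand-in for the actual samplers:

import random

rng = random.Random(1234)   # fixed starting seed

# Sequence determinism (the old behavior): relaunching with the same seed
# replays generations 1, 2, 3... identically, but retrying generation 1
# mid-session gives a different result because the RNG state has advanced.
gen1 = rng.random()
gen2 = rng.random()
assert gen1 != gen2

# Full determinism (the new toggle): re-seed before every generation, so
# any retry with the same story and settings reproduces the same output.
rng.seed(1234)
assert rng.random() == gen1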
aiserver.py:

@@ -238,6 +238,9 @@ class vars:
     tfs         = 1.0    # Default generator tfs (tail-free sampling)
     typical     = 1.0    # Default generator typical sampling threshold
     numseqs     = 1      # Number of sequences to ask the generator to create
+    full_determinism = False  # Whether or not full determinism is enabled
+    seed_specified = False  # Whether or not the current RNG seed was specified by the user (in their settings file)
+    seed        = None   # The current RNG seed (as an int), or None if unknown
     gamestarted = False  # Whether the game has started (disables UI elements)
     gamesaved   = True   # Whether or not current game is saved
     serverstarted = False  # Whether or not the Flask server has started
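The three new fields are the whole state model: `seed` stays `None` until a backend reports its RNG state at startup, and `seed_specified` records whether the value came from the user's settings file (and therefore should be written back on save). A condensed, hypothetical illustration of the intended lifecycle (not KoboldAI's exact code):

# Hypothetical condensation of the seed lifecycle.
full_determinism = False  # toggled from the UI
seed_specified = False    # True only when "seed" is read from the settings file
seed = None               # filled in once a backend initializes its RNG

def on_settings_loaded(js):
    global seed, seed_specified
    seed_specified = "seed" in js
    if seed_specified:
        seed = js["seed"]

def on_backend_ready(backend_seed):
    global seed
    seed = backend_seed   # now an int either way, usable for re-seeding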
@@ -785,8 +788,13 @@ def savesettings():
     js["nopromptgen"] = vars.nopromptgen
     js["rngpersist"]  = vars.rngpersist
     js["nogenmod"]    = vars.nogenmod
+    js["fulldeterminism"] = vars.full_determinism
     js["autosave"]    = vars.autosave
     js["welcome"]     = vars.welcome
+
+    if(vars.seed_specified):
+        js["seed"] = vars.seed
+
     js["newlinemode"] = vars.newlinemode

     js["antemplate"]  = vars.setauthornotetemplate
@@ -886,12 +894,20 @@ def processsettings(js):
         vars.rngpersist = js["rngpersist"]
     if("nogenmod" in js):
         vars.nogenmod = js["nogenmod"]
+    if("fulldeterminism" in js):
+        vars.full_determinism = js["fulldeterminism"]
     if("autosave" in js):
         vars.autosave = js["autosave"]
     if("newlinemode" in js):
         vars.newlinemode = js["newlinemode"]
     if("welcome" in js):
         vars.welcome = js["welcome"]
+
+    if("seed" in js):
+        vars.seed = js["seed"]
+        vars.seed_specified = True
+    else:
+        vars.seed_specified = False

     if("antemplate" in js):
         vars.setauthornotetemplate = js["antemplate"]
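The save/load pair is deliberately asymmetric: `savesettings()` writes `"seed"` only when `seed_specified` is set, so a randomly drawn seed is never pinned across restarts by accident, while `processsettings()` treats the key's mere presence as the user's intent to pin it. A small self-contained round-trip sketch under that assumption (helper names here are illustrative):

import json

def save(seed, seed_specified):
    js = {}
    if seed_specified:       # randomly drawn seeds are intentionally not persisted
        js["seed"] = seed
    return json.dumps(js)

def load(text):
    js = json.loads(text)
    if "seed" in js:
        return js["seed"], True
    return None, False       # a fresh random seed will be drawn at startup

assert load(save(1234, True)) == (1234, True)
assert load(save(5678, False)) == (None, False)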
@@ -3383,6 +3399,10 @@ def get_message(msg):
         vars.nogenmod = msg['data']
         settingschanged()
         refresh_settings()
+    elif(msg['cmd'] == 'setfulldeterminism'):
+        vars.full_determinism = msg['data']
+        settingschanged()
+        refresh_settings()
     elif(not vars.host and msg['cmd'] == 'importwi'):
         wiimportrequest()
     elif(msg['cmd'] == 'debug'):
@@ -3942,6 +3962,9 @@ def calcsubmit(txt):
 #==================================================================#

 def _generate(txt, minimum, maximum, found_entries):
+    if(vars.full_determinism):
+        torch.manual_seed(vars.seed)
+
     gen_in = torch.tensor(txt, dtype=torch.long)[None]
     if(vars.sp is not None):
         soft_tokens = torch.arange(
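`torch.manual_seed` rewinds the default generator's stream, so calling it at the top of `_generate` means the sampler consumes identical random numbers on every retry of the same story state. A minimal demonstration of the rewind (pure torch, no model):

import torch

torch.manual_seed(0)
first = torch.rand(3)
second = torch.rand(3)     # differs: the RNG state advanced
assert not torch.equal(first, second)

torch.manual_seed(0)       # rewind, as _generate does once per generation
assert torch.equal(torch.rand(3), first)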
@@ -4282,6 +4305,9 @@ def sendtocolab(txt, min, max):
 # Send text to TPU mesh transformer backend
 #==================================================================#
 def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
+    if(vars.full_determinism):
+        tpu_mtj_backend.set_rng_seed(vars.seed)
+
     vars.generated_tkns = 0

     if(found_entries is None):
@@ -4567,6 +4593,7 @@ def refresh_settings():
     emit('from_server', {'cmd': 'updatenopromptgen', 'data': vars.nopromptgen}, broadcast=True)
     emit('from_server', {'cmd': 'updaterngpersist', 'data': vars.rngpersist}, broadcast=True)
     emit('from_server', {'cmd': 'updatenogenmod', 'data': vars.nogenmod}, broadcast=True)
+    emit('from_server', {'cmd': 'updatefulldeterminism', 'data': vars.full_determinism}, broadcast=True)

     emit('from_server', {'cmd': 'updatefrmttriminc', 'data': vars.formatoptns["frmttriminc"]}, broadcast=True)
     emit('from_server', {'cmd': 'updatefrmtrmblln', 'data': vars.formatoptns["frmtrmblln"]}, broadcast=True)
@@ -5987,9 +6014,27 @@ def final_startup():
         },
     ).start()

+    # Set the initial RNG seed
+    if(vars.seed is not None):
+        if(vars.use_colab_tpu):
+            if(vars.seed_specified):
+                __import__("tpu_mtj_backend").set_rng_seed(vars.seed)
+            else:
+                __import__("tpu_mtj_backend").randomize_rng_seed()
+        else:
+            if(vars.seed_specified):
+                __import__("torch").manual_seed(vars.seed)
+            else:
+                __import__("torch").seed()
+    vars.seed = __import__("tpu_mtj_backend").get_rng_seed() if vars.use_colab_tpu else __import__("torch").initial_seed()
+
 def send_debug():
     if vars.debug:
         debug_info = ""
+        try:
+            debug_info = "{}Seed: {} ({})\n".format(debug_info, repr(__import__("tpu_mtj_backend").get_rng_seed() if vars.use_colab_tpu else __import__("torch").initial_seed()), "specified by user in settings file" if vars.seed_specified else "randomly generated")
+        except:
+            pass
         try:
             debug_info = "{}Newline Mode: {}\n".format(debug_info, vars.newlinemode)
         except:
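On the local path the startup logic leans on three torch calls: `manual_seed(n)` pins the default generator, `seed()` draws a fresh non-deterministic seed (and returns it), and `initial_seed()` reads back whichever seed is currently in effect, which is how `vars.seed` ends up holding a concrete int even when the user never specified one. A short demonstration of that contract:

import torch

torch.manual_seed(1234)               # pin: user-specified seed
assert torch.initial_seed() == 1234   # initial_seed() reads it back

fresh = torch.seed()                  # randomize: draws and returns a new seed
assert torch.initial_seed() == fresh  # so the drawn value can still be recorded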
gensettings.py:

@@ -230,6 +230,17 @@ gensettingstf = [
     "default": 0,
     "tooltip": "Disables userscript generation modifiers."
     },
+    {
+    "uitype": "toggle",
+    "unit": "bool",
+    "label": "Full Determinism",
+    "id": "setfulldeterminism",
+    "min": 0,
+    "max": 1,
+    "step": 1,
+    "default": 0,
+    "tooltip": "Causes generation to be fully deterministic -- the model will always output the same thing as long as your story, settings and RNG seed are the same. If this is off, only the sequence of outputs that the model makes will be deterministic."
+    },
     {
     "uitype": "toggle",
     "unit": "bool",
static/application.js:

@@ -2606,6 +2606,9 @@ $(document).ready(function(){
         } else if(msg.cmd == "updatenogenmod") {
             // Update toggle state
             $("#setnogenmod").prop('checked', msg.data).change();
+        } else if(msg.cmd == "updatefulldeterminism") {
+            // Update toggle state
+            $("#setfulldeterminism").prop('checked', msg.data).change();
         } else if(msg.cmd == "runs_remotely") {
             remote = true;
             hide([button_savetofile, button_import, button_importwi]);
tpu_mtj_backend.py:

@@ -56,6 +56,22 @@ from mesh_transformer.util import to_bf16

 params: Dict[str, Any] = {}

+__seed = random.randrange(sys.maxsize)
+rng = random.Random(__seed)
+
+
+def get_rng_seed():
+    return __seed
+
+def set_rng_seed(seed: int):
+    global __seed, rng
+    rng = random.Random(seed)
+    __seed = seed
+    return seed
+
+def randomize_rng_seed():
+    return set_rng_seed(random.randrange(sys.maxsize))
+

 def warper_callback(logits) -> np.array:
     raise NotImplementedError("`tpu_mtj_backend.warper_callback()` needs to be defined")
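These helpers keep a module-private `random.Random` instance so the backend's draws are isolated from everything else that touches Python's global `random` state. A usage sketch of the functions above (re-declared here so the snippet runs standalone), showing that re-seeding reproduces the exact draw sequence:

import random, sys

__seed = random.randrange(sys.maxsize)
rng = random.Random(__seed)

def set_rng_seed(seed):
    global __seed, rng
    rng = random.Random(seed)   # replace the private stream entirely
    __seed = seed
    return seed

set_rng_seed(42)
first = [rng.randint(0, 2 ** 60) for _ in range(3)]
set_rng_seed(42)                # same seed, same stream
assert [rng.randint(0, 2 ** 60) for _ in range(3)] == first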
@@ -728,7 +744,7 @@ class PenalizingCausalTransformer(CausalTransformer):
         assert not return_logits
         assert gen_length.ndim == 1
         assert soft_embeddings is not None
-        key = hk.PRNGSequence(random.randint(0, 2 ** 60))
+        key = hk.PRNGSequence(rng.randint(0, 2 ** 60))
         batch_size = ctx.shape[0]
         self.batch_size = batch_size
         _numseqs_aux = jnp.empty((batch_size, numseqs), dtype=np.uint32)
@@ -776,7 +792,7 @@ class PenalizingCausalTransformer(CausalTransformer):
         return sample_data, n_generated, regeneration_required, halt
     def generate_static(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None):
         assert not return_logits
-        key = hk.PRNGSequence(random.randint(0, 2 ** 60))
+        key = hk.PRNGSequence(rng.randint(0, 2 ** 60))
         batch_size = ctx.shape[0]
         self.batch_size = batch_size
         started_compiling_callback()
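This one-line swap in both samplers is what actually wires `set_rng_seed` into TPU generation: `random.randint` draws from Python's global RNG, which the new helpers never touch, so the Haiku PRNG key would have changed on every call regardless of the seed; `rng.randint` draws from the private, seedable instance. A sketch of the effect, with a plain function standing in for `hk.PRNGSequence`:

import random

rng = random.Random()          # the backend's private, seedable instance

def prng_key_seed():
    # Stand-in for `hk.PRNGSequence(rng.randint(0, 2 ** 60))`: the JAX key
    # is derived from the private instance's next draw.
    return rng.randint(0, 2 ** 60)

rng.seed(1234)                 # what set_rng_seed(1234) effectively does
a = prng_key_seed()
rng.seed(1234)
b = prng_key_seed()
assert a == b                  # identical key per generation => full determinism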
|
Loading…
x
Reference in New Issue
Block a user