diff --git a/aiserver.py b/aiserver.py index 8328ac0c..4f533a7d 100644 --- a/aiserver.py +++ b/aiserver.py @@ -157,6 +157,7 @@ class vars: spmeta = None # Metadata of current soft prompt, or None if not using a soft prompt sp = None # Current soft prompt tensor (as a NumPy array) sp_length = 0 # Length of current soft prompt in tokens, or 0 if not using a soft prompt + has_genmod = False # Whether or not at least one loaded Lua userscript has a generation modifier svowname = "" # Filename that was flagged for overwrite confirm saveow = False # Whether or not overwrite confirm has been displayed genseqs = [] # Temporary storage for generated sequences @@ -184,6 +185,7 @@ class vars: remote = False nopromptgen = False rngpersist = False + nogenmod = False #==================================================================# # Function to get model selection at startup @@ -1062,19 +1064,6 @@ else: vars.allowsp = True vars.modeldim = int(tpu_mtj_backend.params["d_model"]) tokenizer = tpu_mtj_backend.tokenizer - soft_tokens = tpumtjgetsofttokens() - threading.Thread( # Compile backend code in background - target=tpu_mtj_backend.infer, - args=(np.tile(np.uint32((23403, 727, 20185)), (vars.numseqs, 1)),), - kwargs={ - "soft_embeddings": vars.sp, - "soft_tokens": soft_tokens, - "use_callback": False, - "gen_len": 1, - "numseqs": vars.numseqs, - "excluded_world_info": list(set() for _ in range(vars.numseqs)), - }, - ).start() # Set up Flask routes @app.route('/') @@ -1182,13 +1171,18 @@ def load_lua_scripts(): modulenames.append(lst[i]["modulename"]) descriptions.append(lst[i]["description"]) + vars.has_genmod = False + try: vars.lua_koboldbridge.obliterate_multiverse() tpool.execute(vars.lua_koboldbridge.load_corescript, vars.corescript) - tpool.execute(vars.lua_koboldbridge.load_userscripts, filenames, modulenames, descriptions) + vars.has_genmod = tpool.execute(vars.lua_koboldbridge.load_userscripts, filenames, modulenames, descriptions) vars.lua_running = True except lupa.LuaError as e: - vars.lua_koboldbridge.obliterate_multiverse() + try: + vars.lua_koboldbridge.obliterate_multiverse() + except: + pass vars.lua_running = False if(vars.serverstarted): emit('from_server', {'cmd': 'errmsg', 'data': 'Lua script error, please check console.'}, broadcast=True) @@ -2035,6 +2029,10 @@ def get_message(msg): vars.rngpersist = msg['data'] settingschanged() refresh_settings() + elif(msg['cmd'] == 'setnogenmod'): + vars.nogenmod = msg['data'] + settingschanged() + refresh_settings() elif(not vars.remote and msg['cmd'] == 'importwi'): wiimportrequest() @@ -2105,6 +2103,8 @@ def savesettings(): js["dynamicscan"] = vars.dynamicscan js["nopromptgen"] = vars.nopromptgen js["rngpersist"] = vars.rngpersist + js["nogenmod"] = vars.nogenmod + js["antemplate"] = vars.setauthornotetemplate js["userscripts"] = vars.userscripts @@ -2170,6 +2170,8 @@ def loadsettings(): vars.nopromptgen = js["nopromptgen"] if("rngpersist" in js): vars.rngpersist = js["rngpersist"] + if("nogenmod" in js): + vars.nogenmod = js["nogenmod"] if("antemplate" in js): vars.setauthornotetemplate = js["antemplate"] @@ -2952,16 +2954,62 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): # Submit input text to generator try: - context = np.tile(np.uint32(txt), (vars.numseqs, 1)) soft_tokens = tpumtjgetsofttokens() global past - past = np.empty((vars.numseqs, 0), dtype=np.uint32) - while(True): - genout, n_generated, regeneration_required, halt = tpool.execute( - tpu_mtj_backend.infer, - context, + if(vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)): + + context = np.tile(np.uint32(txt), (vars.numseqs, 1)) + past = np.empty((vars.numseqs, 0), dtype=np.uint32) + + while(True): + genout, n_generated, regeneration_required, halt = tpool.execute( + tpu_mtj_backend.infer_dynamic, + context, + gen_len = maximum-minimum+1, + temp=vars.temp, + top_p=vars.top_p, + top_k=vars.top_k, + tfs=vars.tfs, + numseqs=vars.numseqs, + repetition_penalty=vars.rep_pen, + soft_embeddings=vars.sp, + soft_tokens=soft_tokens, + excluded_world_info=found_entries, + ) + + past = np.pad(past, ((0, 0), (0, n_generated))) + for r in range(vars.numseqs): + for c in range(vars.lua_koboldbridge.generated_cols): + assert vars.lua_koboldbridge.generated[r+1][c+1] is not None + past[r, c] = vars.lua_koboldbridge.generated[r+1][c+1] + + if(halt or not regeneration_required): + break + print("(regeneration triggered)") + + encoded = [] + for i in range(vars.numseqs): + txt = tokenizer.decode(past[i]) + winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True) + found_entries[i].update(_found_entries) + txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt) + encoded.append(np.array(txt, dtype=np.uint32)) + max_length = len(max(encoded, key=len)) + encoded = np.stack(tuple(np.pad(e, (max_length - len(e), 0), constant_values=tpu_mtj_backend.pad_token_id) for e in encoded)) + context = np.concatenate( + ( + encoded, + past, + ), + axis=-1, + ) + + else: + genout = tpool.execute( + tpu_mtj_backend.infer_static, + np.uint32(txt), gen_len = maximum-minimum+1, temp=vars.temp, top_p=vars.top_p, @@ -2971,35 +3019,10 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None): repetition_penalty=vars.rep_pen, soft_embeddings=vars.sp, soft_tokens=soft_tokens, - excluded_world_info=found_entries, ) - - past = np.pad(past, ((0, 0), (0, n_generated))) - for r in range(vars.numseqs): - for c in range(vars.lua_koboldbridge.generated_cols): - assert vars.lua_koboldbridge.generated[r+1][c+1] is not None - past[r, c] = vars.lua_koboldbridge.generated[r+1][c+1] - - if(halt or not regeneration_required): - break - print("(regeneration triggered)") - - encoded = [] + past = genout for i in range(vars.numseqs): - txt = tokenizer.decode(past[i]) - winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True) - found_entries[i].update(_found_entries) - txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt) - encoded.append(np.array(txt, dtype=np.uint32)) - max_length = len(max(encoded, key=len)) - encoded = np.stack(tuple(np.pad(e, (max_length - len(e), 0), constant_values=tpu_mtj_backend.pad_token_id) for e in encoded)) - context = np.concatenate( - ( - encoded, - past, - ), - axis=-1, - ) + vars.lua_koboldbridge.generated[i+1] = vars.lua_state.table(*genout[i].tolist()) except Exception as e: if(issubclass(type(e), lupa.LuaError)): @@ -3181,6 +3204,7 @@ def refresh_settings(): emit('from_server', {'cmd': 'updatedynamicscan', 'data': vars.dynamicscan}, broadcast=True) emit('from_server', {'cmd': 'updatenopromptgen', 'data': vars.nopromptgen}, broadcast=True) emit('from_server', {'cmd': 'updaterngpersist', 'data': vars.rngpersist}, broadcast=True) + emit('from_server', {'cmd': 'updatenogenmod', 'data': vars.nogenmod}, broadcast=True) emit('from_server', {'cmd': 'updatefrmttriminc', 'data': vars.formatoptns["frmttriminc"]}, broadcast=True) emit('from_server', {'cmd': 'updatefrmtrmblln', 'data': vars.formatoptns["frmtrmblln"]}, broadcast=True) @@ -4434,6 +4458,34 @@ def randomGameRequest(topic, memory=""): loadmodelsettings() loadsettings() +# Precompile TPU backend if required +if(vars.model in ("TPUMeshTransformerGPTJ",)): + soft_tokens = tpumtjgetsofttokens() + if(vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)): + threading.Thread( + target=tpu_mtj_backend.infer_dynamic, + args=(np.tile(np.uint32((23403, 727, 20185)), (vars.numseqs, 1)),), + kwargs={ + "soft_embeddings": vars.sp, + "soft_tokens": soft_tokens, + "gen_len": 1, + "use_callback": False, + "numseqs": vars.numseqs, + "excluded_world_info": list(set() for _ in range(vars.numseqs)), + }, + ).start() + else: + threading.Thread( + target=tpu_mtj_backend.infer_static, + args=(np.uint32((23403, 727, 20185)),), + kwargs={ + "soft_embeddings": vars.sp, + "soft_tokens": soft_tokens, + "gen_len": 1, + "numseqs": vars.numseqs, + }, + ).start() + #==================================================================# # Final startup commands to launch Flask app #==================================================================# diff --git a/bridge.lua b/bridge.lua index d5adbd01..85e7eacc 100644 --- a/bridge.lua +++ b/bridge.lua @@ -1851,13 +1851,14 @@ return function(_python, _bridged) -- API for aiserver.py --========================================================================== - ---@return nil + ---@return boolean function koboldbridge.load_userscripts(filenames, modulenames, descriptions) config_files = {} config_file_filename_map = {} koboldbridge.userscripts = {} koboldbridge.userscriptmodule_filename_map = {} koboldbridge.num_userscripts = 0 + local has_genmod = false for i, filename in _python.enumerate(filenames) do bridged.load_callback(filename, modulenames[i]) koboldbridge.logging_name = modulenames[i] @@ -1865,12 +1866,15 @@ return function(_python, _bridged) local f, err = old_loadfile(join_folder_and_filename(bridged.userscript_path, filename), "t", koboldbridge.get_universe(filename)) if err ~= nil then error(err) - return + return false end ---@type KoboldUserScript local _userscript = f() koboldbridge.logging_name = nil koboldbridge.filename = nil + if _userscript.genmod ~= nil then + has_genmod = true + end local userscript = deepcopy(KoboldUserScriptModule) rawset(userscript, "_inmod", function() koboldbridge.logging_name = modulenames[i] @@ -1903,6 +1907,7 @@ return function(_python, _bridged) koboldbridge.userscriptmodule_filename_map[userscript] = filename koboldbridge.num_userscripts = i + 1 end + return has_genmod end ---@return nil @@ -1949,7 +1954,9 @@ return function(_python, _bridged) koboldbridge.userstate = "genmod" if koboldbridge.genmod ~= nil then local _generated = deepcopy(koboldbridge.generated) - r = koboldbridge.genmod() + if not bridged.vars.nogenmod then + r = koboldbridge.genmod() + end setmetatable(koboldbridge.logits, nil) for kr, vr in old_next, koboldbridge.logits, nil do setmetatable(vr, nil) diff --git a/gensettings.py b/gensettings.py index fa4e14ef..a0e1b636 100644 --- a/gensettings.py +++ b/gensettings.py @@ -162,6 +162,17 @@ gensettingstf = [{ "step": 1, "default": 0, "tooltip": "When enabled, the Memory text box in the Random Story dialog will be prefilled by default with your current story's memory instead of being empty." + }, + { + "uitype": "toggle", + "unit": "bool", + "label": "No Genmod", + "id": "setnogenmod", + "min": 0, + "max": 1, + "step": 1, + "default": 0, + "tooltip": "Disables userscript generation modifiers." }] gensettingsik =[{ diff --git a/static/application.js b/static/application.js index f97333d6..403a9d8b 100644 --- a/static/application.js +++ b/static/application.js @@ -2194,6 +2194,9 @@ $(document).ready(function(){ if(!$("#setrngpersist").prop("checked")) { $("#rngmemory").val(""); } + } else if(msg.cmd == "updatenogenmod") { + // Update toggle state + $("#setnogenmod").prop('checked', msg.data).change(); } else if(msg.cmd == "runs_remotely") { remote = true; hide([button_savetofile, button_import, button_importwi]); diff --git a/templates/index.html b/templates/index.html index 08357979..8bab5ea6 100644 --- a/templates/index.html +++ b/templates/index.html @@ -17,7 +17,7 @@ - + diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index 67196645..c68af60c 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -60,7 +60,7 @@ def __batch_xmap(shard_dim=1): return inner -def apply_repetition_penalty(logits, tokens, repetition_penalty): +def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty): ''' This gets called by generate_loop_fn to apply repetition penalty to the 1D array logits using the provided 1D array of tokens to penalize @@ -85,7 +85,7 @@ def apply_repetition_penalty(logits, tokens, repetition_penalty): logits[tokens] = penalty_logits return logits -def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): +def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): ''' This gets called by generate_loop_fn to apply a series of 4 filters to the logits (top-k, then top-p, then TFS, then temperature) before @@ -183,6 +183,127 @@ def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): # probability distribution) return jax.random.categorical(key, logits, -1).astype(np.uint32) +def apply_repetition_penalty_static(logits, tokens, repetition_penalty): + ''' + This gets called by generate_loop_fn to apply repetition penalty + to the 1D array logits using the provided 1D array of tokens to penalize + ''' + # Make a new array with the same length as the tokens array but with + # each element replaced by the value at the corresponding index in the + # logits array; e.g. + # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1], + # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5] + penalty_logits = jnp.take(logits, tokens) + # Divide positive values by repetition_penalty and multiply negative + # values by repetition_penalty (the academic publication that described + # this technique actually just only divided, but that would cause tokens + # with negative logits to become more likely, which is obviously wrong) + penalty_logits = jnp.where( + penalty_logits > 0, + penalty_logits/repetition_penalty, + penalty_logits*repetition_penalty, + ) + # Finally, put those penalized logit values back into their original + # positions in the logits array + return logits.at[tokens].set(penalty_logits) + +def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0): + ''' + This gets called by generate_loop_fn to apply a series of 4 filters + to the logits (top-k, then top-p, then TFS, then temperature) before + picking one token using the modified logits + ''' + # Top-k (keep only the k tokens with the highest logits and remove + # the rest, by setting their logits to negative infinity) + def top_k_filter(logits): + # After sorting the logits array in descending order, + # sorted_indices_to_remove is a 1D array that is True for tokens + # in the sorted logits array we want to remove and False for ones + # we want to keep, in this case the first top_k elements will be + # False and the rest will be True + sorted_indices_to_remove = jnp.arange(len(logits)) >= top_k + # Unsort the logits array back to its original configuration and + # remove tokens we need to remove + _, indices_to_remove = jax.lax.sort_key_val( + jnp.argsort(-logits), + sorted_indices_to_remove, + ) + return jnp.where(indices_to_remove, -jnp.inf, logits) + logits = jax.lax.cond(top_k > 0, top_k_filter, lambda x: x, logits) + # Top-p (after sorting the remaining tokens again in descending order of + # logit, remove the ones that have cumulative softmax probability + # greater than p) + def top_p_filter(logits): + # Sort the logits array in descending order, replace every element + # with e (Euler's number) to the power of that element, and divide + # each element of the new array by the sum of the elements in the + # new array + sorted_logits = -jnp.sort(-logits) + probabilities = jax.nn.softmax(sorted_logits) + # Calculate cumulative_probabilities as the prefix-sum array of + # probabilities + cumulative_probabilities = jnp.cumsum(probabilities, axis=-1) + # We want to remove tokens with cumulative probability higher + # than top_p + sorted_indices_to_remove = cumulative_probabilities > top_p + # Don't ever remove the token with the highest logit, even if + # the probability is higher than top_p + sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False) + # Unsort and remove + _, indices_to_remove = jax.lax.sort_key_val( + jnp.argsort(-logits), + sorted_indices_to_remove, + ) + return jnp.where(indices_to_remove, -jnp.inf, logits) + logits = jax.lax.cond(top_p < 1.0, top_p_filter, lambda x: x, logits) + # Tail free sampling (basically top-p a second time on remaining tokens + # except it's the "cumulative normalized absolute second finite + # differences of the softmax probabilities" instead of just the + # cumulative softmax probabilities) + def tail_free_filter(logits): + # Sort in descending order + sorted_logits = -jnp.sort(-logits) + # Softmax again + probabilities = jax.nn.softmax(sorted_logits) + # Calculate the second finite differences of that array (i.e. + # calculate the difference array and then calculate the difference + # array of the difference array) + d2 = jnp.diff(jnp.diff(probabilities)) + # Get the absolute values of all those second finite differences + d2 = jnp.abs(d2) + # Normalize (all elements in the array are divided by the sum of the + # array's elements) + d2 = d2 / d2.sum(axis=-1, keepdims=True) + # Get the prefix-sum array + cumulative_d2 = jnp.cumsum(d2, axis=-1) + # We will remove the tokens with a cumulative normalized absolute + # second finite difference larger than the TFS value + sorted_indices_to_remove = cumulative_d2 > tfs + # Don't remove the token with the highest logit + sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False) + # Since the d2 array has two fewer elements than the logits array, + # we'll add two extra Trues to the end + sorted_indices_to_remove = jnp.pad( + sorted_indices_to_remove, + (0, 2), + constant_values=True, + ) + # Unsort and remove + _, indices_to_remove = jax.lax.sort_key_val( + jnp.argsort(-logits), + sorted_indices_to_remove, + ) + return jnp.where(indices_to_remove, -jnp.inf, logits) + logits = jax.lax.cond(tfs < 1.0, tail_free_filter, lambda x: x, logits) + # Temperature (just divide the logits by the temperature) + def temp_filter(logits): + return logits / temp + logits = jax.lax.cond(True, temp_filter, lambda x: x, logits) + # Finally, pick one token using the softmax thingy again (it gives + # an array whose elements sum to 1 so it can be used nicely as a + # probability distribution) + return jax.random.categorical(key, logits, -1).astype(jnp.uint32) + pad_token_id = 50256 def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, sampler_options): @@ -192,11 +313,11 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, sampler_op generated, generated_index, logits, _ = carry[0][0] sample_key = carry[1] # Get the pseudo-random number generator key that will - # be used by kobold_sample to randomly pick a token + # be used by kobold_sample_dynamic to randomly pick a token sample_key, new_key = jax.random.split(sample_key, num=2) # Apply repetition penalty to all tokens that are # currently inside the "generated" array - logits = apply_repetition_penalty( + logits = apply_repetition_penalty_dynamic( logits, generated, repetition_penalty @@ -205,11 +326,11 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, sampler_op # their logits to negative infinity which effectively # makes their probabilities of being chosen zero logits[badwords] = -np.inf - # Use the sampler (kobold_sample) to pick one token + # Use the sampler (kobold_sample_dynamic) to pick one token # based on the logits array as a 0D uint32 array # (higher logit means higher probability of being # picked, non-linearly) - next_token = kobold_sample( + next_token = kobold_sample_dynamic( sample_key, logits, **sampler_options, @@ -236,6 +357,100 @@ class PenalizingCausalTransformer(CausalTransformer): def __init__(self, config): # Initialize super().__init__(config) + def generate_static(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None): + numseqs = numseqs_aux.shape[0] + # These are the tokens that we don't want the AI to ever write + self.badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146]) + @hk.transform + def generate_sample(context, ctx_length): + # Give the initial context to the transformer + transformer = CausalTransformerShard(config) + def generate_initial_scan_fn(sequence_index, _): + _, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings) + # The "generated" array will contain the tokens from the + # context as well as the tokens picked by the sampler at + # each stage, padded with a bunch of 50256s, so we know + # which tokens have to be repetition penalized + generated = jnp.pad(context, (0, config["seq"]), constant_values=pad_token_id) # Let it start off with just the 2048 context tokens, plus some 50256s which will be eventually filled with sampler-chosen tokens + generated_index = config["seq"] + # Add that information to generate_loop_fn's starting state + initial_state = (generated, generated_index, sequence_index) + initial_state + return sequence_index+1, initial_state + _, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, None, numseqs) + sample_key = initial_states[-1][0] + initial_states = list(jax.tree_map(lambda x: x[i], initial_states[:-1]) for i in range(numseqs)) + # Get repetition penalty from the arguments + repetition_penalty = sampler_options.pop('repetition_penalty', None) + # This is the main generation loop + def generate_loop_fn(carry): + # Unpack current generate_loop_fn state + generated, generated_index, sequence_index, next_token, decode_state = carry[0][0] + sample_key = carry[1] + # Get the pseudo-random number generator key that will + # be used by kobold_sample_static to randomly pick a token + sample_key, new_key = jax.random.split(sample_key) + # Give the context to the model and get the logits it + # spits out + # (a 2D array with 1 row and 50400 columns representing + # how strongly it thinks each of the 50257 tokens in its + # vocabulary should be appended to the context, followed + # by 143 apparently useless columns ???) + logits, new_state = transformer.generate_once(next_token, decode_state, soft_embeddings=soft_embeddings) + # Verify that logits does indeed have that many rows and + # columns (if you get an error here, pray for mercy) + assert logits.shape == (1, config["n_vocab"]) + # Flatten it into a 1D array to make it easier to use + logits = logits[0] + # Apply repetition penalty to all tokens that are + # currently inside the "generated" array + if repetition_penalty is not None: + logits = apply_repetition_penalty_static( + logits, + generated, + repetition_penalty + ) + # Remove any tokens in the badwords list by setting + # their logits to negative infinity which effectively + # makes their probabilities of being chosen zero + logits = logits.at[self.badwords].set(-jnp.inf) + # Use the sampler (kobold_sample_static) to pick one token + # based on the logits array as a 0D uint32 array + # (higher logit means higher probability of being + # picked, non-linearly) + next_token = kobold_sample_static( + sample_key, + logits, + **sampler_options, + ) + # Remember what token was picked + generated = generated.at[generated_index].set(next_token) + generated_index += 1 + # Re-pack the current generate_loop_fn's state so we can + # get back the same variables the next time + carry[0][0] = (generated, generated_index, sequence_index, next_token[jnp.newaxis], new_state) + carry[0].append(carry[0].pop(0)) + return carry[0], new_key + return jax.lax.while_loop( + lambda carry: carry[0][0][1] - config["seq"] < gen_length, + generate_loop_fn, + (initial_states, sample_key), + ) + return generate_sample.apply(state["params"], key, ctx, ctx_length) + self.generate_static_xmap = jax.experimental.maps.xmap( + fun=generate_static, + in_axes=( + ["shard", ...], + ["batch", ...], + ["batch", ...], + ["batch", ...], + ["batch", ...], + ["batch", ...], + ["batch", ...], + ["shard", ...], + ), + out_axes=["shard", "batch", ...], + axis_resources={'shard': 'mp', 'batch': 'dp'}, + ) def generate_initial(state, key, ctx, ctx_length, numseqs_aux, soft_embeddings=None): numseqs = numseqs_aux.shape[0] @hk.transform @@ -314,7 +529,7 @@ class PenalizingCausalTransformer(CausalTransformer): out_axes=["shard", "batch", ...], axis_resources={'shard': 'mp', 'batch': 'dp'}, ) - def generate(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None, excluded_world_info=None, use_callback=True): + def generate_dynamic(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None, excluded_world_info=None, use_callback=True): assert excluded_world_info is not None assert not return_logits assert gen_length.ndim == 1 @@ -360,9 +575,24 @@ class PenalizingCausalTransformer(CausalTransformer): else: break return sample_data, n_generated, regeneration_required, halt + def generate_static(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None): + assert not return_logits + key = hk.PRNGSequence(random.randint(0, 2 ** 60)) + batch_size = ctx.shape[0] + self.batch_size = batch_size + return self.generate_static_xmap( + self.state, + jnp.array(key.take(batch_size)), + ctx, + np.array(ctx_length, dtype=np.uint32), + np.array(gen_length, dtype=np.uint32), + np.empty((batch_size, numseqs), dtype=np.uint8), + sampler_options, + soft_embeddings, + ) -def infer( +def infer_dynamic( context: np.array, top_p=0.9, temp=0.5, @@ -394,7 +624,7 @@ def infer( "repetition_penalty": float(repetition_penalty), "top_k": int(top_k), } - output = network.generate( + output = network.generate_dynamic( batched_tokens, np.ones(total_batch, dtype=np.uint32) * provided_ctx, np.ones(total_batch, dtype=np.uint32) * gen_len, @@ -408,6 +638,47 @@ def infer( samples.append(out[0][params["seq"] : params["seq"] + gen_len]) return (samples,) + output[1:] +def infer_static( + context: np.array, + top_p=0.9, + temp=0.5, + top_k=0, + tfs=1.0, + repetition_penalty=1.0, + numseqs=1, + gen_len=80, + soft_embeddings: Optional[np.array] = None, + soft_tokens: Optional[np.array] = None, +) -> List[np.array]: + maps.thread_resources.env = thread_resources_env + total_batch = 1 + tokens = context + if(soft_tokens is not None): + tokens = np.uint32(np.concatenate((soft_tokens, tokens))) + provided_ctx = tokens.shape[0] + pad_amount = seq - provided_ctx + padded_tokens = np.pad(tokens, ((pad_amount, 0),), constant_values=pad_token_id) + batched_tokens = np.array([padded_tokens] * total_batch) + samples = [] + batched_generator_params = { + "temp": temp * np.ones(total_batch), + "top_p": top_p * np.ones(total_batch), + "tfs": tfs * np.ones(total_batch), + "repetition_penalty": repetition_penalty * np.ones(total_batch), + "top_k": np.full(total_batch, top_k, dtype=np.uint32) + } + output = network.generate_static( + batched_tokens, + np.ones(total_batch, dtype=np.uint32) * provided_ctx, + np.ones(total_batch, dtype=np.uint32) * gen_len, + numseqs, + batched_generator_params, + soft_embeddings=soft_embeddings, + )[0] + for o in output: + samples.append(o[0][0, 0, params["seq"] : params["seq"] + gen_len]) + return samples + def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) -> None: global thread_resources_env, seq, tokenizer, network, params