Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)

Merge pull request #58 from VE-FORBRYDERNE/xmap

Dynamic TPU backend xmaps
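For context on the title: jax.experimental.maps.xmap maps a function over named logical axes (here "shard" and "batch") and binds each axis to a mesh resource ("mp" for model parallelism, "dp" for data parallelism), which is how the backend below lays host arrays out across the TPU cores without per-device code. What follows is a minimal, hypothetical sketch of that pattern, mirroring the __move_xmap helper this commit adds to tpu_mtj_backend.py; it assumes a JAX version contemporary with this commit that still ships xmap (the API has since been removed) and at least one attached accelerator.

import numpy as np
import jax
from jax.experimental import maps

# Arrange the local devices into a 1 x N (dp, mp) mesh, as load_model() does.
devices = np.array(jax.devices()).reshape((1, -1))

# Identity function mapped over two named axes: "shard" is backed by the
# model-parallel mesh axis, "batch" by the data-parallel one.
move = maps.xmap(
    lambda s, b: s,
    in_axes=(["shard", ...], ["batch", ...]),
    out_axes=["shard", ...],
    axis_resources={"shard": "mp", "batch": "dp"},
)

with maps.Mesh(devices, ("dp", "mp")):
    # The first argument is split along "mp"; the dummy second argument only
    # exists to give the "batch" axis a size, exactly how the commit's
    # __shard_xmap helper uses it.
    sharded = move(np.zeros((devices.shape[1], 4), dtype=np.float32), np.empty(1))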
Files changed:

aiserver.py (106 lines changed)
tpu_mtj_backend.py
aiserver.py

@@ -22,7 +22,7 @@ import packaging
import contextlib
import traceback
import threading
from typing import Any, Callable, TypeVar, Union, Dict, Set, List
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List

import requests
import html
@@ -993,7 +993,7 @@ else:
                -1,
                tpu_mtj_backend.params["d_model"],
            )
            vars.sp = tensor
            vars.sp = tpu_mtj_backend.shard_xmap(tensor)
        soft_tokens = np.arange(
            tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
            tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
@@ -1001,6 +1001,49 @@ else:
        )
        return soft_tokens

    def tpumtjgenerate_warper_callback(scores) -> "np.array":
        scores_shape = scores.shape
        scores_list = scores.tolist()
        vars.lua_koboldbridge.logits = vars.lua_state.table()
        for r, row in enumerate(scores_list):
            vars.lua_koboldbridge.logits[r+1] = vars.lua_state.table(*row)
        vars.lua_koboldbridge.vocab_size = scores_shape[-1]

        execute_genmod()

        scores = np.array(
            tuple(tuple(row.values()) for row in vars.lua_koboldbridge.logits.values()),
            dtype=scores.dtype,
        )
        assert scores.shape == scores_shape

        return scores

    def tpumtjgenerate_stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]:
        vars.generated_tkns += 1

        assert len(excluded_world_info) == len(generated)
        regeneration_required = vars.lua_koboldbridge.regeneration_required
        halt = not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt
        vars.lua_koboldbridge.regeneration_required = False

        global past

        for i in range(vars.numseqs):
            vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(generated[i, tpu_mtj_backend.params["seq"] + n_generated - 1].item())

        if(not vars.dynamicscan or halt):
            return excluded_world_info, regeneration_required, halt

        for i, t in enumerate(generated):
            decoded = tokenizer.decode(past[i]) + tokenizer.decode(t[tpu_mtj_backend.params["seq"] : tpu_mtj_backend.params["seq"] + n_generated])
            _, found = checkworldinfo(decoded, force_use_txt=True)
            found -= excluded_world_info[i]
            if(len(found) != 0):
                regeneration_required = True
                break
        return excluded_world_info, regeneration_required, halt

    # If we're running Colab or OAI, we still need a tokenizer.
    if(vars.model == "Colab"):
        from transformers import GPT2TokenizerFast
@@ -1013,6 +1056,8 @@ else:
        print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
        assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth)
        import tpu_mtj_backend
        tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
        tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback
        tpu_mtj_backend.load_model(vars.custmodpth)
        vars.allowsp = True
        vars.modeldim = int(tpu_mtj_backend.params["d_model"])
@@ -1020,12 +1065,14 @@ else:
        soft_tokens = tpumtjgetsofttokens()
        threading.Thread(  # Compile backend code in background
            target=tpu_mtj_backend.infer,
            args=(np.uint32((23403, 727, 20185)),),
            args=(np.tile(np.uint32((23403, 727, 20185)), (vars.numseqs, 1)),),
            kwargs={
                "soft_embeddings": vars.sp,
                "soft_tokens": soft_tokens,
                "use_callback": False,
                "gen_len": 1,
                "numseqs": vars.numseqs,
                "excluded_world_info": list(set() for _ in range(vars.numseqs)),
            },
        ).start()

@@ -2890,22 +2937,31 @@ def sendtocolab(txt, min, max):
#  Send text to TPU mesh transformer backend
#==================================================================#
def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
    vars.generated_tkns = 0

    if(found_entries is None):
        found_entries = set()
    found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))

    print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, tokenizer.decode(txt), colors.END))

    vars._actions = vars.actions
    vars._prompt = vars.prompt
    if(vars.dynamicscan):
        vars._actions = vars._actions.copy()

    # Submit input text to generator
    try:
        if(vars.dynamicscan):
            raise ValueError("Dynamic world info scanning is not supported by the TPU backend yet")

        context = np.tile(np.uint32(txt), (vars.numseqs, 1))
        soft_tokens = tpumtjgetsofttokens()

        genout = tpool.execute(
        global past
        past = np.empty((vars.numseqs, 0), dtype=np.uint32)

        while(True):
            genout, n_generated, regeneration_required, halt = tpool.execute(
                tpu_mtj_backend.infer,
            np.uint32(txt),
                context,
                gen_len = maximum-minimum+1,
                temp=vars.temp,
                top_p=vars.top_p,
@@ -2915,6 +2971,34 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
                repetition_penalty=vars.rep_pen,
                soft_embeddings=vars.sp,
                soft_tokens=soft_tokens,
                excluded_world_info=found_entries,
            )

            past = np.pad(past, ((0, 0), (0, n_generated)))
            for r in range(vars.numseqs):
                for c in range(vars.lua_koboldbridge.generated_cols):
                    assert vars.lua_koboldbridge.generated[r+1][c+1] is not None
                    past[r, c] = vars.lua_koboldbridge.generated[r+1][c+1]

            if(halt or not regeneration_required):
                break
            print("(regeneration triggered)")

            encoded = []
            for i in range(vars.numseqs):
                txt = tokenizer.decode(past[i])
                winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
                found_entries[i].update(_found_entries)
                txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
                encoded.append(np.array(txt, dtype=np.uint32))
            max_length = len(max(encoded, key=len))
            encoded = np.stack(tuple(np.pad(e, (max_length - len(e), 0), constant_values=tpu_mtj_backend.pad_token_id) for e in encoded))
            context = np.concatenate(
                (
                    encoded,
                    past,
                ),
                axis=-1,
            )

    except Exception as e:
@@ -2933,8 +3017,8 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
        return

    for i in range(vars.numseqs):
        vars.lua_koboldbridge.generated[i+1] = vars.lua_state.table(*genout[i].tolist())
        vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(genout[i])
        vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(past[i])
    genout = past

    execute_outmod()
    if(vars.lua_koboldbridge.regeneration_required):
@@ -4005,7 +4089,7 @@ def spRequest(filename):
            -1,
            tpu_mtj_backend.params["d_model"],
        )
        vars.sp = np.float32(tensor)
        vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor))
    else:
        vars.sp = torch.from_numpy(tensor)

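Taken together, the two tpumtjgenerate_* functions above implement the host side of a new callback protocol: the backend calls warper_callback with the per-sequence logits before sampling each token, and stopping_callback after each token to decide whether to halt or regenerate. A stripped-down host satisfying the same contract could look like the sketch below; my_warper and my_stopper are hypothetical names, and only the two module attributes and the callback signatures come from the diff.

import numpy as np
import tpu_mtj_backend

def my_warper(scores):
    # scores: float32 array of shape (numseqs, n_vocab); return it,
    # possibly modified, with the same shape and dtype.
    return scores

def my_stopper(generated, n_generated, excluded_world_info):
    # Must return (excluded_world_info, regeneration_required, halt);
    # here we simply halt after 80 tokens and never force a regeneration.
    return excluded_world_info, False, n_generated >= 80

tpu_mtj_backend.warper_callback = my_warper
tpu_mtj_backend.stopping_callback = my_stopper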

tpu_mtj_backend.py

@@ -1,5 +1,5 @@
import multiprocessing
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
import progressbar
import time
import os
@@ -20,6 +20,13 @@ from mesh_transformer.transformer_shard import CausalTransformer, CausalTransfor
params: Dict[str, Any] = {}


def warper_callback(logits) -> np.array:
    raise NotImplementedError("`tpu_mtj_backend.warper_callback()` needs to be defined")

def stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]:
    raise NotImplementedError("`tpu_mtj_backend.stopping_callback()` needs to be defined")


def show_spinner():
    bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), '  ', progressbar.BouncingBar(left='[', right=']', marker='█')])
    i = 0
@@ -28,6 +35,31 @@ def show_spinner():
        time.sleep(0.1)
        i += 1


__F = TypeVar("__F", bound=Callable)
__T = TypeVar("__T")

def __move_xmap(f: __F, out_axis: str) -> __F:
    return maps.xmap(
        f,
        in_axes=(["shard", ...], ["batch", ...]),
        out_axes=[out_axis, ...],
        axis_resources={'shard': 'mp', 'batch': 'dp'},
    )

def __shard_xmap(batch_dim=1):
    xmap = __move_xmap(lambda s, b: s, "shard")
    def inner(x: __T) -> __T:
        return xmap(x, np.empty(batch_dim))
    return inner

def __batch_xmap(shard_dim=1):
    xmap = __move_xmap(lambda s, b: b, "batch")
    def inner(x: __T) -> __T:
        return xmap(np.empty(shard_dim), x)
    return inner


def apply_repetition_penalty(logits, tokens, repetition_penalty):
    '''
    This gets called by generate_loop_fn to apply repetition penalty
@@ -38,19 +70,20 @@ def apply_repetition_penalty(logits, tokens, repetition_penalty):
    # logits array; e.g.
    # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
    # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
    penalty_logits = jnp.take(logits, tokens)
    penalty_logits = np.take(logits, tokens)
    # Divide positive values by repetition_penalty and multiply negative
    # values by repetition_penalty (the academic publication that described
    # this technique actually just only divided, but that would cause tokens
    # with negative logits to become more likely, which is obviously wrong)
    penalty_logits = jnp.where(
    penalty_logits = np.where(
        penalty_logits > 0,
        penalty_logits/repetition_penalty,
        penalty_logits*repetition_penalty,
    )
    # Finally, put those penalized logit values back into their original
    # positions in the logits array
    return logits.at[tokens].set(penalty_logits)
    logits[tokens] = penalty_logits
    return logits

def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
    '''
@@ -66,15 +99,16 @@ def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
        # in the sorted logits array we want to remove and False for ones
        # we want to keep, in this case the first top_k elements will be
        # False and the rest will be True
        sorted_indices_to_remove = jnp.arange(len(logits)) >= top_k
        sorted_indices_to_remove = np.arange(len(logits)) >= top_k
        # Unsort the logits array back to its original configuration and
        # remove tokens we need to remove
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(-logits),
            np.argsort(-logits),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, logits)
    logits = jax.lax.cond(top_k > 0, top_k_filter, lambda x: x, logits)
        return np.where(indices_to_remove, -np.inf, logits)
    if top_k > 0:
        logits = top_k_filter(logits)
    # Top-p (after sorting the remaining tokens again in descending order of
    # logit, remove the ones that have cumulative softmax probability
    # greater than p)
@@ -83,109 +117,167 @@ def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
        # with e (Euler's number) to the power of that element, and divide
        # each element of the new array by the sum of the elements in the
        # new array
        sorted_logits = -jnp.sort(-logits)
        probabilities = jax.nn.softmax(sorted_logits)
        sorted_logits = -np.sort(-logits)
        probabilities = np.array(jax.nn.softmax(sorted_logits), copy=True)
        # Calculate cumulative_probabilities as the prefix-sum array of
        # probabilities
        cumulative_probabilities = jnp.cumsum(probabilities, axis=-1)
        cumulative_probabilities = np.cumsum(probabilities, axis=-1)
        # We want to remove tokens with cumulative probability higher
        # than top_p
        sorted_indices_to_remove = cumulative_probabilities > top_p
        # Don't ever remove the token with the highest logit, even if
        # the probability is higher than top_p
        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
        sorted_indices_to_remove[0] = False
        # Unsort and remove
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(-logits),
            np.argsort(-logits),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, logits)
    logits = jax.lax.cond(top_p < 1.0, top_p_filter, lambda x: x, logits)
        return np.where(indices_to_remove, -np.inf, logits)
    if top_p < 1.0:
        logits = top_p_filter(logits)
    # Tail free sampling (basically top-p a second time on remaining tokens
    # except it's the "cumulative normalized absolute second finite
    # differences of the softmax probabilities" instead of just the
    # cumulative softmax probabilities)
    def tail_free_filter(logits):
        # Sort in descending order
        sorted_logits = -jnp.sort(-logits)
        sorted_logits = -np.sort(-logits)
        # Softmax again
        probabilities = jax.nn.softmax(sorted_logits)
        probabilities = np.array(jax.nn.softmax(sorted_logits), copy=True)
        # Calculate the second finite differences of that array (i.e.
        # calculate the difference array and then calculate the difference
        # array of the difference array)
        d2 = jnp.diff(jnp.diff(probabilities))
        d2 = np.diff(np.diff(probabilities))
        # Get the absolute values of all those second finite differences
        d2 = jnp.abs(d2)
        d2 = np.abs(d2)
        # Normalize (all elements in the array are divided by the sum of the
        # array's elements)
        d2 = d2 / d2.sum(axis=-1, keepdims=True)
        # Get the prefix-sum array
        cumulative_d2 = jnp.cumsum(d2, axis=-1)
        cumulative_d2 = np.cumsum(d2, axis=-1)
        # We will remove the tokens with a cumulative normalized absolute
        # second finite difference larger than the TFS value
        sorted_indices_to_remove = cumulative_d2 > tfs
        # Don't remove the token with the highest logit
        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
        sorted_indices_to_remove[0] = False
        # Since the d2 array has two fewer elements than the logits array,
        # we'll add two extra Trues to the end
        sorted_indices_to_remove = jnp.pad(
        sorted_indices_to_remove = np.pad(
            sorted_indices_to_remove,
            (0, 2),
            constant_values=True,
        )
        # Unsort and remove
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(-logits),
            np.argsort(-logits),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, logits)
    logits = jax.lax.cond(tfs < 1.0, tail_free_filter, lambda x: x, logits)
        return np.where(indices_to_remove, -np.inf, logits)
    if tfs < 1.0:
        logits = tail_free_filter(logits)
    # Temperature (just divide the logits by the temperature)
    def temp_filter(logits):
        return logits / temp
    logits = jax.lax.cond(True, temp_filter, lambda x: x, logits)
    logits /= temp
    # Finally, pick one token using the softmax thingy again (it gives
    # an array whose elements sum to 1 so it can be used nicely as a
    # probability distribution)
    return jax.random.categorical(key, logits, -1).astype(jnp.uint32)[jnp.newaxis]
    return jax.random.categorical(key, logits, -1).astype(np.uint32)

pad_token_id = 50256

def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, sampler_options):
    numseqs = numseqs_aux.shape[0]
    gi = data[0][1]
    def sample_loop_fn(carry):
        generated, generated_index, logits, _ = carry[0][0]
        sample_key = carry[1]
        # Get the pseudo-random number generator key that will
        # be used by kobold_sample to randomly pick a token
        sample_key, new_key = jax.random.split(sample_key, num=2)
        # Apply repetition penalty to all tokens that are
        # currently inside the "generated" array
        logits = apply_repetition_penalty(
            logits,
            generated,
            repetition_penalty
        )
        # Remove any tokens in the badwords list by setting
        # their logits to negative infinity which effectively
        # makes their probabilities of being chosen zero
        logits[badwords] = -np.inf
        # Use the sampler (kobold_sample) to pick one token
        # based on the logits array as a 0D uint32 array
        # (higher logit means higher probability of being
        # picked, non-linearly)
        next_token = kobold_sample(
            sample_key,
            logits,
            **sampler_options,
        )
        # Remember what token was picked
        generated[generated_index] = next_token
        generated_index += 1
        # Re-pack the current sample_loop_fn's state so we can
        # get back the same variables the next time
        carry[0][0] = [generated, generated_index, logits, next_token]
        carry[0].append(carry[0].pop(0))
        return carry[0], new_key
    # return jax.lax.while_loop(
    #     lambda carry: carry[0][0][1] == gi,
    #     sample_loop_fn,
    #     (data, key),
    # )
    carry = (data, key)
    while carry[0][0][1] == gi:
        carry = sample_loop_fn(carry)
    return carry

class PenalizingCausalTransformer(CausalTransformer):
    def __init__(self, config):
        # Initialize
        super().__init__(config)
        def generate(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None):
        def generate_initial(state, key, ctx, ctx_length, numseqs_aux, soft_embeddings=None):
            numseqs = numseqs_aux.shape[0]
            # These are the tokens that we don't want the AI to ever write
            self.badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146])
            def generate_sample(context, ctx_length):
            @hk.transform
            def generate_initial_inner(context, ctx_length):
                # Give the initial context to the transformer
                transformer = CausalTransformerShard(config)
                def generate_initial_scan_fn(sequence_index, _):
                    _, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings)
                    # The "generated" array will contain the tokens from the
                    # context as well as the tokens picked by the sampler at
                    # each stage, padded with a bunch of 50256s, so we know
                    # which tokens have to be repetition penalized
                    generated = jnp.pad(context, (0, config["seq"]), constant_values=pad_token_id)  # Let it start off with just the 2048 context tokens, plus some 50256s which will be eventually filled with sampler-chosen tokens
                def generate_initial_scan_fn(sequence_index, c):
                    _, initial_state = transformer.generate_initial(c, ctx_length, soft_embeddings=soft_embeddings)
                    generated_index = config["seq"]
                    # Add that information to generate_loop_fn's starting state
                    initial_state = (generated, generated_index, sequence_index) + initial_state
                    initial_state = (jnp.empty(config["n_vocab"], dtype=jnp.float32), generated_index, sequence_index) + initial_state
                    return sequence_index+1, initial_state
                _, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, None, numseqs)
                _, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, context, numseqs)
                sample_key = initial_states[-1][0]
                initial_states = list(jax.tree_map(lambda x: x[i], initial_states[:-1]) for i in range(numseqs))
                # Get repetition penalty from the arguments
                repetition_penalty = sampler_options.pop('repetition_penalty', None)
                initial_states = list(list(jax.tree_map(lambda x: x[i], initial_states[:-1])) for i in range(numseqs))
                return initial_states, sample_key
            return generate_initial_inner.apply(state["params"], key, ctx, ctx_length)
        self.generate_initial_xmap = jax.experimental.maps.xmap(
            fun=generate_initial,
            in_axes=(
                ["shard", ...],
                ["batch", ...],
                ["batch", ...],
                ["batch", ...],
                ["batch", ...],
                ["shard", ...],
            ),
            out_axes=["shard", "batch", ...],
            axis_resources={'shard': 'mp', 'batch': 'dp'},
        )
        def generate_once(data, state, numseqs_aux, soft_embeddings=None):
            numseqs = numseqs_aux.shape[0]
            @hk.without_apply_rng
            @hk.transform
            def generate_once_inner():
                gi = data[0][1]
                # Give the initial context to the transformer
                transformer = CausalTransformerShard(config)
                # This is the main generation loop
                def generate_loop_fn(carry):
                    # Unpack current generate_loop_fn state
                    generated, generated_index, sequence_index, next_token, decode_state = carry[0][0]
                    sample_key = carry[1]
                    # Get the pseudo-random number generator key that will
                    # be used by kobold_sample to randomly pick a token
                    sample_key, new_key = jax.random.split(sample_key)
                    _, generated_index, sequence_index, next_token, decode_state = carry[0][0]
                    # Give the context to the model and get the logits it
                    # spits out
                    # (a 2D array with 1 row and 50400 columns representing
@@ -196,75 +288,78 @@ class PenalizingCausalTransformer(CausalTransformer):
                    # Verify that logits does indeed have that many rows and
                    # columns (if you get an error here, pray for mercy)
                    assert logits.shape == (1, config["n_vocab"])
                    assert logits.dtype == jnp.float32
                    # Flatten it into a 1D array to make it easier to use
                    logits = logits[0]
                    # Apply repetition penalty to all tokens that are
                    # currently inside the "generated" array
                    if repetition_penalty is not None:
                        logits = apply_repetition_penalty(
                            logits,
                            generated,
                            repetition_penalty
                        )
                    # Remove any tokens in the badwords list by setting
                    # their logits to negative infinity which effectively
                    # makes their probabilities of being chosen zero
                    logits = logits.at[self.badwords].set(-jnp.inf)
                    # Use the sampler (kobold_sample) to pick one token
                    # based on the logits array as a 1D array with 1 element
                    # (higher logit means higher probability of being
                    # picked, non-linearly)
                    next_token = kobold_sample(
                        sample_key,
                        logits,
                        **sampler_options,
                    )
                    # Remember what token was picked
                    generated = generated.at[generated_index].set(next_token[0])
                    generated_index += 1
                    # Re-pack the current generate_loop_fn's state so we can
                    # get back the same variables the next time
                    carry[0][0] = (generated, generated_index, sequence_index, next_token, new_state)
                    generated_index += 1
                    carry[0][0] = [logits, generated_index, sequence_index, next_token, new_state]
                    carry[0].append(carry[0].pop(0))
                    return carry[0], new_key
                final_state = jax.lax.while_loop(
                    lambda carry: carry[0][0][1] - config["seq"] < gen_length,
                    return carry[0],
                return jax.lax.while_loop(
                    lambda carry: carry[0][0][1] == gi,
                    generate_loop_fn,
                    (initial_states, sample_key),
                    (data,),
                )
                return final_state
            generate_fn = hk.transform(generate_sample).apply
            return generate_fn(state["params"], key, ctx, ctx_length)
        self.generate_xmap = jax.experimental.maps.xmap(
            fun=generate,
            return generate_once_inner.apply(state["params"])
        self.generate_once_xmap = jax.experimental.maps.xmap(
            fun=generate_once,
            in_axes=(
                ["shard", "batch", ...],
                ["shard", ...],
                ["batch", ...],
                ["batch", ...],
                ["batch", ...],
                ["batch", ...],
                ["batch", ...],
                ["batch", ...],
                ["shard", ...],
            ),
            out_axes=["shard", "batch", ...],
            axis_resources={'shard': 'mp', 'batch': 'dp'},
        )
    def generate(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None):
    def generate(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None, excluded_world_info=None, use_callback=True):
        assert excluded_world_info is not None
        assert not return_logits
        assert gen_length.ndim == 1
        assert soft_embeddings is not None
        key = hk.PRNGSequence(random.randint(0, 2 ** 60))
        batch_size = ctx.shape[0]
        self.batch_size = batch_size
        return self.generate_xmap(
            self.state,
            jnp.array(key.take(batch_size)),
            ctx,
            np.array(ctx_length, dtype=np.uint32),
            np.array(gen_length, dtype=np.uint32),
            np.empty((batch_size, numseqs), dtype=np.uint8),
            sampler_options,
            soft_embeddings,
        )
        _numseqs_aux = jnp.empty((batch_size, numseqs), dtype=np.uint32)
        numseqs_aux = batch_xmap(_numseqs_aux)
        sample_data = [
            [
                np.pad(ctx[0][i], (0, params["seq"]), constant_values=pad_token_id),
                params["seq"],
                None,
                np.empty((), dtype=np.uint32),
            ]
            for i in range(numseqs)
        ]
        repetition_penalty = sampler_options.pop("repetition_penalty", 1.0)
        n_generated = 0
        regeneration_required = False
        halt = False
        generate_data, sample_key = self.generate_initial_xmap(self.state, jnp.array(key.take(batch_size)), ctx, ctx_length, numseqs_aux, soft_embeddings)
        sample_key = np.asarray(sample_key[0, 0])
        while True:
            generate_data, = self.generate_once_xmap(generate_data, self.state, numseqs_aux, soft_embeddings)
            for i in range(numseqs):
                sample_data[i][2] = np.array(generate_data[i][0][0, 0], copy=True)
            if use_callback:
                logits = np.float32(tuple(d[2] for d in sample_data))
                logits = warper_callback(logits)
                for i in range(numseqs):
                    sample_data[i][2] = logits[i]
            sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options)
            n_generated += 1
            for i in range(numseqs):
                generate_data[i][3] = np.tile(sample_data[i][0][sample_data[i][1]-1][np.newaxis, np.newaxis], (params["cores_per_replica"], 1, 1))
            if use_callback:
                generated = np.uint32(tuple(d[0] for d in sample_data))
                excluded_world_info, regeneration_required, halt = stopping_callback(generated, n_generated, excluded_world_info)
                if regeneration_required or halt:
                    break
            else:
                break
        return sample_data, n_generated, regeneration_required, halt


def infer(
@@ -278,35 +373,40 @@ def infer(
    gen_len=80,
    soft_embeddings: Optional[np.array] = None,
    soft_tokens: Optional[np.array] = None,
) -> List[str]:
    excluded_world_info = None,
    use_callback=True,
) -> Tuple[List[np.array], int, bool, bool]:
    assert excluded_world_info is not None
    maps.thread_resources.env = thread_resources_env
    total_batch = 1
    tokens = context
    if(soft_tokens is not None):
        tokens = np.uint32(np.concatenate((soft_tokens, tokens)))
    provided_ctx = tokens.shape[0]
        tokens = np.uint32(np.concatenate((np.tile(soft_tokens, (tokens.shape[0], 1)), tokens), axis=-1))
    provided_ctx = tokens.shape[-1]
    pad_amount = seq - provided_ctx
    padded_tokens = np.pad(tokens, ((pad_amount, 0),), constant_values=pad_token_id)
    padded_tokens = np.pad(tokens, ((0, 0), (pad_amount, 0)), constant_values=pad_token_id)
    batched_tokens = np.array([padded_tokens] * total_batch)
    samples = []
    batched_generator_params = {
        "temp": temp * np.ones(total_batch),
        "top_p": top_p * np.ones(total_batch),
        "tfs": tfs * np.ones(total_batch),
        "repetition_penalty": repetition_penalty * np.ones(total_batch),
        "top_k": np.full(total_batch, top_k, dtype=np.uint32)
    generator_params = {
        "temp": float(temp),
        "top_p": float(top_p),
        "tfs": float(tfs),
        "repetition_penalty": float(repetition_penalty),
        "top_k": int(top_k),
    }
    output = network.generate(
        batched_tokens,
        np.ones(total_batch, dtype=np.uint32) * provided_ctx,
        np.ones(total_batch, dtype=np.uint32) * gen_len,
        numseqs,
        batched_generator_params,
        generator_params,
        soft_embeddings=soft_embeddings,
    )[0]
    for o in output:
        samples.append(o[0][0, 0, params["seq"] : params["seq"] + gen_len])
    return samples
        excluded_world_info=excluded_world_info,
        use_callback=use_callback,
    )
    for out in output[0]:
        samples.append(out[0][params["seq"] : params["seq"] + gen_len])
    return (samples,) + output[1:]


def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) -> None:
@@ -354,6 +454,14 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
    maps.thread_resources.env = thread_resources_env
    tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

    global shard_xmap, batch_xmap
    shard_xmap = __shard_xmap()
    batch_xmap = __batch_xmap(shard_dim=cores_per_replica)

    global badwords
    # These are the tokens that we don't want the AI to ever write
    badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146])

    if not path.endswith("/"):
        path += "/"

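As a closing illustration, the background-compilation call in aiserver.py (the threading.Thread above) shows the minimal way to drive the new infer() signature: contexts are now batched per sequence and the return value is a 4-tuple. Below is a hedged sketch of an equivalent call; model_path and prompt_tokens are placeholders, and the zero block standing in for a pre-sharded soft prompt reflects my assumption about its (cores_per_replica, rows, d_model) layout.

import numpy as np
import tpu_mtj_backend

numseqs = 2
tpu_mtj_backend.load_model(model_path)

# Soft embeddings must be presharded with shard_xmap before being passed in
# (aiserver.py does the same to vars.sp); a single zero row serves as a stub.
sp = tpu_mtj_backend.shard_xmap(
    np.zeros(
        (tpu_mtj_backend.params["cores_per_replica"], 1, tpu_mtj_backend.params["d_model"]),
        dtype=np.float32,
    )
)

context = np.tile(np.uint32(prompt_tokens), (numseqs, 1))  # (numseqs, ctx_len)
samples, n_generated, regeneration_required, halt = tpu_mtj_backend.infer(
    context,
    gen_len=1,          # with use_callback=False, generate() stops after one token,
    numseqs=numseqs,    # exactly like the warm-up call in aiserver.py
    soft_embeddings=sp,
    excluded_world_info=[set() for _ in range(numseqs)],
    use_callback=False,
)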