Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Synced 2025-06-05 21:59:24 +02:00

Merge pull request #61 from VE-FORBRYDERNE/xmap

Use original TPU backend when possible
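
In short: the TPU generation path is split in two. infer_dynamic keeps the
token-by-token loop that dynamic world-info scanning and Lua generation
modifiers require, while infer_static restores the original xmap-compiled
generation loop and is used whenever neither feature is active. A minimal
sketch of the selection rule applied throughout this commit (vars.dynamicscan,
vars.nogenmod and vars.has_genmod are introduced in the aiserver.py hunks
below; the helper name is illustrative only):

    def use_dynamic_backend(vars) -> bool:
        # Dynamic WI scanning always needs the token-by-token backend; so
        # do generation modifiers, unless disabled with the new toggle.
        return vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)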

aiserver.py (84 lines changed):

@@ -157,6 +157,7 @@ class vars:
     spmeta      = None   # Metadata of current soft prompt, or None if not using a soft prompt
     sp          = None   # Current soft prompt tensor (as a NumPy array)
     sp_length   = 0      # Length of current soft prompt in tokens, or 0 if not using a soft prompt
+    has_genmod  = False  # Whether or not at least one loaded Lua userscript has a generation modifier
     svowname    = ""     # Filename that was flagged for overwrite confirm
     saveow      = False  # Whether or not overwrite confirm has been displayed
     genseqs     = []     # Temporary storage for generated sequences
@@ -184,6 +185,7 @@ class vars:
     remote      = False
     nopromptgen = False
     rngpersist  = False
+    nogenmod    = False
 
 #==================================================================#
 # Function to get model selection at startup
@@ -1070,19 +1072,6 @@ else:
         vars.allowsp = True
         vars.modeldim = int(tpu_mtj_backend.params["d_model"])
         tokenizer = tpu_mtj_backend.tokenizer
-        soft_tokens = tpumtjgetsofttokens()
-        threading.Thread(  # Compile backend code in background
-            target=tpu_mtj_backend.infer,
-            args=(np.tile(np.uint32((23403, 727, 20185)), (vars.numseqs, 1)),),
-            kwargs={
-                "soft_embeddings": vars.sp,
-                "soft_tokens": soft_tokens,
-                "use_callback": False,
-                "gen_len": 1,
-                "numseqs": vars.numseqs,
-                "excluded_world_info": list(set() for _ in range(vars.numseqs)),
-            },
-        ).start()
 
 # Set up Flask routes
 @app.route('/')
@@ -1190,13 +1179,18 @@ def load_lua_scripts():
             modulenames.append(lst[i]["modulename"])
             descriptions.append(lst[i]["description"])
 
+    vars.has_genmod = False
+
     try:
         vars.lua_koboldbridge.obliterate_multiverse()
         tpool.execute(vars.lua_koboldbridge.load_corescript, vars.corescript)
-        tpool.execute(vars.lua_koboldbridge.load_userscripts, filenames, modulenames, descriptions)
+        vars.has_genmod = tpool.execute(vars.lua_koboldbridge.load_userscripts, filenames, modulenames, descriptions)
         vars.lua_running = True
     except lupa.LuaError as e:
-        vars.lua_koboldbridge.obliterate_multiverse()
+        try:
+            vars.lua_koboldbridge.obliterate_multiverse()
+        except:
+            pass
         vars.lua_running = False
         if(vars.serverstarted):
             emit('from_server', {'cmd': 'errmsg', 'data': 'Lua script error, please check console.'}, broadcast=True)
@@ -2043,6 +2037,10 @@ def get_message(msg):
         vars.rngpersist = msg['data']
         settingschanged()
         refresh_settings()
+    elif(msg['cmd'] == 'setnogenmod'):
+        vars.nogenmod = msg['data']
+        settingschanged()
+        refresh_settings()
     elif(not vars.remote and msg['cmd'] == 'importwi'):
         wiimportrequest()
 
@@ -2113,6 +2111,8 @@ def savesettings():
     js["dynamicscan"] = vars.dynamicscan
     js["nopromptgen"] = vars.nopromptgen
     js["rngpersist"]  = vars.rngpersist
+    js["nogenmod"]    = vars.nogenmod
+
     js["antemplate"]  = vars.setauthornotetemplate
 
     js["userscripts"] = vars.userscripts
@@ -2178,6 +2178,8 @@ def loadsettings():
             vars.nopromptgen = js["nopromptgen"]
         if("rngpersist" in js):
             vars.rngpersist = js["rngpersist"]
+        if("nogenmod" in js):
+            vars.nogenmod = js["nogenmod"]
 
         if("antemplate" in js):
             vars.setauthornotetemplate = js["antemplate"]
@@ -2960,15 +2962,18 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
 
     # Submit input text to generator
     try:
-        context = np.tile(np.uint32(txt), (vars.numseqs, 1))
         soft_tokens = tpumtjgetsofttokens()
 
         global past
-        past = np.empty((vars.numseqs, 0), dtype=np.uint32)
+
+        if(vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)):
+
+            context = np.tile(np.uint32(txt), (vars.numseqs, 1))
+            past = np.empty((vars.numseqs, 0), dtype=np.uint32)
 
             while(True):
                 genout, n_generated, regeneration_required, halt = tpool.execute(
-                    tpu_mtj_backend.infer,
+                    tpu_mtj_backend.infer_dynamic,
                     context,
                     gen_len = maximum-minimum+1,
                     temp=vars.temp,
@@ -3009,6 +3014,24 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
                     axis=-1,
                 )
 
+        else:
+            genout = tpool.execute(
+                tpu_mtj_backend.infer_static,
+                np.uint32(txt),
+                gen_len = maximum-minimum+1,
+                temp=vars.temp,
+                top_p=vars.top_p,
+                top_k=vars.top_k,
+                tfs=vars.tfs,
+                numseqs=vars.numseqs,
+                repetition_penalty=vars.rep_pen,
+                soft_embeddings=vars.sp,
+                soft_tokens=soft_tokens,
+            )
+            past = genout
+            for i in range(vars.numseqs):
+                vars.lua_koboldbridge.generated[i+1] = vars.lua_state.table(*genout[i].tolist())
+
     except Exception as e:
         if(issubclass(type(e), lupa.LuaError)):
             vars.lua_koboldbridge.obliterate_multiverse()
@@ -3189,6 +3212,7 @@ def refresh_settings():
     emit('from_server', {'cmd': 'updatedynamicscan', 'data': vars.dynamicscan}, broadcast=True)
     emit('from_server', {'cmd': 'updatenopromptgen', 'data': vars.nopromptgen}, broadcast=True)
     emit('from_server', {'cmd': 'updaterngpersist', 'data': vars.rngpersist}, broadcast=True)
+    emit('from_server', {'cmd': 'updatenogenmod', 'data': vars.nogenmod}, broadcast=True)
 
     emit('from_server', {'cmd': 'updatefrmttriminc', 'data': vars.formatoptns["frmttriminc"]}, broadcast=True)
     emit('from_server', {'cmd': 'updatefrmtrmblln', 'data': vars.formatoptns["frmtrmblln"]}, broadcast=True)
@@ -4442,6 +4466,34 @@ def randomGameRequest(topic, memory=""):
 loadmodelsettings()
 loadsettings()
 
+# Precompile TPU backend if required
+if(vars.model in ("TPUMeshTransformerGPTJ",)):
+    soft_tokens = tpumtjgetsofttokens()
+    if(vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)):
+        threading.Thread(
+            target=tpu_mtj_backend.infer_dynamic,
+            args=(np.tile(np.uint32((23403, 727, 20185)), (vars.numseqs, 1)),),
+            kwargs={
+                "soft_embeddings": vars.sp,
+                "soft_tokens": soft_tokens,
+                "gen_len": 1,
+                "use_callback": False,
+                "numseqs": vars.numseqs,
+                "excluded_world_info": list(set() for _ in range(vars.numseqs)),
+            },
+        ).start()
+    else:
+        threading.Thread(
+            target=tpu_mtj_backend.infer_static,
+            args=(np.uint32((23403, 727, 20185)),),
+            kwargs={
+                "soft_embeddings": vars.sp,
+                "soft_tokens": soft_tokens,
+                "gen_len": 1,
+                "numseqs": vars.numseqs,
+            },
+        ).start()
+
 #==================================================================#
 #  Final startup commands to launch Flask app
 #==================================================================#

bridge.lua (11 lines changed):

@@ -1851,13 +1851,14 @@ return function(_python, _bridged)
     -- API for aiserver.py
     --==========================================================================
 
-    ---@return nil
+    ---@return boolean
     function koboldbridge.load_userscripts(filenames, modulenames, descriptions)
         config_files = {}
         config_file_filename_map = {}
         koboldbridge.userscripts = {}
         koboldbridge.userscriptmodule_filename_map = {}
         koboldbridge.num_userscripts = 0
+        local has_genmod = false
         for i, filename in _python.enumerate(filenames) do
             bridged.load_callback(filename, modulenames[i])
             koboldbridge.logging_name = modulenames[i]
@@ -1865,12 +1866,15 @@ return function(_python, _bridged)
             local f, err = old_loadfile(join_folder_and_filename(bridged.userscript_path, filename), "t", koboldbridge.get_universe(filename))
             if err ~= nil then
                 error(err)
-                return
+                return false
             end
             ---@type KoboldUserScript
             local _userscript = f()
             koboldbridge.logging_name = nil
             koboldbridge.filename = nil
+            if _userscript.genmod ~= nil then
+                has_genmod = true
+            end
             local userscript = deepcopy(KoboldUserScriptModule)
             rawset(userscript, "_inmod", function()
                 koboldbridge.logging_name = modulenames[i]
@@ -1903,6 +1907,7 @@ return function(_python, _bridged)
             koboldbridge.userscriptmodule_filename_map[userscript] = filename
             koboldbridge.num_userscripts = i + 1
         end
+        return has_genmod
     end
 
     ---@return nil
@@ -1949,7 +1954,9 @@ return function(_python, _bridged)
         koboldbridge.userstate = "genmod"
         if koboldbridge.genmod ~= nil then
             local _generated = deepcopy(koboldbridge.generated)
+            if not bridged.vars.nogenmod then
                 r = koboldbridge.genmod()
+            end
             setmetatable(koboldbridge.logits, nil)
             for kr, vr in old_next, koboldbridge.logits, nil do
                 setmetatable(vr, nil)

gensettings.py:

@@ -162,6 +162,17 @@ gensettingstf = [{
 	"step": 1,
 	"default": 0,
     "tooltip": "When enabled, the Memory text box in the Random Story dialog will be prefilled by default with your current story's memory instead of being empty."
+	},
+	{
+	"uitype": "toggle",
+	"unit": "bool",
+	"label": "No Genmod",
+	"id": "setnogenmod",
+	"min": 0,
+	"max": 1,
+	"step": 1,
+	"default": 0,
+    "tooltip": "Disables userscript generation modifiers."
 	}]
 
 gensettingsik =[{

static/application.js:

@@ -2194,6 +2194,9 @@ $(document).ready(function(){
 			if(!$("#setrngpersist").prop("checked")) {
 				$("#rngmemory").val("");
 			}
+		} else if(msg.cmd == "updatenogenmod") {
+			// Update toggle state
+			$("#setnogenmod").prop('checked', msg.data).change();
 		} else if(msg.cmd == "runs_remotely") {
 			remote = true;
 			hide([button_savetofile, button_import, button_importwi]);

templates/index.html:

@@ -17,7 +17,7 @@
 	<script src="static/bootstrap.min.js"></script>
 	<script src="static/bootstrap-toggle.min.js"></script>
 	<script src="static/rangy-core.min.js"></script>
-	<script src="static/application.js?ver=1.16.4v"></script>
+	<script src="static/application.js?ver=1.16.4w"></script>
 </head>
 <body>
 	<input type="file" id="remote-save-select" accept="application/json" style="display:none">

tpu_mtj_backend.py:

@@ -60,7 +60,7 @@ def __batch_xmap(shard_dim=1):
     return inner
 
 
-def apply_repetition_penalty(logits, tokens, repetition_penalty):
+def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty):
     '''
     This gets called by generate_loop_fn to apply repetition penalty
     to the 1D array logits using the provided 1D array of tokens to penalize
@@ -85,7 +85,7 @@ def apply_repetition_penalty(logits, tokens, repetition_penalty):
     logits[tokens] = penalty_logits
     return logits
 
-def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
+def kobold_sample_dynamic(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
     '''
     This gets called by generate_loop_fn to apply a series of 4 filters
     to the logits (top-k, then top-p, then TFS, then temperature) before
@@ -183,6 +183,127 @@ def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
     # probability distribution)
     return jax.random.categorical(key, logits, -1).astype(np.uint32)
 
+def apply_repetition_penalty_static(logits, tokens, repetition_penalty):
+    '''
+    This gets called by generate_loop_fn to apply repetition penalty
+    to the 1D array logits using the provided 1D array of tokens to penalize
+    '''
+    # Make a new array with the same length as the tokens array but with
+    # each element replaced by the value at the corresponding index in the
+    # logits array; e.g.
+    # if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
+    # then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
+    penalty_logits = jnp.take(logits, tokens)
+    # Divide positive values by repetition_penalty and multiply negative
+    # values by repetition_penalty (the academic publication that described
+    # this technique actually just only divided, but that would cause tokens
+    # with negative logits to become more likely, which is obviously wrong)
+    penalty_logits = jnp.where(
+        penalty_logits > 0,
+        penalty_logits/repetition_penalty,
+        penalty_logits*repetition_penalty,
+    )
+    # Finally, put those penalized logit values back into their original
+    # positions in the logits array
+    return logits.at[tokens].set(penalty_logits)
+
+def kobold_sample_static(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
+    '''
+    This gets called by generate_loop_fn to apply a series of 4 filters
+    to the logits (top-k, then top-p, then TFS, then temperature) before
+    picking one token using the modified logits
+    '''
+    # Top-k (keep only the k tokens with the highest logits and remove
+    # the rest, by setting their logits to negative infinity)
+    def top_k_filter(logits):
+        # After sorting the logits array in descending order,
+        # sorted_indices_to_remove is a 1D array that is True for tokens
+        # in the sorted logits array we want to remove and False for ones
+        # we want to keep, in this case the first top_k elements will be
+        # False and the rest will be True
+        sorted_indices_to_remove = jnp.arange(len(logits)) >= top_k
+        # Unsort the logits array back to its original configuration and
+        # remove tokens we need to remove
+        _, indices_to_remove = jax.lax.sort_key_val(
+            jnp.argsort(-logits),
+            sorted_indices_to_remove,
+        )
+        return jnp.where(indices_to_remove, -jnp.inf, logits)
+    logits = jax.lax.cond(top_k > 0, top_k_filter, lambda x: x, logits)
+    # Top-p (after sorting the remaining tokens again in descending order of
+    # logit, remove the ones that have cumulative softmax probability
+    # greater than p)
+    def top_p_filter(logits):
+        # Sort the logits array in descending order, replace every element
+        # with e (Euler's number) to the power of that element, and divide
+        # each element of the new array by the sum of the elements in the
+        # new array
+        sorted_logits = -jnp.sort(-logits)
+        probabilities = jax.nn.softmax(sorted_logits)
+        # Calculate cumulative_probabilities as the prefix-sum array of
+        # probabilities
+        cumulative_probabilities = jnp.cumsum(probabilities, axis=-1)
+        # We want to remove tokens with cumulative probability higher
+        # than top_p
+        sorted_indices_to_remove = cumulative_probabilities > top_p
+        # Don't ever remove the token with the highest logit, even if
+        # the probability is higher than top_p
+        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
+        # Unsort and remove
+        _, indices_to_remove = jax.lax.sort_key_val(
+            jnp.argsort(-logits),
+            sorted_indices_to_remove,
+        )
+        return jnp.where(indices_to_remove, -jnp.inf, logits)
+    logits = jax.lax.cond(top_p < 1.0, top_p_filter, lambda x: x, logits)
+    # Tail free sampling (basically top-p a second time on remaining tokens
+    # except it's the "cumulative normalized absolute second finite
+    # differences of the softmax probabilities" instead of just the
+    # cumulative softmax probabilities)
+    def tail_free_filter(logits):
+        # Sort in descending order
+        sorted_logits = -jnp.sort(-logits)
+        # Softmax again
+        probabilities = jax.nn.softmax(sorted_logits)
+        # Calculate the second finite differences of that array (i.e.
+        # calculate the difference array and then calculate the difference
+        # array of the difference array)
+        d2 = jnp.diff(jnp.diff(probabilities))
+        # Get the absolute values of all those second finite differences
+        d2 = jnp.abs(d2)
+        # Normalize (all elements in the array are divided by the sum of the
+        # array's elements)
+        d2 = d2 / d2.sum(axis=-1, keepdims=True)
+        # Get the prefix-sum array
+        cumulative_d2 = jnp.cumsum(d2, axis=-1)
+        # We will remove the tokens with a cumulative normalized absolute
+        # second finite difference larger than the TFS value
+        sorted_indices_to_remove = cumulative_d2 > tfs
+        # Don't remove the token with the highest logit
+        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
+        # Since the d2 array has two fewer elements than the logits array,
+        # we'll add two extra Trues to the end
+        sorted_indices_to_remove = jnp.pad(
+            sorted_indices_to_remove,
+            (0, 2),
+            constant_values=True,
+        )
+        # Unsort and remove
+        _, indices_to_remove = jax.lax.sort_key_val(
+            jnp.argsort(-logits),
+            sorted_indices_to_remove,
+        )
+        return jnp.where(indices_to_remove, -jnp.inf, logits)
+    logits = jax.lax.cond(tfs < 1.0, tail_free_filter, lambda x: x, logits)
+    # Temperature (just divide the logits by the temperature)
+    def temp_filter(logits):
+        return logits / temp
+    logits = jax.lax.cond(True, temp_filter, lambda x: x, logits)
+    # Finally, pick one token using the softmax thingy again (it gives
+    # an array whose elements sum to 1 so it can be used nicely as a
+    # probability distribution)
+    return jax.random.categorical(key, logits, -1).astype(jnp.uint32)
+
 pad_token_id = 50256
 
 def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, sampler_options):
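
The docstring's example is worth working through once with real numbers. A
runnable check of apply_repetition_penalty_static's behavior, assuming a
repetition_penalty of 2.0 (any value greater than 1 behaves the same way):

    import jax.numpy as jnp

    logits = jnp.array([77.0, 5.0, 3.0, 98.0])
    tokens = jnp.array([0, 1, 2, 3, 2, 3, 1])  # tokens 0-3 all appear already

    penalty_logits = jnp.take(logits, tokens)  # [77, 5, 3, 98, 3, 98, 5]
    penalty_logits = jnp.where(
        penalty_logits > 0,
        penalty_logits / 2.0,  # positive logits shrink toward zero...
        penalty_logits * 2.0,  # ...negative ones become more negative
    )
    # Scatter back: every already-seen token is now less likely.
    print(logits.at[tokens].set(penalty_logits))  # [38.5  2.5  1.5 49. ]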
@@ -192,11 +313,11 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, sampler_op
         generated, generated_index, logits, _ = carry[0][0]
         sample_key = carry[1]
         # Get the pseudo-random number generator key that will
-        # be used by kobold_sample to randomly pick a token
+        # be used by kobold_sample_dynamic to randomly pick a token
         sample_key, new_key = jax.random.split(sample_key, num=2)
         # Apply repetition penalty to all tokens that are
         # currently inside the "generated" array
-        logits = apply_repetition_penalty(
+        logits = apply_repetition_penalty_dynamic(
             logits,
             generated,
             repetition_penalty
@@ -205,11 +326,11 @@ def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, sampler_op
         # their logits to negative infinity which effectively
         # makes their probabilities of being chosen zero
         logits[badwords] = -np.inf
-        # Use the sampler (kobold_sample) to pick one token
+        # Use the sampler (kobold_sample_dynamic) to pick one token
         # based on the logits array as a 0D uint32 array
         # (higher logit means higher probability of being
         # picked, non-linearly)
-        next_token = kobold_sample(
+        next_token = kobold_sample_dynamic(
             sample_key,
             logits,
             **sampler_options,
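
The filters in kobold_sample_static (added above) all rely on the same
sort/unsort idiom: decide which ranks to drop in descending-logit order, then
use jax.lax.sort_key_val to carry that decision back to each token's original
position. A small standalone demonstration with top_k = 2:

    import jax
    import jax.numpy as jnp

    logits = jnp.array([1.0, 3.0, 2.0, 5.0])
    top_k = 2
    # True at the sorted positions (ranks) we want to drop: ranks 2 and 3
    sorted_indices_to_remove = jnp.arange(len(logits)) >= top_k
    # jnp.argsort(-logits) maps rank -> original index; sorting those keys
    # permutes the removal flags back into original token order
    _, indices_to_remove = jax.lax.sort_key_val(
        jnp.argsort(-logits),
        sorted_indices_to_remove,
    )
    print(jnp.where(indices_to_remove, -jnp.inf, logits))  # [-inf 3. -inf 5.]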
@@ -236,6 +357,100 @@ class PenalizingCausalTransformer(CausalTransformer):
     def __init__(self, config):
         # Initialize
         super().__init__(config)
+        def generate_static(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None):
+            numseqs = numseqs_aux.shape[0]
+            # These are the tokens that we don't want the AI to ever write
+            self.badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146])
+            @hk.transform
+            def generate_sample(context, ctx_length):
+                # Give the initial context to the transformer
+                transformer = CausalTransformerShard(config)
+                def generate_initial_scan_fn(sequence_index, _):
+                    _, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings)
+                    # The "generated" array will contain the tokens from the
+                    # context as well as the tokens picked by the sampler at
+                    # each stage, padded with a bunch of 50256s, so we know
+                    # which tokens have to be repetition penalized
+                    generated = jnp.pad(context, (0, config["seq"]), constant_values=pad_token_id)  # Let it start off with just the 2048 context tokens, plus some 50256s which will be eventually filled with sampler-chosen tokens
+                    generated_index = config["seq"]
+                    # Add that information to generate_loop_fn's starting state
+                    initial_state = (generated, generated_index, sequence_index) + initial_state
+                    return sequence_index+1, initial_state
+                _, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, None, numseqs)
+                sample_key = initial_states[-1][0]
+                initial_states = list(jax.tree_map(lambda x: x[i], initial_states[:-1]) for i in range(numseqs))
+                # Get repetition penalty from the arguments
+                repetition_penalty = sampler_options.pop('repetition_penalty', None)
+                # This is the main generation loop
+                def generate_loop_fn(carry):
+                    # Unpack current generate_loop_fn state
+                    generated, generated_index, sequence_index, next_token, decode_state = carry[0][0]
+                    sample_key = carry[1]
+                    # Get the pseudo-random number generator key that will
+                    # be used by kobold_sample_static to randomly pick a token
+                    sample_key, new_key = jax.random.split(sample_key)
+                    # Give the context to the model and get the logits it
+                    # spits out
+                    # (a 2D array with 1 row and 50400 columns representing
+                    # how strongly it thinks each of the 50257 tokens in its
+                    # vocabulary should be appended to the context, followed
+                    # by 143 apparently useless columns ???)
+                    logits, new_state = transformer.generate_once(next_token, decode_state, soft_embeddings=soft_embeddings)
+                    # Verify that logits does indeed have that many rows and
+                    # columns (if you get an error here, pray for mercy)
+                    assert logits.shape == (1, config["n_vocab"])
+                    # Flatten it into a 1D array to make it easier to use
+                    logits = logits[0]
+                    # Apply repetition penalty to all tokens that are
+                    # currently inside the "generated" array
+                    if repetition_penalty is not None:
+                        logits = apply_repetition_penalty_static(
+                            logits,
+                            generated,
+                            repetition_penalty
+                        )
+                    # Remove any tokens in the badwords list by setting
+                    # their logits to negative infinity which effectively
+                    # makes their probabilities of being chosen zero
+                    logits = logits.at[self.badwords].set(-jnp.inf)
+                    # Use the sampler (kobold_sample_static) to pick one token
+                    # based on the logits array as a 0D uint32 array
+                    # (higher logit means higher probability of being
+                    # picked, non-linearly)
+                    next_token = kobold_sample_static(
+                        sample_key,
+                        logits,
+                        **sampler_options,
+                    )
+                    # Remember what token was picked
+                    generated = generated.at[generated_index].set(next_token)
+                    generated_index += 1
+                    # Re-pack the current generate_loop_fn's state so we can
+                    # get back the same variables the next time
+                    carry[0][0] = (generated, generated_index, sequence_index, next_token[jnp.newaxis], new_state)
+                    carry[0].append(carry[0].pop(0))
+                    return carry[0], new_key
+                return jax.lax.while_loop(
+                    lambda carry: carry[0][0][1] - config["seq"] < gen_length,
+                    generate_loop_fn,
+                    (initial_states, sample_key),
+                )
+            return generate_sample.apply(state["params"], key, ctx, ctx_length)
+        self.generate_static_xmap = jax.experimental.maps.xmap(
+            fun=generate_static,
+            in_axes=(
+                ["shard", ...],
+                ["batch", ...],
+                ["batch", ...],
+                ["batch", ...],
+                ["batch", ...],
+                ["batch", ...],
+                ["batch", ...],
+                ["shard", ...],
+            ),
+            out_axes=["shard", "batch", ...],
+            axis_resources={'shard': 'mp', 'batch': 'dp'},
+        )
         def generate_initial(state, key, ctx, ctx_length, numseqs_aux, soft_embeddings=None):
             numseqs = numseqs_aux.shape[0]
             @hk.transform
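
One non-obvious detail in generate_loop_fn above: carry[0].append(carry[0].pop(0))
rotates the list of per-sequence states, so each iteration of the while loop
advances exactly one sequence and carry[0][0] is always the sequence whose turn
comes next. The idiom in isolation, as plain Python:

    # Round-robin over per-sequence states, as in generate_loop_fn
    states = ["seq0 state", "seq1 state", "seq2 state"]
    for step in range(6):
        head = states[0]  # the sequence advanced on this iteration
        print(f"step {step}: {head}")
        states.append(states.pop(0))  # move the head to the back of the queue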
@@ -314,7 +529,7 @@ class PenalizingCausalTransformer(CausalTransformer):
             out_axes=["shard", "batch", ...],
             axis_resources={'shard': 'mp', 'batch': 'dp'},
         )
-    def generate(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None, excluded_world_info=None, use_callback=True):
+    def generate_dynamic(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None, excluded_world_info=None, use_callback=True):
         assert excluded_world_info is not None
         assert not return_logits
         assert gen_length.ndim == 1
@@ -360,9 +575,24 @@ class PenalizingCausalTransformer(CausalTransformer):
             else:
                 break
         return sample_data, n_generated, regeneration_required, halt
+    def generate_static(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None):
+        assert not return_logits
+        key = hk.PRNGSequence(random.randint(0, 2 ** 60))
+        batch_size = ctx.shape[0]
+        self.batch_size = batch_size
+        return self.generate_static_xmap(
+            self.state,
+            jnp.array(key.take(batch_size)),
+            ctx,
+            np.array(ctx_length, dtype=np.uint32),
+            np.array(gen_length, dtype=np.uint32),
+            np.empty((batch_size, numseqs), dtype=np.uint8),
+            sampler_options,
+            soft_embeddings,
+        )
 
 
-def infer(
+def infer_dynamic(
     context: np.array,
     top_p=0.9,
     temp=0.5,
@@ -394,7 +624,7 @@ def infer(
         "repetition_penalty": float(repetition_penalty),
         "top_k": int(top_k),
     }
-    output = network.generate(
+    output = network.generate_dynamic(
         batched_tokens,
         np.ones(total_batch, dtype=np.uint32) * provided_ctx,
         np.ones(total_batch, dtype=np.uint32) * gen_len,
@@ -408,6 +638,47 @@ def infer(
         samples.append(out[0][params["seq"] : params["seq"] + gen_len])
     return (samples,) + output[1:]
 
+def infer_static(
+    context: np.array,
+    top_p=0.9,
+    temp=0.5,
+    top_k=0,
+    tfs=1.0,
+    repetition_penalty=1.0,
+    numseqs=1,
+    gen_len=80,
+    soft_embeddings: Optional[np.array] = None,
+    soft_tokens: Optional[np.array] = None,
+) -> List[np.array]:
+    maps.thread_resources.env = thread_resources_env
+    total_batch = 1
+    tokens = context
+    if(soft_tokens is not None):
+        tokens = np.uint32(np.concatenate((soft_tokens, tokens)))
+    provided_ctx = tokens.shape[0]
+    pad_amount = seq - provided_ctx
+    padded_tokens = np.pad(tokens, ((pad_amount, 0),), constant_values=pad_token_id)
+    batched_tokens = np.array([padded_tokens] * total_batch)
+    samples = []
+    batched_generator_params = {
+        "temp": temp * np.ones(total_batch),
+        "top_p": top_p * np.ones(total_batch),
+        "tfs": tfs * np.ones(total_batch),
+        "repetition_penalty": repetition_penalty * np.ones(total_batch),
+        "top_k": np.full(total_batch, top_k, dtype=np.uint32)
+    }
+    output = network.generate_static(
+        batched_tokens,
+        np.ones(total_batch, dtype=np.uint32) * provided_ctx,
+        np.ones(total_batch, dtype=np.uint32) * gen_len,
+        numseqs,
+        batched_generator_params,
+        soft_embeddings=soft_embeddings,
+    )[0]
+    for o in output:
+        samples.append(o[0][0, 0, params["seq"] : params["seq"] + gen_len])
+    return samples
+
+
 def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) -> None:
     global thread_resources_env, seq, tokenizer, network, params