mirror of
				https://github.com/KoboldAI/KoboldAI-Client.git
				synced 2025-06-05 21:59:24 +02:00 
			
		
		
		
	Add Lua API for editing logits during generation
TPU backend not supported yet.
This commit is contained in:
		
							
								
								
									
										45
									
								
								aiserver.py
									
									
									
									
									
								
							
							
						
						
									
										45
									
								
								aiserver.py
									
									
									
									
									
								
							| @@ -104,6 +104,7 @@ class vars: | ||||
|     lua_koboldbridge = None  # `koboldbridge` from bridge.lua | ||||
|     lua_kobold  = None   # `kobold` from` bridge.lua | ||||
|     lua_koboldcore = None  # `koboldcore` from bridge.lua | ||||
|     lua_warper  = None   # Transformers logits warper controllable from Lua | ||||
|     # badwords    = []     # Array of str/chr values that should be removed from output | ||||
|     badwordsids = [[13460], [6880], [50256], [42496], [4613], [17414], [22039], [16410], [27], [29], [38430], [37922], [15913], [24618], [28725], [58], [47175], [36937], [26700], [12878], [16471], [37981], [5218], [29795], [13412], [45160], [3693], [49778], [4211], [20598], [36475], [33409], [44167], [32406], [29847], [29342], [42669], [685], [25787], [7359], [3784], [5320], [33994], [33490], [34516], [43734], [17635], [24293], [9959], [23785], [21737], [28401], [18161], [26358], [32509], [1279], [38155], [18189], [26894], [6927], [14610], [23834], [11037], [14631], [26933], [46904], [22330], [25915], [47934], [38214], [1875], [14692], [41832], [13163], [25970], [29565], [44926], [19841], [37250], [49029], [9609], [44438], [16791], [17816], [30109], [41888], [47527], [42924], [23984], [49074], [33717], [31161], [49082], [30138], [31175], [12240], [14804], [7131], [26076], [33250], [3556], [38381], [36338], [32756], [46581], [17912], [49146]] # Tokenized array of badwords used to prevent AI artifacting | ||||
|     deletewi    = -1     # Temporary storage for index to delete | ||||
| @@ -643,6 +644,37 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme | ||||
|                 indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) | ||||
|                 scores = scores.masked_fill(indices_to_remove, self.filter_value) | ||||
|                 return scores | ||||
|          | ||||
|         class LuaLogitsWarper(LogitsWarper): | ||||
|  | ||||
|             def __init__(self): | ||||
|                 self.regeneration_required = False | ||||
|                 self.halt = False | ||||
|                 pass | ||||
|  | ||||
|             def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: | ||||
|                 assert scores.ndim == 2 | ||||
|                 self.regeneration_required = False | ||||
|                 self.halt = False | ||||
|                 scores_shape = scores.shape | ||||
|                 scores_list = self.scores.tolist() | ||||
|                 vars.lua_koboldbridge.logits = vars.lua_state.table() | ||||
|                 for r, row in enumerate(scores_list): | ||||
|                     vars.lua_koboldbridge.logits[r+1] = vars.lua_state.table(*row) | ||||
|                 vars.lua_koboldbridge.vocab_size = scores_shape[-1] | ||||
|                 execute_genmod() | ||||
|                 scores = torch.tensor( | ||||
|                     tuple(tuple(row.values()) for row in vars.lua_koboldbridge.logits.values()), | ||||
|                     device=scores.device, | ||||
|                     dtype=scores.dtype, | ||||
|                 ) | ||||
|                 assert scores.shape == scores_shape | ||||
|                 if(vars.lua_koboldbridge.regeneration_required): | ||||
|                     vars.lua_koboldbridge.regeneration_required = False | ||||
|                     self.regeneration_required = True | ||||
|                 if(not vars.lua_koboldbridge.generating): | ||||
|                     self.halt = True | ||||
|                 return scores | ||||
|  | ||||
|         def new_get_logits_warper( | ||||
|             top_k: int = None, | ||||
| @@ -660,6 +692,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme | ||||
|                 warper_list.append(TailFreeLogitsWarper(tfs=tfs, min_tokens_to_keep=1 + (beams > 1))) | ||||
|             if(temp is not None and temp != 1.0): | ||||
|                 warper_list.append(TemperatureLogitsWarper(temperature=temp)) | ||||
|             vars.lua_warper = LuaLogitsWarper() | ||||
|             warper_list.append(vars.lua_warper) | ||||
|             return warper_list | ||||
|          | ||||
|         def new_sample(self, *args, **kwargs): | ||||
| @@ -705,15 +739,8 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme | ||||
|             ) -> bool: | ||||
|                 assert input_ids.ndim == 2 | ||||
|                 assert len(self.excluded_world_info) == input_ids.shape[0] | ||||
|                 self.regeneration_required = False | ||||
|                 self.halt = False | ||||
|  | ||||
|                 execute_genmod() | ||||
|                 if(vars.lua_koboldbridge.regeneration_required): | ||||
|                     vars.lua_koboldbridge.regeneration_required = False | ||||
|                     self.regeneration_required = True | ||||
|                 if(not vars.lua_koboldbridge.generating): | ||||
|                     self.halt = True | ||||
|                 self.regeneration_required = vars.lua_warper.regeneration_required | ||||
|                 self.halt = vars.lua_warper.halt | ||||
|  | ||||
|                 if(not vars.dynamicscan): | ||||
|                     return False | ||||
|   | ||||
							
								
								
									
										76
									
								
								bridge.lua
									
									
									
									
									
								
							
							
						
						
									
										76
									
								
								bridge.lua
									
									
									
									
									
								
							| @@ -145,6 +145,7 @@ return function(_python, _bridged) | ||||
|     ---@field modelbackend "'readonly'"|"'api'"|"'transformers'"|"'mtj'" | ||||
|     ---@field is_custommodel boolean | ||||
|     ---@field custmodpth string | ||||
|     ---@field logits table<integer, table<integer, number>> | ||||
|     local kobold = setmetatable({}, metawrapper) | ||||
|     local KoboldLib_mt = setmetatable({}, metawrapper) | ||||
|     local KoboldLib_getters = setmetatable({}, metawrapper) | ||||
| @@ -200,6 +201,8 @@ return function(_python, _bridged) | ||||
|     koboldbridge.resend_settings_required = false | ||||
|     koboldbridge.generating = true | ||||
|     koboldbridge.userstate = "inmod" | ||||
|     koboldbridge.logits = {} | ||||
|     koboldbridge.vocab_size = 0 | ||||
|  | ||||
|     ---@return nil | ||||
|     local function maybe_require_regeneration() | ||||
| @@ -883,13 +886,13 @@ return function(_python, _bridged) | ||||
|     end | ||||
|  | ||||
|     ---@param t KoboldLib | ||||
|     ---@return string | ||||
|     ---@return boolean | ||||
|     function KoboldLib_getters.is_custommodel(t) | ||||
|         return bridged.is_custommodel() | ||||
|     end | ||||
|  | ||||
|     ---@param t KoboldLib | ||||
|     ---@param v string | ||||
|     ---@param v boolean | ||||
|     function KoboldLib_setters.is_custommodel(t, v) | ||||
|         error("`KoboldLib.is_custommodel` is a read-only attribute") | ||||
|     end | ||||
| @@ -907,6 +910,65 @@ return function(_python, _bridged) | ||||
|     end | ||||
|  | ||||
|  | ||||
|     --========================================================================== | ||||
|     -- Userscript API: Logit Warping | ||||
|     --========================================================================== | ||||
|  | ||||
|     ---@param t KoboldLib | ||||
|     ---@return integer | ||||
|     function KoboldLib_getters.logits_rows(t) | ||||
|         local backend = kobold.modelbackend | ||||
|         if backend == "readonly" or backend == "api" then | ||||
|             return 0 | ||||
|         end | ||||
|         return kobold.settings.numseqs | ||||
|     end | ||||
|  | ||||
|     ---@param t KoboldLib | ||||
|     ---@return integer | ||||
|     function KoboldLib_setters.logits_rows(t) | ||||
|         error("`KoboldLib.logits_rows` is a read-only attribute") | ||||
|     end | ||||
|  | ||||
|     ---@param t KoboldLib | ||||
|     ---@return integer | ||||
|     function KoboldLib_getters.logits_cols(t) | ||||
|         local backend = kobold.modelbackend | ||||
|         if backend == "readonly" or backend == "api" then | ||||
|             return 0 | ||||
|         end | ||||
|         return math.tointeger(koboldbridge.vocab_size) | ||||
|     end | ||||
|  | ||||
|     ---@param t KoboldLib | ||||
|     ---@return integer | ||||
|     function KoboldLib_setters.logits_cols(t) | ||||
|         error("`KoboldLib.logits_cols` is a read-only attribute") | ||||
|     end | ||||
|  | ||||
|     ---@param t KoboldLib | ||||
|     ---@return table<integer, table<integer, number>> | ||||
|     function KoboldLib_getters.logits(t) | ||||
|         if koboldbridge.userstate ~= "genmod" then | ||||
|             return | ||||
|         end | ||||
|         return koboldbridge.logits | ||||
|     end | ||||
|  | ||||
|     ---@param t KoboldLib | ||||
|     ---@param v table<integer, table<integer, number>> | ||||
|     function KoboldLib_setters.logits(t, v) | ||||
|         if koboldbridge.userstate ~= "genmod" then | ||||
|             error("Cannot write to `KoboldLib.logits` from outside of a generation modifer") | ||||
|             return | ||||
|         elseif type(v) ~= "table" then | ||||
|             error("`KoboldLib.logits` must be a 2D list (table) of numbers; you attempted to set it to a " .. type(v)) | ||||
|             return | ||||
|         end | ||||
|         koboldbridge.logits = v | ||||
|     end | ||||
|  | ||||
|  | ||||
|     --========================================================================== | ||||
|     -- Userscript API: Utilities | ||||
|     --========================================================================== | ||||
| @@ -1386,6 +1448,16 @@ return function(_python, _bridged) | ||||
|         koboldbridge.userstate = "genmod" | ||||
|         if koboldbridge.genmod ~= nil then | ||||
|             r = koboldbridge.genmod() | ||||
|             setmetatable(koboldbridge.logits, nil) | ||||
|             for kr, vr in old_next, koboldbridge.logits, nil do | ||||
|                 setmetatable(vr, nil) | ||||
|                 for kc, vc in old_next, vr, nil do | ||||
|                     if type(vc) ~= "number" then | ||||
|                         error("`kobold.logits` must be a 2D table of numbers, but found a non-number element at row " .. kr .. ", column " .. kc) | ||||
|                         return r | ||||
|                     end | ||||
|                 end | ||||
|             end | ||||
|         end | ||||
|         return r | ||||
|     end | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Gnome Ann
					Gnome Ann