-- Basic phrase bias
-- Makes certain sequences of tokens more or less likely to appear than normal.

-- This file is part of KoboldAI.
--
-- KoboldAI is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as published by
-- the Free Software Foundation, either version 3 of the License, or
-- (at your option) any later version.
--
-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-- GNU Affero General Public License for more details.
--
-- You should have received a copy of the GNU Affero General Public License
-- along with this program.  If not, see <https://www.gnu.org/licenses/>.

kobold = require("bridge")()  -- This line is optional and is only for EmmyLua type annotations
local userscript = {}  ---@class KoboldUserScript


-- Trie node used to match the tail of the token stream against the
-- beginnings of the configured phrases.
---@class Node
local Node = {
    parent = nil,   -- node this one was reached from (nil for the root)
    val = 0,        -- token ID on the edge leading from `parent` to this node
    depth = 0,      -- number of edges between the root and this node
    g = nil,  ---@type table<integer, Node>  -- children, keyed by token ID
    f = nil,  ---@type Node  -- failure link, filled in by the preprocessing pass
    entries = nil,  ---@type table<integer, PhraseBiasEntry>  -- entries whose token path passes through this node
    n_entries = 0,  -- number of elements in `entries`
}
local Node_mt = {}
setmetatable(Node, Node_mt)

local root  ---@type Node
local max_sequence_length = 0  -- token count of the longest configured phrase

--- Creates a new trie node with its own empty `g` and `entries` tables.
---@param parent? Node
---@param val integer
---@return Node
function Node.new(parent, val)
    local node = {}  ---@type Node
    -- Start from a copy of the prototype's fields
    for k, v in pairs(Node) do
        node[k] = v
    end
    node.parent = parent
    if parent ~= nil then
        node.depth = parent.depth + 1
    end
    if val ~= nil then
        node.val = val
    end
    node.g = {}
    node.entries = {}
    setmetatable(node, Node_mt)
    return node
end


---@class PhraseBiasEntry
---@field starting_bias number
---@field ending_bias number
---@field tokens table<integer, integer>
---@field n_tokens integer


local example_config = [[# Phrase bias
#
# For each phrase you want to bias, add a new line into
# this config file as a comma-separated list in this format:
# <starting bias>, <ending bias>, <comma-separated list of token IDs>
# For <starting bias> and <ending bias>, this script accepts floating point
# numbers or -inf, where positive bias values make it more likely for tokens
# to appear, negative bias values make it less likely and -inf makes it
# impossible.
#
# Example 1 (makes it impossible for the word "CHAPTER", case-sensitive, to
# appear at the beginning of a line in the output):
# -inf, -inf, 41481
#
# Example 2 (makes it unlikely for the word " CHAPTER", case-sensitive, with
# a leading space, to appear in the output, with the unlikeliness increasing
# even more if the first token " CH" has appeared):
# -10.0, -20.0, 5870, 29485
#
# Example 3 (makes it more likely for " let the voice of love take you higher",
# case-sensitive, with a leading space, to appear in the output, with the
# bias increasing as each consecutive token in that phrase appears):
# 7, 25.4, 1309, 262, 3809, 286, 1842, 1011, 345, 2440
#
]]

-- If config file is empty, write example config
local f = kobold.get_config_file()
f:seek("set")
if f:read(1) == nil then
    f:write(example_config)
end
f:seek("set")
example_config = nil

-- Read config
print("Loading phrase bias config...")
local bias_array = {}  ---@type table<integer, PhraseBiasEntry>
local bias_array_count = 0
local val_count = 0
local line_count = 0
local row = {}  ---@type PhraseBiasEntry
local val_orig
for line in f:lines("l") do
    line_count = line_count + 1
    -- Skip comment lines (optionally indented with any whitespace) and
    -- lines that contain no non-whitespace characters
    if line:find("^%s*#") == nil and line:find("%S") ~= nil then
        bias_array_count = bias_array_count + 1
        val_count = 0
        row = {}
        row.tokens = {}
        row.n_tokens = 0
        -- Values are separated by commas and/or whitespace
        for val in line:gmatch("[^,%s]+") do
            val_count = val_count + 1
            val_orig = val  -- keep the raw text for error messages
            if val_count <= 2 then
                -- The first two values are the starting and ending bias
                val = val:lower()
                if val:sub(-3) == "inf" then
                    -- Swap the trailing "inf" for "1" so that "-inf"
                    -- becomes -1; anything that is not a nonpositive
                    -- integer after the swap is rejected
                    val = math.tointeger(val:sub(1, -4) .. "1")
                    if val ~= val or type(val) ~= "number" or val > 0 then
                        f:close()
                        error("First two values of line " .. line_count .. " of config file must be finite floating-point numbers or -inf, but got '" .. val_orig .. "' as value #" .. val_count)
                    end
                    val = val * math.huge
                else
                    val = tonumber(val)
                    -- Reject non-numbers, NaN and values that overflow
                    -- to positive infinity (e.g. "1e999")
                    if val ~= val or type(val) ~= "number" or val == math.huge then
                        f:close()
                        error("First two values of line " .. line_count .. " of config file must be finite floating-point numbers or -inf, but got '" .. val_orig .. "' as value #" .. val_count)
                    end
                end
                if val_count == 1 then
                    row.starting_bias = val
                else
                    row.ending_bias = val
                end
            else
                -- Every remaining value is a token ID
                val = math.tointeger(val)
                if type(val) ~= "number" or val < 0 then
                    f:close()
                    error("All values after the first two values of line " .. line_count .. " of config file must be nonnegative integers, but got '" .. val_orig .. "' as value #" .. val_count)
                end
                row.n_tokens = row.n_tokens + 1
                row.tokens[row.n_tokens] = val
            end
        end
        if val_count < 3 then
            f:close()
            error("Line " .. line_count .. " of config file must contain at least 3 values, but found " .. val_count)
        end
        bias_array[bias_array_count] = row
    end
end
f:close()


-- Offline preprocessing of config file for maximum speed

-- Step 1: insert every phrase into the trie.  Each node along a phrase's
-- path records that phrase in `entries`, so a node at depth d lists every
-- entry whose first d tokens spell out the path to that node.
root = Node.new()
for _, entry in ipairs(bias_array) do
    local node = root
    for j, token in ipairs(entry.tokens) do
        if j > max_sequence_length then
            max_sequence_length = j
        end
        if node.g[token] == nil then
            -- Node.new already stores `node` as the new child's parent
            node.g[token] = Node.new(node, token)
        end
        node = node.g[token]
        node.n_entries = node.n_entries + 1
        node.entries[node.n_entries] = entry
    end
end

-- Step 2: walk the trie breadth-first using a singly linked list as the
-- queue and compute each node's failure link `f`.  A node is always
-- enqueued after its parent, so `ptr.f` below is already set by the time
-- it is read.
---@class Linked
local queue = {
    node = root,
    nxt = nil,  ---@type Linked|nil
}
local queue_tail = queue
while queue ~= nil do
    local node = queue.node
    for _, child in pairs(node.g) do
        queue_tail.nxt = {node = child}
        queue_tail = queue_tail.nxt
    end
    queue = queue.nxt
    -- The root and its direct children always fail back to the root
    node.f = root
    if node ~= root and node.parent ~= root then
        -- Follow the parent's chain of failure links until one of them
        -- has a child for this node's token
        local ptr = node.parent
        while ptr ~= root do
            ptr = ptr.f
            if ptr.g[node.val] ~= nil then
                node.f = ptr.g[node.val]
                break
            end
        end
    end
end

print("Successfully loaded " .. bias_array_count .. " phrase bias entr" .. (bias_array_count == 1 and "y" or "ies") .. ".")


local genmod_run = false

--- Interpolates between two logit values by mapping both through the
--- logistic function, interpolating linearly between the results and
--- mapping the interpolated value back with the logit function.
---@param starting_val number
---@param ending_val number
---@param factor number
---@return number
local function logit_interpolate(starting_val, ending_val, factor)
    -- First use the logistic function on the start and end values
    starting_val = 1/(1 + math.exp(-starting_val))
    ending_val = 1/(1 + math.exp(-ending_val))

    -- Use linear interpolation between these two values
    local val = starting_val + factor*(ending_val - starting_val)

    -- Return logit of this value
    return math.log(val/(1 - val))
end


--- Generation modifier: for every partially-generated sequence, adds a bias
--- to the logit of the next token of each configured phrase.
function userscript.genmod()
    genmod_run = true

    local context_tokens = kobold.encode(kobold.worldinfo:compute_context(kobold.submission))
    local factor  ---@type number
    local next_token  ---@type integer
    -- biased_tokens[i][token] is true once a bias has been applied to
    -- `token` in row i, so each token is biased at most once per row
    local biased_tokens = {}  ---@type table<integer, table<integer, boolean>>
    for i = 1, kobold.generated_rows do
        biased_tokens[i] = {}
    end

    local max_overlap = {}  ---@type table<integer, table<PhraseBiasEntry, integer>>

    -- For each partially-generated sequence...
    for i, generated_row in ipairs(kobold.generated) do

        -- Build an array `tokens` as the concatenation of the context
        -- tokens and the generated tokens of this sequence.  Both variables
        -- are local so that they do not leak into the global environment.
        local tokens = {}
        local n_tokens = 0
        for _, v in ipairs(context_tokens) do
            n_tokens = n_tokens + 1
            tokens[n_tokens] = v
        end
        for _, v in ipairs(generated_row) do
            n_tokens = n_tokens + 1
            tokens[n_tokens] = v
        end

        -- For each phrase bias entry `bias_entry`, determine the largest
        -- integer `max_overlap[i][bias_entry]` such that the last
        -- `max_overlap[i][bias_entry]` elements of `tokens` equal the first
        -- `max_overlap[i][bias_entry]` elements of `bias_entry.tokens`
        max_overlap[i] = {}
        local node = root
        -- Only the last `max_sequence_length` tokens can take part in an
        -- overlap, so the scan starts there
        for j = math.max(1, n_tokens - max_sequence_length + 1), n_tokens do
            local v = tokens[j]
            while node ~= root and node.g[v] == nil do
                node = node.f
            end
            node = node.g[v]
            if node == nil then
                node = root
            end
        end
        -- Walk the failure chain from the deepest match towards the root;
        -- the first (deepest) node that lists an entry fixes its overlap
        while node ~= root do
            for _, bias_entry in ipairs(node.entries) do
                if max_overlap[i][bias_entry] == nil then
                    max_overlap[i][bias_entry] = node.depth
                end
            end
            node = node.f
        end
    end

    -- For each phrase bias entry in the config file...
    for _, bias_entry in ipairs(bias_array) do

        -- For each partially-generated sequence...
        for i, _ in ipairs(kobold.generated) do

            if max_overlap[i][bias_entry] == nil then
                max_overlap[i][bias_entry] = 0
            end

            -- Use `max_overlap[i][bias_entry]` to determine which token in the
            -- bias entry to apply bias to
            if max_overlap[i][bias_entry] == 0 or max_overlap[i][bias_entry] == bias_entry.n_tokens then
                -- No overlap, or the whole phrase has just appeared: bias
                -- the first token.  A single-token phrase uses the ending
                -- bias (factor 1), any other phrase the starting bias.
                if bias_entry.tokens[2] == nil then
                    factor = 1
                else
                    factor = 0
                end
                next_token = bias_entry.tokens[1]
            else
                -- Partway through the phrase: bias the token that follows
                -- the overlap.  n_tokens is at least 2 in this branch, so
                -- the divisor is never zero.
                factor = max_overlap[i][bias_entry]/(bias_entry.n_tokens - 1)
                next_token = bias_entry.tokens[max_overlap[i][bias_entry]+1]
            end

            -- Apply bias (the logit for token ID t is stored at index t + 1)
            if not biased_tokens[i][next_token] then
                kobold.logits[i][next_token + 1] = kobold.logits[i][next_token + 1] + logit_interpolate(bias_entry.starting_bias, bias_entry.ending_bias, factor)
                biased_tokens[i][next_token] = true
            end
        end
    end
end

--- Output modifier: warns if the generation modifier never ran.
function userscript.outmod()
    if not genmod_run then
        warn("WARNING:  Generation modifier was not executed, so this script has had no effect")
    end
end

return userscript