2021-12-31 19:47:18 +01:00
-- Basic phrase bias
-- Makes certain sequences of tokens more or less likely to appear than normal.
2022-01-01 07:34:32 +01:00
-- This file is part of KoboldAI.
--
-- KoboldAI is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as published by
-- the Free Software Foundation, either version 3 of the License, or
-- (at your option) any later version.
--
-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-- GNU Affero General Public License for more details.
--
-- You should have received a copy of the GNU Affero General Public License
-- along with this program. If not, see <https://www.gnu.org/licenses/>.
2021-12-31 19:47:18 +01:00
-- Script setup. `kobold` is the API object used below for config-file
-- access, token encoding and logits; `userscript` collects the hook
-- functions (genmod, outmod) that the file returns at the end.
kobold = require ( " bridge " ) ( ) -- This line is optional and is only for EmmyLua type annotations
local userscript = { } ---@class KoboldUserScript
-- One parsed (non-comment) line of the config file: the two biases that
-- genmod interpolates between as more of the phrase appears, plus the
-- phrase's token IDs in order (`n_tokens` of them).
---@class PhraseBiasEntry
---@field starting_bias number
---@field ending_bias number
---@field tokens table<integer, integer>
---@field n_tokens integer
-- Default contents written into the config file when it is empty. The text
-- documents the line format: "<starting bias>, <ending bias>, <token IDs...>";
-- lines starting with '#' are skipped by the config reader.
local example_config = [ [ # Phrase bias
#
# For each phrase you want to bias , add a new line into
# this config file as a comma - separated list in this format :
# < starting bias > , < ending bias > , < comma - separated list of token IDs >
# For < starting bias > and < ending bias > , this script accepts floating point
# numbers or - inf , where positive bias values make it more likely for tokens
# to appear , negative bias values make it less likely and - inf makes it
# impossible .
#
# Example 1 ( makes it impossible for the word " CHAPTER " , case - sensitive , to
# appear at the beginning of a line in the output ) :
# - inf , - inf , 41481
#
# Example 2 ( makes it unlikely for the word " CHAPTER " , case - sensitive , with
# a leading space , to appear in the output , with the unlikeliness increasing
# even more if the first token " CH " has appeared ) :
# - 10.0 , - 20.0 , 5870 , 29485
#
# Example 3 ( makes it more likely for " let the voice of love take you higher " ,
# case - sensitive , with a leading space , to appear in the output , with the
# bias increasing as each consecutive token in that phrase appears ) :
# 7 , 25.4 , 1309 , 262 , 3809 , 286 , 1842 , 1011 , 345 , 2440
#
] ]
-- Obtain the config file handle. If the file has no content yet, seed it
-- with the example config. Either way, rewind to the start afterwards so the
-- config reader below sees the whole file.
local f = kobold.get_config_file()
f:seek(" set ")
local config_is_empty = f:read(1) == nil
if config_is_empty then
    f:write(example_config)
end
f:seek(" set ")
-- The example text is not needed past this point; drop the reference.
example_config = nil
-- Read config
print(" Loading phrase bias config... ")
local bias_array = {} ---@type table<integer, PhraseBiasEntry>
local bias_array_count = 0
local line_count = 0

---Close the config file, then raise `msg`. Every parse failure goes through
---here, so the handle is always closed before the error escapes.
---@param msg string
local function config_error(msg)
    f:close()
    error(msg, 2)
end

---Parse one of the first two values of a config line into a bias.
---Accepted: finite numbers, or "-inf" (case-insensitive). Any prefix that
---turns "<prefix>1" into a negative integer is treated the same way as "-".
---Rejected: text `tonumber` cannot parse, NaN, and numeric literals that
---overflow to +inf (e.g. "1e999"). Without that last check, such a literal
---would produce a +inf logit bias.
---@param val string raw value text from the config line
---@param val_count integer 1 for the starting bias, 2 for the ending bias
---@return number
local function parse_bias(val, val_count)
    local lowered = val:lower()
    local bias
    if lowered:sub(-3) == " inf " then
        -- Replace the "inf" suffix with "1": "-inf" becomes the integer -1,
        -- which is then scaled to -math.huge. "inf", "+inf" (positive) and
        -- non-numeric prefixes fail this check.
        local sign = math.tointeger(lowered:sub(1, -4) .. " 1 ")
        if type(sign) == " number " and sign <= 0 then
            bias = sign * math.huge
        end
    else
        local num = tonumber(lowered)
        -- num == num rejects NaN; num ~= math.huge rejects overflow to +inf
        if num == num and type(num) == " number " and num ~= math.huge then
            bias = num
        end
    end
    if bias == nil then
        config_error(" First two values of line " .. line_count .. " of config file must be finite floating-point numbers or -inf, but got ' " .. val .. " ' as value # " .. val_count)
    end
    return bias
end

for line in f:lines(" l ") do
    line_count = line_count + 1
    -- Skip comment lines and lines with no non-space characters
    if line:find(" ^ *# ") == nil and line:find(" %S ") ~= nil then
        bias_array_count = bias_array_count + 1
        local row = { tokens = {}, n_tokens = 0 } ---@type PhraseBiasEntry
        local val_count = 0
        for val in line:gmatch(" [^,%s]+ ") do
            val_count = val_count + 1
            if val_count == 1 then
                row.starting_bias = parse_bias(val, val_count)
            elseif val_count == 2 then
                row.ending_bias = parse_bias(val, val_count)
            else
                -- Every remaining value is a token ID
                local token = math.tointeger(val)
                if type(token) ~= " number " or token < 0 then
                    config_error(" All values after the first two values of line " .. line_count .. " of config file must be nonnegative integers, but got ' " .. val .. " ' as value # " .. val_count)
                end
                row.n_tokens = row.n_tokens + 1
                row.tokens[row.n_tokens] = token
            end
        end
        if val_count < 3 then
            config_error(" Line " .. line_count .. " of config file must contain at least 3 values, but found " .. val_count)
        end
        bias_array[bias_array_count] = row
    end
end
f:close()
print(" Successfully loaded " .. bias_array_count .. " phrase bias entr " .. (bias_array_count == 1 and " y " or " ies ") .. " . ")
-- Flag set to true by userscript.genmod. userscript.outmod checks it to warn
-- when the generation modifier has never run. Nothing resets it to false.
local genmod_run = false
---Logistic function 1 / (1 + e^-x). It branches on the sign so that
---math.exp is only ever called with a non-positive argument for finite x.
---logistic(-inf) is 0 and logistic(inf) is 1.
---@param x number
---@return number
local function logistic(x)
    if x >= 0 then
        return 1 / (1 + math.exp(-x))
    end
    local e = math.exp(x)
    return e / (1 + e)
end

---Interpolate between two biases in probability space.
---Each bias is mapped through the logistic function, and the two
---probabilities are linearly interpolated by `factor`. The result is the
---logit of that interpolated probability p. A bias of -inf maps to
---probability 0, so logit_interpolate(-inf, -inf, f) is -inf.
---
---The logit is evaluated as log(p) - log(q), where q = 1 - p is built
---directly from the complementary logistic values. Forming 1 - p by
---subtraction rounds to 0 for biases above roughly 37, which would turn a
---finite positive bias into +inf.
---@param starting_val number bias used when factor == 0
---@param ending_val number bias used when factor == 1
---@param factor number interpolation weight (genmod passes values in [0, 1])
---@return number
local function logit_interpolate(starting_val, ending_val, factor)
    local p = (1 - factor) * logistic(starting_val) + factor * logistic(ending_val)
    local q = (1 - factor) * logistic(-starting_val) + factor * logistic(-ending_val)
    return math.log(p) - math.log(q)
end
---Return the largest m such that the last m entries of `tokens` equal the
---first m entries of `phrase` (0 if no m >= 1 works).
---
---Runs the Z-algorithm over the 0-indexed sequence
---s = phrase[1..n] followed by tokens[n_tokens-n+1..n_tokens], where
---n = min(n_tokens, n_phrase). z[k] is the length of the longest common
---prefix of s and s[k..]. The first k whose match runs to the end of s,
---with length at most n, gives the longest overlap.
---@param tokens table<integer, integer>
---@param n_tokens integer
---@param phrase table<integer, integer>
---@param n_phrase integer
---@return integer
local function phrase_overlap(tokens, n_tokens, phrase, n_phrase)
    local n_s = math.min(n_tokens, n_phrase)
    local len = n_s << 1
    local s = {}
    for k = 1, n_s do
        s[k - 1] = phrase[k]
        s[n_s + k - 1] = tokens[n_tokens - n_s + k]
    end
    local z = {}
    -- [l, r] (r inclusive) is the rightmost window known to equal a prefix
    -- of s. Starting at l = r = 0 makes k = 1 take the explicit-scan branch.
    -- Starting at l = r = 1 would copy the placeholder z[0] into z[1] and
    -- miss overlaps for phrases whose leading tokens repeat.
    local l, r = 0, 0
    for k = 1, len - 1 do
        if k <= r and z[k - l] - 1 < r - k then
            z[k] = z[k - l]
        else
            l = k
            if k > r then
                r = k
            end
            while r < len and s[r - l] == s[r] do
                r = r + 1
            end
            z[k] = r - l
            r = r - 1
        end
        -- A match of length <= n_s that runs to the end of s lies wholly in
        -- the tokens half, i.e. a suffix of `tokens` equals a prefix of
        -- `phrase`. Smaller k means a longer overlap, so stop at the first.
        if z[k] <= n_s and z[k] == len - k then
            return z[k]
        end
    end
    return 0
end

---Generation modifier hook. For each sequence being generated and each
---configured phrase, it finds how much of the phrase the sequence currently
---ends with. It then adds the interpolated bias to the logit of the
---phrase's next token.
function userscript.genmod()
    genmod_run = true
    local context_tokens = kobold.encode(kobold.worldinfo:compute_context(kobold.submission))

    -- biased_tokens[i][t] records that token t already received a bias in
    -- row i. Each token is biased at most once per row per call; the first
    -- matching config entry wins.
    local biased_tokens = {} ---@type table<integer, table<integer, boolean>>
    for i = 1, kobold.generated_rows do
        biased_tokens[i] = {}
    end

    -- For each partially-generated sequence...
    for i, generated_row in ipairs(kobold.generated) do
        -- `tokens` is the context tokens followed by this sequence's
        -- generated tokens
        local tokens = {}
        local n_tokens = 0
        for _, v in ipairs(context_tokens) do
            n_tokens = n_tokens + 1
            tokens[n_tokens] = v
        end
        for _, v in ipairs(generated_row) do
            n_tokens = n_tokens + 1
            tokens[n_tokens] = v
        end

        -- For each phrase bias entry in the config file... The overlap is
        -- computed and consumed per (sequence, entry) pair, so each entry is
        -- biased according to its own match.
        for _, bias_entry in ipairs(bias_array) do
            local overlap = phrase_overlap(tokens, n_tokens, bias_entry.tokens, bias_entry.n_tokens)
            local factor ---@type number
            local next_token ---@type integer
            if overlap == 0 or overlap == bias_entry.n_tokens then
                -- No partial match in progress (or the phrase was just
                -- completed): bias the phrase's first token. A one-token
                -- phrase uses its ending bias; longer ones start at the
                -- starting bias.
                if bias_entry.tokens[2] == nil then
                    factor = 1
                else
                    factor = 0
                end
                next_token = bias_entry.tokens[1]
            else
                -- `overlap` tokens already match: bias the next one, moving
                -- from starting_bias toward ending_bias
                factor = overlap / (bias_entry.n_tokens - 1)
                next_token = bias_entry.tokens[overlap + 1]
            end
            -- Apply bias (+1: token IDs start at 0, the logits row at 1)
            if not biased_tokens[i][next_token] then
                kobold.logits[i][next_token + 1] = kobold.logits[i][next_token + 1] + logit_interpolate(bias_entry.starting_bias, bias_entry.ending_bias, factor)
                biased_tokens[i][next_token] = true
            end
        end
    end
end
---Output modifier hook. It does not modify the output. It only warns when
---genmod, where the bias is applied, has not run since the script was
---loaded.
function userscript.outmod()
    if genmod_run then
        return
    end
    warn(" WARNING: Generation modifier was not executed, so this script has had no effect ")
end
-- Return the table carrying this script's hook functions (genmod, outmod).
return userscript