mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-02-17 12:10:49 +01:00
Replace the search algorithm in Basic Phrase Bias with a different one
This commit is contained in:
parent
b3ced30e37
commit
aba150229c
@ -20,44 +20,6 @@ kobold = require("bridge")() -- This line is optional and is only for EmmyLua t
|
||||
local userscript = {} ---@class KoboldUserScript
|
||||
|
||||
|
||||
---Prototype table for the nodes of the token trie. `Node.new` copies every
---key found here into each node it creates, so these values are the defaults.
---@class Node
local Node = {
    -- Node this one hangs under; nil when no parent is given to `Node.new`
    parent = nil,
    -- Value handed to `Node.new`; the config loader passes the token id here
    val = 0,
    -- One more than the parent's depth; 0 for a node without a parent
    depth = 0,
    -- Child nodes; the config loader keys them by token id
    g = nil, ---@type table<integer, Node>
    -- Node to fall back to when no child matches; filled in after the trie is built
    f = nil, ---@type Node
    -- Phrase bias entries registered on this node
    entries = nil, ---@type table<integer, PhraseBiasEntry>
    -- Number of items stored in `entries`
    n_entries = 0,
}

-- One (empty) metatable shared by the prototype and by every node instance
local Node_mt = {}
setmetatable(Node, Node_mt)

local root ---@type Node
local max_sequence_length = 0

---Creates a new node, optionally attached below `parent`.
---@param parent? Node
---@param val integer
---@return Node
function Node.new(parent, val)
    -- Begin with a shallow copy of the prototype so all defaults are present
    local fresh = {} ---@type Node
    for key, default in pairs(Node) do
        fresh[key] = default
    end

    fresh.parent = parent
    if parent ~= nil then
        fresh.depth = parent.depth + 1
    end
    if val ~= nil then
        fresh.val = val
    end

    -- Tables must be created per node; the prototype only holds nil placeholders
    fresh.g = {}
    fresh.entries = {}

    return setmetatable(fresh, Node_mt)
end
|
||||
|
||||
|
||||
---@class PhraseBiasEntry
|
||||
---@field starting_bias number
|
||||
---@field ending_bias number
|
||||
@ -157,50 +119,6 @@ for line in f:lines("l") do
|
||||
end
|
||||
end
|
||||
f:close()
|
||||
|
||||
-- Offline preprocessing of config file for maximum speed

-- Step 1: insert the token sequence of every bias entry into a trie rooted at
-- `root`.  Each node on an entry's path records that entry, and
-- `max_sequence_length` ends up as the length of the longest token sequence.
root = Node.new()
for _, entry in ipairs(bias_array) do
    local node = root
    for position, token in ipairs(entry.tokens) do
        max_sequence_length = math.max(max_sequence_length, position)

        local child = node.g[token]
        if child == nil then
            child = Node.new(node, token)
            child.parent = node
            node.g[token] = child
        end
        node = child

        local slot = node.n_entries + 1
        node.entries[slot] = entry
        node.n_entries = slot
    end
end

-- Step 2: visit the trie breadth-first using a singly linked list as the
-- queue, and give every node its fallback link `f`.
---@class Linked
local queue = {
    node = root,
    nxt = nil, ---@type Linked|nil
}
local queue_tail = queue
while queue ~= nil do
    local node = queue.node

    -- Append the children before advancing, so the list never runs dry early
    for _, child in pairs(node.g) do
        local link = {node = child}
        queue_tail.nxt = link
        queue_tail = link
    end
    queue = queue.nxt

    -- The root and its direct children always fall back to the root
    node.f = root
    if node ~= root and node.parent ~= root then
        -- Follow the parent's chain of fallback links and stop at the first
        -- node that has a child carrying the same value as `node`; the root
        -- is the last node examined.
        local ptr = node.parent
        repeat
            ptr = ptr.f
            local candidate = ptr.g[node.val]
            if candidate ~= nil then
                node.f = candidate
                break
            end
        until ptr == root
    end
end

print("Successfully loaded " .. bias_array_count .. " phrase bias entr" .. (bias_array_count == 1 and "y" or "ies") .. ".")
|
||||
|
||||
|
||||
@ -229,21 +147,22 @@ function userscript.genmod()
|
||||
local context_tokens = kobold.encode(kobold.worldinfo:compute_context(kobold.submission))
|
||||
local factor ---@type number
|
||||
local next_token ---@type integer
|
||||
local sequences = {} ---@type table<integer, table<integer, integer>>
|
||||
local n_tokens = 0
|
||||
local max_overlap = {} ---@type table<integer, integer>
|
||||
|
||||
local biased_tokens = {} ---@type table<integer, table<integer, boolean>>
|
||||
for i = 1, kobold.generated_rows do
|
||||
biased_tokens[i] = {}
|
||||
end
|
||||
|
||||
local max_overlap = {} ---@type table<integer, table<PhraseBiasEntry, integer>>
|
||||
|
||||
-- For each partially-generated sequence...
|
||||
for i, generated_row in ipairs(kobold.generated) do
|
||||
|
||||
-- Build an array `tokens` as the concatenation of the context
|
||||
-- tokens and the generated tokens of this sequence
|
||||
|
||||
tokens = {}
|
||||
local tokens = {}
|
||||
n_tokens = 0
|
||||
for k, v in ipairs(context_tokens) do
|
||||
n_tokens = n_tokens + 1
|
||||
@ -254,30 +173,47 @@ function userscript.genmod()
|
||||
tokens[n_tokens] = v
|
||||
end
|
||||
|
||||
-- For each phrase bias entry `bias_entry`, determine the largest
|
||||
-- integer `max_overlap[i][bias_entry]` such that the last
|
||||
-- `max_overlap[i][bias_entry]` elements of `tokens` equal the first
|
||||
-- `max_overlap[i][bias_entry]` elements of `bias_entry.tokens`
|
||||
-- For each phrase bias entry in the config file...
|
||||
for _, bias_entry in ipairs(bias_array) do
|
||||
|
||||
max_overlap[i] = {}
|
||||
local node = root
|
||||
for j = math.max(1, n_tokens - max_sequence_length + 1), n_tokens do
|
||||
local v = tokens[j]
|
||||
while node ~= root and node.g[v] == nil do
|
||||
node = node.f
|
||||
-- Determine the largest integer `max_overlap[i]` such that the last
|
||||
-- `max_overlap[i]` elements of `tokens` equal the first
|
||||
-- `max_overlap[i]` elements of `bias_entry.tokens`
|
||||
|
||||
max_overlap[i] = 0
|
||||
local s = {}
|
||||
local z = {[0] = 0}
|
||||
local l = 1
|
||||
local r = 1
|
||||
local n_s = math.min(n_tokens, bias_entry.n_tokens)
|
||||
local j = 0
|
||||
for k = 1, n_s do
|
||||
s[j] = bias_entry.tokens[k]
|
||||
j = j + 1
|
||||
end
|
||||
node = node.g[v]
|
||||
if node == nil then
|
||||
node = root
|
||||
for k = n_tokens - n_s + 1, n_tokens do
|
||||
s[j] = tokens[k]
|
||||
j = j + 1
|
||||
end
|
||||
end
|
||||
while node ~= root do
|
||||
for k, bias_entry in ipairs(node.entries) do
|
||||
if max_overlap[i][bias_entry] == nil then
|
||||
max_overlap[i][bias_entry] = node.depth
|
||||
for k = 1, (n_s<<1) - 1 do
|
||||
if k <= r and z[k - l] - 1 < r - k then
|
||||
z[k] = z[k - l]
|
||||
else
|
||||
l = k
|
||||
if k > r then
|
||||
r = k
|
||||
end
|
||||
while r < (n_s<<1) and s[r - l] == s[r] do
|
||||
r = r + 1
|
||||
end
|
||||
z[k] = r - l
|
||||
r = r - 1
|
||||
end
|
||||
if z[k] <= n_s and z[k] == (n_s<<1) - k then
|
||||
max_overlap[i] = z[k]
|
||||
break
|
||||
end
|
||||
end
|
||||
node = node.f
|
||||
end
|
||||
end
|
||||
|
||||
@ -287,14 +223,10 @@ function userscript.genmod()
|
||||
-- For each partially-generated sequence...
|
||||
for i, generated_row in ipairs(kobold.generated) do
|
||||
|
||||
if max_overlap[i][bias_entry] == nil then
|
||||
max_overlap[i][bias_entry] = 0
|
||||
end
|
||||
-- Use `max_overlap` to determine which token in the bias entry to
|
||||
-- apply bias to
|
||||
|
||||
-- Use `max_overlap[i][bias_entry]` to determine which token in the
|
||||
-- bias entry to apply bias to
|
||||
|
||||
if max_overlap[i][bias_entry] == 0 or max_overlap[i][bias_entry] == bias_entry.n_tokens then
|
||||
if max_overlap[i] == 0 or max_overlap[i] == bias_entry.n_tokens then
|
||||
if bias_entry.tokens[2] == nil then
|
||||
factor = 1
|
||||
else
|
||||
@ -302,8 +234,8 @@ function userscript.genmod()
|
||||
end
|
||||
next_token = bias_entry.tokens[1]
|
||||
else
|
||||
factor = max_overlap[i][bias_entry]/(bias_entry.n_tokens - 1)
|
||||
next_token = bias_entry.tokens[max_overlap[i][bias_entry]+1]
|
||||
factor = max_overlap[i]/(bias_entry.n_tokens - 1)
|
||||
next_token = bias_entry.tokens[max_overlap[i]+1]
|
||||
end
|
||||
|
||||
-- Apply bias
|
||||
|
Loading…
x
Reference in New Issue
Block a user