diff --git a/userscripts/kaipreset_logits_viewer.lua b/userscripts/kaipreset_logits_viewer.lua new file mode 100644 index 00000000..eae883af --- /dev/null +++ b/userscripts/kaipreset_logits_viewer.lua @@ -0,0 +1,91 @@ +-- Logit viewer +-- Displays raw token scores and softmax probabilities during generation. + +kobold = require("bridge")() +local userscript = {} ---@class KoboldUserScript + +local K = 10 + +---@class Pair +---@field id integer +---@field score number + +---@class ArrayBase +---@type table +local _ = {} + +---@class Array : ArrayBase +---@field n integer + +---@param array Array +---@param index integer +---@return nil +local function bubble(array, index) + local j = 0 + while (index<<1)+1 < array.n do + j = index + if array[(index<<1)+1].score > array[j].score then + j = (index<<1)+1 + end + if (index<<1)+2 < array.n and array[(index<<1)+2].score > array[j].score then + j = (index<<1)+2 + end + if index == j then + break + end + local b = array[index] + array[index] = array[j] + array[j] = b + index = j + end +end + +---@param array Array +---@return nil +local function build(array) + for i = (array.n-1)>>1, 0, -1 do + bubble(array, i) + end +end + +---@param array Array +---@return Pair +local function pop(array) + local r = array[0] + array.n = array.n - 1 + array[0] = array[array.n] + bubble(array, 0) + return r +end + +function userscript.genmod() + if K > kobold.logits_cols then + error("K must be at most the vocabulary size of the model") + end + + if kobold.generated_cols > 0 then + for s, logits in ipairs(kobold.logits) do + local token = kobold.generated[s][kobold.generated_cols] + print("Previous result for sequence " .. s .. ": [" .. kobold.decode(token):gsub("\n", "\\n") .. "] (" .. math.tointeger(token) .. ")") + end + end + + for s, logits in ipairs(kobold.logits) do + local a = {} ---@type Array + local sum = 0.0 + for i = 0, kobold.logits_cols-1 do + a[i] = {id = i, score = logits[i + 1]} + a.n = i + 1 + sum = sum + math.exp(logits[i + 1]) + end + build(a) + print() + print("Top " .. K .. " scores for sequence " .. s .. ":") + for i = 1, K do + local e = pop(a) + print(("%.6f"):format(e.score), ("%.3f%% "):format(100 * (math.exp(e.score) / sum)), e.id, "[" .. (kobold.decode(e.id):gsub("\n", "\\n")) .. "]") + end + end +end + +return userscript