-- Logit viewer
-- Displays raw token scores and softmax probabilities during generation.

kobold = require("bridge")()
local userscript = {}  ---@class KoboldUserScript

-- Number of highest-scoring tokens printed for every sequence.
local K = 10

---@class Pair
---@field id integer
---@field score number

---@class ArrayBase
---@type table<any, Pair>
local _ = {}

---@class Array : ArrayBase
---@field n integer

--- Sifts the element at `index` down a zero-based binary max-heap (ordered by
--- `score`) until neither of its children has a larger score.
---@param array Array
---@param index integer
---@return nil
local function bubble(array, index)
    while true do
        local left = (index << 1) + 1
        if left >= array.n then
            break  -- no children inside the heap: nothing left to compare
        end

        -- Pick the largest of the node and its (one or two) children.
        local largest = index
        if array[left].score > array[largest].score then
            largest = left
        end
        local right = left + 1
        if right < array.n and array[right].score > array[largest].score then
            largest = right
        end

        if largest == index then
            break  -- heap property already holds at this node
        end

        array[index], array[largest] = array[largest], array[index]
        index = largest
    end
end

--- Rearranges slots `0 .. array.n - 1` in place into a max-heap keyed on
--- `score`.
---@param array Array
---@return nil
local function build(array)
    -- Start at the parent of the last slot and sift every node down to the root.
    for i = (array.n - 1) >> 1, 0, -1 do
        bubble(array, i)
    end
end

--- Removes and returns the highest-scoring pair of the heap, shrinking
--- `array.n` by one.
---@param array Array
---@return Pair
local function pop(array)
    local top = array[0]
    array.n = array.n - 1
    -- Move the last live element to the root and restore the heap property.
    array[0] = array[array.n]
    bubble(array, 0)
    return top
end

--- Prints, for every sequence, the previously generated token (if any) and
--- the K highest-scoring tokens together with their softmax probabilities.
---
--- Raises an error when K exceeds `kobold.logits_cols`.
---@return nil
function userscript.genmod()
    local vocab = kobold.logits_cols
    if K > vocab then
        error("K must be at most the vocabulary size of the model")
    end

    if kobold.generated_cols > 0 then
        for s in ipairs(kobold.logits) do
            local token = kobold.generated[s][kobold.generated_cols]
            -- Parentheses keep only gsub's first result (the escaped string).
            local text = (kobold.decode(token):gsub("\n", "\\n"))
            print("Previous result for sequence " .. s .. ": [" .. text .. "] (" .. math.tointeger(token) .. ")")
        end
    end

    for s, logits in ipairs(kobold.logits) do
        -- Zero-based heap storage; the size is known up front, so set it once.
        local heap = {n = vocab}  ---@type Array
        local peak = -math.huge
        for i = 0, vocab - 1 do
            local score = logits[i + 1]
            heap[i] = {id = i, score = score}
            if score > peak then
                peak = score
            end
        end

        -- Softmax is invariant under subtracting a constant from every logit.
        -- Shifting by the largest logit keeps math.exp from overflowing to inf
        -- (which would turn the printed probabilities into nan). A non-finite
        -- maximum is left unshifted so degenerate inputs behave as before.
        local shift = (peak > -math.huge and peak < math.huge) and peak or 0.0
        local sum = 0.0
        for i = 0, vocab - 1 do
            sum = sum + math.exp(heap[i].score - shift)
        end

        build(heap)
        print()
        print("Top " .. K .. " scores for sequence " .. s .. ":")
        for _ = 1, K do
            local e = pop(heap)
            local probability = 100 * (math.exp(e.score - shift) / sum)
            print(("%.6f"):format(e.score), ("%.3f%% "):format(probability), e.id, "[" .. (kobold.decode(e.id):gsub("\n", "\\n")) .. "]")
        end
    end
end

return userscript