mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2024-12-12 00:35:41 +01:00
92 lines
2.3 KiB
Lua
92 lines
2.3 KiB
Lua
|
-- Logit viewer
|
||
|
-- Displays raw token scores and softmax probabilities during generation.
|
||
|
|
||
|
kobold = require("bridge")()
|
||
|
-- Script table: the genmod function is attached to it further down and the
-- table itself is returned at the end of this file.
local userscript = {} ---@class KoboldUserScript
|
||
|
|
||
|
-- Number of highest-scoring tokens that genmod prints for each sequence.
-- genmod raises an error when this is larger than kobold.logits_cols.
local K = 10
|
||
|
|
||
|
---@class Pair
|
||
|
---@field id integer
|
||
|
---@field score number
|
||
|
|
||
|
---@class ArrayBase
|
||
|
---@type table<any, Pair>
|
||
|
local _ = {} -- never referenced again in this file; presumably only a declaration for the annotations above to attach to
|
||
|
|
||
|
---@class Array : ArrayBase
|
||
|
---@field n integer
|
||
|
|
||
|
--- Sift the entry at `index` down a 0-based binary max-heap keyed on `score`.
--- The entry is swapped with its larger child until neither child has a
--- strictly greater score or the entry has no children left.
---@param array Array
---@param index integer
---@return nil
local function bubble(array, index)
    local size = array.n
    local left = (index << 1) + 1
    while left < size do
        -- Pick the largest of the node and its (one or two) children.
        local largest = index
        if array[left].score > array[largest].score then
            largest = left
        end
        local right = left + 1
        if right < size and array[right].score > array[largest].score then
            largest = right
        end
        if largest == index then
            -- Heap property already holds at this node.
            return
        end
        array[index], array[largest] = array[largest], array[index]
        index = largest
        left = (index << 1) + 1
    end
end
|
||
|
|
||
|
--- Rearrange `array` in place into a 0-based binary max-heap keyed on `score`
--- by sifting down every node that can have a child, from the last such node
--- back to the root. Arrays with zero or one entry are left untouched.
---@param array Array
---@return nil
local function build(array)
    -- (n - 2) // 2 is the last slot that has a child. Floor division yields -1
    -- for n = 0 and n = 1, so the loop body is skipped for them. A right shift
    -- is not usable here: `>>` is a logical shift, so shifting -1 (n = 0)
    -- gives a huge positive start index instead of a negative one.
    for i = (array.n - 2) // 2, 0, -1 do
        bubble(array, i)
    end
end
|
||
|
|
||
|
--- Remove and return the highest-scoring pair from a 0-based max-heap.
--- The last entry is moved to the root and sifted down to restore the heap.
--- When the heap is already empty, nil is returned and `array.n` is left at
--- its current value, rather than going negative and handing back whatever
--- stale pair still sits in slot 0.
---@param array Array
---@return Pair?
local function pop(array)
    local size = array.n
    if size < 1 then
        return nil
    end
    local top = array[0]
    size = size - 1
    array.n = size
    -- With nothing left there is no entry to move to the root or sift down.
    if size > 0 then
        array[0] = array[size]
        bubble(array, 0)
    end
    return top
end
|
||
|
|
||
|
--- For every sequence in `kobold.logits`, print the most recently generated
--- token (once at least one token has been generated) followed by the K
--- highest-scoring entries of that sequence's logits: raw score, softmax
--- probability in percent, token id and decoded text.
--- Raises an error when K exceeds `kobold.logits_cols`.
function userscript.genmod()
    local vocab = kobold.logits_cols
    if K > vocab then
        error("K must be at most the vocabulary size of the model")
    end

    if kobold.generated_cols > 0 then
        for s in ipairs(kobold.logits) do
            local token = kobold.generated[s][kobold.generated_cols]
            print("Previous result for sequence " .. s .. ": [" .. kobold.decode(token):gsub("\n", "\\n") .. "] (" .. math.tointeger(token) .. ")")
        end
    end

    for s, logits in ipairs(kobold.logits) do
        local a = {n = vocab} ---@type Array
        -- Copy the scores into a 0-based array of pairs and track the maximum.
        local peak = -math.huge
        for i = 0, vocab - 1 do
            local score = logits[i + 1]
            a[i] = {id = i, score = score}
            if score > peak then
                peak = score
            end
        end

        -- Softmax denominator. Every score is shifted by the maximum so that
        -- math.exp never sees a positive argument; exponentiating large raw
        -- scores directly overflows to inf and turns the percentages into
        -- nan or 0. The shift cancels out in the ratio printed below.
        local sum = 0.0
        for i = 0, vocab - 1 do
            sum = sum + math.exp(a[i].score - peak)
        end

        build(a)
        print()
        print("Top " .. K .. " scores for sequence " .. s .. ":")
        for _ = 1, K do
            local e = pop(a)
            print(("%.6f"):format(e.score), ("%.3f%% "):format(100 * (math.exp(e.score - peak) / sum)), e.id, "[" .. (kobold.decode(e.id):gsub("\n", "\\n")) .. "]")
        end
    end
end
|
||
|
|
||
|
-- Hand the script table (carrying the genmod function defined above) back to
-- whatever loads this file.
return userscript
|