2023-08-24 20:23:35 +02:00
import { characters , main _api , nai _settings , online _status , this _chid } from "../script.js" ;
2023-08-27 22:20:43 +02:00
import { power _user , registerDebugFunction } from "./power-user.js" ;
2023-08-23 01:38:43 +02:00
import { chat _completion _sources , oai _settings } from "./openai.js" ;
import { groups , selected _group } from "./group-chats.js" ;
import { getStringHash } from "./utils.js" ;
2023-09-01 01:57:35 +02:00
import { kai _flags } from "./kai-settings.js" ;
2023-08-23 01:38:43 +02:00
/**
 * Average number of characters per token; the divisor used by `guesstimate`.
 */
export const CHARACTERS_PER_TOKEN_RATIO = 3.35;

/**
 * sessionStorage key that is set after the user has been warned that the selected API
 * does not support the tokenization endpoint.
 */
const TOKENIZER_WARNING_KEY = 'tokenizationWarningShown';

/**
 * Tokenizer identifiers. `power_user.tokenizer` holds one of these values.
 * NOTE(review): presumably persisted in user settings, so avoid renumbering.
 */
export const tokenizers = {
    // Character-count estimate (see guesstimate); no server call.
    NONE: 0,
    GPT2: 1,
    /**
     * @deprecated Use GPT2 instead.
     */
    LEGACY: 2,
    LLAMA: 3,
    // NerdStash (NovelAI Clio models).
    NERD: 4,
    // NerdStash v2 (NovelAI Kayra models).
    NERD2: 5,
    // Tokenization through the connected backend API.
    API: 6,
    // Placeholder resolved at call time by getTokenizerBestMatch.
    BEST_MATCH: 99,
};
/**
 * localforage store from which the token cache is loaded, saved and removed.
 * `createInstance` is a factory method, so it is called directly rather than with `new`.
 */
const objectStore = localforage.createInstance({ name: "SillyTavern_ChatCompletions" });

/**
 * Token counts grouped by chat id (see getTokenCacheObject), then by a per-tokenizer cache key.
 */
let tokenCache = {};
2023-08-24 19:19:57 +02:00
/**
 * Roughly estimates how many tokens a string contains, without calling any tokenizer.
 * Divides the character count by CHARACTERS_PER_TOKEN_RATIO and rounds up.
 * @param {string} str String to tokenize.
 * @returns {number} Estimated token count.
 */
export function guesstimate(str) {
    const estimatedTokens = str.length / CHARACTERS_PER_TOKEN_RATIO;
    return Math.ceil(estimatedTokens);
}
2023-08-23 01:38:43 +02:00
/**
 * Restores `tokenCache` from the persistent object store.
 * A falsy stored value or any storage error leaves an empty cache.
 */
async function loadTokenCache() {
    try {
        console.debug('Chat Completions: loading token cache');
        const persisted = await objectStore.getItem('tokenCache');
        tokenCache = persisted || {};
    } catch (error) {
        console.log('Chat Completions: unable to load token cache, using default value', error);
        tokenCache = {};
    }
}
/**
 * Writes the in-memory `tokenCache` to the persistent object store.
 * Storage errors are logged and not rethrown.
 */
export async function saveTokenCache() {
    try {
        console.debug('Chat Completions: saving token cache');
        await objectStore.setItem('tokenCache', tokenCache);
    } catch (error) {
        console.log('Chat Completions: unable to save token cache', error);
    }
}
/**
 * Empties the token cache in memory and removes it from the persistent store,
 * then tells the user to reload the chat. Errors are logged and not rethrown.
 */
async function resetTokenCache() {
    try {
        console.debug('Chat Completions: resetting token cache');

        // Keys are deleted in place; the tokenCache object itself is not replaced.
        for (const chatKey of Object.keys(tokenCache)) {
            delete tokenCache[chatKey];
        }

        await objectStore.removeItem('tokenCache');
        toastr.success('Token cache cleared. Please reload the chat to re-tokenize it.');
    } catch (error) {
        console.log('Chat Completions: unable to reset token cache', error);
    }
}
/**
 * Resolves the BEST_MATCH tokenizer setting to a concrete tokenizer for the active API.
 * - NovelAI: NerdStash for Clio models, NerdStash v2 for Kayra models.
 * - Kobold / Text Generation WebUI / Kobold Horde: the backend API tokenizer when usable, else LLaMA.
 * - Anything else (including other NovelAI models): NONE, i.e. the character-based estimate.
 * @returns {number} A value from `tokenizers`.
 */
function getTokenizerBestMatch() {
    if (main_api === 'novel') {
        const novelModel = nai_settings.model_novel;

        if (novelModel.includes('clio')) {
            return tokenizers.NERD;
        }

        if (novelModel.includes('kayra')) {
            return tokenizers.NERD2;
        }
    }

    const isKoboldOrTextGen = ['kobold', 'textgenerationwebui', 'koboldhorde'].includes(main_api);

    if (!isKoboldOrTextGen) {
        return tokenizers.NONE;
    }

    // Use the API tokenizer only when all of these hold:
    // - Kobold passed its version check (kai_flags.can_use_tokenization)
    // - the tokenizer hasn't reported an error earlier in this session
    // - the API is connected
    const canUseApiTokenizer = kai_flags.can_use_tokenization
        && !sessionStorage.getItem(TOKENIZER_WARNING_KEY)
        && online_status !== 'no_connection';

    return canUseApiTokenizer ? tokenizers.API : tokenizers.LLAMA;
}
2023-08-27 17:27:34 +02:00
/**
 * Calls the underlying tokenizer to get the token count for a string.
 * NONE uses the local `guesstimate`; every other supported type asks a server endpoint.
 * Unknown types log a warning and fall back to NONE.
 * @param {number} type Tokenizer type.
 * @param {string} str String to tokenize.
 * @param {number} padding Number of padding tokens.
 * @returns {number} Token count, including padding.
 */
function callTokenizer(type, str, padding) {
    switch (type) {
        case tokenizers.NONE:
            return guesstimate(str) + padding;
        // LEGACY is deprecated in favour of GPT2. Treat it as an alias instead of letting it
        // hit the default branch, which would warn "Unknown tokenizer type" and only estimate.
        case tokenizers.LEGACY:
        case tokenizers.GPT2:
            return countTokensRemote('/tokenize_gpt2', str, padding);
        case tokenizers.LLAMA:
            return countTokensRemote('/tokenize_llama', str, padding);
        case tokenizers.NERD:
            return countTokensRemote('/tokenize_nerdstash', str, padding);
        case tokenizers.NERD2:
            return countTokensRemote('/tokenize_nerdstash_v2', str, padding);
        case tokenizers.API:
            return countTokensRemote('/tokenize_via_api', str, padding);
        default:
            console.warn("Unknown tokenizer type", type);
            return callTokenizer(tokenizers.NONE, str, padding);
    }
}
2023-08-23 01:38:43 +02:00
/**
 * Gets the token count for a string using the current model tokenizer.
 * Results are cached per chat under a key built from tokenizer type, string hash and padding.
 * @param {string} str String to tokenize
 * @param {number | undefined} padding Optional padding tokens. Defaults to 0.
 * @returns {number} Token count; 0 for non-strings, empty strings, or a NaN tokenizer result.
 */
export function getTokenCount(str, padding = undefined) {
    const hasText = typeof str === 'string' && str.length > 0;
    if (!hasText) {
        return 0;
    }

    let tokenizerType = power_user.tokenizer;

    if (main_api === 'openai') {
        // For extensions and WI: use the OpenAI tokenizer (bypasses the cache below).
        if (padding !== power_user.token_padding) {
            return counterWrapperOpenAI(str);
        }

        // For main "shadow" prompt building: use the NONE estimator.
        tokenizerType = tokenizers.NONE;
    }

    if (tokenizerType === tokenizers.BEST_MATCH) {
        tokenizerType = getTokenizerBestMatch();
    }

    const paddingTokens = padding === undefined ? 0 : padding;
    const cacheObject = getTokenCacheObject();
    const cacheKey = `${tokenizerType}-${getStringHash(str)}+${paddingTokens}`;

    const cachedCount = cacheObject[cacheKey];
    if (typeof cachedCount === 'number') {
        return cachedCount;
    }

    const result = callTokenizer(tokenizerType, str, paddingTokens);

    if (isNaN(result)) {
        console.warn("Token count calculation returned NaN");
        return 0;
    }

    cacheObject[cacheKey] = result;
    return result;
}
/**
 * Counts the tokens of a bare string with the OpenAI tokenizer by wrapping it
 * as a single system message and counting it in full mode.
 * @param {string} text Text to tokenize.
 * @returns {number} Token count.
 */
function counterWrapperOpenAI(text) {
    return countTokensOpenAI({ role: 'system', content: text }, true);
}
/**
 * Substring → tokenizer-name rules for proxy services that can route to many models.
 * Checked in order; the first rule whose substring occurs in the model name wins.
 */
const PROXIED_MODEL_TOKENIZERS = [
    ['gpt-4', 'gpt-4'],
    ['gpt-3.5-turbo', 'gpt-3.5-turbo'],
    ['claude', 'claude'],
    ['GPT-NeoXT', 'gpt2'],
];

/**
 * Picks a tokenizer name for a model served through a proxy (WindowAI, OpenRouter).
 * @param {string} modelName Model identifier reported by the proxy.
 * @returns {string | undefined} Tokenizer name, or undefined when no rule matches.
 */
function getTokenizerForProxiedModel(modelName) {
    const rule = PROXIED_MODEL_TOKENIZERS.find(([needle]) => modelName.includes(needle));
    return rule?.[1];
}

/**
 * Returns the tokenizer model name for the current chat completion source.
 * @returns {string} Tokenizer model name passed to the OpenAI tokenization endpoint.
 */
export function getTokenizerModel() {
    const source = oai_settings.chat_completion_source;

    // OpenAI models always provide their own tokenizer
    if (source === chat_completion_sources.OPENAI) {
        return oai_settings.openai_model;
    }

    // Assuming no one would use it for different models.. right?
    if (source === chat_completion_sources.SCALE) {
        return 'gpt-4';
    }

    // Select the correct tokenizer for WindowAI proxies
    if (source === chat_completion_sources.WINDOWAI && oai_settings.windowai_model) {
        const tokenizer = getTokenizerForProxiedModel(oai_settings.windowai_model);
        if (tokenizer) {
            return tokenizer;
        }
    }

    // And for OpenRouter (if not a site model, then it's impossible to determine the tokenizer)
    if (source === chat_completion_sources.OPENROUTER && oai_settings.openrouter_model) {
        const tokenizer = getTokenizerForProxiedModel(oai_settings.openrouter_model);
        if (tokenizer) {
            return tokenizer;
        }
    }

    if (source === chat_completion_sources.CLAUDE) {
        return 'claude';
    }

    // Default to Turbo 3.5
    return 'gpt-3.5-turbo';
}
/**
 * Counts tokens for chat message(s) via the server tokenizer endpoints, caching results per chat.
 * Uses `/tokenize_ai21` when the AI21 source with `use_ai21_tokenizer` is selected; otherwise
 * it uses `/tokenize_openai` with the model from `getTokenizerModel()`.
 * @param {any[] | Object} messages A message object or an array of message objects.
 * @param {boolean} [full=false] When false, 2 is subtracted from the total. Treated as true
 *   (for a non-empty message list) with the Claude tokenizer or AI21 tokenization.
 * @returns {number} Token count.
 */
export function countTokensOpenAI(messages, full = false) {
    const shouldTokenizeAI21 = oai_settings.chat_completion_source === chat_completion_sources.AI21 && oai_settings.use_ai21_tokenizer;
    const cacheObject = getTokenCacheObject();
    const messageList = Array.isArray(messages) ? messages : [messages];

    // The tokenizer model depends only on settings, so resolve it once rather than per message.
    const model = getTokenizerModel();
    // AI21 counts get their own cache namespace; otherwise they would share keys with the
    // tiktoken counts for `model` and toggling use_ai21_tokenizer would return stale values.
    const cachePrefix = shouldTokenizeAI21 ? 'ai21' : model;
    const url = shouldTokenizeAI21 ? '/tokenize_ai21' : `/tokenize_openai?model=${encodeURIComponent(model)}`;

    if (messageList.length > 0 && (model === 'claude' || shouldTokenizeAI21)) {
        full = true;
    }

    // NOTE(review): the count starts at -1; the reason for this offset is not visible in this file.
    let token_count = -1;

    for (const message of messageList) {
        const hash = getStringHash(JSON.stringify(message));
        const cacheKey = `${cachePrefix}-${hash}`;
        const cachedCount = cacheObject[cacheKey];

        // Number.isFinite also rejects a NaN that may have been cached earlier.
        if (Number.isFinite(cachedCount)) {
            token_count += cachedCount;
            continue;
        }

        jQuery.ajax({
            async: false,
            type: 'POST',
            url: url,
            data: JSON.stringify([message]),
            dataType: "json",
            contentType: "application/json",
            success: function (data) {
                const count = Number(data?.token_count);

                // Never add or cache an invalid count; a cached NaN would be served forever.
                if (!Number.isFinite(count)) {
                    console.warn('Chat Completions: tokenizer returned an invalid token count', data);
                    return;
                }

                token_count += count;
                cacheObject[cacheKey] = count;
            },
            error: function (_jqXHR, textStatus, errorThrown) {
                console.error('Chat Completions: token counting request failed', textStatus, errorThrown);
            },
        });
    }

    if (!full) {
        token_count -= 2;
    }

    return token_count;
}
/**
 * Returns the token cache bucket for the active chat, creating an empty one if needed.
 * Group chats use the selected group's chat_id and single-character chats use the character's chat.
 * With no selection, a missing group, or a lookup that throws, the 'undefined' bucket is used.
 * @returns {Object} Token cache object for the current chat.
 */
function getTokenCacheObject() {
    let chatId = 'undefined';

    try {
        if (selected_group) {
            // Loose equality tolerates string/number mismatches between group ids.
            const activeGroup = groups.find((group) => group.id == selected_group);
            chatId = activeGroup?.chat_id;
        } else if (this_chid !== undefined) {
            chatId = characters[this_chid].chat;
        }
    } catch {
        console.log('No character / group selected. Using default cache item');
    }

    const bucketKey = String(chatId);

    if (typeof tokenCache[bucketKey] !== 'object') {
        tokenCache[bucketKey] = {};
    }

    return tokenCache[bucketKey];
}
2023-08-24 19:19:57 +02:00
/**
 * Counts tokens using the remote server API.
 * Falls back to `guesstimate` when the endpoint answers without a numeric `count` (warning the
 * user once per session) or when the request itself fails.
 * @param {string} endpoint API endpoint.
 * @param {string} str String to tokenize.
 * @param {number} padding Number of padding tokens.
 * @returns {number} Token count with padding.
 */
function countTokensRemote(endpoint, str, padding) {
    let tokenCount = 0;

    jQuery.ajax({
        async: false,
        type: 'POST',
        url: endpoint,
        data: JSON.stringify({ text: str }),
        dataType: "json",
        contentType: "application/json",
        success: function (data) {
            if (typeof data?.count === 'number') {
                tokenCount = data.count;
                return;
            }

            tokenCount = guesstimate(str);
            console.error("Error counting tokens");

            if (!sessionStorage.getItem(TOKENIZER_WARNING_KEY)) {
                toastr.warning(
                    "Your selected API doesn't support the tokenization endpoint. Using estimated counts.",
                    "Error counting tokens",
                    { timeOut: 10000, preventDuplicates: true },
                );

                // getTokenizerBestMatch checks this flag and stops choosing the API tokenizer.
                sessionStorage.setItem(TOKENIZER_WARNING_KEY, String(true));
            }
        },
        error: function (_jqXHR, textStatus, errorThrown) {
            // Without a fallback, a failed request would report only the padding,
            // and getTokenCount would cache that value.
            tokenCount = guesstimate(str);
            console.error("Error counting tokens", textStatus, errorThrown);
        },
    });

    return tokenCount + padding;
}
2023-08-27 17:27:34 +02:00
/**
 * Calls the underlying tokenizer model to encode a string to tokens.
 * @param {string} endpoint API endpoint.
 * @param {string} str String to tokenize.
 * @returns {number[]} Array of token ids; empty if the request fails or the response has no `ids` array.
 */
function getTextTokensRemote(endpoint, str) {
    let ids = [];

    jQuery.ajax({
        async: false,
        type: 'POST',
        url: endpoint,
        data: JSON.stringify({ text: str }),
        dataType: "json",
        contentType: "application/json",
        success: function (data) {
            // Keep the documented array contract even if the server answers with another shape.
            if (Array.isArray(data?.ids)) {
                ids = data.ids;
            } else {
                console.warn('Tokenizer endpoint returned no token ids', endpoint);
            }
        },
        error: function (_jqXHR, textStatus, errorThrown) {
            console.error('Error encoding tokens', endpoint, textStatus, errorThrown);
        },
    });

    return ids;
}
2023-08-27 17:27:34 +02:00
/**
 * Calls the underlying tokenizer model to decode token ids to text.
 * @param {string} endpoint API endpoint.
 * @param {number[]} ids Array of token ids
 * @returns {string} Decoded text; empty if the request fails or the response has no string `text`.
 */
function decodeTextTokensRemote(endpoint, ids) {
    let text = '';

    jQuery.ajax({
        async: false,
        type: 'POST',
        url: endpoint,
        data: JSON.stringify({ ids: ids }),
        dataType: "json",
        contentType: "application/json",
        success: function (data) {
            // Keep the string contract even if the server answers with another shape.
            if (typeof data?.text === 'string') {
                text = data.text;
            } else {
                console.warn('Tokenizer endpoint returned no decoded text', endpoint);
            }
        },
        error: function (_jqXHR, textStatus, errorThrown) {
            console.error('Error decoding tokens', endpoint, textStatus, errorThrown);
        },
    });

    return text;
}
/**
 * Encodes a string to token ids via the server tokenizer matching `tokenizerType`.
 * Only GPT2, LLAMA, NERD and NERD2 are supported; any other type logs a warning and yields [].
 * @param {number} tokenizerType Tokenizer type.
 * @param {string} str String to tokenize.
 * @returns {number[]} Array of token ids.
 */
export function getTextTokens(tokenizerType, str) {
    // Map keys compare by SameValueZero, so (as with a switch) '1' and 1 stay distinct.
    const encodeEndpoints = new Map([
        [tokenizers.GPT2, '/tokenize_gpt2'],
        [tokenizers.LLAMA, '/tokenize_llama'],
        [tokenizers.NERD, '/tokenize_nerdstash'],
        [tokenizers.NERD2, '/tokenize_nerdstash_v2'],
    ]);

    const endpoint = encodeEndpoints.get(tokenizerType);

    if (endpoint === undefined) {
        console.warn("Calling getTextTokens with unsupported tokenizer type", tokenizerType);
        return [];
    }

    return getTextTokensRemote(endpoint, str);
}
2023-08-27 17:27:34 +02:00
/**
 * Decodes token ids back to text via the server tokenizer matching `tokenizerType`.
 * Only GPT2, LLAMA, NERD and NERD2 are supported; any other type logs a warning and yields ''.
 * @param {number} tokenizerType Tokenizer type.
 * @param {number[]} ids Array of token ids
 * @returns {string} Decoded text.
 */
export function decodeTextTokens(tokenizerType, ids) {
    // Map keys compare by SameValueZero, so (as with a switch) '1' and 1 stay distinct.
    const decodeEndpoints = new Map([
        [tokenizers.GPT2, '/decode_gpt2'],
        [tokenizers.LLAMA, '/decode_llama'],
        [tokenizers.NERD, '/decode_nerdstash'],
        [tokenizers.NERD2, '/decode_nerdstash_v2'],
    ]);

    const endpoint = decodeEndpoints.get(tokenizerType);

    if (endpoint === undefined) {
        console.warn("Calling decodeTextTokens with unsupported tokenizer type", tokenizerType);
        return '';
    }

    return decodeTextTokensRemote(endpoint, ids);
}
2023-08-23 01:38:43 +02:00
// On document ready: restore the persisted token counts, then register a debug action that purges them.
jQuery(async function initTokenizerModule() {
    await loadTokenCache();

    registerDebugFunction(
        'resetTokenCache',
        'Reset token cache',
        'Purges the calculated token counts. Use this if you want to force a full re-tokenization of all chats or suspect the token counts are wrong.',
        resetTokenCache,
    );
});