import { characters, getAPIServerUrl, main_api, nai_settings, online_status, this_chid } from "../script.js";
import { power_user, registerDebugFunction } from "./power-user.js";
import { chat_completion_sources, model_list, oai_settings } from "./openai.js";
import { groups, selected_group } from "./group-chats.js";
import { getStringHash } from "./utils.js";
import { kai_flags } from "./kai-settings.js";
import { isKoboldCpp, isMancer, isTabby, textgenerationwebui_settings } from "./textgen-settings.js";

export const CHARACTERS_PER_TOKEN_RATIO = 3.35;
const TOKENIZER_WARNING_KEY = 'tokenizationWarningShown';

export const tokenizers = {
    NONE: 0,
    GPT2: 1,
    OPENAI: 2,
    LLAMA: 3,
    NERD: 4,
    NERD2: 5,
    API: 6,
    MISTRAL: 7,
    YI: 8,
    BEST_MATCH: 99,
};

export const SENTENCEPIECE_TOKENIZERS = [
    tokenizers.LLAMA,
    tokenizers.MISTRAL,
    tokenizers.YI,
    // uncomment when NovelAI releases Kayra and Clio weights, lol
    //tokenizers.NERD,
    //tokenizers.NERD2,
];

const objectStore = localforage.createInstance({ name: "SillyTavern_ChatCompletions" });

let tokenCache = {};

/**
 * Guesstimates the token count for a string.
 * @param {string} str String to tokenize.
 * @returns {number} Token count.
 */
export function guesstimate(str) {
    return Math.ceil(str.length / CHARACTERS_PER_TOKEN_RATIO);
}
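
/**
 * Loads the token cache from the object store.
 * Falls back to an empty cache if loading fails.
 */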
async function loadTokenCache() {
    try {
        console.debug('Chat Completions: loading token cache');
        tokenCache = await objectStore.getItem('tokenCache') || {};
    } catch (e) {
        console.log('Chat Completions: unable to load token cache, using default value', e);
        tokenCache = {};
    }
}
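
/**
 * Saves the in-memory token cache to the object store.
 */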
export async function saveTokenCache() {
    try {
        console.debug('Chat Completions: saving token cache');
        await objectStore.setItem('tokenCache', tokenCache);
    } catch (e) {
        console.log('Chat Completions: unable to save token cache', e);
    }
}
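
/**
 * Clears the in-memory token cache and removes the persisted copy.
 * Exposed as a debug function to force a full re-tokenization.
 */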
async function resetTokenCache() {
    try {
        console.debug('Chat Completions: resetting token cache');
        Object.keys(tokenCache).forEach(key => delete tokenCache[key]);
        await objectStore.removeItem('tokenCache');
        toastr.success('Token cache cleared. Please reload the chat to re-tokenize it.');
    } catch (e) {
        console.log('Chat Completions: unable to reset token cache', e);
    }
}

/**
 * Gets the friendly name of the current tokenizer.
 * @param {string} forApi API to get the tokenizer for. Defaults to the main API.
 * @returns {{ tokenizerName: string, tokenizerId: number }} Tokenizer info.
 */
export function getFriendlyTokenizerName(forApi) {
    if (!forApi) {
        forApi = main_api;
    }

    const tokenizerOption = $("#tokenizer").find(':selected');
    let tokenizerId = Number(tokenizerOption.val());
    let tokenizerName = tokenizerOption.text();

    if (forApi !== 'openai' && tokenizerId === tokenizers.BEST_MATCH) {
        tokenizerId = getTokenizerBestMatch(forApi);
        tokenizerName = $(`#tokenizer option[value="${tokenizerId}"]`).text();
    }

    tokenizerName = forApi == 'openai'
        ? getTokenizerModel()
        : tokenizerName;

    tokenizerId = forApi == 'openai'
        ? tokenizers.OPENAI
        : tokenizerId;

    return { tokenizerName, tokenizerId };
}
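
// Example: describe the tokenizer currently in effect for the Kobold backend.
// const { tokenizerName, tokenizerId } = getFriendlyTokenizerName('kobold');
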
/**
 * Gets the best tokenizer for the current API.
 * @param {string} forApi API to get the tokenizer for. Defaults to the main API.
 * @returns {number} Tokenizer type.
 */
export function getTokenizerBestMatch(forApi) {
    if (!forApi) {
        forApi = main_api;
    }

    if (forApi === 'novel') {
        if (nai_settings.model_novel.includes('clio')) {
            return tokenizers.NERD;
        }
        if (nai_settings.model_novel.includes('kayra')) {
            return tokenizers.NERD2;
        }
    }

    if (forApi === 'kobold' || forApi === 'textgenerationwebui' || forApi === 'koboldhorde') {
        // Try to use the API tokenizer if possible:
        // - API must be connected
        // - Kobold must pass a version check
        // - Tokenizer hasn't reported an error previously
        if (kai_flags.can_use_tokenization && !sessionStorage.getItem(TOKENIZER_WARNING_KEY) && online_status !== 'no_connection') {
            return tokenizers.API;
        }

        return tokenizers.LLAMA;
    }

    return tokenizers.NONE;
}
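
// Example: pick the tokenizer for the NovelAI backend; a model name containing
// 'kayra' resolves to the NerdStash v2 tokenizer.
// getTokenizerBestMatch('novel'); // => tokenizers.NERD2
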
/**
 * Calls the underlying tokenizer model to get the token count for a string.
 * @param {number} type Tokenizer type.
 * @param {string} str String to tokenize.
 * @param {number} padding Number of padding tokens.
 * @returns {number} Token count.
 */
function callTokenizer(type, str, padding) {
    switch (type) {
        case tokenizers.NONE:
            return guesstimate(str) + padding;
        case tokenizers.GPT2:
            return countTokensRemote('/api/tokenize/gpt2', str, padding);
        case tokenizers.LLAMA:
            return countTokensRemote('/api/tokenize/llama', str, padding);
        case tokenizers.NERD:
            return countTokensRemote('/api/tokenize/nerdstash', str, padding);
        case tokenizers.NERD2:
            return countTokensRemote('/api/tokenize/nerdstash_v2', str, padding);
        case tokenizers.MISTRAL:
            return countTokensRemote('/api/tokenize/mistral', str, padding);
        case tokenizers.YI:
            return countTokensRemote('/api/tokenize/yi', str, padding);
        case tokenizers.API:
            return countTokensRemote('/tokenize_via_api', str, padding);
        default:
            console.warn("Unknown tokenizer type", type);
            return callTokenizer(tokenizers.NONE, str, padding);
    }
}

/**
 * Gets the token count for a string using the current model tokenizer.
 * @param {string} str String to tokenize.
 * @param {number | undefined} padding Optional padding tokens. Defaults to 0.
 * @returns {number} Token count.
 */
export function getTokenCount(str, padding = undefined) {
    if (typeof str !== 'string' || !str?.length) {
        return 0;
    }

    let tokenizerType = power_user.tokenizer;

    if (main_api === 'openai') {
        if (padding === power_user.token_padding) {
            // For main "shadow" prompt building
            tokenizerType = tokenizers.NONE;
        } else {
            // For extensions and WI
            return counterWrapperOpenAI(str);
        }
    }

    if (tokenizerType === tokenizers.BEST_MATCH) {
        tokenizerType = getTokenizerBestMatch(main_api);
    }

    if (padding === undefined) {
        padding = 0;
    }

    const cacheObject = getTokenCacheObject();
    const hash = getStringHash(str);
    const cacheKey = `${tokenizerType}-${hash}+${padding}`;

    if (typeof cacheObject[cacheKey] === 'number') {
        return cacheObject[cacheKey];
    }

    const result = callTokenizer(tokenizerType, str, padding);

    if (isNaN(result)) {
        console.warn("Token count calculation returned NaN");
        return 0;
    }

    cacheObject[cacheKey] = result;
    return result;
}
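
// Example: count tokens for a prompt fragment with 64 tokens of padding.
// const count = getTokenCount('Hello, world!', 64);
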
/**
 * Gets the token count for a string using the OpenAI tokenizer.
 * @param {string} text Text to tokenize.
 * @returns {number} Token count.
 */
function counterWrapperOpenAI(text) {
    const message = { role: 'system', content: text };
    return countTokensOpenAI(message, true);
}
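
/**
 * Determines the tokenizer model to use for the current Chat Completion source.
 * @returns {string} Tokenizer model name.
 */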
export function getTokenizerModel() {
    // OpenAI models always provide their own tokenizer
    if (oai_settings.chat_completion_source == chat_completion_sources.OPENAI) {
        return oai_settings.openai_model;
    }

    const turbo0301Tokenizer = 'gpt-3.5-turbo-0301';
    const turboTokenizer = 'gpt-3.5-turbo';
    const gpt4Tokenizer = 'gpt-4';
    const gpt2Tokenizer = 'gpt2';
    const claudeTokenizer = 'claude';
    const llamaTokenizer = 'llama';
    const mistralTokenizer = 'mistral';
    const yiTokenizer = 'yi';

    // Assuming no one would use it for different models.. right?
    if (oai_settings.chat_completion_source == chat_completion_sources.SCALE) {
        return gpt4Tokenizer;
    }

    // Select the correct tokenizer for WindowAI proxies
    if (oai_settings.chat_completion_source == chat_completion_sources.WINDOWAI && oai_settings.windowai_model) {
        if (oai_settings.windowai_model.includes('gpt-4')) {
            return gpt4Tokenizer;
        }
        else if (oai_settings.windowai_model.includes('gpt-3.5-turbo-0301')) {
            return turbo0301Tokenizer;
        }
        else if (oai_settings.windowai_model.includes('gpt-3.5-turbo')) {
            return turboTokenizer;
        }
        else if (oai_settings.windowai_model.includes('claude')) {
            return claudeTokenizer;
        }
        else if (oai_settings.windowai_model.includes('GPT-NeoXT')) {
            return gpt2Tokenizer;
        }
    }

    // And for OpenRouter (if not a site model, then it's impossible to determine the tokenizer)
    if (oai_settings.chat_completion_source == chat_completion_sources.OPENROUTER && oai_settings.openrouter_model) {
        const model = model_list.find(x => x.id === oai_settings.openrouter_model);

        if (model?.architecture?.tokenizer === 'Llama2') {
            return llamaTokenizer;
        }
        else if (model?.architecture?.tokenizer === 'Mistral') {
            return mistralTokenizer;
        }
        else if (model?.architecture?.tokenizer === 'Yi') {
            return yiTokenizer;
        }
        else if (oai_settings.openrouter_model.includes('gpt-4')) {
            return gpt4Tokenizer;
        }
        else if (oai_settings.openrouter_model.includes('gpt-3.5-turbo-0301')) {
            return turbo0301Tokenizer;
        }
        else if (oai_settings.openrouter_model.includes('gpt-3.5-turbo')) {
            return turboTokenizer;
        }
        else if (oai_settings.openrouter_model.includes('claude')) {
            return claudeTokenizer;
        }
        else if (oai_settings.openrouter_model.includes('GPT-NeoXT')) {
            return gpt2Tokenizer;
        }
    }

    if (oai_settings.chat_completion_source == chat_completion_sources.CLAUDE) {
        return claudeTokenizer;
    }

    // Default to Turbo 3.5
    return turboTokenizer;
}

/**
 * Counts the tokens in chat completion messages using the remote tokenizer.
 * @param {any[] | Object} messages Messages to tokenize.
 * @param {boolean} full Whether to return the full (unadjusted) token count.
 * @returns {number} Token count.
 */
export function countTokensOpenAI(messages, full = false) {
    const shouldTokenizeAI21 = oai_settings.chat_completion_source === chat_completion_sources.AI21 && oai_settings.use_ai21_tokenizer;
    const cacheObject = getTokenCacheObject();

    if (!Array.isArray(messages)) {
        messages = [messages];
    }

    let token_count = -1;

    for (const message of messages) {
        const model = getTokenizerModel();

        if (model === 'claude' || shouldTokenizeAI21) {
            full = true;
        }

        const hash = getStringHash(JSON.stringify(message));
        const cacheKey = `${model}-${hash}`;
        const cachedCount = cacheObject[cacheKey];

        if (typeof cachedCount === 'number') {
            token_count += cachedCount;
        }
        else {
            jQuery.ajax({
                async: false,
                type: 'POST',
                url: shouldTokenizeAI21 ? '/api/tokenize/ai21' : `/api/tokenize/openai?model=${model}`,
                data: JSON.stringify([message]),
                dataType: "json",
                contentType: "application/json",
                success: function (data) {
                    token_count += Number(data.token_count);
                    cacheObject[cacheKey] = Number(data.token_count);
                }
            });
        }
    }

    if (!full) token_count -= 2;

    return token_count;
}
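
// Example: synchronously count a single message; results are cached by message hash.
// countTokensOpenAI({ role: 'user', content: 'Hi there!' });
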
/**
 * Gets the token cache object for the current chat.
 * @returns {Object} Token cache object for the current chat.
 */
function getTokenCacheObject() {
    let chatId = 'undefined';

    try {
        if (selected_group) {
            chatId = groups.find(x => x.id == selected_group)?.chat_id;
        }
        else if (this_chid !== undefined) {
            chatId = characters[this_chid].chat;
        }
    } catch {
        console.log('No character / group selected. Using default cache item');
    }

    if (typeof tokenCache[chatId] !== 'object') {
        tokenCache[chatId] = {};
    }

    return tokenCache[String(chatId)];
}
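
/**
 * Builds the request body shared by the remote tokenization endpoints.
 * @param {string} str String to tokenize.
 * @returns {Object} Tokenization request parameters.
 */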
function getRemoteTokenizationParams(str) {
    return {
        text: str,
        api: main_api,
        url: getAPIServerUrl(),
        legacy_api: main_api === 'textgenerationwebui' && textgenerationwebui_settings.legacy_api && !isMancer(),
        use_tabby: main_api === 'textgenerationwebui' && isTabby(),
        use_koboldcpp: main_api === 'textgenerationwebui' && isKoboldCpp(),
    };
}

/**
 * Counts tokens using the remote server API.
 * @param {string} endpoint API endpoint.
 * @param {string} str String to tokenize.
 * @param {number} padding Number of padding tokens.
 * @returns {number} Token count with padding.
 */
function countTokensRemote(endpoint, str, padding) {
    let tokenCount = 0;

    jQuery.ajax({
        async: false,
        type: 'POST',
        url: endpoint,
        data: JSON.stringify(getRemoteTokenizationParams(str)),
        dataType: "json",
        contentType: "application/json",
        success: function (data) {
            if (typeof data.count === 'number') {
                tokenCount = data.count;
            } else {
                tokenCount = guesstimate(str);
                console.error("Error counting tokens");

                if (!sessionStorage.getItem(TOKENIZER_WARNING_KEY)) {
                    toastr.warning(
                        "Your selected API doesn't support the tokenization endpoint. Using estimated counts.",
                        "Error counting tokens",
                        { timeOut: 10000, preventDuplicates: true },
                    );

                    sessionStorage.setItem(TOKENIZER_WARNING_KEY, String(true));
                }
            }
        }
    });

    return tokenCount + padding;
}

/**
 * Calls the underlying tokenizer model to encode a string to tokens.
 * @param {string} endpoint API endpoint.
 * @param {string} str String to tokenize.
 * @param {string} model Tokenizer model.
 * @returns {number[]} Array of token ids.
 */
function getTextTokensRemote(endpoint, str, model = '') {
    if (model) {
        endpoint += `?model=${model}`;
    }

    let ids = [];
    jQuery.ajax({
        async: false,
        type: 'POST',
        url: endpoint,
        data: JSON.stringify(getRemoteTokenizationParams(str)),
        dataType: "json",
        contentType: "application/json",
        success: function (data) {
            ids = data.ids;

            // Don't want to break backward compatibility, so sprinkle in some of the JS magic
            if (Array.isArray(data.chunks)) {
                Object.defineProperty(ids, 'chunks', { value: data.chunks });
            }
        }
    });
    return ids;
}

/**
 * Calls the underlying tokenizer model to decode token ids to text.
 * @param {string} endpoint API endpoint.
 * @param {number[]} ids Array of token ids.
 * @param {string} model Tokenizer model.
 * @returns {string} Decoded text.
 */
function decodeTextTokensRemote(endpoint, ids, model = '') {
    if (model) {
        endpoint += `?model=${model}`;
    }

    let text = '';
    jQuery.ajax({
        async: false,
        type: 'POST',
        url: endpoint,
        data: JSON.stringify({ ids: ids }),
        dataType: "json",
        contentType: "application/json",
        success: function (data) {
            text = data.text;
        }
    });
    return text;
}

/**
 * Encodes a string to tokens using the remote server API.
 * @param {number} tokenizerType Tokenizer type.
 * @param {string} str String to tokenize.
 * @returns {number[]} Array of token ids.
 */
export function getTextTokens(tokenizerType, str) {
    switch (tokenizerType) {
        case tokenizers.GPT2:
            return getTextTokensRemote('/api/tokenize/gpt2', str);
        case tokenizers.LLAMA:
            return getTextTokensRemote('/api/tokenize/llama', str);
        case tokenizers.NERD:
            return getTextTokensRemote('/api/tokenize/nerdstash', str);
        case tokenizers.NERD2:
            return getTextTokensRemote('/api/tokenize/nerdstash_v2', str);
        case tokenizers.MISTRAL:
            return getTextTokensRemote('/api/tokenize/mistral', str);
        case tokenizers.YI:
            return getTextTokensRemote('/api/tokenize/yi', str);
        case tokenizers.OPENAI: {
            const model = getTokenizerModel();
            return getTextTokensRemote('/api/tokenize/openai-encode', str, model);
        }
        case tokenizers.API:
            return getTextTokensRemote('/tokenize_via_api', str);
        default:
            console.warn("Calling getTextTokens with unsupported tokenizer type", tokenizerType);
            return [];
    }
}

/**
 * Decodes token ids to text using the remote server API.
 * @param {number} tokenizerType Tokenizer type.
 * @param {number[]} ids Array of token ids.
 * @returns {string} Decoded text.
 */
export function decodeTextTokens(tokenizerType, ids) {
    switch (tokenizerType) {
        case tokenizers.GPT2:
            return decodeTextTokensRemote('/api/decode/gpt2', ids);
        case tokenizers.LLAMA:
            return decodeTextTokensRemote('/api/decode/llama', ids);
        case tokenizers.NERD:
            return decodeTextTokensRemote('/api/decode/nerdstash', ids);
        case tokenizers.NERD2:
            return decodeTextTokensRemote('/api/decode/nerdstash_v2', ids);
        case tokenizers.MISTRAL:
            return decodeTextTokensRemote('/api/decode/mistral', ids);
        case tokenizers.YI:
            return decodeTextTokensRemote('/api/decode/yi', ids);
        case tokenizers.OPENAI: {
            const model = getTokenizerModel();
            return decodeTextTokensRemote('/api/decode/openai', ids, model);
        }
        default:
            console.warn("Calling decodeTextTokens with unsupported tokenizer type", tokenizerType);
            return '';
    }
}
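
// Example: a round trip through the Llama tokenizer. Assumes the remote
// tokenization endpoints are reachable; decoding the encoded ids should
// reproduce the original string.
// const ids = getTextTokens(tokenizers.LLAMA, 'Hello, world!');
// const text = decodeTextTokens(tokenizers.LLAMA, ids);

/**
 * Initializes the tokenizer module: loads the persisted token cache and
 * registers the token cache reset as a debug function.
 */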
export async function initTokenizers() {
    await loadTokenCache();
    registerDebugFunction('resetTokenCache', 'Reset token cache', 'Purges the calculated token counts. Use this if you want to force a full re-tokenization of all chats or suspect the token counts are wrong.', resetTokenCache);
}