Mirror of https://github.com/SillyTavern/SillyTavern.git
	Add caching of OAI messages tokens
@@ -18,6 +18,8 @@ import {
     name1,
     name2,
     extension_prompt_types,
+    characters,
+    this_chid,
 } from "../script.js";
 import { groups, selected_group } from "./group-chats.js";
 
@@ -25,6 +27,10 @@ import {
     pin_examples,
 } from "./power-user.js";
 
+import {
+    getStringHash,
+} from "./utils.js";
+
 export {
     is_get_status_openai,
     openai_msgs,
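
The getStringHash import above is what keys the new cache: token counts are stored per hash of a message's content. The utility itself is not part of this diff; what follows is a minimal sketch of the kind of cheap, non-cryptographic string hash assumed here (the exact algorithm in utils.js may differ):

    function getStringHash(str) {
        // Classic 31-based rolling hash, kept in 32 bits. A collision would
        // only yield a stale token count, so cryptographic strength is
        // unnecessary here.
        let hash = 0;
        for (let i = 0; i < str.length; i++) {
            hash = ((hash << 5) - hash + str.charCodeAt(i)) | 0; // hash * 31 + char
        }
        return hash;
    }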
@@ -50,6 +56,8 @@ const default_nsfw_prompt = "NSFW/Smut is allowed. Assume consent was granted, b
 const gpt3_max = 4095;
 const gpt4_max = 8191;
 
+const tokenCache = {};
+
 const oai_settings = {
     preset_settings_openai: 'Default',
     api_key_openai: '',
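
tokenCache lives at module scope, so it persists for the browser session. It is keyed first by chat (the character's chat file name, or the group id) and then by content hash, so counts are never reused across chats. A sketch of the shape it ends up with, using invented illustrative keys and values:

    const tokenCache = {
        // one lazily created bucket per chat file or group id (example key)
        "Seraphina - 2023-04-12.jsonl": {
            "-1203981124": 18, // getStringHash(message.content) -> token count
            "884211020": 42,
        },
    };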
@@ -101,7 +109,7 @@ function setOpenAIMessages(chat) {
 
         // replace bias markup
         content = (content ?? '').replace(/{.*}/g, '');
-    
+
         // Apply the "wrap in quotes" option
         if (role == 'user' && oai_settings.wrap_in_quotes) content = `"${content}"`;
         openai_msgs[i] = { "role": role, "content": content };
@@ -249,15 +257,15 @@ async function prepareOpenAIMessages(name2, storyString, worldInfoBefore, worldI
 
     // todo: static value, maybe include in the initial context calculation
     let new_chat_msg = { "role": "system", "content": "[Start a new chat]" };
-    let start_chat_count = await countTokens([new_chat_msg]);
-    let total_count = await countTokens([prompt_msg], true) + start_chat_count;
+    let start_chat_count = countTokens([new_chat_msg]);
+    let total_count = countTokens([prompt_msg], true) + start_chat_count;
 
     if (bias && bias.trim().length) {
         let bias_msg = { "role": "system", "content": bias.trim() };
         openai_msgs.push(bias_msg);
-        total_count += await countTokens([bias_msg], true);
+        total_count += countTokens([bias_msg], true);
     }
-    
+
     if (selected_group) {
         // set "special" group nudging messages
         const groupMembers = groups.find(x => x.id === selected_group)?.members;
@@ -267,20 +275,20 @@ async function prepareOpenAIMessages(name2, storyString, worldInfoBefore, worldI
         openai_msgs.push(group_nudge);
 
         // add a group nudge count
-        let group_nudge_count = await countTokens([group_nudge], true);
+        let group_nudge_count = countTokens([group_nudge], true);
         total_count += group_nudge_count;
-        
+
         // recount tokens for new start message
         total_count -= start_chat_count
-        start_chat_count = await countTokens([new_chat_msg]);
+        start_chat_count = countTokens([new_chat_msg]);
         total_count += start_chat_count;
     }
 
     if (oai_settings.jailbreak_system) {
-        const jailbreakMessage = { "role": "system", "content": `[System note: ${oai_settings.nsfw_prompt}]`};
+        const jailbreakMessage = { "role": "system", "content": `[System note: ${oai_settings.nsfw_prompt}]` };
         openai_msgs.push(jailbreakMessage);
 
-        total_count += await countTokens([jailbreakMessage], true);
+        total_count += countTokens([jailbreakMessage], true);
     }
 
     // The user wants to always have all example messages in the context
@@ -302,11 +310,11 @@ async function prepareOpenAIMessages(name2, storyString, worldInfoBefore, worldI
                 examples_tosend.push(example);
             }
         }
-        total_count += await countTokens(examples_tosend);
+        total_count += countTokens(examples_tosend);
         // go from newest message to oldest, because we want to delete the older ones from the context
         for (let j = openai_msgs.length - 1; j >= 0; j--) {
            let item = openai_msgs[j];
-            let item_count = await countTokens(item);
+            let item_count = countTokens(item);
             // If we have enough space for this message, also account for the max assistant reply size
             if ((total_count + item_count) < (this_max_context - oai_settings.openai_max_tokens)) {
                 openai_msgs_tosend.push(item);
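
The loop above walks openai_msgs newest-first and keeps a message only while it fits the remaining budget, which reserves openai_max_tokens for the model's reply. The same rule, isolated as a standalone sketch (function and parameter names are assumptions for illustration):

    // Keep the newest messages whose token cost fits, reserving reply room.
    function trimToBudget(messages, usedTokens, maxContext, maxReplyTokens, countFn) {
        const kept = [];
        for (let j = messages.length - 1; j >= 0; j--) { // newest first
            const cost = countFn(messages[j]);
            if (usedTokens + cost < maxContext - maxReplyTokens) {
                kept.push(messages[j]);
                usedTokens += cost;
            } else {
                break; // older messages are dropped from the context
            }
        }
        return kept;
    }

Because countTokens is now synchronous, this trimming no longer awaits per message; a cache hit makes each step a plain lookup.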
@@ -320,7 +328,7 @@ async function prepareOpenAIMessages(name2, storyString, worldInfoBefore, worldI
     } else {
         for (let j = openai_msgs.length - 1; j >= 0; j--) {
             let item = openai_msgs[j];
-            let item_count = await countTokens(item);
+            let item_count = countTokens(item);
             // If we have enough space for this message, also account for the max assistant reply size
             if ((total_count + item_count) < (this_max_context - oai_settings.openai_max_tokens)) {
                 openai_msgs_tosend.push(item);
@@ -340,7 +348,7 @@ async function prepareOpenAIMessages(name2, storyString, worldInfoBefore, worldI
 
             for (let k = 0; k < example_block.length; k++) {
                 if (example_block.length == 0) { continue; }
-                let example_count = await countTokens(example_block[k]);
+                let example_count = countTokens(example_block[k]);
                 // add all the messages from the example
                 if ((total_count + example_count + start_chat_count) < (this_max_context - oai_settings.openai_max_tokens)) {
                     if (k == 0) {
@@ -448,26 +456,45 @@ function onStream(e, resolve, reject, last_view_mes) {
     }
 }
 
-async function countTokens(messages, full = false) {
-        return new Promise((resolve) => {
-            if (!Array.isArray(messages)) {
-                messages = [messages];
-            }
-            let token_count = -1;
+function countTokens(messages, full = false) {
+    let chatId = selected_group ? selected_group : characters[this_chid].chat;
+
+    if (typeof tokenCache[chatId] !== 'object') {
+        tokenCache[chatId] = {};
+    }
+
+    if (!Array.isArray(messages)) {
+        messages = [messages];
+    }
+
+    let token_count = -1;
+
+    for (const message of messages) {
+        const hash = getStringHash(message.content);
+        const cachedCount = tokenCache[chatId][hash];
+
+        if (cachedCount) {
+            token_count += cachedCount;
+        }
+        else {
             jQuery.ajax({
-                async: true,
+                async: false,
                 type: 'POST', //
                 url: `/tokenize_openai?model=${oai_settings.openai_model}`,
-                data: JSON.stringify(messages),
+                data: JSON.stringify([message]),
                 dataType: "json",
                 contentType: "application/json",
                 success: function (data) {
-                    token_count = data.token_count;
-                    if (!full) token_count -= 2;
-                    resolve(token_count);
+                    token_count += data.token_count;
+                    tokenCache[chatId][hash] = data.token_count;
                 }
             });
-        });
+        }
+    }
+
+    if (!full) token_count -= 2;
+
+    return token_count;
 }
 
 function loadOpenAISettings(data, settings) {
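
This hunk turns countTokens from a Promise-returning function into a synchronous one: a cache hit is a dictionary lookup, and a miss falls back to a blocking request (async: false) that tokenizes a single message and stores the result under its content hash. That is why every call site in the earlier hunks drops its await. A condensed sketch of the read-through pattern, detached from jQuery (helper names are assumptions):

    function countWithCache(cache, message, tokenizeSync) {
        const hash = getStringHash(message.content);
        // Checking with `in` (rather than truthiness, as the diff does) would
        // also treat a cached count of 0 as a hit instead of re-tokenizing.
        if (hash in cache) {
            return cache[hash]; // hit: no round-trip to /tokenize_openai
        }
        const count = tokenizeSync(message); // miss: blocking POST, one message at a time
        cache[hash] = count;
        return count;
    }

A synchronous XMLHttpRequest blocks the UI thread, but only on a cache miss, and each unique message is tokenized at most once per session, so repeated prompt building stays off the network.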
@@ -607,7 +634,7 @@ $(document).ready(function () {
         saveSettingsDebounced();
     });
 
-    $("#model_openai_select").change(function() {
+    $("#model_openai_select").change(function () {
         const value = $(this).val();
         oai_settings.openai_model = value;
 