Add text trimming commands

This commit is contained in:
Cohee 2023-11-26 13:55:22 +02:00
parent 7b3f2a8986
commit e6c96553d0
3 changed files with 113 additions and 3 deletions

View File

@ -28,6 +28,7 @@ import {
callPopup,
deactivateSendButtons,
activateSendButtons,
main_api,
} from "../script.js";
import { getMessageTimeStamp } from "./RossAscends-mods.js";
import { findGroupMemberId, groups, is_group_generating, resetSelectedGroup, saveGroupChat, selected_group } from "./group-chats.js";
@ -36,8 +37,9 @@ import { addEphemeralStoppingString, chat_styles, flushEphemeralStoppingStrings,
import { autoSelectPersona } from "./personas.js";
import { getContext } from "./extensions.js";
import { hideChatMessage, unhideChatMessage } from "./chats.js";
import { delay, isFalseBoolean, isTrueBoolean, stringToRange } from "./utils.js";
import { delay, isFalseBoolean, isTrueBoolean, stringToRange, trimToEndSentence, trimToStartSentence } from "./utils.js";
import { registerVariableCommands, resolveVariable } from "./variables.js";
import { decodeTextTokens, getFriendlyTokenizerName, getTextTokens, getTokenCount } from "./tokenizers.js";
export {
executeSlashCommands,
registerSlashCommand,
@ -175,6 +177,9 @@ parser.addCommand('run', runCallback, ['call', 'exec'], '<span class="monospace"
parser.addCommand('messages', getMessagesCallback, ['message'], '<span class="monospace">(names=off/on [message index or range])</span> returns the specified message or range of messages as a string.', true, true);
parser.addCommand('setinput', setInputCallback, [], '<span class="monospace">(text)</span> sets the user input to the specified text and passes it to the next command through the pipe.', true, true);
parser.addCommand('popup', popupCallback, [], '<span class="monospace">(text)</span> shows a blocking popup with the specified text.', true, true);
parser.addCommand('trimtokens', trimTokensCallback, [], '<span class="monospace">(direction=start/end limit=number [text])</span> trims the start or end of text to the specified number of tokens.', true, true);
parser.addCommand('trimstart', trimStartCallback, [], '<span class="monospace">(text)</span> trims the text to the start of the first full sentence.', true, true);
parser.addCommand('trimend', trimEndCallback, [], '<span class="monospace">(text)</span> trims the text to the end of the last full sentence.', true, true);
registerVariableCommands();
const NARRATOR_NAME_KEY = 'narrator_name';
@ -186,6 +191,70 @@ function setInputCallback(_, value) {
return value;
}
function trimStartCallback(_, value) {
if (!value) {
return '';
}
return trimToStartSentence(value);
}
function trimEndCallback(_, value) {
if (!value) {
return '';
}
return trimToEndSentence(value);
}
function trimTokensCallback(arg, value) {
if (!value) {
console.warn('WARN: No argument provided for /trimtokens command');
return '';
}
const limit = Number(resolveVariable(arg.limit));
if (isNaN(limit)) {
console.warn(`WARN: Invalid limit provided for /trimtokens command: ${limit}`);
return value;
}
if (limit <= 0) {
return '';
}
const direction = arg.direction || 'end';
const tokenCount = getTokenCount(value)
// Token count is less than the limit, do nothing
if (tokenCount <= limit) {
return value;
}
const { tokenizerName, tokenizerId } = getFriendlyTokenizerName(main_api);
console.debug('Requesting tokenization for /trimtokens command', tokenizerName);
try {
const textTokens = getTextTokens(tokenizerId, value);
if (!Array.isArray(textTokens) || !textTokens.length) {
console.warn('WARN: No tokens returned for /trimtokens command, falling back to estimation');
const percentage = limit / tokenCount;
const trimIndex = Math.floor(value.length * percentage);
const trimmedText = direction === 'start' ? value.substring(trimIndex) : value.substring(0, value.length - trimIndex);
return trimmedText;
}
const sliceTokens = direction === 'start' ? textTokens.slice(0, limit) : textTokens.slice(-limit);
const decodedText = decodeTextTokens(tokenizerId, sliceTokens);
return decodedText;
} catch (error) {
console.warn('WARN: Tokenization failed for /trimtokens command, returning original', error);
return value;
}
}
async function popupCallback(_, value) {
const safeValue = DOMPurify.sanitize(value || '');
await delay(1);

View File

@ -467,7 +467,11 @@ function getTextTokensRemote(endpoint, str, model = '') {
* @param {string} endpoint API endpoint.
* @param {number[]} ids Array of token ids
*/
function decodeTextTokensRemote(endpoint, ids) {
function decodeTextTokensRemote(endpoint, ids, model = '') {
if (model) {
endpoint += `?model=${model}`;
}
let text = '';
jQuery.ajax({
async: false,
@ -533,6 +537,9 @@ export function decodeTextTokens(tokenizerType, ids) {
return decodeTextTokensRemote('/api/decode/mistral', ids);
case tokenizers.YI:
return decodeTextTokensRemote('/api/decode/yi', ids);
case tokenizers.OPENAI:
const model = getTokenizerModel();
return decodeTextTokensRemote('/api/decode/openai', ids, model);
default:
console.warn("Calling decodeTextTokens with unsupported tokenizer type", tokenizerType);
return '';

View File

@ -245,7 +245,7 @@ async function loadClaudeTokenizer(modelPath) {
}
function countClaudeTokens(tokenizer, messages) {
const convertedPrompt = convertClaudePrompt(messages, false, false);
const convertedPrompt = convertClaudePrompt(messages, false, false, false);
// Fallback to strlen estimation
if (!tokenizer) {
@ -435,6 +435,40 @@ function registerEndpoints(app, jsonParser) {
}
});
app.post('/api/decode/openai', jsonParser, async function (req, res) {
try {
const queryModel = String(req.query.model || '');
if (queryModel.includes('llama')) {
const handler = createSentencepieceDecodingHandler(spp_llama);
return handler(req, res);
}
if (queryModel.includes('mistral')) {
const handler = createSentencepieceDecodingHandler(spp_mistral);
return handler(req, res);
}
if (queryModel.includes('yi')) {
const handler = createSentencepieceDecodingHandler(spp_yi);
return handler(req, res);
}
if (queryModel.includes('claude')) {
const ids = req.body.ids || [];
const chunkText = await claude_tokenizer.decode(new Uint32Array(ids));
return res.send({ text: chunkText });
}
const model = getTokenizerModel(queryModel);
const handler = createTiktokenDecodingHandler(model);
return handler(req, res);
} catch (error) {
console.log(error);
return res.send({ text: '' });
}
});
app.post("/api/tokenize/openai", jsonParser, async function (req, res) {
try {
if (!req.body) return res.sendStatus(400);