Add NovelAI hypebot plugin

Cohee
2023-08-27 18:27:34 +03:00
parent 8ec9b64be4
commit 9660aaa2c2
14 changed files with 505 additions and 561 deletions


@@ -1,7 +1,6 @@
import { characters, main_api, nai_settings, online_status, this_chid } from "../script.js";
import { power_user } from "./power-user.js";
import { encode } from "../lib/gpt-2-3-tokenizer/mod.js";
import { GPT3BrowserTokenizer } from "../lib/gpt-3-tokenizer/gpt3-tokenizer.js";
import { chat_completion_sources, oai_settings } from "./openai.js";
import { groups, selected_group } from "./group-chats.js";
import { getStringHash } from "./utils.js";
@@ -12,7 +11,7 @@ const TOKENIZER_WARNING_KEY = 'tokenizationWarningShown';
export const tokenizers = {
NONE: 0,
GPT3: 1,
GPT2: 1,
CLASSIC: 2,
LLAMA: 3,
NERD: 4,
@@ -22,7 +21,6 @@ export const tokenizers = {
};
const objectStore = new localforage.createInstance({ name: "SillyTavern_ChatCompletions" });
const gpt3 = new GPT3BrowserTokenizer({ type: 'gpt3' });
let tokenCache = {};
@@ -93,6 +91,35 @@ function getTokenizerBestMatch() {
return tokenizers.NONE;
}
/**
* Calls the underlying tokenizer model to get the token count for a string.
* @param {number} type Tokenizer type.
* @param {string} str String to tokenize.
* @param {number} padding Number of padding tokens.
* @returns {number} Token count.
*/
function callTokenizer(type, str, padding) {
switch (type) {
case tokenizers.NONE:
return guesstimate(str) + padding;
case tokenizers.GPT2:
return countTokensRemote('/tokenize_gpt2', str, padding);
case tokenizers.CLASSIC:
return encode(str).length + padding;
case tokenizers.LLAMA:
return countTokensRemote('/tokenize_llama', str, padding);
case tokenizers.NERD:
return countTokensRemote('/tokenize_nerdstash', str, padding);
case tokenizers.NERD2:
return countTokensRemote('/tokenize_nerdstash_v2', str, padding);
case tokenizers.API:
return countTokensRemote('/tokenize_via_api', str, padding);
default:
console.warn("Unknown tokenizer type", type);
return callTokenizer(tokenizers.NONE, str, padding);
}
}
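
With this change, all tokenizer dispatch is centralized in the module-level callTokenizer, and the exported getTokenCount below delegates to it instead of a nested helper. A minimal usage sketch of the public API (the import path is an assumption; adjust it to wherever tokenizers.js lives in your tree):

import { getTokenCount } from "./scripts/tokenizers.js"; // path assumed

// Count tokens with whichever tokenizer matches the current model/API;
// the second argument is the number of padding tokens added to the result.
const promptTokens = getTokenCount("Hello, world!", 0);
console.log(`Prompt is ${promptTokens} tokens.`);
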
/**
* Gets the token count for a string using the current model tokenizer.
* @param {string} str String to tokenize
@@ -100,33 +127,6 @@ function getTokenizerBestMatch() {
* @returns {number} Token count.
*/
export function getTokenCount(str, padding = undefined) {
/**
* Calculates the token count for a string.
* @param {number} [type] Tokenizer type.
* @returns {number} Token count.
*/
function calculate(type) {
switch (type) {
case tokenizers.NONE:
return guesstimate(str) + padding;
case tokenizers.GPT3:
return gpt3.encode(str).bpe.length + padding;
case tokenizers.CLASSIC:
return encode(str).length + padding;
case tokenizers.LLAMA:
return countTokensRemote('/tokenize_llama', str, padding);
case tokenizers.NERD:
return countTokensRemote('/tokenize_nerdstash', str, padding);
case tokenizers.NERD2:
return countTokensRemote('/tokenize_nerdstash_v2', str, padding);
case tokenizers.API:
return countTokensRemote('/tokenize_via_api', str, padding);
default:
console.warn("Unknown tokenizer type", type);
return calculate(tokenizers.NONE);
}
}
if (typeof str !== 'string' || !str?.length) {
return 0;
}
@@ -159,7 +159,7 @@ export function getTokenCount(str, padding = undefined) {
return cacheObject[cacheKey];
}
const result = calculate(tokenizerType);
const result = callTokenizer(tokenizerType, str, padding);
if (isNaN(result)) {
console.warn("Token count calculation returned NaN");
@@ -350,6 +350,12 @@ function countTokensRemote(endpoint, str, padding) {
return tokenCount + padding;
}
/**
* Calls the underlying tokenizer model to encode a string to tokens.
* @param {string} endpoint API endpoint.
* @param {string} str String to tokenize.
* @returns {number[]} Array of token ids.
*/
function getTextTokensRemote(endpoint, str) {
let ids = [];
jQuery.ajax({
@@ -366,8 +372,37 @@ function getTextTokensRemote(endpoint, str) {
return ids;
}
/**
* Calls the underlying tokenizer model to decode token ids to text.
* @param {string} endpoint API endpoint.
* @param {number[]} ids Array of token ids.
* @returns {string} Decoded text.
*/
function decodeTextTokensRemote(endpoint, ids) {
let text = '';
jQuery.ajax({
async: false,
type: 'POST',
url: endpoint,
data: JSON.stringify({ ids: ids }),
dataType: "json",
contentType: "application/json",
success: function (data) {
text = data.text;
}
});
return text;
}
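
Note the async: false in both getTextTokensRemote and decodeTextTokensRemote: each remote call blocks the UI thread until the server responds, presumably so existing call sites can stay synchronous. For comparison only, a hypothetical non-blocking variant (fetch-based; not part of this commit):

// Hypothetical async alternative, sketched for illustration only.
async function decodeTextTokensRemoteAsync(endpoint, ids) {
    const response = await fetch(endpoint, {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ ids: ids }),
    });
    const data = await response.json();
    return data.text; // same response shape as the blocking version
}
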
/**
* Encodes a string to tokens using the remote server API.
* @param {number} tokenizerType Tokenizer type.
* @param {string} str String to tokenize.
* @returns {number[]} Array of token ids.
*/
export function getTextTokens(tokenizerType, str) {
switch (tokenizerType) {
case tokenizers.GPT2:
return getTextTokensRemote('/tokenize_gpt2', str);
case tokenizers.LLAMA:
return getTextTokensRemote('/tokenize_llama', str);
case tokenizers.NERD:
@@ -380,6 +415,27 @@ export function getTextTokens(tokenizerType, str) {
}
}
/**
* Decodes token ids to text using the remote server API.
* @param {number} tokenizerType Tokenizer type.
* @param {number[]} ids Array of token ids.
* @returns {string} Decoded text.
*/
export function decodeTextTokens(tokenizerType, ids) {
switch (tokenizerType) {
case tokenizers.GPT2:
return decodeTextTokensRemote('/decode_gpt2', ids);
case tokenizers.LLAMA:
return decodeTextTokensRemote('/decode_llama', ids);
case tokenizers.NERD:
return decodeTextTokensRemote('/decode_nerdstash', ids);
case tokenizers.NERD2:
return decodeTextTokensRemote('/decode_nerdstash_v2', ids);
default:
console.warn("Calling decodeTextTokens with unsupported tokenizer type", tokenizerType);
return '';
}
}
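
Together, the new encode/decode exports make a round trip possible for the remote tokenizers. A short sketch, assuming a model where the NERD2 (nerdstash v2) tokenizer applies and the same assumed import path as above:

import { tokenizers, getTextTokens, decodeTextTokens } from "./scripts/tokenizers.js"; // path assumed

// Encode a string to token ids, then decode the ids back to text.
const ids = getTextTokens(tokenizers.NERD2, "The quick brown fox");
const text = decodeTextTokens(tokenizers.NERD2, ids);
console.log(ids, text); // token ids, then the reconstructed string
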
jQuery(async () => {
await loadTokenCache();
});