Add OpenAI vector source.

This commit is contained in:
Cohee 2023-09-08 13:57:27 +03:00
parent 02bdd56e20
commit a5acc7872d
8 changed files with 345 additions and 159 deletions

View File

@ -6,20 +6,21 @@ import { debounce, getStringHash as calculateHash } from "../../utils.js";
const MODULE_NAME = 'vectors';
const AMOUNT_TO_LEAVE = 5;
const INSERT_AMOUNT = 3;
const QUERY_TEXT_AMOUNT = 3;
const QUERY_TEXT_AMOUNT = 2;
export const EXTENSION_PROMPT_TAG = '3_vectors';
const settings = {
enabled: false,
source: 'local',
};
const moduleWorker = new ModuleWorkerWrapper(synchronizeChat);
async function synchronizeChat() {
async function synchronizeChat(batchSize = 10) {
try {
if (!settings.enabled) {
return;
return -1;
}
const context = getContext();
@ -37,7 +38,7 @@ async function synchronizeChat() {
const deletedHashes = hashesInCollection.filter(x => !hashedMessages.some(y => y.hash === x));
if (newVectorItems.length > 0) {
await insertVectorItems(chatId, newVectorItems);
await insertVectorItems(chatId, newVectorItems.slice(0, batchSize));
console.log(`Vectors: Inserted ${newVectorItems.length} new items`);
}
@ -45,6 +46,8 @@ async function synchronizeChat() {
await deleteVectorItems(chatId, deletedHashes);
console.log(`Vectors: Deleted ${deletedHashes.length} old hashes`);
}
return newVectorItems.length - batchSize;
} catch (error) {
console.error('Vectors: Failed to synchronize chat', error);
}
@ -59,18 +62,18 @@ const hashCache = {};
* @returns {number} Hash value
*/
function getStringHash(str) {
// Check if the hash is already in the cache
if (hashCache.hasOwnProperty(str)) {
return hashCache[str];
}
// Check if the hash is already in the cache
if (hashCache.hasOwnProperty(str)) {
return hashCache[str];
}
// Calculate the hash value
const hash = calculateHash(str);
// Calculate the hash value
const hash = calculateHash(str);
// Store the hash in the cache
hashCache[str] = hash;
// Store the hash in the cache
hashCache[str] = hash;
return hash;
return hash;
}
/**
@ -79,6 +82,9 @@ function getStringHash(str) {
*/
async function rearrangeChat(chat) {
try {
// Clear the extension prompt
setExtensionPrompt(EXTENSION_PROMPT_TAG, '', extension_prompt_types.IN_PROMPT, 0);
if (!settings.enabled) {
return;
}
@ -127,6 +133,11 @@ async function rearrangeChat(chat) {
}
}
if (queriedMessages.length === 0) {
console.debug('Vectors: No relevant messages found');
return;
}
// Format queried messages into a single string
const queriedText = 'Past events: ' + queriedMessages.map(x => collapseNewlines(`${x.name}: ${x.mes}`).trim()).join('\n\n');
setExtensionPrompt(EXTENSION_PROMPT_TAG, queriedText, extension_prompt_types.IN_PROMPT, 0);
@ -171,7 +182,10 @@ async function getSavedHashes(collectionId) {
const response = await fetch('/api/vector/list', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify({ collectionId }),
body: JSON.stringify({
collectionId: collectionId,
source: settings.source,
}),
});
if (!response.ok) {
@ -192,7 +206,11 @@ async function insertVectorItems(collectionId, items) {
const response = await fetch('/api/vector/insert', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify({ collectionId, items }),
body: JSON.stringify({
collectionId: collectionId,
items: items,
source: settings.source,
}),
});
if (!response.ok) {
@ -210,7 +228,11 @@ async function deleteVectorItems(collectionId, hashes) {
const response = await fetch('/api/vector/delete', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify({ collectionId, hashes }),
body: JSON.stringify({
collectionId: collectionId,
hashes: hashes,
source: settings.source,
}),
});
if (!response.ok) {
@ -228,7 +250,12 @@ async function queryCollection(collectionId, searchText, topK) {
const response = await fetch('/api/vector/query', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify({ collectionId, searchText, topK }),
body: JSON.stringify({
collectionId: collectionId,
searchText: searchText,
topK: topK,
source: settings.source,
}),
});
if (!response.ok) {
@ -251,6 +278,11 @@ jQuery(async () => {
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
$('#vectors_source').val(settings.source).on('change', () => {
settings.source = String($('#vectors_source').val());
Object.assign(extension_settings.vectors, settings);
saveSettingsDebounced();
});
eventSource.on(event_types.CHAT_CHANGED, onChatEvent);
eventSource.on(event_types.MESSAGE_DELETED, onChatEvent);

View File

@ -9,6 +9,13 @@
<input id="vectors_enabled" type="checkbox" class="checkbox">
Enabled
</label>
<label for="vectors_source">
Source
</label>
<select id="vectors_source" class="select">
<option value="local">Local</option>
<option value="openai">OpenAI</option>
</select>
</div>
</div>
</div>

View File

@ -629,6 +629,12 @@ function populateChatCompletion(prompts, chatCompletion, { bias, quietPrompt, ty
if (true === afterScenario) chatCompletion.insert(authorsNote, 'scenario');
}
// Vectors Memory
if (prompts.has('vectorsMemory')) {
const vectorsMemory = Message.fromPrompt(prompts.get('vectorsMemory'));
chatCompletion.insert(vectorsMemory, 'main');
}
// Decide whether dialogue examples should always be added
if (power_user.pin_examples) {
populateDialogueExamples(prompts, chatCompletion);

120
server.js
View File

@ -68,6 +68,7 @@ const characterCardParser = require('./src/character-card-parser.js');
const contentManager = require('./src/content-manager');
const novelai = require('./src/novelai');
const statsHelpers = require('./statsHelpers.js');
const { writeSecret, readSecret, readSecretState, migrateSecrets, SECRET_KEYS, getAllSecrets } = require('./src/secrets');
function createDefaultFiles() {
const files = {
@ -325,6 +326,7 @@ function humanizedISO8601DateTime(date) {
var charactersPath = 'public/characters/';
var chatsPath = 'public/chats/';
const UPLOADS_PATH = './uploads';
const SETTINGS_FILE = './public/settings.json';
const AVATAR_WIDTH = 400;
const AVATAR_HEIGHT = 600;
const jsonParser = express.json({ limit: '100mb' });
@ -4070,7 +4072,7 @@ const setupTasks = async function () {
console.log(`SillyTavern ${version.pkgVersion}` + (version.gitBranch ? ` '${version.gitBranch}' (${version.gitRevision})` : ''));
backupSettings();
migrateSecrets();
migrateSecrets(SETTINGS_FILE);
ensurePublicDirectoriesExist();
await ensureThumbnailCache();
contentManager.checkForNewContent();
@ -4221,69 +4223,6 @@ function ensurePublicDirectoriesExist() {
}
}
const SECRETS_FILE = './secrets.json';
const SETTINGS_FILE = './public/settings.json';
const SECRET_KEYS = {
HORDE: 'api_key_horde',
MANCER: 'api_key_mancer',
OPENAI: 'api_key_openai',
NOVEL: 'api_key_novel',
CLAUDE: 'api_key_claude',
DEEPL: 'deepl',
LIBRE: 'libre',
LIBRE_URL: 'libre_url',
OPENROUTER: 'api_key_openrouter',
SCALE: 'api_key_scale',
AI21: 'api_key_ai21',
SCALE_COOKIE: 'scale_cookie',
}
function migrateSecrets() {
if (!fs.existsSync(SETTINGS_FILE)) {
console.log('Settings file does not exist');
return;
}
try {
let modified = false;
const fileContents = fs.readFileSync(SETTINGS_FILE, 'utf8');
const settings = JSON.parse(fileContents);
const oaiKey = settings?.api_key_openai;
const hordeKey = settings?.horde_settings?.api_key;
const novelKey = settings?.api_key_novel;
if (typeof oaiKey === 'string') {
console.log('Migrating OpenAI key...');
writeSecret(SECRET_KEYS.OPENAI, oaiKey);
delete settings.api_key_openai;
modified = true;
}
if (typeof hordeKey === 'string') {
console.log('Migrating Horde key...');
writeSecret(SECRET_KEYS.HORDE, hordeKey);
delete settings.horde_settings.api_key;
modified = true;
}
if (typeof novelKey === 'string') {
console.log('Migrating Novel key...');
writeSecret(SECRET_KEYS.NOVEL, novelKey);
delete settings.api_key_novel;
modified = true;
}
if (modified) {
console.log('Writing updated settings.json...');
const settingsContent = JSON.stringify(settings);
writeFileAtomicSync(SETTINGS_FILE, settingsContent, "utf-8");
}
}
catch (error) {
console.error('Could not migrate secrets file. Proceed with caution.');
}
}
app.post('/writesecret', jsonParser, (request, response) => {
const key = request.body.key;
const value = request.body.value;
@ -4293,19 +4232,9 @@ app.post('/writesecret', jsonParser, (request, response) => {
});
app.post('/readsecretstate', jsonParser, (_, response) => {
if (!fs.existsSync(SECRETS_FILE)) {
return response.send({});
}
try {
const fileContents = fs.readFileSync(SECRETS_FILE, 'utf8');
const secrets = JSON.parse(fileContents);
const state = {};
for (const key of Object.values(SECRET_KEYS)) {
state[key] = !!secrets[key]; // convert to boolean
}
const state = readSecretState();
return response.send(state);
} catch (error) {
console.error(error);
@ -4351,14 +4280,13 @@ app.post('/viewsecrets', jsonParser, async (_, response) => {
return response.sendStatus(403);
}
if (!fs.existsSync(SECRETS_FILE)) {
console.error('secrets.json does not exist');
return response.sendStatus(404);
}
try {
const fileContents = fs.readFileSync(SECRETS_FILE, 'utf-8');
const secrets = JSON.parse(fileContents);
const secrets = getAllSecrets();
if (!secrets) {
return response.sendStatus(404);
}
return response.send(secrets);
} catch (error) {
console.error(error);
@ -4797,6 +4725,11 @@ app.post('/libre_translate', jsonParser, async (request, response) => {
const key = readSecret(SECRET_KEYS.LIBRE);
const url = readSecret(SECRET_KEYS.LIBRE_URL);
if (!url) {
console.log('LibreTranslate URL is not configured.');
return response.sendStatus(401);
}
const text = request.body.text;
const lang = request.body.lang;
@ -5292,27 +5225,6 @@ function importRisuSprites(data) {
}
}
function writeSecret(key, value) {
if (!fs.existsSync(SECRETS_FILE)) {
const emptyFile = JSON.stringify({});
writeFileAtomicSync(SECRETS_FILE, emptyFile, "utf-8");
}
const fileContents = fs.readFileSync(SECRETS_FILE, 'utf-8');
const secrets = JSON.parse(fileContents);
secrets[key] = value;
writeFileAtomicSync(SECRETS_FILE, JSON.stringify(secrets), "utf-8");
}
function readSecret(key) {
if (!fs.existsSync(SECRETS_FILE)) {
return undefined;
}
const fileContents = fs.readFileSync(SECRETS_FILE, 'utf-8');
const secrets = JSON.parse(fileContents);
return secrets[key];
}
async function readAllChunks(readableStream) {
return new Promise((resolve, reject) => {
@ -5385,8 +5297,6 @@ async function getImageBuffers(zipFilePath) {
});
}
/**
* This function extracts the extension information from the manifest file.
* @param {string} extensionPath - The path of the extension folder

38
src/local-vectors.js Normal file
View File

@ -0,0 +1,38 @@
require('@tensorflow/tfjs');
const encoder = require('@tensorflow-models/universal-sentence-encoder');
/**
 * Lazy loading class for the embedding model.
 */
class EmbeddingModel {
    /**
     * @type {encoder.UniversalSentenceEncoder} - The embedding model
     */
    model;

    /**
     * @type {Promise<encoder.UniversalSentenceEncoder>} - In-flight load, shared by concurrent callers
     */
    loadingPromise;

    /**
     * Gets the embedding model, loading it on first use.
     * Caches the load promise so concurrent first calls do not
     * each trigger a separate (expensive) encoder.load().
     * @returns {Promise<encoder.UniversalSentenceEncoder>} The embedding model
     */
    async get() {
        if (this.model) {
            return this.model;
        }

        if (!this.loadingPromise) {
            this.loadingPromise = encoder.load();
        }

        this.model = await this.loadingPromise;
        return this.model;
    }
}
const model = new EmbeddingModel();
/**
 * Computes the embedding vector for a piece of text using the local model.
 * @param {string} text - The text to embed
 * @returns {Promise<number[]>} The embedding vector
 */
async function getLocalVector(text) {
    const embedder = await model.get();
    const embedding = await embedder.embed(text);
    return Array.from(await embedding.data());
}
module.exports = {
getLocalVector,
};

47
src/openai-vectors.js Normal file
View File

@ -0,0 +1,47 @@
const fetch = require('node-fetch').default;
const { SECRET_KEYS, readSecret } = require('./secrets');
/**
 * Gets the vector for the given text from OpenAI ada model
 * @param {string} text - The text to get the vector for
 * @returns {Promise<number[]>} - The vector for the text
 * @throws {Error} If the API key is missing, the request fails, or the response is malformed
 */
async function getOpenAIVector(text) {
    const key = readSecret(SECRET_KEYS.OPENAI);

    if (!key) {
        console.log('No OpenAI key found');
        throw new Error('No OpenAI key found');
    }

    const response = await fetch('https://api.openai.com/v1/embeddings', {
        method: 'POST',
        headers: {
            'Content-Type': 'application/json',
            Authorization: `Bearer ${key}`,
        },
        body: JSON.stringify({
            input: text,
            model: 'text-embedding-ada-002',
        }),
    });

    if (!response.ok) {
        // Surface the HTTP status so failed requests can be diagnosed
        console.log('OpenAI request failed', response.status, response.statusText);
        throw new Error('OpenAI request failed');
    }

    const data = await response.json();
    // Chain must be fully optional: `data?.data[0]` would throw a TypeError
    // (instead of the intended Error below) when the payload has no `data` array.
    const vector = data?.data?.[0]?.embedding;

    if (!Array.isArray(vector)) {
        console.log('OpenAI response was not an array');
        throw new Error('OpenAI response was not an array');
    }

    return vector;
}
module.exports = {
getOpenAIVector,
};

146
src/secrets.js Normal file
View File

@ -0,0 +1,146 @@
const fs = require('fs');
const path = require('path');
const writeFileAtomicSync = require('write-file-atomic').sync;
const SECRETS_FILE = path.join(process.cwd(), './secrets.json');
const SECRET_KEYS = {
HORDE: 'api_key_horde',
MANCER: 'api_key_mancer',
OPENAI: 'api_key_openai',
NOVEL: 'api_key_novel',
CLAUDE: 'api_key_claude',
DEEPL: 'deepl',
LIBRE: 'libre',
LIBRE_URL: 'libre_url',
OPENROUTER: 'api_key_openrouter',
SCALE: 'api_key_scale',
AI21: 'api_key_ai21',
SCALE_COOKIE: 'scale_cookie',
}
/**
 * Writes a secret to the secrets file
 * @param {string} key Secret key
 * @param {string} value Secret value
 */
function writeSecret(key, value) {
    // Create an empty secrets store on first use
    if (!fs.existsSync(SECRETS_FILE)) {
        writeFileAtomicSync(SECRETS_FILE, JSON.stringify({}), "utf-8");
    }

    const secrets = JSON.parse(fs.readFileSync(SECRETS_FILE, 'utf-8'));
    secrets[key] = value;
    writeFileAtomicSync(SECRETS_FILE, JSON.stringify(secrets), "utf-8");
}
/**
 * Reads a secret from the secrets file
 * @param {string} key Secret key
 * @returns {string} Secret value, or an empty string when the key is not set
 */
function readSecret(key) {
    if (!fs.existsSync(SECRETS_FILE)) {
        return '';
    }

    const fileContents = fs.readFileSync(SECRETS_FILE, 'utf-8');
    const secrets = JSON.parse(fileContents);
    // Normalize missing keys to '' so the return type matches the documented
    // contract and the missing-file branch above (was: bare `secrets[key]`,
    // which leaked `undefined` for absent keys).
    return secrets[key] ?? '';
}
/**
 * Reads the secret state from the secrets file
 * @returns {object} Map of secret key name to a boolean indicating its presence
 */
function readSecretState() {
    if (!fs.existsSync(SECRETS_FILE)) {
        return {};
    }

    const secrets = JSON.parse(fs.readFileSync(SECRETS_FILE, 'utf8'));

    // Report only whether each known secret is set — never its value
    return Object.fromEntries(
        Object.values(SECRET_KEYS).map((key) => [key, Boolean(secrets[key])])
    );
}
/**
 * Migrates secrets from settings.json to secrets.json
 * @param {string} settingsFile Path to settings.json
 * @returns {void}
 */
function migrateSecrets(settingsFile) {
    if (!fs.existsSync(settingsFile)) {
        console.log('Settings file does not exist');
        return;
    }

    try {
        let modified = false;
        const fileContents = fs.readFileSync(settingsFile, 'utf8');
        const settings = JSON.parse(fileContents);
        const oaiKey = settings?.api_key_openai;
        const hordeKey = settings?.horde_settings?.api_key;
        const novelKey = settings?.api_key_novel;

        // Move each legacy plain-text key into the secrets store,
        // then strip it from the settings object.
        if (typeof oaiKey === 'string') {
            console.log('Migrating OpenAI key...');
            writeSecret(SECRET_KEYS.OPENAI, oaiKey);
            delete settings.api_key_openai;
            modified = true;
        }

        if (typeof hordeKey === 'string') {
            console.log('Migrating Horde key...');
            writeSecret(SECRET_KEYS.HORDE, hordeKey);
            delete settings.horde_settings.api_key;
            modified = true;
        }

        if (typeof novelKey === 'string') {
            console.log('Migrating Novel key...');
            writeSecret(SECRET_KEYS.NOVEL, novelKey);
            delete settings.api_key_novel;
            modified = true;
        }

        // Rewrite settings only if at least one key was migrated
        if (modified) {
            console.log('Writing updated settings.json...');
            const settingsContent = JSON.stringify(settings);
            writeFileAtomicSync(settingsFile, settingsContent, "utf-8");
        }
    }
    catch (error) {
        // Include the caught error — previously discarded, which hid the cause
        console.error('Could not migrate secrets file. Proceed with caution.', error);
    }
}
/**
 * Reads all secrets from the secrets file
 * @returns {Record<string, string> | undefined} All stored secrets, or undefined when no file exists
 */
function getAllSecrets() {
    if (!fs.existsSync(SECRETS_FILE)) {
        console.log('Secrets file does not exist');
        return undefined;
    }

    const raw = fs.readFileSync(SECRETS_FILE, 'utf8');
    return JSON.parse(raw);
}
module.exports = {
writeSecret,
readSecret,
readSecretState,
migrateSecrets,
getAllSecrets,
SECRET_KEYS,
};

View File

@ -2,36 +2,32 @@ const express = require('express');
const vectra = require('vectra');
const path = require('path');
const sanitize = require('sanitize-filename');
require('@tensorflow/tfjs');
const encoder = require('@tensorflow-models/universal-sentence-encoder');
/**
* Lazy loading class for the embedding model.
* Gets the vector for the given text from the given source.
* @param {string} source - The source of the vector
* @param {string} text - The text to get the vector for
* @returns {Promise<number[]>} - The vector for the text
*/
class EmbeddingModel {
/**
* @type {encoder.UniversalSentenceEncoder} - The embedding model
*/
model;
async get() {
if (!this.model) {
this.model = await encoder.load();
}
return this.model;
async function getVector(source, text) {
switch (source) {
case 'local':
return require('./local-vectors').getLocalVector(text);
case 'openai':
return require('./openai-vectors').getOpenAIVector(text);
}
}
const model = new EmbeddingModel();
throw new Error(`Unknown vector source ${source}`);
}
/**
* Gets the index for the vector collection
* @param {string} collectionId - The collection ID
* @param {string} source - The source of the vector
* @returns {Promise<vectra.LocalIndex>} - The index for the collection
*/
async function getIndex(collectionId) {
const index = new vectra.LocalIndex(path.join(process.cwd(), 'vectors', sanitize(collectionId)));
async function getIndex(collectionId, source) {
const index = new vectra.LocalIndex(path.join(process.cwd(), 'vectors', sanitize(source), sanitize(collectionId)));
if (!await index.isIndexCreated()) {
await index.createIndex();
@ -43,19 +39,18 @@ async function getIndex(collectionId) {
/**
* Inserts items into the vector collection
* @param {string} collectionId - The collection ID
* @param {string} source - The source of the vector
* @param {{ hash: number; text: string; }[]} items - The items to insert
*/
async function insertVectorItems(collectionId, items) {
const index = await getIndex(collectionId);
const use = await model.get();
async function insertVectorItems(collectionId, source, items) {
const index = await getIndex(collectionId, source);
await index.beginUpdate();
for (const item of items) {
const text = item.text;
const hash = item.hash;
const tensor = await use.embed(text);
const vector = Array.from(await tensor.data());
const vector = await getVector(source, text);
await index.upsertItem({ vector: vector, metadata: { hash, text } });
}
@ -65,10 +60,11 @@ async function insertVectorItems(collectionId, items) {
/**
* Gets the hashes of the items in the vector collection
* @param {string} collectionId - The collection ID
* @param {string} source - The source of the vector
* @returns {Promise<number[]>} - The hashes of the items in the collection
*/
async function getSavedHashes(collectionId) {
const index = await getIndex(collectionId);
async function getSavedHashes(collectionId, source) {
const index = await getIndex(collectionId, source);
const items = await index.listItems();
const hashes = items.map(x => Number(x.metadata.hash));
@ -79,10 +75,11 @@ async function getSavedHashes(collectionId) {
/**
* Deletes items from the vector collection by hash
* @param {string} collectionId - The collection ID
* @param {string} source - The source of the vector
* @param {number[]} hashes - The hashes of the items to delete
*/
async function deleteVectorItems(collectionId, hashes) {
const index = await getIndex(collectionId);
async function deleteVectorItems(collectionId, source, hashes) {
const index = await getIndex(collectionId, source);
const items = await index.listItemsByMetadata({ hash: { '$in': hashes } });
await index.beginUpdate();
@ -97,15 +94,14 @@ async function deleteVectorItems(collectionId, hashes) {
/**
* Gets the hashes of the items in the vector collection that match the search text
* @param {string} collectionId - The collection ID
* @param {string} source - The source of the vector
* @param {string} searchText - The text to search for
* @param {number} topK - The number of results to return
* @returns {Promise<number[]>} - The hashes of the items that match the search text
*/
async function queryCollection(collectionId, searchText, topK) {
const index = await getIndex(collectionId);
const use = await model.get();
const tensor = await use.embed(searchText);
const vector = Array.from(await tensor.data());
async function queryCollection(collectionId, source, searchText, topK) {
const index = await getIndex(collectionId, source);
const vector = await getVector(source, searchText);
const result = await index.queryItems(vector, topK);
const hashes = result.map(x => Number(x.item.metadata.hash));
@ -127,8 +123,9 @@ async function registerEndpoints(app, jsonParser) {
const collectionId = String(req.body.collectionId);
const searchText = String(req.body.searchText);
const topK = Number(req.body.topK) || 10;
const source = String(req.body.source) || 'local';
const results = await queryCollection(collectionId, searchText, topK);
const results = await queryCollection(collectionId, source, searchText, topK);
return res.json(results);
} catch (error) {
console.error(error);
@ -144,8 +141,9 @@ async function registerEndpoints(app, jsonParser) {
const collectionId = String(req.body.collectionId);
const items = req.body.items.map(x => ({ hash: x.hash, text: x.text }));
const source = String(req.body.source) || 'local';
await insertVectorItems(collectionId, items);
await insertVectorItems(collectionId, source, items);
return res.sendStatus(200);
} catch (error) {
console.error(error);
@ -160,8 +158,9 @@ async function registerEndpoints(app, jsonParser) {
}
const collectionId = String(req.body.collectionId);
const source = String(req.body.source) || 'local';
const hashes = await getSavedHashes(collectionId);
const hashes = await getSavedHashes(collectionId, source);
return res.json(hashes);
} catch (error) {
console.error(error);
@ -177,8 +176,9 @@ async function registerEndpoints(app, jsonParser) {
const collectionId = String(req.body.collectionId);
const hashes = req.body.hashes.map(x => Number(x));
const source = String(req.body.source) || 'local';
await deleteVectorItems(collectionId, hashes);
await deleteVectorItems(collectionId, source, hashes);
return res.sendStatus(200);
} catch (error) {
console.error(error);