Add OpenAI vector source.
This commit is contained in:
parent
02bdd56e20
commit
a5acc7872d
|
@ -6,20 +6,21 @@ import { debounce, getStringHash as calculateHash } from "../../utils.js";
|
|||
const MODULE_NAME = 'vectors';
|
||||
const AMOUNT_TO_LEAVE = 5;
|
||||
const INSERT_AMOUNT = 3;
|
||||
const QUERY_TEXT_AMOUNT = 3;
|
||||
const QUERY_TEXT_AMOUNT = 2;
|
||||
|
||||
export const EXTENSION_PROMPT_TAG = '3_vectors';
|
||||
|
||||
const settings = {
|
||||
enabled: false,
|
||||
source: 'local',
|
||||
};
|
||||
|
||||
const moduleWorker = new ModuleWorkerWrapper(synchronizeChat);
|
||||
|
||||
async function synchronizeChat() {
|
||||
async function synchronizeChat(batchSize = 10) {
|
||||
try {
|
||||
if (!settings.enabled) {
|
||||
return;
|
||||
return -1;
|
||||
}
|
||||
|
||||
const context = getContext();
|
||||
|
@ -37,7 +38,7 @@ async function synchronizeChat() {
|
|||
const deletedHashes = hashesInCollection.filter(x => !hashedMessages.some(y => y.hash === x));
|
||||
|
||||
if (newVectorItems.length > 0) {
|
||||
await insertVectorItems(chatId, newVectorItems);
|
||||
await insertVectorItems(chatId, newVectorItems.slice(0, batchSize));
|
||||
console.log(`Vectors: Inserted ${newVectorItems.length} new items`);
|
||||
}
|
||||
|
||||
|
@ -45,6 +46,8 @@ async function synchronizeChat() {
|
|||
await deleteVectorItems(chatId, deletedHashes);
|
||||
console.log(`Vectors: Deleted ${deletedHashes.length} old hashes`);
|
||||
}
|
||||
|
||||
return newVectorItems.length - batchSize;
|
||||
} catch (error) {
|
||||
console.error('Vectors: Failed to synchronize chat', error);
|
||||
}
|
||||
|
@ -59,18 +62,18 @@ const hashCache = {};
|
|||
* @returns {number} Hash value
|
||||
*/
|
||||
function getStringHash(str) {
|
||||
// Check if the hash is already in the cache
|
||||
if (hashCache.hasOwnProperty(str)) {
|
||||
return hashCache[str];
|
||||
}
|
||||
// Check if the hash is already in the cache
|
||||
if (hashCache.hasOwnProperty(str)) {
|
||||
return hashCache[str];
|
||||
}
|
||||
|
||||
// Calculate the hash value
|
||||
const hash = calculateHash(str);
|
||||
// Calculate the hash value
|
||||
const hash = calculateHash(str);
|
||||
|
||||
// Store the hash in the cache
|
||||
hashCache[str] = hash;
|
||||
// Store the hash in the cache
|
||||
hashCache[str] = hash;
|
||||
|
||||
return hash;
|
||||
return hash;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -79,6 +82,9 @@ function getStringHash(str) {
|
|||
*/
|
||||
async function rearrangeChat(chat) {
|
||||
try {
|
||||
// Clear the extension prompt
|
||||
setExtensionPrompt(EXTENSION_PROMPT_TAG, '', extension_prompt_types.IN_PROMPT, 0);
|
||||
|
||||
if (!settings.enabled) {
|
||||
return;
|
||||
}
|
||||
|
@ -127,6 +133,11 @@ async function rearrangeChat(chat) {
|
|||
}
|
||||
}
|
||||
|
||||
if (queriedMessages.length === 0) {
|
||||
console.debug('Vectors: No relevant messages found');
|
||||
return;
|
||||
}
|
||||
|
||||
// Format queried messages into a single string
|
||||
const queriedText = 'Past events: ' + queriedMessages.map(x => collapseNewlines(`${x.name}: ${x.mes}`).trim()).join('\n\n');
|
||||
setExtensionPrompt(EXTENSION_PROMPT_TAG, queriedText, extension_prompt_types.IN_PROMPT, 0);
|
||||
|
@ -171,7 +182,10 @@ async function getSavedHashes(collectionId) {
|
|||
const response = await fetch('/api/vector/list', {
|
||||
method: 'POST',
|
||||
headers: getRequestHeaders(),
|
||||
body: JSON.stringify({ collectionId }),
|
||||
body: JSON.stringify({
|
||||
collectionId: collectionId,
|
||||
source: settings.source,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
|
@ -192,7 +206,11 @@ async function insertVectorItems(collectionId, items) {
|
|||
const response = await fetch('/api/vector/insert', {
|
||||
method: 'POST',
|
||||
headers: getRequestHeaders(),
|
||||
body: JSON.stringify({ collectionId, items }),
|
||||
body: JSON.stringify({
|
||||
collectionId: collectionId,
|
||||
items: items,
|
||||
source: settings.source,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
|
@ -210,7 +228,11 @@ async function deleteVectorItems(collectionId, hashes) {
|
|||
const response = await fetch('/api/vector/delete', {
|
||||
method: 'POST',
|
||||
headers: getRequestHeaders(),
|
||||
body: JSON.stringify({ collectionId, hashes }),
|
||||
body: JSON.stringify({
|
||||
collectionId: collectionId,
|
||||
hashes: hashes,
|
||||
source: settings.source,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
|
@ -228,7 +250,12 @@ async function queryCollection(collectionId, searchText, topK) {
|
|||
const response = await fetch('/api/vector/query', {
|
||||
method: 'POST',
|
||||
headers: getRequestHeaders(),
|
||||
body: JSON.stringify({ collectionId, searchText, topK }),
|
||||
body: JSON.stringify({
|
||||
collectionId: collectionId,
|
||||
searchText: searchText,
|
||||
topK: topK,
|
||||
source: settings.source,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
|
@ -251,6 +278,11 @@ jQuery(async () => {
|
|||
Object.assign(extension_settings.vectors, settings);
|
||||
saveSettingsDebounced();
|
||||
});
|
||||
$('#vectors_source').val(settings.source).on('change', () => {
|
||||
settings.source = String($('#vectors_source').val());
|
||||
Object.assign(extension_settings.vectors, settings);
|
||||
saveSettingsDebounced();
|
||||
});
|
||||
|
||||
eventSource.on(event_types.CHAT_CHANGED, onChatEvent);
|
||||
eventSource.on(event_types.MESSAGE_DELETED, onChatEvent);
|
||||
|
|
|
@ -9,6 +9,13 @@
|
|||
<input id="vectors_enabled" type="checkbox" class="checkbox">
|
||||
Enabled
|
||||
</label>
|
||||
<label for="vectors_source">
|
||||
Source
|
||||
</label>
|
||||
<select id="vectors_source" class="select">
|
||||
<option value="local">Local</option>
|
||||
<option value="openai">OpenAI</option>
|
||||
</select>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
@ -629,6 +629,12 @@ function populateChatCompletion(prompts, chatCompletion, { bias, quietPrompt, ty
|
|||
if (true === afterScenario) chatCompletion.insert(authorsNote, 'scenario');
|
||||
}
|
||||
|
||||
// Vectors Memory
|
||||
if (prompts.has('vectorsMemory')) {
|
||||
const vectorsMemory = Message.fromPrompt(prompts.get('vectorsMemory'));
|
||||
chatCompletion.insert(vectorsMemory, 'main');
|
||||
}
|
||||
|
||||
// Decide whether dialogue examples should always be added
|
||||
if (power_user.pin_examples) {
|
||||
populateDialogueExamples(prompts, chatCompletion);
|
||||
|
|
120
server.js
120
server.js
|
@ -68,6 +68,7 @@ const characterCardParser = require('./src/character-card-parser.js');
|
|||
const contentManager = require('./src/content-manager');
|
||||
const novelai = require('./src/novelai');
|
||||
const statsHelpers = require('./statsHelpers.js');
|
||||
const { writeSecret, readSecret, readSecretState, migrateSecrets, SECRET_KEYS, getAllSecrets } = require('./src/secrets');
|
||||
|
||||
function createDefaultFiles() {
|
||||
const files = {
|
||||
|
@ -325,6 +326,7 @@ function humanizedISO8601DateTime(date) {
|
|||
var charactersPath = 'public/characters/';
|
||||
var chatsPath = 'public/chats/';
|
||||
const UPLOADS_PATH = './uploads';
|
||||
const SETTINGS_FILE = './public/settings.json';
|
||||
const AVATAR_WIDTH = 400;
|
||||
const AVATAR_HEIGHT = 600;
|
||||
const jsonParser = express.json({ limit: '100mb' });
|
||||
|
@ -4070,7 +4072,7 @@ const setupTasks = async function () {
|
|||
console.log(`SillyTavern ${version.pkgVersion}` + (version.gitBranch ? ` '${version.gitBranch}' (${version.gitRevision})` : ''));
|
||||
|
||||
backupSettings();
|
||||
migrateSecrets();
|
||||
migrateSecrets(SETTINGS_FILE);
|
||||
ensurePublicDirectoriesExist();
|
||||
await ensureThumbnailCache();
|
||||
contentManager.checkForNewContent();
|
||||
|
@ -4221,69 +4223,6 @@ function ensurePublicDirectoriesExist() {
|
|||
}
|
||||
}
|
||||
|
||||
const SECRETS_FILE = './secrets.json';
|
||||
const SETTINGS_FILE = './public/settings.json';
|
||||
const SECRET_KEYS = {
|
||||
HORDE: 'api_key_horde',
|
||||
MANCER: 'api_key_mancer',
|
||||
OPENAI: 'api_key_openai',
|
||||
NOVEL: 'api_key_novel',
|
||||
CLAUDE: 'api_key_claude',
|
||||
DEEPL: 'deepl',
|
||||
LIBRE: 'libre',
|
||||
LIBRE_URL: 'libre_url',
|
||||
OPENROUTER: 'api_key_openrouter',
|
||||
SCALE: 'api_key_scale',
|
||||
AI21: 'api_key_ai21',
|
||||
SCALE_COOKIE: 'scale_cookie',
|
||||
}
|
||||
|
||||
function migrateSecrets() {
|
||||
if (!fs.existsSync(SETTINGS_FILE)) {
|
||||
console.log('Settings file does not exist');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
let modified = false;
|
||||
const fileContents = fs.readFileSync(SETTINGS_FILE, 'utf8');
|
||||
const settings = JSON.parse(fileContents);
|
||||
const oaiKey = settings?.api_key_openai;
|
||||
const hordeKey = settings?.horde_settings?.api_key;
|
||||
const novelKey = settings?.api_key_novel;
|
||||
|
||||
if (typeof oaiKey === 'string') {
|
||||
console.log('Migrating OpenAI key...');
|
||||
writeSecret(SECRET_KEYS.OPENAI, oaiKey);
|
||||
delete settings.api_key_openai;
|
||||
modified = true;
|
||||
}
|
||||
|
||||
if (typeof hordeKey === 'string') {
|
||||
console.log('Migrating Horde key...');
|
||||
writeSecret(SECRET_KEYS.HORDE, hordeKey);
|
||||
delete settings.horde_settings.api_key;
|
||||
modified = true;
|
||||
}
|
||||
|
||||
if (typeof novelKey === 'string') {
|
||||
console.log('Migrating Novel key...');
|
||||
writeSecret(SECRET_KEYS.NOVEL, novelKey);
|
||||
delete settings.api_key_novel;
|
||||
modified = true;
|
||||
}
|
||||
|
||||
if (modified) {
|
||||
console.log('Writing updated settings.json...');
|
||||
const settingsContent = JSON.stringify(settings);
|
||||
writeFileAtomicSync(SETTINGS_FILE, settingsContent, "utf-8");
|
||||
}
|
||||
}
|
||||
catch (error) {
|
||||
console.error('Could not migrate secrets file. Proceed with caution.');
|
||||
}
|
||||
}
|
||||
|
||||
app.post('/writesecret', jsonParser, (request, response) => {
|
||||
const key = request.body.key;
|
||||
const value = request.body.value;
|
||||
|
@ -4293,19 +4232,9 @@ app.post('/writesecret', jsonParser, (request, response) => {
|
|||
});
|
||||
|
||||
app.post('/readsecretstate', jsonParser, (_, response) => {
|
||||
if (!fs.existsSync(SECRETS_FILE)) {
|
||||
return response.send({});
|
||||
}
|
||||
|
||||
try {
|
||||
const fileContents = fs.readFileSync(SECRETS_FILE, 'utf8');
|
||||
const secrets = JSON.parse(fileContents);
|
||||
const state = {};
|
||||
|
||||
for (const key of Object.values(SECRET_KEYS)) {
|
||||
state[key] = !!secrets[key]; // convert to boolean
|
||||
}
|
||||
|
||||
const state = readSecretState();
|
||||
return response.send(state);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
|
@ -4351,14 +4280,13 @@ app.post('/viewsecrets', jsonParser, async (_, response) => {
|
|||
return response.sendStatus(403);
|
||||
}
|
||||
|
||||
if (!fs.existsSync(SECRETS_FILE)) {
|
||||
console.error('secrets.json does not exist');
|
||||
return response.sendStatus(404);
|
||||
}
|
||||
|
||||
try {
|
||||
const fileContents = fs.readFileSync(SECRETS_FILE, 'utf-8');
|
||||
const secrets = JSON.parse(fileContents);
|
||||
const secrets = getAllSecrets();
|
||||
|
||||
if (!secrets) {
|
||||
return response.sendStatus(404);
|
||||
}
|
||||
|
||||
return response.send(secrets);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
|
@ -4797,6 +4725,11 @@ app.post('/libre_translate', jsonParser, async (request, response) => {
|
|||
const key = readSecret(SECRET_KEYS.LIBRE);
|
||||
const url = readSecret(SECRET_KEYS.LIBRE_URL);
|
||||
|
||||
if (!url) {
|
||||
console.log('LibreTranslate URL is not configured.');
|
||||
return response.sendStatus(401);
|
||||
}
|
||||
|
||||
const text = request.body.text;
|
||||
const lang = request.body.lang;
|
||||
|
||||
|
@ -5292,27 +5225,6 @@ function importRisuSprites(data) {
|
|||
}
|
||||
}
|
||||
|
||||
function writeSecret(key, value) {
|
||||
if (!fs.existsSync(SECRETS_FILE)) {
|
||||
const emptyFile = JSON.stringify({});
|
||||
writeFileAtomicSync(SECRETS_FILE, emptyFile, "utf-8");
|
||||
}
|
||||
|
||||
const fileContents = fs.readFileSync(SECRETS_FILE, 'utf-8');
|
||||
const secrets = JSON.parse(fileContents);
|
||||
secrets[key] = value;
|
||||
writeFileAtomicSync(SECRETS_FILE, JSON.stringify(secrets), "utf-8");
|
||||
}
|
||||
|
||||
function readSecret(key) {
|
||||
if (!fs.existsSync(SECRETS_FILE)) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const fileContents = fs.readFileSync(SECRETS_FILE, 'utf-8');
|
||||
const secrets = JSON.parse(fileContents);
|
||||
return secrets[key];
|
||||
}
|
||||
|
||||
async function readAllChunks(readableStream) {
|
||||
return new Promise((resolve, reject) => {
|
||||
|
@ -5385,8 +5297,6 @@ async function getImageBuffers(zipFilePath) {
|
|||
});
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* This function extracts the extension information from the manifest file.
|
||||
* @param {string} extensionPath - The path of the extension folder
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
|
||||
require('@tensorflow/tfjs');
|
||||
const encoder = require('@tensorflow-models/universal-sentence-encoder');
|
||||
|
||||
/**
|
||||
* Lazy loading class for the embedding model.
|
||||
*/
|
||||
class EmbeddingModel {
|
||||
/**
|
||||
* @type {encoder.UniversalSentenceEncoder} - The embedding model
|
||||
*/
|
||||
model;
|
||||
|
||||
async get() {
|
||||
if (!this.model) {
|
||||
this.model = await encoder.load();
|
||||
}
|
||||
|
||||
return this.model;
|
||||
}
|
||||
}
|
||||
|
||||
const model = new EmbeddingModel();
|
||||
|
||||
/**
|
||||
* @param {string} text
|
||||
*/
|
||||
async function getLocalVector(text) {
|
||||
const use = await model.get();
|
||||
const tensor = await use.embed(text);
|
||||
const vector = Array.from(await tensor.data());
|
||||
|
||||
return vector;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getLocalVector,
|
||||
};
|
|
@ -0,0 +1,47 @@
|
|||
const fetch = require('node-fetch').default;
|
||||
const { SECRET_KEYS, readSecret } = require('./secrets');
|
||||
|
||||
/**
|
||||
* Gets the vector for the given text from OpenAI ada model
|
||||
* @param {string} text - The text to get the vector for
|
||||
* @returns {Promise<number[]>} - The vector for the text
|
||||
*/
|
||||
async function getOpenAIVector(text) {
|
||||
const key = readSecret(SECRET_KEYS.OPENAI);
|
||||
|
||||
if (!key) {
|
||||
console.log('No OpenAI key found');
|
||||
throw new Error('No OpenAI key found');
|
||||
}
|
||||
|
||||
const response = await fetch('https://api.openai.com/v1/embeddings', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${key}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
input: text,
|
||||
model: 'text-embedding-ada-002',
|
||||
})
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
console.log('OpenAI request failed');
|
||||
throw new Error('OpenAI request failed');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
const vector = data?.data[0]?.embedding;
|
||||
|
||||
if (!Array.isArray(vector)) {
|
||||
console.log('OpenAI response was not an array');
|
||||
throw new Error('OpenAI response was not an array');
|
||||
}
|
||||
|
||||
return vector;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getOpenAIVector,
|
||||
};
|
|
@ -0,0 +1,146 @@
|
|||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const writeFileAtomicSync = require('write-file-atomic').sync;
|
||||
|
||||
const SECRETS_FILE = path.join(process.cwd(), './secrets.json');
|
||||
const SECRET_KEYS = {
|
||||
HORDE: 'api_key_horde',
|
||||
MANCER: 'api_key_mancer',
|
||||
OPENAI: 'api_key_openai',
|
||||
NOVEL: 'api_key_novel',
|
||||
CLAUDE: 'api_key_claude',
|
||||
DEEPL: 'deepl',
|
||||
LIBRE: 'libre',
|
||||
LIBRE_URL: 'libre_url',
|
||||
OPENROUTER: 'api_key_openrouter',
|
||||
SCALE: 'api_key_scale',
|
||||
AI21: 'api_key_ai21',
|
||||
SCALE_COOKIE: 'scale_cookie',
|
||||
}
|
||||
|
||||
/**
|
||||
* Writes a secret to the secrets file
|
||||
* @param {string} key Secret key
|
||||
* @param {string} value Secret value
|
||||
*/
|
||||
function writeSecret(key, value) {
|
||||
if (!fs.existsSync(SECRETS_FILE)) {
|
||||
const emptyFile = JSON.stringify({});
|
||||
writeFileAtomicSync(SECRETS_FILE, emptyFile, "utf-8");
|
||||
}
|
||||
|
||||
const fileContents = fs.readFileSync(SECRETS_FILE, 'utf-8');
|
||||
const secrets = JSON.parse(fileContents);
|
||||
secrets[key] = value;
|
||||
writeFileAtomicSync(SECRETS_FILE, JSON.stringify(secrets), "utf-8");
|
||||
}
|
||||
|
||||
/**
|
||||
* Reads a secret from the secrets file
|
||||
* @param {string} key Secret key
|
||||
* @returns {string} Secret value
|
||||
*/
|
||||
function readSecret(key) {
|
||||
if (!fs.existsSync(SECRETS_FILE)) {
|
||||
return '';
|
||||
}
|
||||
|
||||
const fileContents = fs.readFileSync(SECRETS_FILE, 'utf-8');
|
||||
const secrets = JSON.parse(fileContents);
|
||||
return secrets[key];
|
||||
}
|
||||
|
||||
/**
|
||||
* Reads the secret state from the secrets file
|
||||
* @returns {object} Secret state
|
||||
*/
|
||||
function readSecretState() {
|
||||
if (!fs.existsSync(SECRETS_FILE)) {
|
||||
return {};
|
||||
}
|
||||
|
||||
const fileContents = fs.readFileSync(SECRETS_FILE, 'utf8');
|
||||
const secrets = JSON.parse(fileContents);
|
||||
const state = {};
|
||||
|
||||
for (const key of Object.values(SECRET_KEYS)) {
|
||||
state[key] = !!secrets[key]; // convert to boolean
|
||||
}
|
||||
|
||||
return state;
|
||||
}
|
||||
|
||||
/**
|
||||
* Migrates secrets from settings.json to secrets.json
|
||||
* @param {string} settingsFile Path to settings.json
|
||||
* @returns {void}
|
||||
*/
|
||||
function migrateSecrets(settingsFile) {
|
||||
if (!fs.existsSync(settingsFile)) {
|
||||
console.log('Settings file does not exist');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
let modified = false;
|
||||
const fileContents = fs.readFileSync(settingsFile, 'utf8');
|
||||
const settings = JSON.parse(fileContents);
|
||||
const oaiKey = settings?.api_key_openai;
|
||||
const hordeKey = settings?.horde_settings?.api_key;
|
||||
const novelKey = settings?.api_key_novel;
|
||||
|
||||
if (typeof oaiKey === 'string') {
|
||||
console.log('Migrating OpenAI key...');
|
||||
writeSecret(SECRET_KEYS.OPENAI, oaiKey);
|
||||
delete settings.api_key_openai;
|
||||
modified = true;
|
||||
}
|
||||
|
||||
if (typeof hordeKey === 'string') {
|
||||
console.log('Migrating Horde key...');
|
||||
writeSecret(SECRET_KEYS.HORDE, hordeKey);
|
||||
delete settings.horde_settings.api_key;
|
||||
modified = true;
|
||||
}
|
||||
|
||||
if (typeof novelKey === 'string') {
|
||||
console.log('Migrating Novel key...');
|
||||
writeSecret(SECRET_KEYS.NOVEL, novelKey);
|
||||
delete settings.api_key_novel;
|
||||
modified = true;
|
||||
}
|
||||
|
||||
if (modified) {
|
||||
console.log('Writing updated settings.json...');
|
||||
const settingsContent = JSON.stringify(settings);
|
||||
writeFileAtomicSync(settingsFile, settingsContent, "utf-8");
|
||||
}
|
||||
}
|
||||
catch (error) {
|
||||
console.error('Could not migrate secrets file. Proceed with caution.');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Reads all secrets from the secrets file
|
||||
* @returns {Record<string, string> | undefined} Secrets
|
||||
*/
|
||||
function getAllSecrets() {
|
||||
if (!fs.existsSync(SECRETS_FILE)) {
|
||||
console.log('Secrets file does not exist');
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const fileContents = fs.readFileSync(SECRETS_FILE, 'utf8');
|
||||
const secrets = JSON.parse(fileContents);
|
||||
return secrets;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
writeSecret,
|
||||
readSecret,
|
||||
readSecretState,
|
||||
migrateSecrets,
|
||||
getAllSecrets,
|
||||
SECRET_KEYS,
|
||||
};
|
|
@ -2,36 +2,32 @@ const express = require('express');
|
|||
const vectra = require('vectra');
|
||||
const path = require('path');
|
||||
const sanitize = require('sanitize-filename');
|
||||
require('@tensorflow/tfjs');
|
||||
const encoder = require('@tensorflow-models/universal-sentence-encoder');
|
||||
|
||||
/**
|
||||
* Lazy loading class for the embedding model.
|
||||
* Gets the vector for the given text from the given source.
|
||||
* @param {string} source - The source of the vector
|
||||
* @param {string} text - The text to get the vector for
|
||||
* @returns {Promise<number[]>} - The vector for the text
|
||||
*/
|
||||
class EmbeddingModel {
|
||||
/**
|
||||
* @type {encoder.UniversalSentenceEncoder} - The embedding model
|
||||
*/
|
||||
model;
|
||||
|
||||
async get() {
|
||||
if (!this.model) {
|
||||
this.model = await encoder.load();
|
||||
}
|
||||
|
||||
return this.model;
|
||||
async function getVector(source, text) {
|
||||
switch (source) {
|
||||
case 'local':
|
||||
return require('./local-vectors').getLocalVector(text);
|
||||
case 'openai':
|
||||
return require('./openai-vectors').getOpenAIVector(text);
|
||||
}
|
||||
}
|
||||
|
||||
const model = new EmbeddingModel();
|
||||
throw new Error(`Unknown vector source ${source}`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the index for the vector collection
|
||||
* @param {string} collectionId - The collection ID
|
||||
* @param {string} source - The source of the vector
|
||||
* @returns {Promise<vectra.LocalIndex>} - The index for the collection
|
||||
*/
|
||||
async function getIndex(collectionId) {
|
||||
const index = new vectra.LocalIndex(path.join(process.cwd(), 'vectors', sanitize(collectionId)));
|
||||
async function getIndex(collectionId, source) {
|
||||
const index = new vectra.LocalIndex(path.join(process.cwd(), 'vectors', sanitize(source), sanitize(collectionId)));
|
||||
|
||||
if (!await index.isIndexCreated()) {
|
||||
await index.createIndex();
|
||||
|
@ -43,19 +39,18 @@ async function getIndex(collectionId) {
|
|||
/**
|
||||
* Inserts items into the vector collection
|
||||
* @param {string} collectionId - The collection ID
|
||||
* @param {string} source - The source of the vector
|
||||
* @param {{ hash: number; text: string; }[]} items - The items to insert
|
||||
*/
|
||||
async function insertVectorItems(collectionId, items) {
|
||||
const index = await getIndex(collectionId);
|
||||
const use = await model.get();
|
||||
async function insertVectorItems(collectionId, source, items) {
|
||||
const index = await getIndex(collectionId, source);
|
||||
|
||||
await index.beginUpdate();
|
||||
|
||||
for (const item of items) {
|
||||
const text = item.text;
|
||||
const hash = item.hash;
|
||||
const tensor = await use.embed(text);
|
||||
const vector = Array.from(await tensor.data());
|
||||
const vector = await getVector(source, text);
|
||||
await index.upsertItem({ vector: vector, metadata: { hash, text } });
|
||||
}
|
||||
|
||||
|
@ -65,10 +60,11 @@ async function insertVectorItems(collectionId, items) {
|
|||
/**
|
||||
* Gets the hashes of the items in the vector collection
|
||||
* @param {string} collectionId - The collection ID
|
||||
* @param {string} source - The source of the vector
|
||||
* @returns {Promise<number[]>} - The hashes of the items in the collection
|
||||
*/
|
||||
async function getSavedHashes(collectionId) {
|
||||
const index = await getIndex(collectionId);
|
||||
async function getSavedHashes(collectionId, source) {
|
||||
const index = await getIndex(collectionId, source);
|
||||
|
||||
const items = await index.listItems();
|
||||
const hashes = items.map(x => Number(x.metadata.hash));
|
||||
|
@ -79,10 +75,11 @@ async function getSavedHashes(collectionId) {
|
|||
/**
|
||||
* Deletes items from the vector collection by hash
|
||||
* @param {string} collectionId - The collection ID
|
||||
* @param {string} source - The source of the vector
|
||||
* @param {number[]} hashes - The hashes of the items to delete
|
||||
*/
|
||||
async function deleteVectorItems(collectionId, hashes) {
|
||||
const index = await getIndex(collectionId);
|
||||
async function deleteVectorItems(collectionId, source, hashes) {
|
||||
const index = await getIndex(collectionId, source);
|
||||
const items = await index.listItemsByMetadata({ hash: { '$in': hashes } });
|
||||
|
||||
await index.beginUpdate();
|
||||
|
@ -97,15 +94,14 @@ async function deleteVectorItems(collectionId, hashes) {
|
|||
/**
|
||||
* Gets the hashes of the items in the vector collection that match the search text
|
||||
* @param {string} collectionId - The collection ID
|
||||
* @param {string} source - The source of the vector
|
||||
* @param {string} searchText - The text to search for
|
||||
* @param {number} topK - The number of results to return
|
||||
* @returns {Promise<number[]>} - The hashes of the items that match the search text
|
||||
*/
|
||||
async function queryCollection(collectionId, searchText, topK) {
|
||||
const index = await getIndex(collectionId);
|
||||
const use = await model.get();
|
||||
const tensor = await use.embed(searchText);
|
||||
const vector = Array.from(await tensor.data());
|
||||
async function queryCollection(collectionId, source, searchText, topK) {
|
||||
const index = await getIndex(collectionId, source);
|
||||
const vector = await getVector(source, searchText);
|
||||
|
||||
const result = await index.queryItems(vector, topK);
|
||||
const hashes = result.map(x => Number(x.item.metadata.hash));
|
||||
|
@ -127,8 +123,9 @@ async function registerEndpoints(app, jsonParser) {
|
|||
const collectionId = String(req.body.collectionId);
|
||||
const searchText = String(req.body.searchText);
|
||||
const topK = Number(req.body.topK) || 10;
|
||||
const source = String(req.body.source) || 'local';
|
||||
|
||||
const results = await queryCollection(collectionId, searchText, topK);
|
||||
const results = await queryCollection(collectionId, source, searchText, topK);
|
||||
return res.json(results);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
|
@ -144,8 +141,9 @@ async function registerEndpoints(app, jsonParser) {
|
|||
|
||||
const collectionId = String(req.body.collectionId);
|
||||
const items = req.body.items.map(x => ({ hash: x.hash, text: x.text }));
|
||||
const source = String(req.body.source) || 'local';
|
||||
|
||||
await insertVectorItems(collectionId, items);
|
||||
await insertVectorItems(collectionId, source, items);
|
||||
return res.sendStatus(200);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
|
@ -160,8 +158,9 @@ async function registerEndpoints(app, jsonParser) {
|
|||
}
|
||||
|
||||
const collectionId = String(req.body.collectionId);
|
||||
const source = String(req.body.source) || 'local';
|
||||
|
||||
const hashes = await getSavedHashes(collectionId);
|
||||
const hashes = await getSavedHashes(collectionId, source);
|
||||
return res.json(hashes);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
|
@ -177,8 +176,9 @@ async function registerEndpoints(app, jsonParser) {
|
|||
|
||||
const collectionId = String(req.body.collectionId);
|
||||
const hashes = req.body.hashes.map(x => Number(x));
|
||||
const source = String(req.body.source) || 'local';
|
||||
|
||||
await deleteVectorItems(collectionId, hashes);
|
||||
await deleteVectorItems(collectionId, source, hashes);
|
||||
return res.sendStatus(200);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
|
|
Loading…
Reference in New Issue