Remove tensorflow vector source.

This commit is contained in:
Cohee
2023-09-17 14:09:24 +03:00
parent 323493962a
commit dc1121b72a
6 changed files with 13 additions and 334 deletions

View File

@@ -1,38 +0,0 @@
require('@tensorflow/tfjs');
const encoder = require('@tensorflow-models/universal-sentence-encoder');
/**
* Lazy loading class for the embedding model.
*/
class EmbeddingModel {
/**
* @type {encoder.UniversalSentenceEncoder} - The embedding model
*/
model;
async get() {
if (!this.model) {
this.model = await encoder.load();
}
return this.model;
}
}
const model = new EmbeddingModel();
/**
* @param {string} text
*/
async function getLocalVector(text) {
const use = await model.get();
const tensor = await use.embed(text);
const vector = Array.from(await tensor.data());
return vector;
}
module.exports = {
getLocalVector,
};

View File

@@ -11,8 +11,6 @@ const sanitize = require('sanitize-filename');
*/
async function getVector(source, text) {
switch (source) {
case 'local':
return require('./local-vectors').getLocalVector(text);
case 'openai':
return require('./openai-vectors').getOpenAIVector(text);
case 'transformers':
@@ -126,7 +124,7 @@ async function registerEndpoints(app, jsonParser) {
const collectionId = String(req.body.collectionId);
const searchText = String(req.body.searchText);
const topK = Number(req.body.topK) || 10;
const source = String(req.body.source) || 'local';
const source = String(req.body.source) || 'transformers';
const results = await queryCollection(collectionId, source, searchText, topK);
return res.json(results);
@@ -144,7 +142,7 @@ async function registerEndpoints(app, jsonParser) {
const collectionId = String(req.body.collectionId);
const items = req.body.items.map(x => ({ hash: x.hash, text: x.text }));
const source = String(req.body.source) || 'local';
const source = String(req.body.source) || 'transformers';
await insertVectorItems(collectionId, source, items);
return res.sendStatus(200);
@@ -161,7 +159,7 @@ async function registerEndpoints(app, jsonParser) {
}
const collectionId = String(req.body.collectionId);
const source = String(req.body.source) || 'local';
const source = String(req.body.source) || 'transformers';
const hashes = await getSavedHashes(collectionId, source);
return res.json(hashes);
@@ -179,7 +177,7 @@ async function registerEndpoints(app, jsonParser) {
const collectionId = String(req.body.collectionId);
const hashes = req.body.hashes.map(x => Number(x));
const source = String(req.body.source) || 'local';
const source = String(req.body.source) || 'transformers';
await deleteVectorItems(collectionId, source, hashes);
return res.sendStatus(200);
@@ -197,7 +195,7 @@ async function registerEndpoints(app, jsonParser) {
const collectionId = String(req.body.collectionId);
const sources = ['local', 'openai'];
const sources = ['transformers', 'openai'];
for (const source of sources) {
const index = await getIndex(collectionId, source, false);