mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-06-05 21:59:27 +02:00
Remove tensorflow vector source.
This commit is contained in:
@@ -1,38 +0,0 @@
|
||||
|
||||
require('@tensorflow/tfjs');
|
||||
const encoder = require('@tensorflow-models/universal-sentence-encoder');
|
||||
|
||||
/**
|
||||
* Lazy loading class for the embedding model.
|
||||
*/
|
||||
class EmbeddingModel {
|
||||
/**
|
||||
* @type {encoder.UniversalSentenceEncoder} - The embedding model
|
||||
*/
|
||||
model;
|
||||
|
||||
async get() {
|
||||
if (!this.model) {
|
||||
this.model = await encoder.load();
|
||||
}
|
||||
|
||||
return this.model;
|
||||
}
|
||||
}
|
||||
|
||||
const model = new EmbeddingModel();
|
||||
|
||||
/**
|
||||
* @param {string} text
|
||||
*/
|
||||
async function getLocalVector(text) {
|
||||
const use = await model.get();
|
||||
const tensor = await use.embed(text);
|
||||
const vector = Array.from(await tensor.data());
|
||||
|
||||
return vector;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getLocalVector,
|
||||
};
|
@@ -11,8 +11,6 @@ const sanitize = require('sanitize-filename');
|
||||
*/
|
||||
async function getVector(source, text) {
|
||||
switch (source) {
|
||||
case 'local':
|
||||
return require('./local-vectors').getLocalVector(text);
|
||||
case 'openai':
|
||||
return require('./openai-vectors').getOpenAIVector(text);
|
||||
case 'transformers':
|
||||
@@ -126,7 +124,7 @@ async function registerEndpoints(app, jsonParser) {
|
||||
const collectionId = String(req.body.collectionId);
|
||||
const searchText = String(req.body.searchText);
|
||||
const topK = Number(req.body.topK) || 10;
|
||||
const source = String(req.body.source) || 'local';
|
||||
const source = String(req.body.source) || 'transformers';
|
||||
|
||||
const results = await queryCollection(collectionId, source, searchText, topK);
|
||||
return res.json(results);
|
||||
@@ -144,7 +142,7 @@ async function registerEndpoints(app, jsonParser) {
|
||||
|
||||
const collectionId = String(req.body.collectionId);
|
||||
const items = req.body.items.map(x => ({ hash: x.hash, text: x.text }));
|
||||
const source = String(req.body.source) || 'local';
|
||||
const source = String(req.body.source) || 'transformers';
|
||||
|
||||
await insertVectorItems(collectionId, source, items);
|
||||
return res.sendStatus(200);
|
||||
@@ -161,7 +159,7 @@ async function registerEndpoints(app, jsonParser) {
|
||||
}
|
||||
|
||||
const collectionId = String(req.body.collectionId);
|
||||
const source = String(req.body.source) || 'local';
|
||||
const source = String(req.body.source) || 'transformers';
|
||||
|
||||
const hashes = await getSavedHashes(collectionId, source);
|
||||
return res.json(hashes);
|
||||
@@ -179,7 +177,7 @@ async function registerEndpoints(app, jsonParser) {
|
||||
|
||||
const collectionId = String(req.body.collectionId);
|
||||
const hashes = req.body.hashes.map(x => Number(x));
|
||||
const source = String(req.body.source) || 'local';
|
||||
const source = String(req.body.source) || 'transformers';
|
||||
|
||||
await deleteVectorItems(collectionId, source, hashes);
|
||||
return res.sendStatus(200);
|
||||
@@ -197,7 +195,7 @@ async function registerEndpoints(app, jsonParser) {
|
||||
|
||||
const collectionId = String(req.body.collectionId);
|
||||
|
||||
const sources = ['local', 'openai'];
|
||||
const sources = ['transformers', 'openai'];
|
||||
for (const source of sources) {
|
||||
const index = await getIndex(collectionId, source, false);
|
||||
|
||||
|
Reference in New Issue
Block a user