From 6ad786f348a6c524adf2711964a1d35ee96efe82 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Thu, 14 Sep 2023 23:40:13 +0300 Subject: [PATCH] Add alternative local vectors source. x5 speed boost!! --- default/config.conf | 5 ++++- public/scripts/extensions/vectors/settings.html | 3 ++- src/embedding.js | 17 +++++++++++++++++ src/transformers.mjs | 13 ++++--------- src/util.js | 13 +++++++++++++ src/vectors.js | 2 ++ 6 files changed, 42 insertions(+), 11 deletions(-) create mode 100644 src/embedding.js diff --git a/default/config.conf b/default/config.conf index c3a4f3dfc..e1bb5b1e3 100644 --- a/default/config.conf +++ b/default/config.conf @@ -17,12 +17,15 @@ const securityOverride = false; // Additional settings for extra modules / extensions const extras = { + // Disables auto-download of models from the HuggingFace Hub. + // You will need to manually download the models and put them into the /cache folder. + disableAutoDownload: false, // Text classification model for sentiment analysis. HuggingFace ID of a model in ONNX format. classificationModel: 'Cohee/distilbert-base-uncased-go-emotions-onnx', // Image captioning model. HuggingFace ID of a model in ONNX format. captioningModel: 'Xenova/vit-gpt2-image-captioning', // Feature extraction model. HuggingFace ID of a model in ONNX format. - embeddingModel: 'Xenova/all-mpnet-base-v2, + embeddingModel: 'Xenova/all-mpnet-base-v2', }; // Request overrides for additional headers diff --git a/public/scripts/extensions/vectors/settings.html b/public/scripts/extensions/vectors/settings.html index d51f47986..fff8289ca 100644 --- a/public/scripts/extensions/vectors/settings.html +++ b/public/scripts/extensions/vectors/settings.html @@ -13,7 +13,8 @@ Vectorization Source
diff --git a/src/embedding.js b/src/embedding.js new file mode 100644 index 000000000..f7bb9c080 --- /dev/null +++ b/src/embedding.js @@ -0,0 +1,17 @@ +const TASK = 'feature-extraction'; + +/** + * @param {string} text - The text to vectorize + * @returns {Promise} - The vectorized text in form of an array of numbers + */ +async function getTransformersVector(text) { + const module = await import('./transformers.mjs'); + const pipe = await module.default.getPipeline(TASK); + const result = await pipe(text, { pooling: 'mean', normalize: true }); + const vector = Array.from(result.data); + return vector; +} + +module.exports = { + getTransformersVector, +} diff --git a/src/transformers.mjs b/src/transformers.mjs index f31d8911d..c956f0d8a 100644 --- a/src/transformers.mjs +++ b/src/transformers.mjs @@ -1,5 +1,5 @@ import { pipeline, env, RawImage } from 'sillytavern-transformers'; -import { getConfig } from './util.js'; +import { getConfigValue } from './util.js'; import path from 'path'; import _ from 'lodash'; @@ -43,8 +43,7 @@ function getModelForTask(task) { const defaultModel = tasks[task].defaultModel; try { - const config = getConfig(); - const model = _.get(config, tasks[task].configField, null); + const model = getConfigValue(tasks[task].configField, null); return model || defaultModel; } catch (error) { console.warn('Failed to read config.conf, using default classification model.'); @@ -52,11 +51,6 @@ function getModelForTask(task) { } } -function progressCallback() { - // TODO: Implement progress callback - // console.log(arguments); -} - async function getPipeline(task) { if (tasks[task].pipeline) { return tasks[task].pipeline; @@ -64,8 +58,9 @@ async function getPipeline(task) { const cache_dir = path.join(process.cwd(), 'cache'); const model = getModelForTask(task); + const localOnly = getConfigValue('extras.disableAutoDownload', false); console.log('Initializing transformers.js pipeline for task', task, 'with model', model); - const instance = await pipeline(task, model, { cache_dir, quantized: true, progress_callback: progressCallback }); + const instance = await pipeline(task, model, { cache_dir, quantized: true, local_files_only: localOnly }); tasks[task].pipeline = instance; return instance; } diff --git a/src/util.js b/src/util.js index 3df16a896..3350548fc 100644 --- a/src/util.js +++ b/src/util.js @@ -1,6 +1,7 @@ const path = require('path'); const child_process = require('child_process'); const commandExistsSync = require('command-exists').sync; +const _ = require('lodash'); /** * Returns the config object from the config.conf file. @@ -16,6 +17,17 @@ function getConfig() { } } +/** + * Returns the value for the given key from the config object. + * @param {string} key - Key to get from the config object + * @param {any} defaultValue - Default value to return if the key is not found + * @returns {any} Value for the given key + */ +function getConfigValue(key, defaultValue = null) { + const config = getConfig(); + return _.get(config, key, defaultValue); +} + /** * Encodes the Basic Auth header value for the given user and password. * @param {string} auth username:password @@ -67,6 +79,7 @@ function delay(ms) { module.exports = { getConfig, + getConfigValue, getVersion, getBasicAuthHeader, delay, diff --git a/src/vectors.js b/src/vectors.js index d42a812f9..98bb1cd30 100644 --- a/src/vectors.js +++ b/src/vectors.js @@ -15,6 +15,8 @@ async function getVector(source, text) { return require('./local-vectors').getLocalVector(text); case 'openai': return require('./openai-vectors').getOpenAIVector(text); + case 'transformers': + return require('./embedding').getTransformersVector(text); } throw new Error(`Unknown vector source ${source}`);