From 6ad786f348a6c524adf2711964a1d35ee96efe82 Mon Sep 17 00:00:00 2001
From: Cohee <18619528+Cohee1207@users.noreply.github.com>
Date: Thu, 14 Sep 2023 23:40:13 +0300
Subject: [PATCH] Add alternative local vectors source. x5 speed boost!!
---
default/config.conf | 5 ++++-
public/scripts/extensions/vectors/settings.html | 3 ++-
src/embedding.js | 17 +++++++++++++++++
src/transformers.mjs | 13 ++++---------
src/util.js | 13 +++++++++++++
src/vectors.js | 2 ++
6 files changed, 42 insertions(+), 11 deletions(-)
create mode 100644 src/embedding.js
diff --git a/default/config.conf b/default/config.conf
index c3a4f3dfc..e1bb5b1e3 100644
--- a/default/config.conf
+++ b/default/config.conf
@@ -17,12 +17,15 @@ const securityOverride = false;
// Additional settings for extra modules / extensions
const extras = {
+ // Disables auto-download of models from the HuggingFace Hub.
+ // You will need to manually download the models and put them into the /cache folder.
+ disableAutoDownload: false,
// Text classification model for sentiment analysis. HuggingFace ID of a model in ONNX format.
classificationModel: 'Cohee/distilbert-base-uncased-go-emotions-onnx',
// Image captioning model. HuggingFace ID of a model in ONNX format.
captioningModel: 'Xenova/vit-gpt2-image-captioning',
// Feature extraction model. HuggingFace ID of a model in ONNX format.
- embeddingModel: 'Xenova/all-mpnet-base-v2,
+ embeddingModel: 'Xenova/all-mpnet-base-v2',
};
// Request overrides for additional headers
diff --git a/public/scripts/extensions/vectors/settings.html b/public/scripts/extensions/vectors/settings.html
index d51f47986..fff8289ca 100644
--- a/public/scripts/extensions/vectors/settings.html
+++ b/public/scripts/extensions/vectors/settings.html
@@ -13,7 +13,8 @@
Vectorization Source
diff --git a/src/embedding.js b/src/embedding.js
new file mode 100644
index 000000000..f7bb9c080
--- /dev/null
+++ b/src/embedding.js
@@ -0,0 +1,17 @@
+const TASK = 'feature-extraction';
+
+/**
+ * @param {string} text - The text to vectorize
+ * @returns {Promise} - The vectorized text in form of an array of numbers
+ */
+async function getTransformersVector(text) {
+ const module = await import('./transformers.mjs');
+ const pipe = await module.default.getPipeline(TASK);
+ const result = await pipe(text, { pooling: 'mean', normalize: true });
+ const vector = Array.from(result.data);
+ return vector;
+}
+
+module.exports = {
+ getTransformersVector,
+}
diff --git a/src/transformers.mjs b/src/transformers.mjs
index f31d8911d..c956f0d8a 100644
--- a/src/transformers.mjs
+++ b/src/transformers.mjs
@@ -1,5 +1,5 @@
import { pipeline, env, RawImage } from 'sillytavern-transformers';
-import { getConfig } from './util.js';
+import { getConfigValue } from './util.js';
import path from 'path';
import _ from 'lodash';
@@ -43,8 +43,7 @@ function getModelForTask(task) {
const defaultModel = tasks[task].defaultModel;
try {
- const config = getConfig();
- const model = _.get(config, tasks[task].configField, null);
+ const model = getConfigValue(tasks[task].configField, null);
return model || defaultModel;
} catch (error) {
console.warn('Failed to read config.conf, using default classification model.');
@@ -52,11 +51,6 @@ function getModelForTask(task) {
}
}
-function progressCallback() {
- // TODO: Implement progress callback
- // console.log(arguments);
-}
-
async function getPipeline(task) {
if (tasks[task].pipeline) {
return tasks[task].pipeline;
@@ -64,8 +58,9 @@ async function getPipeline(task) {
const cache_dir = path.join(process.cwd(), 'cache');
const model = getModelForTask(task);
+ const localOnly = getConfigValue('extras.disableAutoDownload', false);
console.log('Initializing transformers.js pipeline for task', task, 'with model', model);
- const instance = await pipeline(task, model, { cache_dir, quantized: true, progress_callback: progressCallback });
+ const instance = await pipeline(task, model, { cache_dir, quantized: true, local_files_only: localOnly });
tasks[task].pipeline = instance;
return instance;
}
diff --git a/src/util.js b/src/util.js
index 3df16a896..3350548fc 100644
--- a/src/util.js
+++ b/src/util.js
@@ -1,6 +1,7 @@
const path = require('path');
const child_process = require('child_process');
const commandExistsSync = require('command-exists').sync;
+const _ = require('lodash');
/**
* Returns the config object from the config.conf file.
@@ -16,6 +17,17 @@ function getConfig() {
}
}
+/**
+ * Returns the value for the given key from the config object.
+ * @param {string} key - Key to get from the config object
+ * @param {any} defaultValue - Default value to return if the key is not found
+ * @returns {any} Value for the given key
+ */
+function getConfigValue(key, defaultValue = null) {
+ const config = getConfig();
+ return _.get(config, key, defaultValue);
+}
+
/**
* Encodes the Basic Auth header value for the given user and password.
* @param {string} auth username:password
@@ -67,6 +79,7 @@ function delay(ms) {
module.exports = {
getConfig,
+ getConfigValue,
getVersion,
getBasicAuthHeader,
delay,
diff --git a/src/vectors.js b/src/vectors.js
index d42a812f9..98bb1cd30 100644
--- a/src/vectors.js
+++ b/src/vectors.js
@@ -15,6 +15,8 @@ async function getVector(source, text) {
return require('./local-vectors').getLocalVector(text);
case 'openai':
return require('./openai-vectors').getOpenAIVector(text);
+ case 'transformers':
+ return require('./embedding').getTransformersVector(text);
}
throw new Error(`Unknown vector source ${source}`);