Add alternative local vectors source.

x5 speed boost!!
This commit is contained in:
Cohee 2023-09-14 23:40:13 +03:00
parent 0cc048cb64
commit 6ad786f348
6 changed files with 42 additions and 11 deletions

View File

@ -17,12 +17,15 @@ const securityOverride = false;
// Additional settings for extra modules / extensions
const extras = {
// Disables auto-download of models from the HuggingFace Hub.
// You will need to manually download the models and put them into the /cache folder.
disableAutoDownload: false,
// Text classification model for sentiment analysis. HuggingFace ID of a model in ONNX format.
classificationModel: 'Cohee/distilbert-base-uncased-go-emotions-onnx',
// Image captioning model. HuggingFace ID of a model in ONNX format.
captioningModel: 'Xenova/vit-gpt2-image-captioning',
// Feature extraction model. HuggingFace ID of a model in ONNX format.
embeddingModel: 'Xenova/all-mpnet-base-v2,
embeddingModel: 'Xenova/all-mpnet-base-v2',
};
// Request overrides for additional headers

View File

@ -13,7 +13,8 @@
Vectorization Source
</label>
<select id="vectors_source" class="select">
<option value="local">Local</option>
<option value="transformers">Local (Transformers)</option>
<option value="local">Local (Tensorflow)</option>
<option value="openai">OpenAI</option>
</select>
<div id="vectors_advanced_settings" data-newbie-hidden>

17
src/embedding.js Normal file
View File

@ -0,0 +1,17 @@
const TASK = 'feature-extraction';
/**
* @param {string} text - The text to vectorize
* @returns {Promise<number[]>} - The vectorized text in form of an array of numbers
*/
async function getTransformersVector(text) {
const module = await import('./transformers.mjs');
const pipe = await module.default.getPipeline(TASK);
const result = await pipe(text, { pooling: 'mean', normalize: true });
const vector = Array.from(result.data);
return vector;
}
module.exports = {
getTransformersVector,
}

View File

@ -1,5 +1,5 @@
import { pipeline, env, RawImage } from 'sillytavern-transformers';
import { getConfig } from './util.js';
import { getConfigValue } from './util.js';
import path from 'path';
import _ from 'lodash';
@ -43,8 +43,7 @@ function getModelForTask(task) {
const defaultModel = tasks[task].defaultModel;
try {
const config = getConfig();
const model = _.get(config, tasks[task].configField, null);
const model = getConfigValue(tasks[task].configField, null);
return model || defaultModel;
} catch (error) {
console.warn('Failed to read config.conf, using default classification model.');
@ -52,11 +51,6 @@ function getModelForTask(task) {
}
}
function progressCallback() {
// TODO: Implement progress callback
// console.log(arguments);
}
async function getPipeline(task) {
if (tasks[task].pipeline) {
return tasks[task].pipeline;
@ -64,8 +58,9 @@ async function getPipeline(task) {
const cache_dir = path.join(process.cwd(), 'cache');
const model = getModelForTask(task);
const localOnly = getConfigValue('extras.disableAutoDownload', false);
console.log('Initializing transformers.js pipeline for task', task, 'with model', model);
const instance = await pipeline(task, model, { cache_dir, quantized: true, progress_callback: progressCallback });
const instance = await pipeline(task, model, { cache_dir, quantized: true, local_files_only: localOnly });
tasks[task].pipeline = instance;
return instance;
}

View File

@ -1,6 +1,7 @@
const path = require('path');
const child_process = require('child_process');
const commandExistsSync = require('command-exists').sync;
const _ = require('lodash');
/**
* Returns the config object from the config.conf file.
@ -16,6 +17,17 @@ function getConfig() {
}
}
/**
* Returns the value for the given key from the config object.
* @param {string} key - Key to get from the config object
* @param {any} defaultValue - Default value to return if the key is not found
* @returns {any} Value for the given key
*/
function getConfigValue(key, defaultValue = null) {
const config = getConfig();
return _.get(config, key, defaultValue);
}
/**
* Encodes the Basic Auth header value for the given user and password.
* @param {string} auth username:password
@ -67,6 +79,7 @@ function delay(ms) {
module.exports = {
getConfig,
getConfigValue,
getVersion,
getBasicAuthHeader,
delay,

View File

@ -15,6 +15,8 @@ async function getVector(source, text) {
return require('./local-vectors').getLocalVector(text);
case 'openai':
return require('./openai-vectors').getOpenAIVector(text);
case 'transformers':
return require('./embedding').getTransformersVector(text);
}
throw new Error(`Unknown vector source ${source}`);