parent
0cc048cb64
commit
6ad786f348
|
@ -17,12 +17,15 @@ const securityOverride = false;
|
||||||
|
|
||||||
// Additional settings for extra modules / extensions
|
// Additional settings for extra modules / extensions
|
||||||
const extras = {
|
const extras = {
|
||||||
|
// Disables auto-download of models from the HuggingFace Hub.
|
||||||
|
// You will need to manually download the models and put them into the /cache folder.
|
||||||
|
disableAutoDownload: false,
|
||||||
// Text classification model for sentiment analysis. HuggingFace ID of a model in ONNX format.
|
// Text classification model for sentiment analysis. HuggingFace ID of a model in ONNX format.
|
||||||
classificationModel: 'Cohee/distilbert-base-uncased-go-emotions-onnx',
|
classificationModel: 'Cohee/distilbert-base-uncased-go-emotions-onnx',
|
||||||
// Image captioning model. HuggingFace ID of a model in ONNX format.
|
// Image captioning model. HuggingFace ID of a model in ONNX format.
|
||||||
captioningModel: 'Xenova/vit-gpt2-image-captioning',
|
captioningModel: 'Xenova/vit-gpt2-image-captioning',
|
||||||
// Feature extraction model. HuggingFace ID of a model in ONNX format.
|
// Feature extraction model. HuggingFace ID of a model in ONNX format.
|
||||||
embeddingModel: 'Xenova/all-mpnet-base-v2,
|
embeddingModel: 'Xenova/all-mpnet-base-v2',
|
||||||
};
|
};
|
||||||
|
|
||||||
// Request overrides for additional headers
|
// Request overrides for additional headers
|
||||||
|
|
|
@ -13,7 +13,8 @@
|
||||||
Vectorization Source
|
Vectorization Source
|
||||||
</label>
|
</label>
|
||||||
<select id="vectors_source" class="select">
|
<select id="vectors_source" class="select">
|
||||||
<option value="local">Local</option>
|
<option value="transformers">Local (Transformers)</option>
|
||||||
|
<option value="local">Local (Tensorflow)</option>
|
||||||
<option value="openai">OpenAI</option>
|
<option value="openai">OpenAI</option>
|
||||||
</select>
|
</select>
|
||||||
<div id="vectors_advanced_settings" data-newbie-hidden>
|
<div id="vectors_advanced_settings" data-newbie-hidden>
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
const TASK = 'feature-extraction';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param {string} text - The text to vectorize
|
||||||
|
* @returns {Promise<number[]>} - The vectorized text in form of an array of numbers
|
||||||
|
*/
|
||||||
|
async function getTransformersVector(text) {
|
||||||
|
const module = await import('./transformers.mjs');
|
||||||
|
const pipe = await module.default.getPipeline(TASK);
|
||||||
|
const result = await pipe(text, { pooling: 'mean', normalize: true });
|
||||||
|
const vector = Array.from(result.data);
|
||||||
|
return vector;
|
||||||
|
}
|
||||||
|
|
||||||
|
module.exports = {
|
||||||
|
getTransformersVector,
|
||||||
|
}
|
|
@ -1,5 +1,5 @@
|
||||||
import { pipeline, env, RawImage } from 'sillytavern-transformers';
|
import { pipeline, env, RawImage } from 'sillytavern-transformers';
|
||||||
import { getConfig } from './util.js';
|
import { getConfigValue } from './util.js';
|
||||||
import path from 'path';
|
import path from 'path';
|
||||||
import _ from 'lodash';
|
import _ from 'lodash';
|
||||||
|
|
||||||
|
@ -43,8 +43,7 @@ function getModelForTask(task) {
|
||||||
const defaultModel = tasks[task].defaultModel;
|
const defaultModel = tasks[task].defaultModel;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const config = getConfig();
|
const model = getConfigValue(tasks[task].configField, null);
|
||||||
const model = _.get(config, tasks[task].configField, null);
|
|
||||||
return model || defaultModel;
|
return model || defaultModel;
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.warn('Failed to read config.conf, using default classification model.');
|
console.warn('Failed to read config.conf, using default classification model.');
|
||||||
|
@ -52,11 +51,6 @@ function getModelForTask(task) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function progressCallback() {
|
|
||||||
// TODO: Implement progress callback
|
|
||||||
// console.log(arguments);
|
|
||||||
}
|
|
||||||
|
|
||||||
async function getPipeline(task) {
|
async function getPipeline(task) {
|
||||||
if (tasks[task].pipeline) {
|
if (tasks[task].pipeline) {
|
||||||
return tasks[task].pipeline;
|
return tasks[task].pipeline;
|
||||||
|
@ -64,8 +58,9 @@ async function getPipeline(task) {
|
||||||
|
|
||||||
const cache_dir = path.join(process.cwd(), 'cache');
|
const cache_dir = path.join(process.cwd(), 'cache');
|
||||||
const model = getModelForTask(task);
|
const model = getModelForTask(task);
|
||||||
|
const localOnly = getConfigValue('extras.disableAutoDownload', false);
|
||||||
console.log('Initializing transformers.js pipeline for task', task, 'with model', model);
|
console.log('Initializing transformers.js pipeline for task', task, 'with model', model);
|
||||||
const instance = await pipeline(task, model, { cache_dir, quantized: true, progress_callback: progressCallback });
|
const instance = await pipeline(task, model, { cache_dir, quantized: true, local_files_only: localOnly });
|
||||||
tasks[task].pipeline = instance;
|
tasks[task].pipeline = instance;
|
||||||
return instance;
|
return instance;
|
||||||
}
|
}
|
||||||
|
|
13
src/util.js
13
src/util.js
|
@ -1,6 +1,7 @@
|
||||||
const path = require('path');
|
const path = require('path');
|
||||||
const child_process = require('child_process');
|
const child_process = require('child_process');
|
||||||
const commandExistsSync = require('command-exists').sync;
|
const commandExistsSync = require('command-exists').sync;
|
||||||
|
const _ = require('lodash');
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the config object from the config.conf file.
|
* Returns the config object from the config.conf file.
|
||||||
|
@ -16,6 +17,17 @@ function getConfig() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the value for the given key from the config object.
|
||||||
|
* @param {string} key - Key to get from the config object
|
||||||
|
* @param {any} defaultValue - Default value to return if the key is not found
|
||||||
|
* @returns {any} Value for the given key
|
||||||
|
*/
|
||||||
|
function getConfigValue(key, defaultValue = null) {
|
||||||
|
const config = getConfig();
|
||||||
|
return _.get(config, key, defaultValue);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Encodes the Basic Auth header value for the given user and password.
|
* Encodes the Basic Auth header value for the given user and password.
|
||||||
* @param {string} auth username:password
|
* @param {string} auth username:password
|
||||||
|
@ -67,6 +79,7 @@ function delay(ms) {
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
getConfig,
|
getConfig,
|
||||||
|
getConfigValue,
|
||||||
getVersion,
|
getVersion,
|
||||||
getBasicAuthHeader,
|
getBasicAuthHeader,
|
||||||
delay,
|
delay,
|
||||||
|
|
|
@ -15,6 +15,8 @@ async function getVector(source, text) {
|
||||||
return require('./local-vectors').getLocalVector(text);
|
return require('./local-vectors').getLocalVector(text);
|
||||||
case 'openai':
|
case 'openai':
|
||||||
return require('./openai-vectors').getOpenAIVector(text);
|
return require('./openai-vectors').getOpenAIVector(text);
|
||||||
|
case 'transformers':
|
||||||
|
return require('./embedding').getTransformersVector(text);
|
||||||
}
|
}
|
||||||
|
|
||||||
throw new Error(`Unknown vector source ${source}`);
|
throw new Error(`Unknown vector source ${source}`);
|
||||||
|
|
Loading…
Reference in New Issue