parent
0cc048cb64
commit
6ad786f348
|
@ -17,12 +17,15 @@ const securityOverride = false;
|
|||
|
||||
// Additional settings for extra modules / extensions
|
||||
const extras = {
|
||||
// Disables auto-download of models from the HuggingFace Hub.
|
||||
// You will need to manually download the models and put them into the /cache folder.
|
||||
disableAutoDownload: false,
|
||||
// Text classification model for sentiment analysis. HuggingFace ID of a model in ONNX format.
|
||||
classificationModel: 'Cohee/distilbert-base-uncased-go-emotions-onnx',
|
||||
// Image captioning model. HuggingFace ID of a model in ONNX format.
|
||||
captioningModel: 'Xenova/vit-gpt2-image-captioning',
|
||||
// Feature extraction model. HuggingFace ID of a model in ONNX format.
|
||||
embeddingModel: 'Xenova/all-mpnet-base-v2,
|
||||
embeddingModel: 'Xenova/all-mpnet-base-v2',
|
||||
};
|
||||
|
||||
// Request overrides for additional headers
|
||||
|
|
|
@ -13,7 +13,8 @@
|
|||
Vectorization Source
|
||||
</label>
|
||||
<select id="vectors_source" class="select">
|
||||
<option value="local">Local</option>
|
||||
<option value="transformers">Local (Transformers)</option>
|
||||
<option value="local">Local (Tensorflow)</option>
|
||||
<option value="openai">OpenAI</option>
|
||||
</select>
|
||||
<div id="vectors_advanced_settings" data-newbie-hidden>
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
const TASK = 'feature-extraction';
|
||||
|
||||
/**
|
||||
* @param {string} text - The text to vectorize
|
||||
* @returns {Promise<number[]>} - The vectorized text in form of an array of numbers
|
||||
*/
|
||||
async function getTransformersVector(text) {
|
||||
const module = await import('./transformers.mjs');
|
||||
const pipe = await module.default.getPipeline(TASK);
|
||||
const result = await pipe(text, { pooling: 'mean', normalize: true });
|
||||
const vector = Array.from(result.data);
|
||||
return vector;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getTransformersVector,
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
import { pipeline, env, RawImage } from 'sillytavern-transformers';
|
||||
import { getConfig } from './util.js';
|
||||
import { getConfigValue } from './util.js';
|
||||
import path from 'path';
|
||||
import _ from 'lodash';
|
||||
|
||||
|
@ -43,8 +43,7 @@ function getModelForTask(task) {
|
|||
const defaultModel = tasks[task].defaultModel;
|
||||
|
||||
try {
|
||||
const config = getConfig();
|
||||
const model = _.get(config, tasks[task].configField, null);
|
||||
const model = getConfigValue(tasks[task].configField, null);
|
||||
return model || defaultModel;
|
||||
} catch (error) {
|
||||
console.warn('Failed to read config.conf, using default classification model.');
|
||||
|
@ -52,11 +51,6 @@ function getModelForTask(task) {
|
|||
}
|
||||
}
|
||||
|
||||
function progressCallback() {
|
||||
// TODO: Implement progress callback
|
||||
// console.log(arguments);
|
||||
}
|
||||
|
||||
async function getPipeline(task) {
|
||||
if (tasks[task].pipeline) {
|
||||
return tasks[task].pipeline;
|
||||
|
@ -64,8 +58,9 @@ async function getPipeline(task) {
|
|||
|
||||
const cache_dir = path.join(process.cwd(), 'cache');
|
||||
const model = getModelForTask(task);
|
||||
const localOnly = getConfigValue('extras.disableAutoDownload', false);
|
||||
console.log('Initializing transformers.js pipeline for task', task, 'with model', model);
|
||||
const instance = await pipeline(task, model, { cache_dir, quantized: true, progress_callback: progressCallback });
|
||||
const instance = await pipeline(task, model, { cache_dir, quantized: true, local_files_only: localOnly });
|
||||
tasks[task].pipeline = instance;
|
||||
return instance;
|
||||
}
|
||||
|
|
13
src/util.js
13
src/util.js
|
@ -1,6 +1,7 @@
|
|||
const path = require('path');
|
||||
const child_process = require('child_process');
|
||||
const commandExistsSync = require('command-exists').sync;
|
||||
const _ = require('lodash');
|
||||
|
||||
/**
|
||||
* Returns the config object from the config.conf file.
|
||||
|
@ -16,6 +17,17 @@ function getConfig() {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the value for the given key from the config object.
|
||||
* @param {string} key - Key to get from the config object
|
||||
* @param {any} defaultValue - Default value to return if the key is not found
|
||||
* @returns {any} Value for the given key
|
||||
*/
|
||||
function getConfigValue(key, defaultValue = null) {
|
||||
const config = getConfig();
|
||||
return _.get(config, key, defaultValue);
|
||||
}
|
||||
|
||||
/**
|
||||
* Encodes the Basic Auth header value for the given user and password.
|
||||
* @param {string} auth username:password
|
||||
|
@ -67,6 +79,7 @@ function delay(ms) {
|
|||
|
||||
module.exports = {
|
||||
getConfig,
|
||||
getConfigValue,
|
||||
getVersion,
|
||||
getBasicAuthHeader,
|
||||
delay,
|
||||
|
|
|
@ -15,6 +15,8 @@ async function getVector(source, text) {
|
|||
return require('./local-vectors').getLocalVector(text);
|
||||
case 'openai':
|
||||
return require('./openai-vectors').getOpenAIVector(text);
|
||||
case 'transformers':
|
||||
return require('./embedding').getTransformersVector(text);
|
||||
}
|
||||
|
||||
throw new Error(`Unknown vector source ${source}`);
|
||||
|
|
Loading…
Reference in New Issue