From 0cc048cb64c89b0e8aea1c94b3b50a4adda4ad20 Mon Sep 17 00:00:00 2001
From: Cohee <18619528+Cohee1207@users.noreply.github.com>
Date: Thu, 14 Sep 2023 23:12:33 +0300
Subject: [PATCH] Refactor transformers.js usage

---
 default/config.conf  |  2 +
 server.js            | 17 ++++-----
 src/caption.js       | 29 +++++++++++++++
 src/caption.mjs      | 72 -----------------------------------
 src/classify.js      | 53 ++++++++++++++++++++++++++
 src/classify.mjs     | 89 --------------------------------------------
 src/transformers.mjs | 76 +++++++++++++++++++++++++++++++++++++
 7 files changed, 167 insertions(+), 171 deletions(-)
 create mode 100644 src/caption.js
 delete mode 100644 src/caption.mjs
 create mode 100644 src/classify.js
 delete mode 100644 src/classify.mjs
 create mode 100644 src/transformers.mjs

diff --git a/default/config.conf b/default/config.conf
index 87801a256..c3a4f3dfc 100644
--- a/default/config.conf
+++ b/default/config.conf
@@ -21,6 +21,8 @@ const extras = {
     classificationModel: 'Cohee/distilbert-base-uncased-go-emotions-onnx',
     // Image captioning model. HuggingFace ID of a model in ONNX format.
     captioningModel: 'Xenova/vit-gpt2-image-captioning',
+    // Feature extraction model. HuggingFace ID of a model in ONNX format.
+    embeddingModel: 'Xenova/all-mpnet-base-v2',
 };
 
 // Request overrides for additional headers
diff --git a/server.js b/server.js
index 4818a951f..ab064dfb6 100644
--- a/server.js
+++ b/server.js
@@ -5250,21 +5250,18 @@ app.post('/get_character_assets_list', jsonParser, async (request, response) =>
 
 // Stable Diffusion generation
 require('./src/stable-diffusion').registerEndpoints(app, jsonParser);
+
 // LLM and SD Horde generation
 require('./src/horde').registerEndpoints(app, jsonParser);
+
 // Vector storage DB
 require('./src/vectors').registerEndpoints(app, jsonParser);
+
 // Chat translation
 require('./src/translate').registerEndpoints(app, jsonParser);
+
 // Emotion classification
-import('./src/classify.mjs').then(module => {
-    module.default.registerEndpoints(app, jsonParser);
-}).catch(err => {
-    console.error(err);
-});
+require('./src/classify').registerEndpoints(app, jsonParser);
+
 // Image captioning
-import('./src/caption.mjs').then(module => {
-    module.default.registerEndpoints(app, jsonParser);
-}).catch(err => {
-    console.error(err);
-});
+require('./src/caption').registerEndpoints(app, jsonParser);
diff --git a/src/caption.js b/src/caption.js
new file mode 100644
index 000000000..e79581fac
--- /dev/null
+++ b/src/caption.js
@@ -0,0 +1,29 @@
+const TASK = 'image-to-text';
+
+/**
+ * @param {import("express").Express} app
+ * @param {any} jsonParser
+ */
+function registerEndpoints(app, jsonParser) {
+    app.post('/api/extra/caption', jsonParser, async (req, res) => {
+        try {
+            const { image } = req.body;
+
+            const module = await import('./transformers.mjs');
+            const rawImage = await module.default.getRawImage(image);
+            const pipe = await module.default.getPipeline(TASK);
+            const result = await pipe(rawImage);
+            const text = result[0].generated_text;
+            console.log('Image caption:', text);
+
+            return res.json({ caption: text });
+        } catch (error) {
+            console.error(error);
+            return res.sendStatus(500);
+        }
+    });
+}
+
+module.exports = {
+    registerEndpoints,
+};
diff --git a/src/caption.mjs b/src/caption.mjs
deleted file mode 100644
index ab8241f42..000000000
--- a/src/caption.mjs
+++ /dev/null
@@ -1,72 +0,0 @@
-import { pipeline, env, RawImage } from 'sillytavern-transformers';
-import path from 'path';
-import { getConfig } from './util.js';
-
-// Limit the number of threads to 1 to avoid issues on Android
-env.backends.onnx.wasm.numThreads = 1;
-// Use WASM from a local folder to avoid CDN connections
-env.backends.onnx.wasm.wasmPaths = path.join(process.cwd(), 'dist') + path.sep;
-
-class PipelineAccessor {
-    /**
-     * @type {import("sillytavern-transformers").ImageToTextPipeline}
-     */
-    pipe;
-
-    async get() {
-        if (!this.pipe) {
-            const cache_dir = path.join(process.cwd(), 'cache');
-            const model = this.getCaptioningModel();
-            this.pipe = await pipeline('image-to-text', model, { cache_dir, quantized: true });
-        }
-
-        return this.pipe;
-    }
-
-    getCaptioningModel() {
-        const DEFAULT_MODEL = 'Xenova/vit-gpt2-image-captioning';
-
-        try {
-            const config = getConfig();
-            const model = config?.extras?.captioningModel;
-            return model || DEFAULT_MODEL;
-        } catch (error) {
-            console.warn('Failed to read config.conf, using default captioning model.');
-            return DEFAULT_MODEL;
-        }
-    }
-}
-
-/**
- * @param {import("express").Express} app
- * @param {any} jsonParser
- */
-function registerEndpoints(app, jsonParser) {
-    const pipelineAccessor = new PipelineAccessor();
-
-    app.post('/api/extra/caption', jsonParser, async (req, res) => {
-        try {
-            const { image } = req.body;
-
-            // base64 string to blob
-            const buffer = Buffer.from(image, 'base64');
-            const byteArray = new Uint8Array(buffer);
-            const blob = new Blob([byteArray]);
-
-            const rawImage = await RawImage.fromBlob(blob);
-            const pipe = await pipelineAccessor.get();
-            const result = await pipe(rawImage);
-            const text = result[0].generated_text;
-            console.log('Image caption:', text);
-
-            return res.json({ caption: text });
-        } catch (error) {
-            console.error(error);
-            return res.sendStatus(500);
-        }
-    });
-}
-
-export default {
-    registerEndpoints,
-};
diff --git a/src/classify.js b/src/classify.js
new file mode 100644
index 000000000..d9a5a72a7
--- /dev/null
+++ b/src/classify.js
@@ -0,0 +1,53 @@
+const TASK = 'text-classification';
+
+/**
+ * @param {import("express").Express} app
+ * @param {any} jsonParser
+ */
+function registerEndpoints(app, jsonParser) {
+    const cacheObject = {};
+
+    app.post('/api/extra/classify/labels', jsonParser, async (req, res) => {
+        try {
+            const module = await import('./transformers.mjs');
+            const pipe = await module.default.getPipeline(TASK);
+            const result = Object.keys(pipe.model.config.label2id);
+            return res.json({ labels: result });
+        } catch (error) {
+            console.error(error);
+            return res.sendStatus(500);
+        }
+    });
+
+    app.post('/api/extra/classify', jsonParser, async (req, res) => {
+        try {
+            const { text } = req.body;
+
+            async function getResult(text) {
+                if (cacheObject.hasOwnProperty(text)) {
+                    return cacheObject[text];
+                } else {
+                    const module = await import('./transformers.mjs');
+                    const pipe = await module.default.getPipeline(TASK);
+                    const result = await pipe(text, { topk: 5 });
+                    result.sort((a, b) => b.score - a.score);
+                    cacheObject[text] = result;
+                    return result;
+                }
+            }
+
+            console.log('Classify input:', text);
+            const result = await getResult(text);
+            console.log('Classify output:', result);
+
+            return res.json({ classification: result });
+        } catch (error) {
+            console.error(error);
+            return res.sendStatus(500);
+        }
+    });
+}
+
+module.exports = {
+    registerEndpoints,
+};
diff --git a/src/classify.mjs b/src/classify.mjs
deleted file mode 100644
index b35965116..000000000
--- a/src/classify.mjs
+++ /dev/null
@@ -1,89 +0,0 @@
-import { pipeline, env } from 'sillytavern-transformers';
-import path from 'path';
-import { getConfig } from './util.js';
-
-// Limit the number of threads to 1 to avoid issues on Android
-env.backends.onnx.wasm.numThreads = 1;
-// Use WASM from a local folder to avoid CDN connections
-env.backends.onnx.wasm.wasmPaths = path.join(process.cwd(), 'dist') + path.sep;
-
-class PipelineAccessor {
-    /**
-     * @type {import("sillytavern-transformers").TextClassificationPipeline}
-     */
-    pipe;
-
-    async get() {
-        if (!this.pipe) {
-            const cache_dir = path.join(process.cwd(), 'cache');
-            const model = this.getClassificationModel();
-            this.pipe = await pipeline('text-classification', model, { cache_dir, quantized: true });
-        }
-
-        return this.pipe;
-    }
-
-    getClassificationModel() {
-        const DEFAULT_MODEL = 'Cohee/distilbert-base-uncased-go-emotions-onnx';
-
-        try {
-            const config = getConfig();
-            const model = config?.extras?.classificationModel;
-            return model || DEFAULT_MODEL;
-        } catch (error) {
-            console.warn('Failed to read config.conf, using default classification model.');
-            return DEFAULT_MODEL;
-        }
-    }
-}
-
-/**
- * @param {import("express").Express} app
- * @param {any} jsonParser
- */
-function registerEndpoints(app, jsonParser) {
-    const cacheObject = {};
-    const pipelineAccessor = new PipelineAccessor();
-
-    app.post('/api/extra/classify/labels', jsonParser, async (req, res) => {
-        try {
-            const pipe = await pipelineAccessor.get();
-            const result = Object.keys(pipe.model.config.label2id);
-            return res.json({ labels: result });
-        } catch (error) {
-            console.error(error);
-            return res.sendStatus(500);
-        }
-    });
-
-    app.post('/api/extra/classify', jsonParser, async (req, res) => {
-        try {
-            const { text } = req.body;
-
-            async function getResult(text) {
-                if (cacheObject.hasOwnProperty(text)) {
-                    return cacheObject[text];
-                } else {
-                    const pipe = await pipelineAccessor.get();
-                    const result = await pipe(text, { topk: 5 });
-                    result.sort((a, b) => b.score - a.score);
-                    cacheObject[text] = result;
-                    return result;
-                }
-            }
-
-            console.log('Classify input:', text);
-            const result = await getResult(text);
-            console.log('Classify output:', result);
-
-            return res.json({ classification: result });
-        } catch (error) {
-            console.error(error);
-            return res.sendStatus(500);
-        }
-    });
-}
-
-export default {
-    registerEndpoints,
-};
diff --git a/src/transformers.mjs b/src/transformers.mjs
new file mode 100644
index 000000000..f31d8911d
--- /dev/null
+++ b/src/transformers.mjs
@@ -0,0 +1,76 @@
+import { pipeline, env, RawImage } from 'sillytavern-transformers';
+import { getConfig } from './util.js';
+import path from 'path';
+import _ from 'lodash';
+
+configureTransformers();
+
+function configureTransformers() {
+    // Limit the number of threads to 1 to avoid issues on Android
+    env.backends.onnx.wasm.numThreads = 1;
+    // Use WASM from a local folder to avoid CDN connections
+    env.backends.onnx.wasm.wasmPaths = path.join(process.cwd(), 'dist') + path.sep;
+}
+
+const tasks = {
+    'text-classification': {
+        defaultModel: 'Cohee/distilbert-base-uncased-go-emotions-onnx',
+        pipeline: null,
+        configField: 'extras.classificationModel',
+    },
+    'image-to-text': {
+        defaultModel: 'Xenova/vit-gpt2-image-captioning',
+        pipeline: null,
+        configField: 'extras.captioningModel',
+    },
+    'feature-extraction': {
+        defaultModel: 'Xenova/all-mpnet-base-v2',
+        pipeline: null,
+        configField: 'extras.embeddingModel',
+    },
+}
+
+async function getRawImage(image) {
+    const buffer = Buffer.from(image, 'base64');
+    const byteArray = new Uint8Array(buffer);
+    const blob = new Blob([byteArray]);
+
+    const rawImage = await RawImage.fromBlob(blob);
+    return rawImage;
+}
+
+function getModelForTask(task) {
+    const defaultModel = tasks[task].defaultModel;
+
+    try {
+        const config = getConfig();
+        const model = _.get(config, tasks[task].configField, null);
+        return model || defaultModel;
+    } catch (error) {
+        console.warn('Failed to read config.conf, using default model.');
+        return defaultModel;
+    }
+}
+
+function progressCallback() {
+    // TODO: Implement progress callback
+    // console.log(arguments);
+}
+
+async function getPipeline(task) {
+    if (tasks[task].pipeline) {
+        return tasks[task].pipeline;
+    }
+
+    const cache_dir = path.join(process.cwd(), 'cache');
+    const model = getModelForTask(task);
+    console.log('Initializing transformers.js pipeline for task', task, 'with model', model);
+    const instance = await pipeline(task, model, { cache_dir, quantized: true, progress_callback: progressCallback });
+    tasks[task].pipeline = instance;
+    return instance;
+}
+
+export default {
+    getPipeline,
+    getRawImage,
+}
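--
Usage note, not part of the patch: the new src/transformers.mjs lazily creates and memoizes one pipeline per task, and CommonJS callers reach it through dynamic import() because the module itself is ESM. This commit adds the 'feature-extraction' task and the embeddingModel config key without a consumer; the sketch below shows how a hypothetical CommonJS endpoint could use them, following the same pattern as src/caption.js and src/classify.js. The '/api/extra/embed' route is illustrative only, and the pooling/normalize options assume the upstream transformers.js feature-extraction API.

// embed-example.js -- hypothetical consumer of the 'feature-extraction' task (not in this commit)
const TASK = 'feature-extraction';

function registerEndpoints(app, jsonParser) {
    app.post('/api/extra/embed', jsonParser, async (req, res) => {
        try {
            const { text } = req.body;

            // Same CommonJS -> ESM bridge used by caption.js and classify.js
            const module = await import('./transformers.mjs');
            const pipe = await module.default.getPipeline(TASK);

            // Mean-pool and normalize token embeddings into one fixed-size vector
            const output = await pipe(text, { pooling: 'mean', normalize: true });
            return res.json({ embedding: Array.from(output.data) });
        } catch (error) {
            console.error(error);
            return res.sendStatus(500);
        }
    });
}

module.exports = {
    registerEndpoints,
};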