From c76c76410c3a1bc67d97894e19138380599c4e6c Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Mon, 11 Sep 2023 01:25:22 +0300 Subject: [PATCH] Add ability to override local classification model --- default/config.conf | 9 ++++++ src/classify.mjs | 67 +++++++++++++++++++++++++++++++-------------- 2 files changed, 55 insertions(+), 21 deletions(-) diff --git a/default/config.conf b/default/config.conf index 962c6ce5b..5e9d2920f 100644 --- a/default/config.conf +++ b/default/config.conf @@ -15,7 +15,15 @@ const skipContentCheck = false; // If true, no new default content will be deliv // Change this setting only on "trusted networks". Do not change this value unless you are aware of the issues that can arise from changing this setting and configuring a insecure setting. const securityOverride = false; +// Additional settings for extra modules / extensions +const extras = { + // Text classification model for sentiment analysis. HuggingFace ID of a model in ONNX format. + classificationModel: 'Cohee/distilbert-base-uncased-go-emotions-onnx', +}; + // Request overrides for additional headers +// Format is an array of objects: +// { hosts: [ "" ], headers: {
: "" } } const requestOverrides = []; module.exports = { @@ -32,4 +40,5 @@ module.exports = { securityOverride, skipContentCheck, requestOverrides, + extras, }; diff --git a/src/classify.mjs b/src/classify.mjs index 97062d84c..7489494e9 100644 --- a/src/classify.mjs +++ b/src/classify.mjs @@ -1,22 +1,37 @@ -import { pipeline, TextClassificationPipeline, env } from 'sillytavern-transformers'; +import { pipeline, env } from 'sillytavern-transformers'; import path from 'path'; +// Limit the number of threads to 1 to avoid issues on Android env.backends.onnx.wasm.numThreads = 1; class PipelineAccessor { /** - * @type {TextClassificationPipeline} + * @type {import("sillytavern-transformers").TextClassificationPipeline} */ pipe; async get() { if (!this.pipe) { const cache_dir = path.join(process.cwd(), 'cache'); - this.pipe = await pipeline('text-classification', 'Cohee/distilbert-base-uncased-go-emotions-onnx', { cache_dir, quantized: true }); + const model = this.getClassificationModel(); + this.pipe = await pipeline('text-classification', model, { cache_dir, quantized: true }); } return this.pipe; } + + getClassificationModel() { + const DEFAULT_MODEL = 'Cohee/distilbert-base-uncased-go-emotions-onnx'; + + try { + const config = require(path.join(process.cwd(), './config.conf')); + const model = config?.extras?.classificationModel; + return model || DEFAULT_MODEL; + } catch { + console.warn('Failed to read config.conf, using default classification model.'); + return DEFAULT_MODEL; + } + } } /** @@ -28,30 +43,40 @@ function registerEndpoints(app, jsonParser) { const pipelineAccessor = new PipelineAccessor(); app.post('/api/extra/classify/labels', jsonParser, async (req, res) => { - const pipe = await pipelineAccessor.get(); - const result = Object.keys(pipe.model.config.label2id); - return res.json({ labels: result }); + try { + const pipe = await pipelineAccessor.get(); + const result = Object.keys(pipe.model.config.label2id); + return res.json({ labels: result }); + } catch (error) { + console.error(error); + return res.sendStatus(500); + } }); app.post('/api/extra/classify', jsonParser, async (req, res) => { - const { text } = req.body; + try { + const { text } = req.body; - async function getResult(text) { - if (cacheObject.hasOwnProperty(text)) { - return cacheObject[text]; - } else { - const pipe = await pipelineAccessor.get(); - const result = await pipe(text); - cacheObject[text] = result; - return result; + async function getResult(text) { + if (cacheObject.hasOwnProperty(text)) { + return cacheObject[text]; + } else { + const pipe = await pipelineAccessor.get(); + const result = await pipe(text); + cacheObject[text] = result; + return result; + } } + + console.log('Classify input:', text); + const result = await getResult(text); + console.log('Classify output:', result); + + return res.json({ classification: result }); + } catch (error) { + console.error(error); + return res.sendStatus(500); } - - console.log('Classify input:', text); - const result = await getResult(text); - console.log('Classify output:', result); - - return res.json({ classification: result }); }); }