mirror of
				https://github.com/SillyTavern/SillyTavern.git
				synced 2025-06-05 21:59:27 +02:00 
			
		
		
		
	Add ability to override local classification model
This commit is contained in:
		| @@ -15,7 +15,15 @@ const skipContentCheck = false; // If true, no new default content will be deliv | |||||||
| // Change this setting only on "trusted networks". Do not change this value unless you are aware of the issues that can arise from changing this setting and configuring a insecure setting. | // Change this setting only on "trusted networks". Do not change this value unless you are aware of the issues that can arise from changing this setting and configuring a insecure setting. | ||||||
| const securityOverride = false; | const securityOverride = false; | ||||||
|  |  | ||||||
|  | // Additional settings for extra modules / extensions | ||||||
|  | const extras = { | ||||||
|  |     // Text classification model for sentiment analysis. HuggingFace ID of a model in ONNX format. | ||||||
|  |     classificationModel: 'Cohee/distilbert-base-uncased-go-emotions-onnx', | ||||||
|  | }; | ||||||
|  |  | ||||||
| // Request overrides for additional headers | // Request overrides for additional headers | ||||||
|  | // Format is an array of objects: | ||||||
|  | // { hosts: [ "<url>" ], headers: { <header>: "<value>" } } | ||||||
| const requestOverrides = []; | const requestOverrides = []; | ||||||
|  |  | ||||||
| module.exports = { | module.exports = { | ||||||
| @@ -32,4 +40,5 @@ module.exports = { | |||||||
|     securityOverride, |     securityOverride, | ||||||
|     skipContentCheck, |     skipContentCheck, | ||||||
|     requestOverrides, |     requestOverrides, | ||||||
|  |     extras, | ||||||
| }; | }; | ||||||
|   | |||||||
| @@ -1,22 +1,37 @@ | |||||||
| import { pipeline, TextClassificationPipeline, env } from 'sillytavern-transformers'; | import { pipeline, env } from 'sillytavern-transformers'; | ||||||
| import path from 'path'; | import path from 'path'; | ||||||
|  |  | ||||||
|  | // Limit the number of threads to 1 to avoid issues on Android | ||||||
| env.backends.onnx.wasm.numThreads = 1; | env.backends.onnx.wasm.numThreads = 1; | ||||||
|  |  | ||||||
| class PipelineAccessor { | class PipelineAccessor { | ||||||
|     /** |     /** | ||||||
|      * @type {TextClassificationPipeline} |      * @type {import("sillytavern-transformers").TextClassificationPipeline} | ||||||
|      */ |      */ | ||||||
|     pipe; |     pipe; | ||||||
|  |  | ||||||
|     async get() { |     async get() { | ||||||
|         if (!this.pipe) { |         if (!this.pipe) { | ||||||
|             const cache_dir = path.join(process.cwd(), 'cache'); |             const cache_dir = path.join(process.cwd(), 'cache'); | ||||||
|             this.pipe = await pipeline('text-classification', 'Cohee/distilbert-base-uncased-go-emotions-onnx', { cache_dir, quantized: true }); |             const model = this.getClassificationModel(); | ||||||
|  |             this.pipe = await pipeline('text-classification', model, { cache_dir, quantized: true }); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         return this.pipe; |         return this.pipe; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     getClassificationModel() { | ||||||
|  |         const DEFAULT_MODEL = 'Cohee/distilbert-base-uncased-go-emotions-onnx'; | ||||||
|  |  | ||||||
|  |         try { | ||||||
|  |             const config = require(path.join(process.cwd(), './config.conf')); | ||||||
|  |             const model = config?.extras?.classificationModel; | ||||||
|  |             return model || DEFAULT_MODEL; | ||||||
|  |         } catch { | ||||||
|  |             console.warn('Failed to read config.conf, using default classification model.'); | ||||||
|  |             return DEFAULT_MODEL; | ||||||
|  |         } | ||||||
|  |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| /** | /** | ||||||
| @@ -28,12 +43,18 @@ function registerEndpoints(app, jsonParser) { | |||||||
|     const pipelineAccessor = new PipelineAccessor(); |     const pipelineAccessor = new PipelineAccessor(); | ||||||
|  |  | ||||||
|     app.post('/api/extra/classify/labels', jsonParser, async (req, res) => { |     app.post('/api/extra/classify/labels', jsonParser, async (req, res) => { | ||||||
|  |         try { | ||||||
|             const pipe = await pipelineAccessor.get(); |             const pipe = await pipelineAccessor.get(); | ||||||
|             const result = Object.keys(pipe.model.config.label2id); |             const result = Object.keys(pipe.model.config.label2id); | ||||||
|             return res.json({ labels: result }); |             return res.json({ labels: result }); | ||||||
|  |         } catch (error) { | ||||||
|  |             console.error(error); | ||||||
|  |             return res.sendStatus(500); | ||||||
|  |         } | ||||||
|     }); |     }); | ||||||
|  |  | ||||||
|     app.post('/api/extra/classify', jsonParser, async (req, res) => { |     app.post('/api/extra/classify', jsonParser, async (req, res) => { | ||||||
|  |         try { | ||||||
|             const { text } = req.body; |             const { text } = req.body; | ||||||
|  |  | ||||||
|             async function getResult(text) { |             async function getResult(text) { | ||||||
| @@ -52,6 +73,10 @@ function registerEndpoints(app, jsonParser) { | |||||||
|             console.log('Classify output:', result); |             console.log('Classify output:', result); | ||||||
|  |  | ||||||
|             return res.json({ classification: result }); |             return res.json({ classification: result }); | ||||||
|  |         } catch (error) { | ||||||
|  |             console.error(error); | ||||||
|  |             return res.sendStatus(500); | ||||||
|  |         } | ||||||
|     }); |     }); | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user