mirror of
				https://github.com/SillyTavern/SillyTavern.git
				synced 2025-06-05 21:59:27 +02:00 
			
		
		
		
	Add ability to override local classification model
This commit is contained in:
		| @@ -15,7 +15,15 @@ const skipContentCheck = false; // If true, no new default content will be deliv | ||||
| // Change this setting only on "trusted networks". Do not change this value unless you are aware of the issues that can arise from changing this setting and configuring a insecure setting. | ||||
| const securityOverride = false; | ||||
|  | ||||
| // Additional settings for extra modules / extensions | ||||
| const extras = { | ||||
|     // Text classification model for sentiment analysis. HuggingFace ID of a model in ONNX format. | ||||
|     classificationModel: 'Cohee/distilbert-base-uncased-go-emotions-onnx', | ||||
| }; | ||||
|  | ||||
| // Request overrides for additional headers | ||||
| // Format is an array of objects: | ||||
| // { hosts: [ "<url>" ], headers: { <header>: "<value>" } } | ||||
| const requestOverrides = []; | ||||
|  | ||||
| module.exports = { | ||||
| @@ -32,4 +40,5 @@ module.exports = { | ||||
|     securityOverride, | ||||
|     skipContentCheck, | ||||
|     requestOverrides, | ||||
|     extras, | ||||
| }; | ||||
|   | ||||
| @@ -1,22 +1,37 @@ | ||||
| import { pipeline, TextClassificationPipeline, env } from 'sillytavern-transformers'; | ||||
| import { pipeline, env } from 'sillytavern-transformers'; | ||||
| import path from 'path'; | ||||
|  | ||||
| // Limit the number of threads to 1 to avoid issues on Android | ||||
| env.backends.onnx.wasm.numThreads = 1; | ||||
|  | ||||
| class PipelineAccessor { | ||||
|     /** | ||||
|      * @type {TextClassificationPipeline} | ||||
|      * @type {import("sillytavern-transformers").TextClassificationPipeline} | ||||
|      */ | ||||
|     pipe; | ||||
|  | ||||
|     async get() { | ||||
|         if (!this.pipe) { | ||||
|             const cache_dir = path.join(process.cwd(), 'cache'); | ||||
|             this.pipe = await pipeline('text-classification', 'Cohee/distilbert-base-uncased-go-emotions-onnx', { cache_dir, quantized: true }); | ||||
|             const model = this.getClassificationModel(); | ||||
|             this.pipe = await pipeline('text-classification', model, { cache_dir, quantized: true }); | ||||
|         } | ||||
|  | ||||
|         return this.pipe; | ||||
|     } | ||||
|  | ||||
|     getClassificationModel() { | ||||
|         const DEFAULT_MODEL = 'Cohee/distilbert-base-uncased-go-emotions-onnx'; | ||||
|  | ||||
|         try { | ||||
|             const config = require(path.join(process.cwd(), './config.conf')); | ||||
|             const model = config?.extras?.classificationModel; | ||||
|             return model || DEFAULT_MODEL; | ||||
|         } catch { | ||||
|             console.warn('Failed to read config.conf, using default classification model.'); | ||||
|             return DEFAULT_MODEL; | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| /** | ||||
| @@ -28,30 +43,40 @@ function registerEndpoints(app, jsonParser) { | ||||
|     const pipelineAccessor = new PipelineAccessor(); | ||||
|  | ||||
|     app.post('/api/extra/classify/labels', jsonParser, async (req, res) => { | ||||
|         const pipe = await pipelineAccessor.get(); | ||||
|         const result = Object.keys(pipe.model.config.label2id); | ||||
|         return res.json({ labels: result }); | ||||
|         try { | ||||
|             const pipe = await pipelineAccessor.get(); | ||||
|             const result = Object.keys(pipe.model.config.label2id); | ||||
|             return res.json({ labels: result }); | ||||
|         } catch (error) { | ||||
|             console.error(error); | ||||
|             return res.sendStatus(500); | ||||
|         } | ||||
|     }); | ||||
|  | ||||
|     app.post('/api/extra/classify', jsonParser, async (req, res) => { | ||||
|         const { text } = req.body; | ||||
|         try { | ||||
|             const { text } = req.body; | ||||
|  | ||||
|         async function getResult(text) { | ||||
|             if (cacheObject.hasOwnProperty(text)) { | ||||
|                 return cacheObject[text]; | ||||
|             } else { | ||||
|                 const pipe = await pipelineAccessor.get(); | ||||
|                 const result = await pipe(text); | ||||
|                 cacheObject[text] = result; | ||||
|                 return result; | ||||
|             async function getResult(text) { | ||||
|                 if (cacheObject.hasOwnProperty(text)) { | ||||
|                     return cacheObject[text]; | ||||
|                 } else { | ||||
|                     const pipe = await pipelineAccessor.get(); | ||||
|                     const result = await pipe(text); | ||||
|                     cacheObject[text] = result; | ||||
|                     return result; | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             console.log('Classify input:', text); | ||||
|             const result = await getResult(text); | ||||
|             console.log('Classify output:', result); | ||||
|  | ||||
|             return res.json({ classification: result }); | ||||
|         } catch (error) { | ||||
|             console.error(error); | ||||
|             return res.sendStatus(500); | ||||
|         } | ||||
|  | ||||
|         console.log('Classify input:', text); | ||||
|         const result = await getResult(text); | ||||
|         console.log('Classify output:', result); | ||||
|  | ||||
|         return res.json({ classification: result }); | ||||
|     }); | ||||
| } | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user