(WIP) Local emotion classification pipeline

This commit is contained in:
Cohee
2023-09-09 15:14:16 +03:00
parent 4d08e3e9be
commit 967a084aad
7 changed files with 389 additions and 75 deletions

48
src/classify.mjs Normal file
View File

@@ -0,0 +1,48 @@
import { pipeline, TextClassificationPipeline } from '@xenova/transformers';
import path from 'path';
class PipelineAccessor {
/**
* @type {TextClassificationPipeline}
*/
pipe;
async get() {
if (!this.pipe) {
const cache_dir = path.join(process.cwd(), 'cache');
this.pipe = await pipeline('text-classification', 'Cohee/distilbert-base-uncased-go-emotions-onnx', { cache_dir, quantized: true });
}
return this.pipe;
}
}
/**
* @param {import("express").Express} app
* @param {any} jsonParser
*/
function registerEndpoints(app, jsonParser) {
const pipelineAccessor = new PipelineAccessor();
app.post('/api/extra/classify/labels', jsonParser, async (req, res) => {
const pipe = await pipelineAccessor.get();
const result = Object.keys(pipe.model.config.label2id);
return res.json({ labels: result });
});
app.post('/api/extra/classify', jsonParser, async (req, res) => {
const { text } = req.body;
const pipe = await pipelineAccessor.get();
const result = await pipe(text);
console.log('Classify input:', text);
console.log('Classify output:', result);
return res.json({ classification: result });
});
}
export default {
registerEndpoints,
};