Add backend for transformers.js whisper

Cohee
2024-02-02 00:36:40 +02:00
parent 695b438c0d
commit 4b845dd442
6 changed files with 114 additions and 25 deletions

default/config.yaml

@@ -53,6 +53,7 @@ extras:
captioningModel: Xenova/vit-gpt2-image-captioning
embeddingModel: Cohee/jina-embeddings-v2-base-en
promptExpansionModel: Cohee/fooocus_expansion-onnx
+speechToTextModel: Xenova/whisper-small
# -- OPENAI CONFIGURATION --
openai:
# Will send a random user ID to OpenAI completion API
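The new speechToTextModel key is consumed by the 'automatic-speech-recognition' task entry added to src/transformers.mjs at the end of this diff. Assuming getModelForTask falls back to the task's defaultModel when the config key is unset (a sketch of the effective lookup, not code from this commit):

const model = getConfigValue('extras.speechToTextModel', 'Xenova/whisper-small');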

package-lock.json (generated)

@@ -37,9 +37,10 @@
"png-chunks-extract": "^1.0.0",
"response-time": "^2.3.2",
"sanitize-filename": "^1.6.3",
"sillytavern-transformers": "^2.7.3",
"sillytavern-transformers": "^2.14.6",
"simple-git": "^3.19.1",
"vectra": "^0.2.2",
"wavefile": "^11.0.0",
"write-file-atomic": "^5.0.1",
"ws": "^8.13.0",
"yaml": "^2.3.4",
@@ -232,6 +233,14 @@
"node": "^12.22.0 || ^14.17.0 || >=16.0.0"
}
},
"node_modules/@huggingface/jinja": {
"version": "0.1.2",
"resolved": "https://registry.npmjs.org/@huggingface/jinja/-/jinja-0.1.2.tgz",
"integrity": "sha512-x5mpbfJt1nKmVep5WNP5VjNsjWApWNj8pPYI+uYMkBWH9bWUJmQmHt2lbf0VCoQd54Oq3XuFEh/UyoVh7rPxmg==",
"engines": {
"node": ">=18"
}
},
"node_modules/@humanwhocodes/config-array": {
"version": "0.11.13",
"resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.13.tgz",
@@ -3670,20 +3679,6 @@
"resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.14.0.tgz",
"integrity": "sha512-3LJpegM2iMNRX2wUmtYfeX/ytfOzNwAWKSq1HbRrKc9+uqG/FsEA0bbKZl1btQeZaXhC26l44NWpNUeXPII7Ew=="
},
"node_modules/onnxruntime-node": {
"version": "1.14.0",
"resolved": "https://registry.npmjs.org/onnxruntime-node/-/onnxruntime-node-1.14.0.tgz",
"integrity": "sha512-5ba7TWomIV/9b6NH/1x/8QEeowsb+jBEvFzU6z0T4mNsFwdPqXeFUM7uxC6QeSRkEbWu3qEB0VMjrvzN/0S9+w==",
"optional": true,
"os": [
"win32",
"darwin",
"linux"
],
"dependencies": {
"onnxruntime-common": "~1.14.0"
}
},
"node_modules/onnxruntime-web": {
"version": "1.14.0",
"resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.14.0.tgz",
@@ -4681,15 +4676,13 @@
}
},
"node_modules/sillytavern-transformers": {
"version": "2.7.3",
"resolved": "https://registry.npmjs.org/sillytavern-transformers/-/sillytavern-transformers-2.7.3.tgz",
"integrity": "sha512-vr6BQdLlT3TbCLJdzLt5Sc/MzZ7LWoTzdkkQJgtvKwU3sX1TcnW0Oz23hl211sefWdxwkj/g0RZdvL18hk1Jew==",
"version": "2.14.6",
"resolved": "https://registry.npmjs.org/sillytavern-transformers/-/sillytavern-transformers-2.14.6.tgz",
"integrity": "sha512-Tpu3lcDfa3vQB/wRgF+7ZG8ZNtYygT6vEQs9+4BpXLghVanx6ic7rBSxmTxx9Sm90G1P3W8mxoVkzfs8KAvMiA==",
"dependencies": {
"@huggingface/jinja": "^0.1.0",
"jimp": "^0.22.10",
"onnxruntime-web": "1.14.0"
-},
-"optionalDependencies": {
-"onnxruntime-node": "1.14.0"
}
},
"node_modules/simple-concat": {
@@ -5152,6 +5145,17 @@
"vectra": "bin/vectra.js"
}
},
"node_modules/wavefile": {
"version": "11.0.0",
"resolved": "https://registry.npmjs.org/wavefile/-/wavefile-11.0.0.tgz",
"integrity": "sha512-/OBiAALgWU24IG7sC84cDO/KfFuvajWc5Uec0oV2zrpOOZZDgGdOwHwgEzOrwh8jkubBk7PtZfQBIcI1OaE5Ng==",
"bin": {
"wavefile": "bin/wavefile.js"
},
"engines": {
"node": ">=8"
}
},
"node_modules/web-streams-polyfill": {
"version": "3.2.1",
"resolved": "https://registry.npmjs.org/web-streams-polyfill/-/web-streams-polyfill-3.2.1.tgz",

package.json

@@ -27,9 +27,10 @@
"png-chunks-extract": "^1.0.0",
"response-time": "^2.3.2",
"sanitize-filename": "^1.6.3",
"sillytavern-transformers": "^2.7.3",
"sillytavern-transformers": "^2.14.6",
"simple-git": "^3.19.1",
"vectra": "^0.2.2",
"wavefile": "^11.0.0",
"write-file-atomic": "^5.0.1",
"ws": "^8.13.0",
"yaml": "^2.3.4",

server.js

@@ -593,6 +593,9 @@ app.use('/api/backends/chat-completions', require('./src/endpoints/backends/chat
// Scale (alt method)
app.use('/api/backends/scale-alt', require('./src/endpoints/backends/scale-alt').router);
+// Speech (text-to-speech and speech-to-text)
+app.use('/api/speech', require('./src/endpoints/speech').router);
const tavernUrl = new URL(
(cliArguments.ssl ? 'https://' : 'http://') +
(listen ? '0.0.0.0' : '127.0.0.1') +
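With the router mounted at /api/speech, a client-side call could look like the following sketch (not part of this commit; the request shape is inferred from the handler in src/endpoints/speech.js below, and any auth headers the server expects are omitted):

const response = await fetch('/api/speech/recognize', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({
        model: 'Xenova/whisper-small', // optional; an empty value falls back to extras.speechToTextModel
        audio: wavDataUri, // assumed: a 'data:audio/wav;base64,...' string, since the server parses it with wavefile's fromDataURI
        lang: 'en', // optional language hint; the handler passes null when omitted
    }),
});
const { text } = await response.json();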

src/endpoints/speech.js (new file)

@@ -0,0 +1,74 @@
const express = require('express');
const { jsonParser } = require('../express-common');
const router = express.Router();
/**
* Gets the audio data from a base64-encoded audio file.
* @param {string} audio Base64-encoded audio
* @returns {Float64Array} Audio data
*/
function getWaveFile(audio) {
const wavefile = require('wavefile');
const wav = new wavefile.WaveFile();
wav.fromDataURI(audio);
wav.toBitDepth('32f');
wav.toSampleRate(16000);
let audioData = wav.getSamples();
if (Array.isArray(audioData)) {
if (audioData.length > 1) {
const SCALING_FACTOR = Math.sqrt(2);
// Merge channels (into first channel to save memory)
for (let i = 0; i < audioData[0].length; ++i) {
audioData[0][i] = SCALING_FACTOR * (audioData[0][i] + audioData[1][i]) / 2;
}
}
// Select first channel
audioData = audioData[0];
}
return audioData;
}
router.post('/recognize', jsonParser, async (req, res) => {
try {
const TASK = 'automatic-speech-recognition';
const { model, audio, lang } = req.body;
const module = await import('../transformers.mjs');
const pipe = await module.default.getPipeline(TASK, model);
const wav = getWaveFile(audio);
const start = performance.now();
const result = await pipe(wav, { language: lang || null });
const end = performance.now();
console.log(`Execution duration: ${(end - start) / 1000} seconds`);
console.log('Transcribed audio:', result.text);
return res.json({ text: result.text });
} catch (error) {
console.error(error);
return res.sendStatus(500);
}
});
router.post('/synthesize', jsonParser, async (req, res) => {
try {
const TASK = 'text-to-speech';
const { model, text, lang } = req.body;
const module = await import('../transformers.mjs');
const pipe = await module.default.getPipeline(TASK, model);
const start = performance.now();
const result = await pipe(text, { language: lang || null });
const end = performance.now();
console.log(`Execution duration: ${(end - start) / 1000} seconds`);
console.log('Synthesized audio:', result.audio);
return res.json({ audio: result.audio });
} catch (error) {
console.error(error);
return res.sendStatus(500);
}
});
module.exports = { router };
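A note on the /synthesize response: a transformers.js text-to-speech pipeline returns raw samples plus a sampling rate rather than an encoded file, so result.audio arrives in the JSON as unencoded data. If a playable WAV were needed, the wavefile dependency added above could wrap it, roughly like this (a sketch under that assumption, not code from this commit):

const wavefile = require('wavefile');

function toWavDataUri(result) {
    const wav = new wavefile.WaveFile();
    // fromScratch(channels, sampleRate, bitDepth, samples)
    wav.fromScratch(1, result.sampling_rate, '32f', result.audio);
    return wav.toDataURI(); // 'data:audio/wav;base64,...'
}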

src/transformers.mjs

@@ -33,6 +33,11 @@ const tasks = {
pipeline: null,
configField: 'extras.promptExpansionModel',
},
+'automatic-speech-recognition': {
+defaultModel: 'Xenova/whisper-small',
+pipeline: null,
+configField: 'extras.speechToTextModel',
+},
}
/**
@@ -72,16 +77,17 @@ function getModelForTask(task) {
/**
* Gets the transformers.js pipeline for a given task.
-* @param {string} task The task to get the pipeline for
+* @param {import('sillytavern-transformers').PipelineType} task The task to get the pipeline for
+* @param {string} forceModel The model to use for the pipeline, if any
* @returns {Promise<Pipeline>} Pipeline for the task
*/
-async function getPipeline(task) {
+async function getPipeline(task, forceModel = '') {
if (tasks[task].pipeline) {
return tasks[task].pipeline;
}
const cache_dir = path.join(process.cwd(), 'cache');
-const model = getModelForTask(task);
+const model = forceModel || getModelForTask(task);
const localOnly = getConfigValue('extras.disableAutoDownload', false);
console.log('Initializing transformers.js pipeline for task', task, 'with model', model);
const instance = await pipeline(task, model, { cache_dir, quantized: true, local_files_only: localOnly });
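One behavior worth noting in getPipeline: the cache is keyed per task, so once a task's pipeline exists, later calls return it even if forceModel names a different model. If mixing models per task were ever required, keying the cache on both values would be one option, sketched here with a hypothetical pipelines map (not part of this commit):

const pipelines = new Map();

async function getPipelineByModel(task, forceModel = '') {
    const model = forceModel || getModelForTask(task);
    const key = `${task}:${model}`;
    if (!pipelines.has(key)) {
        const cache_dir = path.join(process.cwd(), 'cache');
        const localOnly = getConfigValue('extras.disableAutoDownload', false);
        pipelines.set(key, await pipeline(task, model, { cache_dir, quantized: true, local_files_only: localOnly }));
    }
    return pipelines.get(key);
}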