diff --git a/default/config.yaml b/default/config.yaml index be82ae6f7..bb4847c30 100644 --- a/default/config.yaml +++ b/default/config.yaml @@ -53,6 +53,7 @@ extras: captioningModel: Xenova/vit-gpt2-image-captioning embeddingModel: Cohee/jina-embeddings-v2-base-en promptExpansionModel: Cohee/fooocus_expansion-onnx + speechToTextModel: Xenova/whisper-small # -- OPENAI CONFIGURATION -- openai: # Will send a random user ID to OpenAI completion API diff --git a/package-lock.json b/package-lock.json index 36cfb9528..896d6b124 100644 --- a/package-lock.json +++ b/package-lock.json @@ -37,9 +37,10 @@ "png-chunks-extract": "^1.0.0", "response-time": "^2.3.2", "sanitize-filename": "^1.6.3", - "sillytavern-transformers": "^2.7.3", + "sillytavern-transformers": "^2.14.6", "simple-git": "^3.19.1", "vectra": "^0.2.2", + "wavefile": "^11.0.0", "write-file-atomic": "^5.0.1", "ws": "^8.13.0", "yaml": "^2.3.4", @@ -232,6 +233,14 @@ "node": "^12.22.0 || ^14.17.0 || >=16.0.0" } }, + "node_modules/@huggingface/jinja": { + "version": "0.1.2", + "resolved": "https://registry.npmjs.org/@huggingface/jinja/-/jinja-0.1.2.tgz", + "integrity": "sha512-x5mpbfJt1nKmVep5WNP5VjNsjWApWNj8pPYI+uYMkBWH9bWUJmQmHt2lbf0VCoQd54Oq3XuFEh/UyoVh7rPxmg==", + "engines": { + "node": ">=18" + } + }, "node_modules/@humanwhocodes/config-array": { "version": "0.11.13", "resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.13.tgz", @@ -3670,20 +3679,6 @@ "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.14.0.tgz", "integrity": "sha512-3LJpegM2iMNRX2wUmtYfeX/ytfOzNwAWKSq1HbRrKc9+uqG/FsEA0bbKZl1btQeZaXhC26l44NWpNUeXPII7Ew==" }, - "node_modules/onnxruntime-node": { - "version": "1.14.0", - "resolved": "https://registry.npmjs.org/onnxruntime-node/-/onnxruntime-node-1.14.0.tgz", - "integrity": "sha512-5ba7TWomIV/9b6NH/1x/8QEeowsb+jBEvFzU6z0T4mNsFwdPqXeFUM7uxC6QeSRkEbWu3qEB0VMjrvzN/0S9+w==", - "optional": true, - "os": [ - "win32", - "darwin", - "linux" - ], - "dependencies": { - "onnxruntime-common": "~1.14.0" - } - }, "node_modules/onnxruntime-web": { "version": "1.14.0", "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.14.0.tgz", @@ -4681,15 +4676,13 @@ } }, "node_modules/sillytavern-transformers": { - "version": "2.7.3", - "resolved": "https://registry.npmjs.org/sillytavern-transformers/-/sillytavern-transformers-2.7.3.tgz", - "integrity": "sha512-vr6BQdLlT3TbCLJdzLt5Sc/MzZ7LWoTzdkkQJgtvKwU3sX1TcnW0Oz23hl211sefWdxwkj/g0RZdvL18hk1Jew==", + "version": "2.14.6", + "resolved": "https://registry.npmjs.org/sillytavern-transformers/-/sillytavern-transformers-2.14.6.tgz", + "integrity": "sha512-Tpu3lcDfa3vQB/wRgF+7ZG8ZNtYygT6vEQs9+4BpXLghVanx6ic7rBSxmTxx9Sm90G1P3W8mxoVkzfs8KAvMiA==", "dependencies": { + "@huggingface/jinja": "^0.1.0", "jimp": "^0.22.10", "onnxruntime-web": "1.14.0" - }, - "optionalDependencies": { - "onnxruntime-node": "1.14.0" } }, "node_modules/simple-concat": { @@ -5152,6 +5145,17 @@ "vectra": "bin/vectra.js" } }, + "node_modules/wavefile": { + "version": "11.0.0", + "resolved": "https://registry.npmjs.org/wavefile/-/wavefile-11.0.0.tgz", + "integrity": "sha512-/OBiAALgWU24IG7sC84cDO/KfFuvajWc5Uec0oV2zrpOOZZDgGdOwHwgEzOrwh8jkubBk7PtZfQBIcI1OaE5Ng==", + "bin": { + "wavefile": "bin/wavefile.js" + }, + "engines": { + "node": ">=8" + } + }, "node_modules/web-streams-polyfill": { "version": "3.2.1", "resolved": "https://registry.npmjs.org/web-streams-polyfill/-/web-streams-polyfill-3.2.1.tgz", diff --git a/package.json b/package.json index 1d8eb4781..63773d25f 100644 --- a/package.json +++ b/package.json @@ -27,9 +27,10 @@ "png-chunks-extract": "^1.0.0", "response-time": "^2.3.2", "sanitize-filename": "^1.6.3", - "sillytavern-transformers": "^2.7.3", + "sillytavern-transformers": "^2.14.6", "simple-git": "^3.19.1", "vectra": "^0.2.2", + "wavefile": "^11.0.0", "write-file-atomic": "^5.0.1", "ws": "^8.13.0", "yaml": "^2.3.4", diff --git a/server.js b/server.js index bebcc7bb0..de637b7d7 100644 --- a/server.js +++ b/server.js @@ -593,6 +593,9 @@ app.use('/api/backends/chat-completions', require('./src/endpoints/backends/chat // Scale (alt method) app.use('/api/backends/scale-alt', require('./src/endpoints/backends/scale-alt').router); +// Speech (text-to-speech and speech-to-text) +app.use('/api/speech', require('./src/endpoints/speech').router); + const tavernUrl = new URL( (cliArguments.ssl ? 'https://' : 'http://') + (listen ? '0.0.0.0' : '127.0.0.1') + diff --git a/src/endpoints/speech.js b/src/endpoints/speech.js new file mode 100644 index 000000000..de5e758c9 --- /dev/null +++ b/src/endpoints/speech.js @@ -0,0 +1,74 @@ +const express = require('express'); +const { jsonParser } = require('../express-common'); + +const router = express.Router(); + +/** + * Gets the audio data from a base64-encoded audio file. + * @param {string} audio Base64-encoded audio + * @returns {Float64Array} Audio data + */ +function getWaveFile(audio) { + const wavefile = require('wavefile'); + const wav = new wavefile.WaveFile(); + wav.fromDataURI(audio); + wav.toBitDepth('32f'); + wav.toSampleRate(16000); + let audioData = wav.getSamples(); + if (Array.isArray(audioData)) { + if (audioData.length > 1) { + const SCALING_FACTOR = Math.sqrt(2); + + // Merge channels (into first channel to save memory) + for (let i = 0; i < audioData[0].length; ++i) { + audioData[0][i] = SCALING_FACTOR * (audioData[0][i] + audioData[1][i]) / 2; + } + } + + // Select first channel + audioData = audioData[0]; + } + + return audioData; +} + +router.post('/recognize', jsonParser, async (req, res) => { + try { + const TASK = 'automatic-speech-recognition'; + const { model, audio, lang } = req.body; + const module = await import('../transformers.mjs'); + const pipe = await module.default.getPipeline(TASK, model); + const wav = getWaveFile(audio); + const start = performance.now(); + const result = await pipe(wav, { language: lang || null }); + const end = performance.now(); + console.log(`Execution duration: ${(end - start) / 1000} seconds`); + console.log('Transcribed audio:', result.text); + + return res.json({ text: result.text }); + } catch (error) { + console.error(error); + return res.sendStatus(500); + } +}); + +router.post('/synthesize', jsonParser, async (req, res) => { + try { + const TASK = 'text-to-speech'; + const { model, text, lang } = req.body; + const module = await import('../transformers.mjs'); + const pipe = await module.default.getPipeline(TASK, model); + const start = performance.now(); + const result = await pipe(text, { language: lang || null }); + const end = performance.now(); + console.log(`Execution duration: ${(end - start) / 1000} seconds`); + console.log('Synthesized audio:', result.audio); + + return res.json({ audio: result.audio }); + } catch (error) { + console.error(error); + return res.sendStatus(500); + } +}); + +module.exports = { router }; diff --git a/src/transformers.mjs b/src/transformers.mjs index db3649fb4..3a30edf62 100644 --- a/src/transformers.mjs +++ b/src/transformers.mjs @@ -33,6 +33,11 @@ const tasks = { pipeline: null, configField: 'extras.promptExpansionModel', }, + 'automatic-speech-recognition': { + defaultModel: 'Xenova/whisper-small', + pipeline: null, + configField: 'extras.speechToTextModel', + }, } /** @@ -72,16 +77,17 @@ function getModelForTask(task) { /** * Gets the transformers.js pipeline for a given task. - * @param {string} task The task to get the pipeline for + * @param {import('sillytavern-transformers').PipelineType} task The task to get the pipeline for + * @param {string} forceModel The model to use for the pipeline, if any * @returns {Promise} Pipeline for the task */ -async function getPipeline(task) { +async function getPipeline(task, forceModel = '') { if (tasks[task].pipeline) { return tasks[task].pipeline; } const cache_dir = path.join(process.cwd(), 'cache'); - const model = getModelForTask(task); + const model = forceModel || getModelForTask(task); const localOnly = getConfigValue('extras.disableAutoDownload', false); console.log('Initializing transformers.js pipeline for task', task, 'with model', model); const instance = await pipeline(task, model, { cache_dir, quantized: true, local_files_only: localOnly });