diff --git a/src/classify.mjs b/src/classify.mjs index 7489494e9..df2baca75 100644 --- a/src/classify.mjs +++ b/src/classify.mjs @@ -1,5 +1,6 @@ import { pipeline, env } from 'sillytavern-transformers'; import path from 'path'; +import { getConfig } from './util.js'; // Limit the number of threads to 1 to avoid issues on Android env.backends.onnx.wasm.numThreads = 1; @@ -24,10 +25,10 @@ class PipelineAccessor { const DEFAULT_MODEL = 'Cohee/distilbert-base-uncased-go-emotions-onnx'; try { - const config = require(path.join(process.cwd(), './config.conf')); + const config = getConfig(); const model = config?.extras?.classificationModel; return model || DEFAULT_MODEL; - } catch { + } catch (error) { console.warn('Failed to read config.conf, using default classification model.'); return DEFAULT_MODEL; } @@ -62,7 +63,8 @@ function registerEndpoints(app, jsonParser) { return cacheObject[text]; } else { const pipe = await pipelineAccessor.get(); - const result = await pipe(text); + const result = await pipe(text, { topk: 5 }); + result.sort((a, b) => b.score - a.score); cacheObject[text] = result; return result; } diff --git a/src/middleware/basicAuthMiddleware.js b/src/middleware/basicAuthMiddleware.js index a5b1b3459..25036ca1e 100644 --- a/src/middleware/basicAuthMiddleware.js +++ b/src/middleware/basicAuthMiddleware.js @@ -2,12 +2,7 @@ * When applied, this middleware will ensure the request contains the required header for basic authentication and only * allow access to the endpoint after successful authentication. */ - -//const {dirname} = require('path'); -//const appDir = dirname(require.main.filename); -//const config = require(appDir + '/config.conf'); -const path = require('path'); -const config = require(path.join(process.cwd(), './config.conf')); +const { getConfig } = require('./../util.js'); const unauthorizedResponse = (res) => { res.set('WWW-Authenticate', 'Basic realm="SillyTavern", charset="UTF-8"'); @@ -15,6 +10,7 @@ const unauthorizedResponse = (res) => { }; const basicAuthMiddleware = function (request, response, callback) { + const config = getConfig(); const authHeader = request.headers.authorization; if (!authHeader) { diff --git a/src/util.js b/src/util.js new file mode 100644 index 000000000..a870eb169 --- /dev/null +++ b/src/util.js @@ -0,0 +1,10 @@ +const path = require('path'); + +function getConfig() { + const config = require(path.join(process.cwd(), './config.conf')); + return config; +} + +module.exports = { + getConfig, +};