Fix config access. Add Top K to classification results

This commit is contained in:
Cohee 2023-09-11 01:49:47 +03:00
parent c76c76410c
commit 7aeb098212
3 changed files with 17 additions and 9 deletions

View File

@ -1,5 +1,6 @@
import { pipeline, env } from 'sillytavern-transformers'; import { pipeline, env } from 'sillytavern-transformers';
import path from 'path'; import path from 'path';
import { getConfig } from './util.js';
// Limit the number of threads to 1 to avoid issues on Android // Limit the number of threads to 1 to avoid issues on Android
env.backends.onnx.wasm.numThreads = 1; env.backends.onnx.wasm.numThreads = 1;
@ -24,10 +25,10 @@ class PipelineAccessor {
const DEFAULT_MODEL = 'Cohee/distilbert-base-uncased-go-emotions-onnx'; const DEFAULT_MODEL = 'Cohee/distilbert-base-uncased-go-emotions-onnx';
try { try {
const config = require(path.join(process.cwd(), './config.conf')); const config = getConfig();
const model = config?.extras?.classificationModel; const model = config?.extras?.classificationModel;
return model || DEFAULT_MODEL; return model || DEFAULT_MODEL;
} catch { } catch (error) {
console.warn('Failed to read config.conf, using default classification model.'); console.warn('Failed to read config.conf, using default classification model.');
return DEFAULT_MODEL; return DEFAULT_MODEL;
} }
@ -62,7 +63,8 @@ function registerEndpoints(app, jsonParser) {
return cacheObject[text]; return cacheObject[text];
} else { } else {
const pipe = await pipelineAccessor.get(); const pipe = await pipelineAccessor.get();
const result = await pipe(text); const result = await pipe(text, { topk: 5 });
result.sort((a, b) => b.score - a.score);
cacheObject[text] = result; cacheObject[text] = result;
return result; return result;
} }

View File

@ -2,12 +2,7 @@
* When applied, this middleware will ensure the request contains the required header for basic authentication and only * When applied, this middleware will ensure the request contains the required header for basic authentication and only
* allow access to the endpoint after successful authentication. * allow access to the endpoint after successful authentication.
*/ */
const { getConfig } = require('./../util.js');
//const {dirname} = require('path');
//const appDir = dirname(require.main.filename);
//const config = require(appDir + '/config.conf');
const path = require('path');
const config = require(path.join(process.cwd(), './config.conf'));
const unauthorizedResponse = (res) => { const unauthorizedResponse = (res) => {
res.set('WWW-Authenticate', 'Basic realm="SillyTavern", charset="UTF-8"'); res.set('WWW-Authenticate', 'Basic realm="SillyTavern", charset="UTF-8"');
@ -15,6 +10,7 @@ const unauthorizedResponse = (res) => {
}; };
const basicAuthMiddleware = function (request, response, callback) { const basicAuthMiddleware = function (request, response, callback) {
const config = getConfig();
const authHeader = request.headers.authorization; const authHeader = request.headers.authorization;
if (!authHeader) { if (!authHeader) {

10
src/util.js Normal file
View File

@ -0,0 +1,10 @@
const path = require('path');
function getConfig() {
const config = require(path.join(process.cwd(), './config.conf'));
return config;
}
module.exports = {
getConfig,
};