Fix config access. Add Top K to classification results

This commit is contained in:
Cohee 2023-09-11 01:49:47 +03:00
parent c76c76410c
commit 7aeb098212
3 changed files with 17 additions and 9 deletions

View File

@ -1,5 +1,6 @@
import { pipeline, env } from 'sillytavern-transformers';
import path from 'path';
import { getConfig } from './util.js';
// Limit the number of threads to 1 to avoid issues on Android
env.backends.onnx.wasm.numThreads = 1;
@ -24,10 +25,10 @@ class PipelineAccessor {
const DEFAULT_MODEL = 'Cohee/distilbert-base-uncased-go-emotions-onnx';
try {
const config = require(path.join(process.cwd(), './config.conf'));
const config = getConfig();
const model = config?.extras?.classificationModel;
return model || DEFAULT_MODEL;
} catch {
} catch (error) {
console.warn('Failed to read config.conf, using default classification model.');
return DEFAULT_MODEL;
}
@ -62,7 +63,8 @@ function registerEndpoints(app, jsonParser) {
return cacheObject[text];
} else {
const pipe = await pipelineAccessor.get();
const result = await pipe(text);
const result = await pipe(text, { topk: 5 });
result.sort((a, b) => b.score - a.score);
cacheObject[text] = result;
return result;
}

View File

@ -2,12 +2,7 @@
* When applied, this middleware will ensure the request contains the required header for basic authentication and only
* allow access to the endpoint after successful authentication.
*/
//const {dirname} = require('path');
//const appDir = dirname(require.main.filename);
//const config = require(appDir + '/config.conf');
const path = require('path');
const config = require(path.join(process.cwd(), './config.conf'));
const { getConfig } = require('./../util.js');
const unauthorizedResponse = (res) => {
res.set('WWW-Authenticate', 'Basic realm="SillyTavern", charset="UTF-8"');
@ -15,6 +10,7 @@ const unauthorizedResponse = (res) => {
};
const basicAuthMiddleware = function (request, response, callback) {
const config = getConfig();
const authHeader = request.headers.authorization;
if (!authHeader) {

10
src/util.js Normal file
View File

@ -0,0 +1,10 @@
const path = require('path');
function getConfig() {
const config = require(path.join(process.cwd(), './config.conf'));
return config;
}
module.exports = {
getConfig,
};