mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-04-15 19:27:20 +02:00
Fix config access. Add Top K to classification results
This commit is contained in:
parent
c76c76410c
commit
7aeb098212
@ -1,5 +1,6 @@
|
|||||||
import { pipeline, env } from 'sillytavern-transformers';
|
import { pipeline, env } from 'sillytavern-transformers';
|
||||||
import path from 'path';
|
import path from 'path';
|
||||||
|
import { getConfig } from './util.js';
|
||||||
|
|
||||||
// Limit the number of threads to 1 to avoid issues on Android
|
// Limit the number of threads to 1 to avoid issues on Android
|
||||||
env.backends.onnx.wasm.numThreads = 1;
|
env.backends.onnx.wasm.numThreads = 1;
|
||||||
@ -24,10 +25,10 @@ class PipelineAccessor {
|
|||||||
const DEFAULT_MODEL = 'Cohee/distilbert-base-uncased-go-emotions-onnx';
|
const DEFAULT_MODEL = 'Cohee/distilbert-base-uncased-go-emotions-onnx';
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const config = require(path.join(process.cwd(), './config.conf'));
|
const config = getConfig();
|
||||||
const model = config?.extras?.classificationModel;
|
const model = config?.extras?.classificationModel;
|
||||||
return model || DEFAULT_MODEL;
|
return model || DEFAULT_MODEL;
|
||||||
} catch {
|
} catch (error) {
|
||||||
console.warn('Failed to read config.conf, using default classification model.');
|
console.warn('Failed to read config.conf, using default classification model.');
|
||||||
return DEFAULT_MODEL;
|
return DEFAULT_MODEL;
|
||||||
}
|
}
|
||||||
@ -62,7 +63,8 @@ function registerEndpoints(app, jsonParser) {
|
|||||||
return cacheObject[text];
|
return cacheObject[text];
|
||||||
} else {
|
} else {
|
||||||
const pipe = await pipelineAccessor.get();
|
const pipe = await pipelineAccessor.get();
|
||||||
const result = await pipe(text);
|
const result = await pipe(text, { topk: 5 });
|
||||||
|
result.sort((a, b) => b.score - a.score);
|
||||||
cacheObject[text] = result;
|
cacheObject[text] = result;
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -2,12 +2,7 @@
|
|||||||
* When applied, this middleware will ensure the request contains the required header for basic authentication and only
|
* When applied, this middleware will ensure the request contains the required header for basic authentication and only
|
||||||
* allow access to the endpoint after successful authentication.
|
* allow access to the endpoint after successful authentication.
|
||||||
*/
|
*/
|
||||||
|
const { getConfig } = require('./../util.js');
|
||||||
//const {dirname} = require('path');
|
|
||||||
//const appDir = dirname(require.main.filename);
|
|
||||||
//const config = require(appDir + '/config.conf');
|
|
||||||
const path = require('path');
|
|
||||||
const config = require(path.join(process.cwd(), './config.conf'));
|
|
||||||
|
|
||||||
const unauthorizedResponse = (res) => {
|
const unauthorizedResponse = (res) => {
|
||||||
res.set('WWW-Authenticate', 'Basic realm="SillyTavern", charset="UTF-8"');
|
res.set('WWW-Authenticate', 'Basic realm="SillyTavern", charset="UTF-8"');
|
||||||
@ -15,6 +10,7 @@ const unauthorizedResponse = (res) => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const basicAuthMiddleware = function (request, response, callback) {
|
const basicAuthMiddleware = function (request, response, callback) {
|
||||||
|
const config = getConfig();
|
||||||
const authHeader = request.headers.authorization;
|
const authHeader = request.headers.authorization;
|
||||||
|
|
||||||
if (!authHeader) {
|
if (!authHeader) {
|
||||||
|
10
src/util.js
Normal file
10
src/util.js
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
const path = require('path');
|
||||||
|
|
||||||
|
function getConfig() {
|
||||||
|
const config = require(path.join(process.cwd(), './config.conf'));
|
||||||
|
return config;
|
||||||
|
}
|
||||||
|
|
||||||
|
module.exports = {
|
||||||
|
getConfig,
|
||||||
|
};
|
Loading…
x
Reference in New Issue
Block a user