Merge pull request #1441 from valadaptive/completion-source-refactor
Refactor chat completion source API parameter
commit d44bb9f1e0
@@ -1496,6 +1496,7 @@ async function sendOpenAIRequest(type, messages, signal) {
         'stream': stream,
         'logit_bias': logit_bias,
         'stop': getCustomStoppingStrings(openai_max_stop_strings),
+        'chat_completion_source': oai_settings.chat_completion_source,
     };

     // Empty array will produce a validation error
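In effect, the client now tags every request with a single chat_completion_source string instead of per-backend boolean flags. A minimal sketch of the resulting payload shape, assuming a Claude source (the settings stub and field values here are illustrative, not from the diff):

    // Illustrative only: shape of the body POSTed to /generate_openai after this change.
    const oai_settings = { chat_completion_source: 'claude' }; // hypothetical stub
    const generate_data = {
        stream: false,
        stop: ['\nUser:'],
        chat_completion_source: oai_settings.chat_completion_source, // replaces use_claude, use_openrouter, ...
    };
    console.log(JSON.stringify(generate_data, null, 4));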
@@ -1516,7 +1517,6 @@ async function sendOpenAIRequest(type, messages, signal) {
     }

     if (isClaude) {
-        generate_data['use_claude'] = true;
         generate_data['top_k'] = Number(oai_settings.top_k_openai);
         generate_data['exclude_assistant'] = oai_settings.exclude_assistant;
         generate_data['stop'] = getCustomStoppingStrings(); // Claude shouldn't have limits on stop strings.
@@ -1527,7 +1527,6 @@ async function sendOpenAIRequest(type, messages, signal) {
     }

     if (isOpenRouter) {
-        generate_data['use_openrouter'] = true;
         generate_data['top_k'] = Number(oai_settings.top_k_openai);
         generate_data['use_fallback'] = oai_settings.openrouter_use_fallback;
@@ -1537,20 +1536,17 @@ async function sendOpenAIRequest(type, messages, signal) {
     }

     if (isScale) {
-        generate_data['use_scale'] = true;
         generate_data['api_url_scale'] = oai_settings.api_url_scale;
     }

     if (isPalm) {
         const nameStopString = isImpersonate ? `\n${name2}:` : `\n${name1}:`;
         const stopStringsLimit = 3; // 5 - 2 (nameStopString and new_chat_prompt)
-        generate_data['use_palm'] = true;
         generate_data['top_k'] = Number(oai_settings.top_k_openai);
         generate_data['stop'] = [nameStopString, oai_settings.new_chat_prompt, ...getCustomStoppingStrings(stopStringsLimit)];
     }

     if (isAI21) {
-        generate_data['use_ai21'] = true;
         generate_data['top_k'] = Number(oai_settings.top_k_openai);
         generate_data['count_pen'] = Number(oai_settings.count_pen);
         generate_data['stop_tokens'] = [name1 + ':', oai_settings.new_chat_prompt, oai_settings.new_group_chat_prompt];
@@ -2463,10 +2459,10 @@ async function getStatusOpen() {
     let data = {
         reverse_proxy: oai_settings.reverse_proxy,
         proxy_password: oai_settings.proxy_password,
-        use_openrouter: oai_settings.chat_completion_source == chat_completion_sources.OPENROUTER,
+        chat_completion_source: oai_settings.chat_completion_source,
     };

-    if (oai_settings.reverse_proxy && !data.use_openrouter) {
+    if (oai_settings.reverse_proxy && oai_settings.chat_completion_source !== chat_completion_sources.OPENROUTER) {
         validateReverseProxy();
     }
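getStatusOpen() likewise sends the discriminator itself and compares it directly, rather than deriving a use_openrouter boolean first; behavior is unchanged, since use_openrouter was already defined as source == OPENROUTER. A self-contained sketch of the new guard, with stand-in values and a stubbed validator:

    // Sketch of the reverse-proxy guard above; values and the stub are hypothetical.
    const chat_completion_sources = { OPENROUTER: 'openrouter' }; // subset of the client enum
    const oai_settings = { reverse_proxy: 'http://127.0.0.1:5001/v1', chat_completion_source: 'openai' };
    const validateReverseProxy = () => console.log('validating', oai_settings.reverse_proxy); // stub

    if (oai_settings.reverse_proxy && oai_settings.chat_completion_source !== chat_completion_sources.OPENROUTER) {
        validateReverseProxy(); // still skipped only for OpenRouter
    }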
server.js
@@ -217,7 +217,7 @@ const AVATAR_WIDTH = 400;
 const AVATAR_HEIGHT = 600;
 const jsonParser = express.json({ limit: '200mb' });
 const urlencodedParser = express.urlencoded({ extended: true, limit: '200mb' });
-const { DIRECTORIES, UPLOADS_PATH, PALM_SAFETY, TEXTGEN_TYPES } = require('./src/constants');
+const { DIRECTORIES, UPLOADS_PATH, PALM_SAFETY, TEXTGEN_TYPES, CHAT_COMPLETION_SOURCES } = require('./src/constants');
 const { TavernCardValidator } = require('./src/validator/TavernCardValidator');

 // CSRF Protection //
@@ -2794,7 +2794,7 @@ app.post('/getstatus_openai', jsonParser, async function (request, response_getstatus_openai) {
     let api_key_openai;
     let headers;

-    if (request.body.use_openrouter == false) {
+    if (request.body.chat_completion_source !== CHAT_COMPLETION_SOURCES.OPENROUTER) {
         api_url = new URL(request.body.reverse_proxy || API_OPENAI).toString();
         api_key_openai = request.body.reverse_proxy ? request.body.proxy_password : readSecret(SECRET_KEYS.OPENAI);
         headers = {};
@@ -2822,7 +2822,7 @@ app.post('/getstatus_openai', jsonParser, async function (request, response_getstatus_openai) {
     const data = await response.json();
     response_getstatus_openai.send(data);

-    if (request.body.use_openrouter && Array.isArray(data?.data)) {
+    if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.OPENROUTER && Array.isArray(data?.data)) {
         let models = [];

         data.data.forEach(model => {
@@ -3237,20 +3237,11 @@ async function sendPalmRequest(request, response) {
 app.post('/generate_openai', jsonParser, function (request, response_generate_openai) {
     if (!request.body) return response_generate_openai.status(400).send({ error: true });

-    if (request.body.use_claude) {
-        return sendClaudeRequest(request, response_generate_openai);
-    }
-
-    if (request.body.use_scale) {
-        return sendScaleRequest(request, response_generate_openai);
-    }
-
-    if (request.body.use_ai21) {
-        return sendAI21Request(request, response_generate_openai);
-    }
-
-    if (request.body.use_palm) {
-        return sendPalmRequest(request, response_generate_openai);
+    switch (request.body.chat_completion_source) {
+        case CHAT_COMPLETION_SOURCES.CLAUDE: return sendClaudeRequest(request, response_generate_openai);
+        case CHAT_COMPLETION_SOURCES.SCALE: return sendScaleRequest(request, response_generate_openai);
+        case CHAT_COMPLETION_SOURCES.AI21: return sendAI21Request(request, response_generate_openai);
+        case CHAT_COMPLETION_SOURCES.PALM: return sendPalmRequest(request, response_generate_openai);
     }

     let api_url;
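Server-side, the four sequential if blocks collapse into one switch on the discriminator; a source with no dedicated case (plain OpenAI, or OpenRouter, which is OpenAI-compatible) falls through to the generic request path below. A runnable sketch of the dispatch pattern with stubbed handlers:

    // Stub handlers stand in for the real per-backend proxies.
    const CHAT_COMPLETION_SOURCES = { CLAUDE: 'claude', SCALE: 'scale', AI21: 'ai21', PALM: 'palm' };
    const sendClaudeRequest = () => 'claude handler';
    const sendScaleRequest = () => 'scale handler';
    const sendAI21Request = () => 'ai21 handler';
    const sendPalmRequest = () => 'palm handler';

    function routeGeneration(body) {
        switch (body.chat_completion_source) {
            case CHAT_COMPLETION_SOURCES.CLAUDE: return sendClaudeRequest();
            case CHAT_COMPLETION_SOURCES.SCALE: return sendScaleRequest();
            case CHAT_COMPLETION_SOURCES.AI21: return sendAI21Request();
            case CHAT_COMPLETION_SOURCES.PALM: return sendPalmRequest();
        }
        return 'openai-compatible path'; // no case matched: fall through to the generic request
    }

    console.log(routeGeneration({ chat_completion_source: 'scale' })); // 'scale handler'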
@@ -3258,7 +3249,7 @@ app.post('/generate_openai', jsonParser, function (request, response_generate_openai) {
     let headers;
     let bodyParams;

-    if (!request.body.use_openrouter) {
+    if (request.body.chat_completion_source !== CHAT_COMPLETION_SOURCES.OPENROUTER) {
         api_url = new URL(request.body.reverse_proxy || API_OPENAI).toString();
         api_key_openai = request.body.reverse_proxy ? request.body.proxy_password : readSecret(SECRET_KEYS.OPENAI);
         headers = {};
@@ -3290,7 +3281,9 @@ app.post('/generate_openai', jsonParser, function (request, response_generate_openai) {

     const isTextCompletion = Boolean(request.body.model && TEXT_COMPLETION_MODELS.includes(request.body.model)) || typeof request.body.messages === 'string';
     const textPrompt = isTextCompletion ? convertChatMLPrompt(request.body.messages) : '';
-    const endpointUrl = isTextCompletion && !request.body.use_openrouter ? `${api_url}/completions` : `${api_url}/chat/completions`;
+    const endpointUrl = isTextCompletion && request.body.chat_completion_source !== CHAT_COMPLETION_SOURCES.OPENROUTER ?
+        `${api_url}/completions` :
+        `${api_url}/chat/completions`;

     const controller = new AbortController();
     request.socket.removeAllListeners('close');
src/constants.js

@@ -132,6 +132,16 @@ const PALM_SAFETY = [
     },
 ];

+const CHAT_COMPLETION_SOURCES = {
+    OPENAI: 'openai',
+    WINDOWAI: 'windowai',
+    CLAUDE: 'claude',
+    SCALE: 'scale',
+    OPENROUTER: 'openrouter',
+    AI21: 'ai21',
+    PALM: 'palm',
+};
+
 const UPLOADS_PATH = './uploads';

 // TODO: this is copied from the client code; there should be a way to de-duplicate it eventually
@@ -149,4 +159,5 @@ module.exports = {
     UPLOADS_PATH,
     PALM_SAFETY,
     TEXTGEN_TYPES,
+    CHAT_COMPLETION_SOURCES,
 };
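With CHAT_COMPLETION_SOURCES exported from src/constants, the server has one canonical list of valid discriminator values. A hedged example of using it for validation (this helper is hypothetical and not part of the commit):

    // Hypothetical helper: check a request's discriminator against the exported enum.
    const CHAT_COMPLETION_SOURCES = {
        OPENAI: 'openai', WINDOWAI: 'windowai', CLAUDE: 'claude', SCALE: 'scale',
        OPENROUTER: 'openrouter', AI21: 'ai21', PALM: 'palm',
    };

    function isKnownSource(value) {
        return Object.values(CHAT_COMPLETION_SOURCES).includes(value);
    }

    console.log(isKnownSource('palm'));     // true
    console.log(isKnownSource('use_palm')); // false; the old flag names are gone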