From ebde9c2c1fe10b06f0fb1e2c7f2f2c40a3295190 Mon Sep 17 00:00:00 2001
From: valadaptive
Date: Sun, 3 Dec 2023 15:03:32 -0500
Subject: [PATCH] Copy chat_completion_sources enum to server code

---
 server.js        | 18 +++++++++---------
 src/constants.js | 11 +++++++++++
 2 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/server.js b/server.js
index cda208ba7..58f7b72bc 100644
--- a/server.js
+++ b/server.js
@@ -217,7 +217,7 @@ const AVATAR_WIDTH = 400;
 const AVATAR_HEIGHT = 600;
 const jsonParser = express.json({ limit: '200mb' });
 const urlencodedParser = express.urlencoded({ extended: true, limit: '200mb' });
-const { DIRECTORIES, UPLOADS_PATH, PALM_SAFETY, TEXTGEN_TYPES } = require('./src/constants');
+const { DIRECTORIES, UPLOADS_PATH, PALM_SAFETY, TEXTGEN_TYPES, CHAT_COMPLETION_SOURCES } = require('./src/constants');
 const { TavernCardValidator } = require('./src/validator/TavernCardValidator');
 
 // CSRF Protection //
@@ -2794,7 +2794,7 @@ app.post('/getstatus_openai', jsonParser, async function (request, response_gets
     let api_key_openai;
     let headers;
 
-    if (request.body.chat_completion_source !== 'openrouter') {
+    if (request.body.chat_completion_source !== CHAT_COMPLETION_SOURCES.OPENROUTER) {
         api_url = new URL(request.body.reverse_proxy || API_OPENAI).toString();
         api_key_openai = request.body.reverse_proxy ? request.body.proxy_password : readSecret(SECRET_KEYS.OPENAI);
         headers = {};
@@ -2822,7 +2822,7 @@ app.post('/getstatus_openai', jsonParser, async function (request, response_gets
         const data = await response.json();
         response_getstatus_openai.send(data);
 
-        if (request.body.chat_completion_source === 'openrouter' && Array.isArray(data?.data)) {
+        if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.OPENROUTER && Array.isArray(data?.data)) {
             let models = [];
 
             data.data.forEach(model => {
@@ -3238,10 +3238,10 @@ app.post('/generate_openai', jsonParser, function (request, response_generate_op
     if (!request.body) return response_generate_openai.status(400).send({ error: true });
 
     switch (request.body.chat_completion_source) {
-        case 'claude': return sendClaudeRequest(request, response_generate_openai);
-        case 'scale': return sendScaleRequest(request, response_generate_openai);
-        case 'ai21': return sendAI21Request(request, response_generate_openai);
-        case 'palm': return sendPalmRequest(request, response_generate_openai);
+        case CHAT_COMPLETION_SOURCES.CLAUDE: return sendClaudeRequest(request, response_generate_openai);
+        case CHAT_COMPLETION_SOURCES.SCALE: return sendScaleRequest(request, response_generate_openai);
+        case CHAT_COMPLETION_SOURCES.AI21: return sendAI21Request(request, response_generate_openai);
+        case CHAT_COMPLETION_SOURCES.PALM: return sendPalmRequest(request, response_generate_openai);
     }
 
     let api_url;
@@ -3249,7 +3249,7 @@ app.post('/generate_openai', jsonParser, function (request, response_generate_op
     let headers;
     let bodyParams;
 
-    if (request.body.chat_completion_source !== 'openrouter') {
+    if (request.body.chat_completion_source !== CHAT_COMPLETION_SOURCES.OPENROUTER) {
         api_url = new URL(request.body.reverse_proxy || API_OPENAI).toString();
         api_key_openai = request.body.reverse_proxy ? request.body.proxy_password : readSecret(SECRET_KEYS.OPENAI);
         headers = {};
@@ -3281,7 +3281,7 @@ app.post('/generate_openai', jsonParser, function (request, response_generate_op
 
     const isTextCompletion = Boolean(request.body.model && TEXT_COMPLETION_MODELS.includes(request.body.model)) || typeof request.body.messages === 'string';
     const textPrompt = isTextCompletion ? convertChatMLPrompt(request.body.messages) : '';
-    const endpointUrl = isTextCompletion && request.body.chat_completion_source !== 'openrouter' ?
+    const endpointUrl = isTextCompletion && request.body.chat_completion_source !== CHAT_COMPLETION_SOURCES.OPENROUTER ?
         `${api_url}/completions` :
         `${api_url}/chat/completions`;
 
diff --git a/src/constants.js b/src/constants.js
index 978cc7c43..eac932d69 100644
--- a/src/constants.js
+++ b/src/constants.js
@@ -132,6 +132,16 @@ const PALM_SAFETY = [
     },
 ];
 
+const CHAT_COMPLETION_SOURCES = {
+    OPENAI: 'openai',
+    WINDOWAI: 'windowai',
+    CLAUDE: 'claude',
+    SCALE: 'scale',
+    OPENROUTER: 'openrouter',
+    AI21: 'ai21',
+    PALM: 'palm',
+};
+
 const UPLOADS_PATH = './uploads';
 
 // TODO: this is copied from the client code; there should be a way to de-duplicate it eventually
@@ -149,4 +159,5 @@ module.exports = {
     UPLOADS_PATH,
     PALM_SAFETY,
     TEXTGEN_TYPES,
+    CHAT_COMPLETION_SOURCES,
 };