+
For privacy reasons, your API key will be hidden after you reload the page.
-
-
+
+
Google Model
+
+
diff --git a/public/script.js b/public/script.js
index c9d45baf7..7d4ea3a37 100644
--- a/public/script.js
+++ b/public/script.js
@@ -2557,8 +2557,8 @@ function getCharacterCardFields() {
}
function isStreamingEnabled() {
- const noStreamSources = [chat_completion_sources.SCALE, chat_completion_sources.AI21, chat_completion_sources.PALM];
- return ((main_api == 'openai' && oai_settings.stream_openai && !noStreamSources.includes(oai_settings.chat_completion_source))
+ const noStreamSources = [chat_completion_sources.SCALE, chat_completion_sources.AI21];
+ return ((main_api == 'openai' && oai_settings.stream_openai && !noStreamSources.includes(oai_settings.chat_completion_source) && !(oai_settings.chat_completion_source == chat_completion_sources.MAKERSUITE && oai_settings.google_model.includes('bison')))
|| (main_api == 'kobold' && kai_settings.streaming_kobold && kai_flags.can_use_streaming)
|| (main_api == 'novel' && nai_settings.streaming_novel)
|| (main_api == 'textgenerationwebui' && textgen_settings.streaming));
@@ -5395,7 +5395,7 @@ function changeMainAPI() {
case chat_completion_sources.CLAUDE:
case chat_completion_sources.OPENAI:
case chat_completion_sources.AI21:
- case chat_completion_sources.PALM:
+ case chat_completion_sources.MAKERSUITE:
default:
setupChatCompletionPromptManager(oai_settings);
break;
@@ -7535,9 +7535,9 @@ async function connectAPISlash(_, text) {
source: 'ai21',
button: '#api_button_openai',
},
- 'palm': {
+ 'makersuite': {
selected: 'openai',
- source: 'palm',
+ source: 'makersuite',
button: '#api_button_openai',
},
};
@@ -7826,7 +7826,7 @@ jQuery(async function () {
}
registerSlashCommand('dupe', DupeChar, [], '– duplicates the currently selected character', true, true);
- registerSlashCommand('api', connectAPISlash, [], '(kobold, horde, novel, ooba, tabby, mancer, aphrodite, kcpp, oai, claude, windowai, openrouter, scale, ai21, palm) – connect to an API', true, true);
+ registerSlashCommand('api', connectAPISlash, [], '(kobold, horde, novel, ooba, tabby, mancer, aphrodite, kcpp, oai, claude, windowai, openrouter, scale, ai21, makersuite) – connect to an API', true, true);
registerSlashCommand('impersonate', doImpersonate, ['imp'], '– calls an impersonation response', true, true);
registerSlashCommand('delchat', doDeleteChat, [], '– deletes the current chat', true, true);
registerSlashCommand('closechat', doCloseChat, [], '– closes the current chat', true, true);
diff --git a/public/scripts/RossAscends-mods.js b/public/scripts/RossAscends-mods.js
index f5bad628b..0cf4fa5c3 100644
--- a/public/scripts/RossAscends-mods.js
+++ b/public/scripts/RossAscends-mods.js
@@ -415,7 +415,7 @@ function RA_autoconnect(PrevApi) {
|| (oai_settings.chat_completion_source == chat_completion_sources.WINDOWAI)
|| (secret_state[SECRET_KEYS.OPENROUTER] && oai_settings.chat_completion_source == chat_completion_sources.OPENROUTER)
|| (secret_state[SECRET_KEYS.AI21] && oai_settings.chat_completion_source == chat_completion_sources.AI21)
- || (secret_state[SECRET_KEYS.PALM] && oai_settings.chat_completion_source == chat_completion_sources.PALM)
+ || (secret_state[SECRET_KEYS.MAKERSUITE] && oai_settings.chat_completion_source == chat_completion_sources.MAKERSUITE)
) {
$('#api_button_openai').trigger('click');
}
diff --git a/public/scripts/extensions/caption/index.js b/public/scripts/extensions/caption/index.js
index 395fbf8d4..aa666a232 100644
--- a/public/scripts/extensions/caption/index.js
+++ b/public/scripts/extensions/caption/index.js
@@ -134,7 +134,7 @@ async function doCaptionRequest(base64Img, fileData) {
case 'horde':
return await captionHorde(base64Img);
case 'multimodal':
- return await captionMultimodal(fileData);
+ return await captionMultimodal(extension_settings.caption.multimodal_api === 'google' ? base64Img : fileData);
default:
throw new Error('Unknown caption source.');
}
@@ -273,6 +273,7 @@ jQuery(function () {
(modules.includes('caption') && extension_settings.caption.source === 'extras') ||
(extension_settings.caption.source === 'multimodal' && extension_settings.caption.multimodal_api === 'openai' && secret_state[SECRET_KEYS.OPENAI]) ||
(extension_settings.caption.source === 'multimodal' && extension_settings.caption.multimodal_api === 'openrouter' && secret_state[SECRET_KEYS.OPENROUTER]) ||
+ (extension_settings.caption.source === 'multimodal' && extension_settings.caption.multimodal_api === 'google' && secret_state[SECRET_KEYS.MAKERSUITE]) ||
extension_settings.caption.source === 'local' ||
extension_settings.caption.source === 'horde';
@@ -328,7 +329,7 @@ jQuery(function () {
@@ -338,12 +339,14 @@ jQuery(function () {
diff --git a/public/scripts/extensions/shared.js b/public/scripts/extensions/shared.js
index 1eb4cd905..9058204ec 100644
--- a/public/scripts/extensions/shared.js
+++ b/public/scripts/extensions/shared.js
@@ -18,22 +18,35 @@ export async function getMultimodalCaption(base64Img, prompt) {
throw new Error('OpenRouter API key is not set.');
}
- // OpenRouter has a payload limit of ~2MB
- const base64Bytes = base64Img.length * 0.75;
- const compressionLimit = 2 * 1024 * 1024;
- if (extension_settings.caption.multimodal_api === 'openrouter' && base64Bytes > compressionLimit) {
- const maxSide = 1024;
- base64Img = await createThumbnail(base64Img, maxSide, maxSide, 'image/jpeg');
+ if (extension_settings.caption.multimodal_api === 'google' && !secret_state[SECRET_KEYS.MAKERSUITE]) {
+ throw new Error('MakerSuite API key is not set.');
}
- const apiResult = await fetch('/api/openai/caption-image', {
+ // OpenRouter has a payload limit of ~2MB. Google is 4MB, but we love democracy.
+ const isGoogle = extension_settings.caption.multimodal_api === 'google';
+ const base64Bytes = base64Img.length * 0.75;
+ const compressionLimit = 2 * 1024 * 1024;
+ if (['google', 'openrouter'].includes(extension_settings.caption.multimodal_api) && base64Bytes > compressionLimit) {
+ const maxSide = 1024;
+ base64Img = await createThumbnail(base64Img, maxSide, maxSide, 'image/jpeg');
+
+ if (isGoogle) {
+ base64Img = base64Img.split(',')[1];
+ }
+ }
+
+ const apiResult = await fetch(`/api/${isGoogle ? 'google' : 'openai'}/caption-image`, {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify({
image: base64Img,
prompt: prompt,
- api: extension_settings.caption.multimodal_api || 'openai',
- model: extension_settings.caption.multimodal_model || 'gpt-4-vision-preview',
+ ...(isGoogle
+ ? {}
+ : {
+ api: extension_settings.caption.multimodal_api || 'openai',
+ model: extension_settings.caption.multimodal_model || 'gpt-4-vision-preview',
+ }),
}),
});
diff --git a/public/scripts/extensions/stable-diffusion/index.js b/public/scripts/extensions/stable-diffusion/index.js
index 9935d43e4..204e268c2 100644
--- a/public/scripts/extensions/stable-diffusion/index.js
+++ b/public/scripts/extensions/stable-diffusion/index.js
@@ -1756,22 +1756,28 @@ async function generateMultimodalPrompt(generationType, quietPrompt) {
}
}
- const response = await fetch(avatarUrl);
+ try {
+ const response = await fetch(avatarUrl);
- if (!response.ok) {
- throw new Error('Could not fetch avatar image.');
- }
+ if (!response.ok) {
+ throw new Error('Could not fetch avatar image.');
+ }
- const avatarBlob = await response.blob();
- const avatarBase64 = await getBase64Async(avatarBlob);
+ const avatarBlob = await response.blob();
+ const avatarBase64 = await getBase64Async(avatarBlob);
- const caption = await getMultimodalCaption(avatarBase64, quietPrompt);
+ const caption = await getMultimodalCaption(avatarBase64, quietPrompt);
- if (!caption) {
+ if (!caption) {
+ throw new Error('No caption returned from the API.');
+ }
+
+ return caption;
+ } catch (error) {
+ console.error(error);
+ toastr.error('Multimodal captioning failed. Please try again.', 'Image Generation');
throw new Error('Multimodal captioning failed.');
}
-
- return caption;
}
/**
diff --git a/public/scripts/extensions/vectors/index.js b/public/scripts/extensions/vectors/index.js
index e02501f2e..214b1d887 100644
--- a/public/scripts/extensions/vectors/index.js
+++ b/public/scripts/extensions/vectors/index.js
@@ -394,7 +394,7 @@ async function getSavedHashes(collectionId) {
*/
async function insertVectorItems(collectionId, items) {
if (settings.source === 'openai' && !secret_state[SECRET_KEYS.OPENAI] ||
- settings.source === 'palm' && !secret_state[SECRET_KEYS.PALM]) {
+ settings.source === 'palm' && !secret_state[SECRET_KEYS.MAKERSUITE]) {
throw new Error('Vectors: API key missing', { cause: 'api_key_missing' });
}
diff --git a/public/scripts/openai.js b/public/scripts/openai.js
index b8cdc818f..9007017f6 100644
--- a/public/scripts/openai.js
+++ b/public/scripts/openai.js
@@ -37,8 +37,8 @@ import {
chatCompletionDefaultPrompts,
INJECTION_POSITION,
Prompt,
- promptManagerDefaultPromptOrders,
PromptManager,
+ promptManagerDefaultPromptOrders,
} from './PromptManager.js';
import { getCustomStoppingStrings, persona_description_positions, power_user } from './power-user.js';
@@ -114,7 +114,6 @@ const max_128k = 128 * 1000;
const max_200k = 200 * 1000;
const scale_max = 8191;
const claude_max = 9000; // We have a proper tokenizer, so theoretically could be larger (up to 9k)
-const palm2_max = 7400; // The real context window is 8192, spare some for padding due to using turbo tokenizer
const claude_100k_max = 99000;
let ai21_max = 9200; //can easily fit 9k gpt tokens because j2's tokenizer is efficient af
const unlocked_max = 100 * 1024;
@@ -164,7 +163,7 @@ export const chat_completion_sources = {
SCALE: 'scale',
OPENROUTER: 'openrouter',
AI21: 'ai21',
- PALM: 'palm',
+ MAKERSUITE: 'makersuite',
};
const prefixMap = selected_group ? {
@@ -207,6 +206,7 @@ const default_settings = {
personality_format: default_personality_format,
openai_model: 'gpt-3.5-turbo',
claude_model: 'claude-instant-v1',
+ google_model: 'gemini-pro',
ai21_model: 'j2-ultra',
windowai_model: '',
openrouter_model: openrouter_website_model,
@@ -223,6 +223,7 @@ const default_settings = {
proxy_password: '',
assistant_prefill: '',
use_ai21_tokenizer: false,
+ use_google_tokenizer: false,
exclude_assistant: false,
use_alt_scale: false,
squash_system_messages: false,
@@ -260,6 +261,7 @@ const oai_settings = {
personality_format: default_personality_format,
openai_model: 'gpt-3.5-turbo',
claude_model: 'claude-instant-v1',
+ google_model: 'gemini-pro',
ai21_model: 'j2-ultra',
windowai_model: '',
openrouter_model: openrouter_website_model,
@@ -276,6 +278,7 @@ const oai_settings = {
proxy_password: '',
assistant_prefill: '',
use_ai21_tokenizer: false,
+ use_google_tokenizer: false,
exclude_assistant: false,
use_alt_scale: false,
squash_system_messages: false,
@@ -1252,8 +1255,8 @@ function getChatCompletionModel() {
return oai_settings.windowai_model;
case chat_completion_sources.SCALE:
return '';
- case chat_completion_sources.PALM:
- return '';
+ case chat_completion_sources.MAKERSUITE:
+ return oai_settings.google_model;
case chat_completion_sources.OPENROUTER:
return oai_settings.openrouter_model !== openrouter_website_model ? oai_settings.openrouter_model : null;
case chat_completion_sources.AI21:
@@ -1443,20 +1446,20 @@ async function sendOpenAIRequest(type, messages, signal) {
const isOpenRouter = oai_settings.chat_completion_source == chat_completion_sources.OPENROUTER;
const isScale = oai_settings.chat_completion_source == chat_completion_sources.SCALE;
const isAI21 = oai_settings.chat_completion_source == chat_completion_sources.AI21;
- const isPalm = oai_settings.chat_completion_source == chat_completion_sources.PALM;
+ const isGoogle = oai_settings.chat_completion_source == chat_completion_sources.MAKERSUITE;
const isOAI = oai_settings.chat_completion_source == chat_completion_sources.OPENAI;
const isTextCompletion = (isOAI && textCompletionModels.includes(oai_settings.openai_model)) || (isOpenRouter && oai_settings.openrouter_force_instruct && power_user.instruct.enabled);
const isQuiet = type === 'quiet';
const isImpersonate = type === 'impersonate';
const isContinue = type === 'continue';
- const stream = oai_settings.stream_openai && !isQuiet && !isScale && !isAI21 && !isPalm;
+ const stream = oai_settings.stream_openai && !isQuiet && !isScale && !isAI21 && !(isGoogle && oai_settings.google_model.includes('bison'));
if (isTextCompletion && isOpenRouter) {
messages = convertChatCompletionToInstruct(messages, type);
replaceItemizedPromptText(messageId, messages);
}
- if (isAI21 || isPalm) {
+ if (isAI21) {
const joinedMsgs = messages.reduce((acc, obj) => {
const prefix = prefixMap[obj.role];
return acc + (prefix ? (selected_group ? '\n' : prefix + ' ') : '') + obj.content + '\n';
@@ -1539,7 +1542,7 @@ async function sendOpenAIRequest(type, messages, signal) {
generate_data['api_url_scale'] = oai_settings.api_url_scale;
}
- if (isPalm) {
+ if (isGoogle) {
const nameStopString = isImpersonate ? `\n${name2}:` : `\n${name1}:`;
const stopStringsLimit = 3; // 5 - 2 (nameStopString and new_chat_prompt)
generate_data['top_k'] = Number(oai_settings.top_k_openai);
@@ -1568,23 +1571,26 @@ async function sendOpenAIRequest(type, messages, signal) {
tryParseStreamingError(response, await response.text());
throw new Error(`Got response status ${response.status}`);
}
-
if (stream) {
- const eventStream = new EventSourceStream();
- response.body.pipeThrough(eventStream);
- const reader = eventStream.readable.getReader();
+ let reader;
+ let isSSEStream = oai_settings.chat_completion_source !== chat_completion_sources.MAKERSUITE;
+ if (isSSEStream) {
+ const eventStream = new EventSourceStream();
+ response.body.pipeThrough(eventStream);
+ reader = eventStream.readable.getReader();
+ } else {
+ reader = response.body.getReader();
+ }
return async function* streamData() {
let text = '';
+ let utf8Decoder = new TextDecoder();
while (true) {
const { done, value } = await reader.read();
if (done) return;
- if (value.data === '[DONE]') return;
-
- tryParseStreamingError(response, value.data);
-
- // the first and last messages are undefined, protect against that
- text += getStreamingReply(JSON.parse(value.data));
-
+ const rawData = isSSEStream ? value.data : utf8Decoder.decode(value, { stream: true });
+ if (isSSEStream && rawData === '[DONE]') return;
+ tryParseStreamingError(response, rawData);
+ text += getStreamingReply(JSON.parse(rawData));
yield { text, swipes: [] };
}
};
@@ -1606,6 +1612,8 @@ async function sendOpenAIRequest(type, messages, signal) {
function getStreamingReply(data) {
if (oai_settings.chat_completion_source == chat_completion_sources.CLAUDE) {
return data?.completion || '';
+ } else if (oai_settings.chat_completion_source == chat_completion_sources.MAKERSUITE) {
+ return data?.candidates[0].content.parts[0].text || '';
} else {
return data.choices[0]?.delta?.content || data.choices[0]?.message?.content || data.choices[0]?.text || '';
}
@@ -1787,13 +1795,15 @@ class Message {
async addImage(image) {
const textContent = this.content;
const isDataUrl = isDataURL(image);
-
if (!isDataUrl) {
try {
const response = await fetch(image, { method: 'GET', cache: 'force-cache' });
if (!response.ok) throw new Error('Failed to fetch image');
const blob = await response.blob();
image = await getBase64Async(blob);
+ if (oai_settings.chat_completion_source === chat_completion_sources.MAKERSUITE) {
+ image = image.split(',')[1];
+ }
} catch (error) {
console.error('Image adding skipped', error);
return;
@@ -2290,6 +2300,7 @@ function loadOpenAISettings(data, settings) {
oai_settings.openrouter_use_fallback = settings.openrouter_use_fallback ?? default_settings.openrouter_use_fallback;
oai_settings.openrouter_force_instruct = settings.openrouter_force_instruct ?? default_settings.openrouter_force_instruct;
oai_settings.ai21_model = settings.ai21_model ?? default_settings.ai21_model;
+ oai_settings.google_model = settings.google_model ?? default_settings.google_model;
oai_settings.chat_completion_source = settings.chat_completion_source ?? default_settings.chat_completion_source;
oai_settings.api_url_scale = settings.api_url_scale ?? default_settings.api_url_scale;
oai_settings.show_external_models = settings.show_external_models ?? default_settings.show_external_models;
@@ -2311,6 +2322,7 @@ function loadOpenAISettings(data, settings) {
if (settings.names_in_completion !== undefined) oai_settings.names_in_completion = !!settings.names_in_completion;
if (settings.openai_model !== undefined) oai_settings.openai_model = settings.openai_model;
if (settings.use_ai21_tokenizer !== undefined) { oai_settings.use_ai21_tokenizer = !!settings.use_ai21_tokenizer; oai_settings.use_ai21_tokenizer ? ai21_max = 8191 : ai21_max = 9200; }
+ if (settings.use_google_tokenizer !== undefined) oai_settings.use_google_tokenizer = !!settings.use_google_tokenizer;
if (settings.exclude_assistant !== undefined) oai_settings.exclude_assistant = !!settings.exclude_assistant;
if (settings.use_alt_scale !== undefined) { oai_settings.use_alt_scale = !!settings.use_alt_scale; updateScaleForm(); }
$('#stream_toggle').prop('checked', oai_settings.stream_openai);
@@ -2326,6 +2338,8 @@ function loadOpenAISettings(data, settings) {
$(`#model_claude_select option[value="${oai_settings.claude_model}"`).attr('selected', true);
$('#model_windowai_select').val(oai_settings.windowai_model);
$(`#model_windowai_select option[value="${oai_settings.windowai_model}"`).attr('selected', true);
+ $('#model_google_select').val(oai_settings.google_model);
+ $(`#model_google_select option[value="${oai_settings.google_model}"`).attr('selected', true);
$('#model_ai21_select').val(oai_settings.ai21_model);
$(`#model_ai21_select option[value="${oai_settings.ai21_model}"`).attr('selected', true);
$('#openai_max_context').val(oai_settings.openai_max_context);
@@ -2341,6 +2355,7 @@ function loadOpenAISettings(data, settings) {
$('#openai_show_external_models').prop('checked', oai_settings.show_external_models);
$('#openai_external_category').toggle(oai_settings.show_external_models);
$('#use_ai21_tokenizer').prop('checked', oai_settings.use_ai21_tokenizer);
+ $('#use_google_tokenizer').prop('checked', oai_settings.use_google_tokenizer);
$('#exclude_assistant').prop('checked', oai_settings.exclude_assistant);
$('#scale-alt').prop('checked', oai_settings.use_alt_scale);
$('#openrouter_use_fallback').prop('checked', oai_settings.openrouter_use_fallback);
@@ -2396,6 +2411,11 @@ function loadOpenAISettings(data, settings) {
}
$('#openai_logit_bias_preset').trigger('change');
+ // Upgrade Palm to Makersuite
+ if (oai_settings.chat_completion_source === 'palm') {
+ oai_settings.chat_completion_source = chat_completion_sources.MAKERSUITE;
+ }
+
$('#chat_completion_source').val(oai_settings.chat_completion_source).trigger('change');
$('#oai_max_context_unlocked').prop('checked', oai_settings.max_context_unlocked);
}
@@ -2416,7 +2436,7 @@ async function getStatusOpen() {
return resultCheckStatus();
}
- const noValidateSources = [chat_completion_sources.SCALE, chat_completion_sources.CLAUDE, chat_completion_sources.AI21, chat_completion_sources.PALM];
+ const noValidateSources = [chat_completion_sources.SCALE, chat_completion_sources.CLAUDE, chat_completion_sources.AI21, chat_completion_sources.MAKERSUITE];
if (noValidateSources.includes(oai_settings.chat_completion_source)) {
let status = 'Unable to verify key; press "Test Message" to validate.';
setOnlineStatus(status);
@@ -2499,6 +2519,7 @@ async function saveOpenAIPreset(name, settings, triggerUi = true) {
openrouter_group_models: settings.openrouter_group_models,
openrouter_sort_models: settings.openrouter_sort_models,
ai21_model: settings.ai21_model,
+ google_model: settings.google_model,
temperature: settings.temp_openai,
frequency_penalty: settings.freq_pen_openai,
presence_penalty: settings.pres_pen_openai,
@@ -2532,6 +2553,7 @@ async function saveOpenAIPreset(name, settings, triggerUi = true) {
show_external_models: settings.show_external_models,
assistant_prefill: settings.assistant_prefill,
use_ai21_tokenizer: settings.use_ai21_tokenizer,
+ use_google_tokenizer: settings.use_google_tokenizer,
exclude_assistant: settings.exclude_assistant,
use_alt_scale: settings.use_alt_scale,
squash_system_messages: settings.squash_system_messages,
@@ -2868,6 +2890,7 @@ function onSettingsPresetChange() {
openrouter_group_models: ['#openrouter_group_models', 'openrouter_group_models', false],
openrouter_sort_models: ['#openrouter_sort_models', 'openrouter_sort_models', false],
ai21_model: ['#model_ai21_select', 'ai21_model', false],
+ google_model: ['#model_google_select', 'google_model', false],
openai_max_context: ['#openai_max_context', 'openai_max_context', false],
openai_max_tokens: ['#openai_max_tokens', 'openai_max_tokens', false],
wrap_in_quotes: ['#wrap_in_quotes', 'wrap_in_quotes', true],
@@ -2892,6 +2915,7 @@ function onSettingsPresetChange() {
proxy_password: ['#openai_proxy_password', 'proxy_password', false],
assistant_prefill: ['#claude_assistant_prefill', 'assistant_prefill', false],
use_ai21_tokenizer: ['#use_ai21_tokenizer', 'use_ai21_tokenizer', true],
+ use_google_tokenizer: ['#use_google_tokenizer', 'use_google_tokenizer', true],
exclude_assistant: ['#exclude_assistant', 'exclude_assistant', true],
use_alt_scale: ['#use_alt_scale', 'use_alt_scale', true],
squash_system_messages: ['#squash_system_messages', 'squash_system_messages', true],
@@ -3000,7 +3024,7 @@ function getMaxContextWindowAI(value) {
return max_8k;
}
else if (value.includes('palm-2')) {
- return palm2_max;
+ return max_8k;
}
else if (value.includes('GPT-NeoXT')) {
return max_2k;
@@ -3045,6 +3069,11 @@ async function onModelChange() {
oai_settings.ai21_model = value;
}
+ if ($(this).is('#model_google_select')) {
+ console.log('Google model changed to', value);
+ oai_settings.google_model = value;
+ }
+
if (oai_settings.chat_completion_source == chat_completion_sources.SCALE) {
if (oai_settings.max_context_unlocked) {
$('#openai_max_context').attr('max', unlocked_max);
@@ -3055,13 +3084,18 @@ async function onModelChange() {
$('#openai_max_context').val(oai_settings.openai_max_context).trigger('input');
}
- if (oai_settings.chat_completion_source == chat_completion_sources.PALM) {
+ if (oai_settings.chat_completion_source == chat_completion_sources.MAKERSUITE) {
if (oai_settings.max_context_unlocked) {
$('#openai_max_context').attr('max', unlocked_max);
+ } else if (value === 'gemini-pro') {
+ $('#openai_max_context').attr('max', max_32k);
+ } else if (value === 'gemini-pro-vision') {
+ $('#openai_max_context').attr('max', max_16k);
} else {
- $('#openai_max_context').attr('max', palm2_max);
+ $('#openai_max_context').attr('max', max_8k);
}
-
+ oai_settings.temp_openai = Math.min(claude_max_temp, oai_settings.temp_openai);
+ $('#temp_openai').attr('max', claude_max_temp).val(oai_settings.temp_openai).trigger('input');
oai_settings.openai_max_context = Math.min(Number($('#openai_max_context').attr('max')), oai_settings.openai_max_context);
$('#openai_max_context').val(oai_settings.openai_max_context).trigger('input');
}
@@ -3254,15 +3288,15 @@ async function onConnectButtonClick(e) {
}
}
- if (oai_settings.chat_completion_source == chat_completion_sources.PALM) {
- const api_key_palm = String($('#api_key_palm').val()).trim();
+ if (oai_settings.chat_completion_source == chat_completion_sources.MAKERSUITE) {
+ const api_key_makersuite = String($('#api_key_makersuite').val()).trim();
- if (api_key_palm.length) {
- await writeSecret(SECRET_KEYS.PALM, api_key_palm);
+ if (api_key_makersuite.length) {
+ await writeSecret(SECRET_KEYS.MAKERSUITE, api_key_makersuite);
}
- if (!secret_state[SECRET_KEYS.PALM]) {
- console.log('No secret key saved for PALM');
+ if (!secret_state[SECRET_KEYS.MAKERSUITE]) {
+ console.log('No secret key saved for MakerSuite');
return;
}
}
@@ -3329,8 +3363,8 @@ function toggleChatCompletionForms() {
else if (oai_settings.chat_completion_source == chat_completion_sources.SCALE) {
$('#model_scale_select').trigger('change');
}
- else if (oai_settings.chat_completion_source == chat_completion_sources.PALM) {
- $('#model_palm_select').trigger('change');
+ else if (oai_settings.chat_completion_source == chat_completion_sources.MAKERSUITE) {
+ $('#model_google_select').trigger('change');
}
else if (oai_settings.chat_completion_source == chat_completion_sources.OPENROUTER) {
$('#model_openrouter_select').trigger('change');
@@ -3398,6 +3432,7 @@ export function isImageInliningSupported() {
}
const gpt4v = 'gpt-4-vision';
+ const geminiProV = 'gemini-pro-vision';
const llava13b = 'llava-13b';
if (!oai_settings.image_inlining) {
@@ -3407,6 +3442,8 @@ export function isImageInliningSupported() {
switch (oai_settings.chat_completion_source) {
case chat_completion_sources.OPENAI:
return oai_settings.openai_model.includes(gpt4v);
+ case chat_completion_sources.MAKERSUITE:
+ return oai_settings.google_model.includes(geminiProV);
case chat_completion_sources.OPENROUTER:
return oai_settings.openrouter_model.includes(gpt4v) || oai_settings.openrouter_model.includes(llava13b);
default:
@@ -3491,6 +3528,11 @@ $(document).ready(async function () {
saveSettingsDebounced();
});
+ $('#use_google_tokenizer').on('change', function () {
+ oai_settings.use_google_tokenizer = !!$('#use_google_tokenizer').prop('checked');
+ saveSettingsDebounced();
+ });
+
$('#exclude_assistant').on('change', function () {
oai_settings.exclude_assistant = !!$('#exclude_assistant').prop('checked');
$('#claude_assistant_prefill_block').toggle(!oai_settings.exclude_assistant);
@@ -3702,7 +3744,7 @@ $(document).ready(async function () {
$('#model_claude_select').on('change', onModelChange);
$('#model_windowai_select').on('change', onModelChange);
$('#model_scale_select').on('change', onModelChange);
- $('#model_palm_select').on('change', onModelChange);
+ $('#model_google_select').on('change', onModelChange);
$('#model_openrouter_select').on('change', onModelChange);
$('#openrouter_group_models').on('change', onOpenrouterModelSortChange);
$('#openrouter_sort_models').on('change', onOpenrouterModelSortChange);
diff --git a/public/scripts/secrets.js b/public/scripts/secrets.js
index 84279641d..6afb538f1 100644
--- a/public/scripts/secrets.js
+++ b/public/scripts/secrets.js
@@ -12,7 +12,7 @@ export const SECRET_KEYS = {
SCALE: 'api_key_scale',
AI21: 'api_key_ai21',
SCALE_COOKIE: 'scale_cookie',
- PALM: 'api_key_palm',
+ MAKERSUITE: 'api_key_makersuite',
SERPAPI: 'api_key_serpapi',
};
@@ -26,7 +26,7 @@ const INPUT_MAP = {
[SECRET_KEYS.SCALE]: '#api_key_scale',
[SECRET_KEYS.AI21]: '#api_key_ai21',
[SECRET_KEYS.SCALE_COOKIE]: '#scale_cookie',
- [SECRET_KEYS.PALM]: '#api_key_palm',
+ [SECRET_KEYS.MAKERSUITE]: '#api_key_makersuite',
[SECRET_KEYS.APHRODITE]: '#api_key_aphrodite',
[SECRET_KEYS.TABBY]: '#api_key_tabby',
};
diff --git a/public/scripts/tokenizers.js b/public/scripts/tokenizers.js
index decd0f919..196f3ec9c 100644
--- a/public/scripts/tokenizers.js
+++ b/public/scripts/tokenizers.js
@@ -376,6 +376,10 @@ export function getTokenizerModel() {
}
}
+ if (oai_settings.chat_completion_source == chat_completion_sources.MAKERSUITE) {
+ return oai_settings.google_model;
+ }
+
if (oai_settings.chat_completion_source == chat_completion_sources.CLAUDE) {
return claudeTokenizer;
}
@@ -389,6 +393,15 @@ export function getTokenizerModel() {
*/
export function countTokensOpenAI(messages, full = false) {
const shouldTokenizeAI21 = oai_settings.chat_completion_source === chat_completion_sources.AI21 && oai_settings.use_ai21_tokenizer;
+ const shouldTokenizeGoogle = oai_settings.chat_completion_source === chat_completion_sources.MAKERSUITE && oai_settings.use_google_tokenizer;
+ let tokenizerEndpoint = '';
+ if (shouldTokenizeAI21) {
+ tokenizerEndpoint = '/api/tokenizers/ai21/count';
+ } else if (shouldTokenizeGoogle) {
+ tokenizerEndpoint = `/api/tokenizers/google/count?model=${getTokenizerModel()}`;
+ } else {
+ tokenizerEndpoint = `/api/tokenizers/openai/count?model=${getTokenizerModel()}`;
+ }
const cacheObject = getTokenCacheObject();
if (!Array.isArray(messages)) {
@@ -400,7 +413,7 @@ export function countTokensOpenAI(messages, full = false) {
for (const message of messages) {
const model = getTokenizerModel();
- if (model === 'claude' || shouldTokenizeAI21) {
+ if (model === 'claude' || shouldTokenizeAI21 || shouldTokenizeGoogle) {
full = true;
}
@@ -416,7 +429,7 @@ export function countTokensOpenAI(messages, full = false) {
jQuery.ajax({
async: false,
type: 'POST', //
- url: shouldTokenizeAI21 ? '/api/tokenizers/ai21/count' : `/api/tokenizers/openai/count?model=${model}`,
+ url: tokenizerEndpoint,
data: JSON.stringify([message]),
dataType: 'json',
contentType: 'application/json',
diff --git a/public/scripts/utils.js b/public/scripts/utils.js
index dfaba74ce..798564c22 100644
--- a/public/scripts/utils.js
+++ b/public/scripts/utils.js
@@ -1030,6 +1030,11 @@ export function loadFileToDocument(url, type) {
* @returns {Promise} A promise that resolves to the thumbnail data URL.
*/
export function createThumbnail(dataUrl, maxWidth, maxHeight, type = 'image/jpeg') {
+ // Someone might pass in a base64 encoded string without the data URL prefix
+ if (!dataUrl.includes('data:')) {
+ dataUrl = `data:image/jpeg;base64,${dataUrl}`;
+ }
+
return new Promise((resolve, reject) => {
const img = new Image();
img.src = dataUrl;
diff --git a/server.js b/server.js
index 41e074551..430a38457 100644
--- a/server.js
+++ b/server.js
@@ -689,6 +689,9 @@ redirect('/downloadbackground', '/api/backgrounds/upload'); // yes, the download
// OpenAI API
app.use('/api/openai', require('./src/endpoints/openai').router);
+//Google API
+app.use('/api/google', require('./src/endpoints/google').router);
+
// Tokenizers
app.use('/api/tokenizers', require('./src/endpoints/tokenizers').router);
diff --git a/src/constants.js b/src/constants.js
index 32ea6fad5..92af44cf2 100644
--- a/src/constants.js
+++ b/src/constants.js
@@ -105,7 +105,26 @@ const UNSAFE_EXTENSIONS = [
'.ws',
];
-const PALM_SAFETY = [
+const GEMINI_SAFETY = [
+ {
+ category: 'HARM_CATEGORY_HARASSMENT',
+ threshold: 'BLOCK_NONE',
+ },
+ {
+ category: 'HARM_CATEGORY_HATE_SPEECH',
+ threshold: 'BLOCK_NONE',
+ },
+ {
+ category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
+ threshold: 'BLOCK_NONE',
+ },
+ {
+ category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
+ threshold: 'BLOCK_NONE',
+ },
+];
+
+const BISON_SAFETY = [
{
category: 'HARM_CATEGORY_DEROGATORY',
threshold: 'BLOCK_NONE',
@@ -139,7 +158,7 @@ const CHAT_COMPLETION_SOURCES = {
SCALE: 'scale',
OPENROUTER: 'openrouter',
AI21: 'ai21',
- PALM: 'palm',
+ MAKERSUITE: 'makersuite',
};
const UPLOADS_PATH = './uploads';
@@ -160,7 +179,8 @@ module.exports = {
DIRECTORIES,
UNSAFE_EXTENSIONS,
UPLOADS_PATH,
- PALM_SAFETY,
+ GEMINI_SAFETY,
+ BISON_SAFETY,
TEXTGEN_TYPES,
CHAT_COMPLETION_SOURCES,
AVATAR_WIDTH,
diff --git a/src/endpoints/backends/chat-completions.js b/src/endpoints/backends/chat-completions.js
index af463bd21..13e09cd56 100644
--- a/src/endpoints/backends/chat-completions.js
+++ b/src/endpoints/backends/chat-completions.js
@@ -1,10 +1,11 @@
const express = require('express');
const fetch = require('node-fetch').default;
+const { Readable } = require('stream');
const { jsonParser } = require('../../express-common');
-const { CHAT_COMPLETION_SOURCES, PALM_SAFETY } = require('../../constants');
+const { CHAT_COMPLETION_SOURCES, GEMINI_SAFETY, BISON_SAFETY } = require('../../constants');
const { forwardFetchResponse, getConfigValue, tryParse, uuidv4 } = require('../../util');
-const { convertClaudePrompt, convertTextCompletionPrompt } = require('../prompt-converters');
+const { convertClaudePrompt, convertGooglePrompt, convertTextCompletionPrompt } = require('../prompt-converters');
const { readSecret, SECRET_KEYS } = require('../secrets');
const { getTokenizerModel, getSentencepiceTokenizer, getTiktokenTokenizer, sentencepieceTokenizers, TEXT_COMPLETION_MODELS } = require('../tokenizers');
@@ -151,28 +152,70 @@ async function sendScaleRequest(request, response) {
* @param {express.Request} request Express request
* @param {express.Response} response Express response
*/
-async function sendPalmRequest(request, response) {
- const api_key_palm = readSecret(SECRET_KEYS.PALM);
+async function sendMakerSuiteRequest(request, response) {
+ const apiKey = readSecret(SECRET_KEYS.MAKERSUITE);
- if (!api_key_palm) {
- console.log('Palm API key is missing.');
+ if (!apiKey) {
+ console.log('MakerSuite API key is missing.');
return response.status(400).send({ error: true });
}
- const body = {
- prompt: {
- text: request.body.messages,
- },
+ const model = String(request.body.model);
+ const isGemini = model.includes('gemini');
+ const isText = model.includes('text');
+ const stream = Boolean(request.body.stream) && isGemini;
+
+ const generationConfig = {
stopSequences: request.body.stop,
- safetySettings: PALM_SAFETY,
+ candidateCount: 1,
+ maxOutputTokens: request.body.max_tokens,
temperature: request.body.temperature,
topP: request.body.top_p,
topK: request.body.top_k || undefined,
- maxOutputTokens: request.body.max_tokens,
- candidate_count: 1,
};
- console.log('Palm request:', body);
+ function getGeminiBody() {
+ return {
+ contents: convertGooglePrompt(request.body.messages, model),
+ safetySettings: GEMINI_SAFETY,
+ generationConfig: generationConfig,
+ };
+ }
+
+ function getBisonBody() {
+ const prompt = isText
+ ? ({ text: convertTextCompletionPrompt(request.body.messages) })
+ : ({ messages: convertGooglePrompt(request.body.messages, model) });
+
+ /** @type {any} Shut the lint up */
+ const bisonBody = {
+ ...generationConfig,
+ safetySettings: BISON_SAFETY,
+ candidate_count: 1, // legacy spelling
+ prompt: prompt,
+ };
+
+ if (!isText) {
+ delete bisonBody.stopSequences;
+ delete bisonBody.maxOutputTokens;
+ delete bisonBody.safetySettings;
+
+ if (Array.isArray(prompt.messages)) {
+ for (const msg of prompt.messages) {
+ msg.author = msg.role;
+ msg.content = msg.parts[0].text;
+ delete msg.parts;
+ delete msg.role;
+ }
+ }
+ }
+
+ delete bisonBody.candidateCount;
+ return bisonBody;
+ }
+
+ const body = isGemini ? getGeminiBody() : getBisonBody();
+ console.log('MakerSuite request:', body);
try {
const controller = new AbortController();
@@ -181,7 +224,12 @@ async function sendPalmRequest(request, response) {
controller.abort();
});
- const generateResponse = await fetch(`https://generativelanguage.googleapis.com/v1beta2/models/text-bison-001:generateText?key=${api_key_palm}`, {
+ const apiVersion = isGemini ? 'v1beta' : 'v1beta2';
+ const responseType = isGemini
+ ? (stream ? 'streamGenerateContent' : 'generateContent')
+ : (isText ? 'generateText' : 'generateMessage');
+
+ const generateResponse = await fetch(`https://generativelanguage.googleapis.com/${apiVersion}/models/${model}:${responseType}?key=${apiKey}`, {
body: JSON.stringify(body),
method: 'POST',
headers: {
@@ -190,34 +238,79 @@ async function sendPalmRequest(request, response) {
signal: controller.signal,
timeout: 0,
});
+ // have to do this because of their busted ass streaming endpoint
+ if (stream) {
+ try {
+ let partialData = '';
+ generateResponse.body.on('data', (data) => {
+ const chunk = data.toString();
+ if (chunk.startsWith(',') || chunk.endsWith(',') || chunk.startsWith('[') || chunk.endsWith(']')) {
+ partialData = chunk.slice(1);
+ } else {
+ partialData += chunk;
+ }
+ while (true) {
+ let json;
+ try {
+ json = JSON.parse(partialData);
+ } catch (e) {
+ break;
+ }
+ response.write(JSON.stringify(json));
+ partialData = '';
+ }
+ });
- if (!generateResponse.ok) {
- console.log(`Palm API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`);
- return response.status(generateResponse.status).send({ error: true });
- }
+ request.socket.on('close', function () {
+ if (generateResponse.body instanceof Readable) generateResponse.body.destroy();
+ response.end();
+ });
- const generateResponseJson = await generateResponse.json();
- const responseText = generateResponseJson?.candidates?.[0]?.output;
+ generateResponse.body.on('end', () => {
+ console.log('Streaming request finished');
+ response.end();
+ });
- if (!responseText) {
- console.log('Palm API returned no response', generateResponseJson);
- let message = `Palm API returned no response: ${JSON.stringify(generateResponseJson)}`;
-
- // Check for filters
- if (generateResponseJson?.filters?.[0]?.reason) {
- message = `Palm filter triggered: ${generateResponseJson.filters[0].reason}`;
+ } catch (error) {
+ console.log('Error forwarding streaming response:', error);
+ if (!response.headersSent) {
+ return response.status(500).send({ error: true });
+ }
+ }
+ } else {
+ if (!generateResponse.ok) {
+ console.log(`MakerSuite API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`);
+ return response.status(generateResponse.status).send({ error: true });
}
- return response.send({ error: { message } });
+ const generateResponseJson = await generateResponse.json();
+
+ const candidates = generateResponseJson?.candidates;
+ if (!candidates || candidates.length === 0) {
+ let message = 'MakerSuite API returned no candidate';
+ console.log(message, generateResponseJson);
+ if (generateResponseJson?.promptFeedback?.blockReason) {
+ message += `\nPrompt was blocked due to : ${generateResponseJson.promptFeedback.blockReason}`;
+ }
+ return response.send({ error: { message } });
+ }
+
+ const responseContent = candidates[0].content ?? candidates[0].output;
+ const responseText = typeof responseContent === 'string' ? responseContent : responseContent.parts?.[0]?.text;
+ if (!responseText) {
+ let message = 'MakerSuite Candidate text empty';
+ console.log(message, generateResponseJson);
+ return response.send({ error: { message } });
+ }
+
+ console.log('MakerSuite response:', responseText);
+
+ // Wrap it back to OAI format
+ const reply = { choices: [{ 'message': { 'content': responseText } }] };
+ return response.send(reply);
}
-
- console.log('Palm response:', responseText);
-
- // Wrap it back to OAI format
- const reply = { choices: [{ 'message': { 'content': responseText } }] };
- return response.send(reply);
} catch (error) {
- console.log('Error communicating with Palm API: ', error);
+ console.log('Error communicating with MakerSuite API: ', error);
if (!response.headersSent) {
return response.status(500).send({ error: true });
}
@@ -225,7 +318,7 @@ async function sendPalmRequest(request, response) {
}
/**
- * Sends a request to Google AI API.
+ * Sends a request to AI21 API.
* @param {express.Request} request Express request
* @param {express.Response} response Express response
*/
@@ -457,7 +550,7 @@ router.post('/generate', jsonParser, function (request, response) {
case CHAT_COMPLETION_SOURCES.CLAUDE: return sendClaudeRequest(request, response);
case CHAT_COMPLETION_SOURCES.SCALE: return sendScaleRequest(request, response);
case CHAT_COMPLETION_SOURCES.AI21: return sendAI21Request(request, response);
- case CHAT_COMPLETION_SOURCES.PALM: return sendPalmRequest(request, response);
+ case CHAT_COMPLETION_SOURCES.MAKERSUITE: return sendMakerSuiteRequest(request, response);
}
let apiUrl;
diff --git a/src/endpoints/google.js b/src/endpoints/google.js
new file mode 100644
index 000000000..010b6f0ea
--- /dev/null
+++ b/src/endpoints/google.js
@@ -0,0 +1,66 @@
+const { readSecret, SECRET_KEYS } = require('./secrets');
+const fetch = require('node-fetch').default;
+const express = require('express');
+const { jsonParser } = require('../express-common');
+const { GEMINI_SAFETY } = require('../constants');
+
+const router = express.Router();
+
+router.post('/caption-image', jsonParser, async (request, response) => {
+ try {
+ const mimeType = request.body.image.split(';')[0].split(':')[1];
+ const base64Data = request.body.image.split(',')[1];
+ const url = `https://generativelanguage.googleapis.com/v1beta/models/gemini-pro-vision:generateContent?key=${readSecret(SECRET_KEYS.MAKERSUITE)}`;
+ const body = {
+ contents: [{
+ parts: [
+ { text: request.body.prompt },
+ {
+ inlineData: {
+ mimeType: 'image/png', // A MIME type must be embedded in the data itself if the image is not a PNG
+ data: mimeType === 'image/png' ? base64Data : request.body.image,
+ },
+ }],
+ }],
+ safetySettings: GEMINI_SAFETY,
+ generationConfig: { maxOutputTokens: 1000 },
+ };
+
+ console.log('Multimodal captioning request', body);
+
+ const result = await fetch(url, {
+ body: JSON.stringify(body),
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json',
+ },
+ timeout: 0,
+ });
+
+ if (!result.ok) {
+ const error = await result.json();
+ console.log(`MakerSuite API returned error: ${result.status} ${result.statusText}`, error);
+ return response.status(result.status).send({ error: true });
+ }
+
+ const data = await result.json();
+ console.log('Multimodal captioning response', data);
+
+ const candidates = data?.candidates;
+ if (!candidates) {
+ return response.status(500).send('No candidates found, image was most likely filtered.');
+ }
+
+ const caption = candidates[0].content.parts[0].text;
+ if (!caption) {
+ return response.status(500).send('No caption found');
+ }
+
+ return response.json({ caption });
+ } catch (error) {
+ console.error(error);
+ response.status(500).send('Internal server error');
+ }
+});
+
+module.exports = { router };
diff --git a/src/endpoints/prompt-converters.js b/src/endpoints/prompt-converters.js
index 4ffdb459e..e564b2eb9 100644
--- a/src/endpoints/prompt-converters.js
+++ b/src/endpoints/prompt-converters.js
@@ -72,6 +72,68 @@ function convertClaudePrompt(messages, addHumanPrefix, addAssistantPostfix, with
return requestPrompt;
}
+/**
+ * Convert a prompt from the ChatML objects to the format used by Google MakerSuite models.
+ * @param {object[]} messages Array of messages
+ * @param {string} model Model name
+ * @returns {object[]} Prompt for Google MakerSuite models
+ */
+function convertGooglePrompt(messages, model) {
+ // This is a 1x1 transparent PNG
+ const PNG_PIXEL = 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII=';
+ const contents = [];
+ let lastRole = '';
+ let currentText = '';
+
+ const isMultimodal = model === 'gemini-pro-vision';
+
+ if (isMultimodal) {
+ const combinedText = messages.map((message) => {
+ const role = message.role === 'assistant' ? 'MODEL: ' : 'USER: ';
+ return role + message.content;
+ }).join('\n\n').trim();
+
+ const imageEntry = messages.find((message) => message.content?.[1]?.image_url);
+ const imageData = imageEntry?.content?.[1]?.image_url?.data ?? PNG_PIXEL;
+ contents.push({
+ parts: [
+ { text: combinedText },
+ {
+ inlineData: {
+ mimeType: 'image/png',
+ data: imageData,
+ },
+ },
+ ],
+ role: 'user',
+ });
+ } else {
+ messages.forEach((message, index) => {
+ const role = message.role === 'assistant' ? 'model' : 'user';
+ if (lastRole === role) {
+ currentText += '\n\n' + message.content;
+ } else {
+ if (currentText !== '') {
+ contents.push({
+ parts: [{ text: currentText.trim() }],
+ role: lastRole,
+ });
+ }
+ currentText = message.content;
+ lastRole = role;
+ }
+ if (index === messages.length - 1) {
+ contents.push({
+ parts: [{ text: currentText.trim() }],
+ role: lastRole,
+ });
+ }
+ });
+ }
+
+ return contents;
+}
+
/**
* Convert a prompt from the ChatML objects to the format used by Text Completion API.
* @param {object[]} messages Array of messages
@@ -99,5 +161,6 @@ function convertTextCompletionPrompt(messages) {
module.exports = {
convertClaudePrompt,
+ convertGooglePrompt,
convertTextCompletionPrompt,
};
diff --git a/src/endpoints/secrets.js b/src/endpoints/secrets.js
index 54687cbeb..5da0cc730 100644
--- a/src/endpoints/secrets.js
+++ b/src/endpoints/secrets.js
@@ -23,7 +23,7 @@ const SECRET_KEYS = {
SCALE_COOKIE: 'scale_cookie',
ONERING_URL: 'oneringtranslator_url',
DEEPLX_URL: 'deeplx_url',
- PALM: 'api_key_palm',
+ MAKERSUITE: 'api_key_makersuite',
SERPAPI: 'api_key_serpapi',
};
@@ -44,6 +44,17 @@ function writeSecret(key, value) {
writeFileAtomicSync(SECRETS_FILE, JSON.stringify(secrets, null, 4), 'utf-8');
}
+function deleteSecret(key) {
+ if (!fs.existsSync(SECRETS_FILE)) {
+ return;
+ }
+
+ const fileContents = fs.readFileSync(SECRETS_FILE, 'utf-8');
+ const secrets = JSON.parse(fileContents);
+ delete secrets[key];
+ writeFileAtomicSync(SECRETS_FILE, JSON.stringify(secrets, null, 4), 'utf-8');
+}
+
/**
* Reads a secret from the secrets file
* @param {string} key Secret key
@@ -85,6 +96,13 @@ function readSecretState() {
* @returns {void}
*/
function migrateSecrets(settingsFile) {
+ const palmKey = readSecret('api_key_palm');
+ if (palmKey) {
+ console.log('Migrating Palm key...');
+ writeSecret(SECRET_KEYS.MAKERSUITE, palmKey);
+ deleteSecret('api_key_palm');
+ }
+
if (!fs.existsSync(settingsFile)) {
console.log('Settings file does not exist');
return;
diff --git a/src/endpoints/tokenizers.js b/src/endpoints/tokenizers.js
index 38c04f864..b5bbf50ca 100644
--- a/src/endpoints/tokenizers.js
+++ b/src/endpoints/tokenizers.js
@@ -4,7 +4,7 @@ const express = require('express');
const { SentencePieceProcessor } = require('@agnai/sentencepiece-js');
const tiktoken = require('@dqbd/tiktoken');
const { Tokenizer } = require('@agnai/web-tokenizers');
-const { convertClaudePrompt } = require('./prompt-converters');
+const { convertClaudePrompt, convertGooglePrompt } = require('./prompt-converters');
const { readSecret, SECRET_KEYS } = require('./secrets');
const { TEXTGEN_TYPES } = require('../constants');
const { jsonParser } = require('../express-common');
@@ -387,6 +387,26 @@ router.post('/ai21/count', jsonParser, async function (req, res) {
}
});
+router.post('/google/count', jsonParser, async function (req, res) {
+ if (!req.body) return res.sendStatus(400);
+ const options = {
+ method: 'POST',
+ headers: {
+ accept: 'application/json',
+ 'content-type': 'application/json',
+ },
+ body: JSON.stringify({ contents: convertGooglePrompt(req.body) }),
+ };
+ try {
+ const response = await fetch(`https://generativelanguage.googleapis.com/v1beta/models/${req.query.model}:countTokens?key=${readSecret(SECRET_KEYS.MAKERSUITE)}`, options);
+ const data = await response.json();
+ return res.send({ 'token_count': data?.totalTokens || 0 });
+ } catch (err) {
+ console.error(err);
+ return res.send({ 'token_count': 0 });
+ }
+});
+
router.post('/llama/encode', jsonParser, createSentencepieceEncodingHandler(spp_llama));
router.post('/nerdstash/encode', jsonParser, createSentencepieceEncodingHandler(spp_nerd));
router.post('/nerdstash_v2/encode', jsonParser, createSentencepieceEncodingHandler(spp_nerd_v2));
diff --git a/src/endpoints/vectors.js b/src/endpoints/vectors.js
index 387803ccb..e49d157fa 100644
--- a/src/endpoints/vectors.js
+++ b/src/endpoints/vectors.js
@@ -17,7 +17,7 @@ async function getVector(source, text) {
case 'transformers':
return require('../embedding').getTransformersVector(text);
case 'palm':
- return require('../palm-vectors').getPaLMVector(text);
+ return require('../makersuite-vectors').getMakerSuiteVector(text);
}
throw new Error(`Unknown vector source ${source}`);
@@ -196,7 +196,7 @@ router.post('/purge', jsonParser, async (req, res) => {
const collectionId = String(req.body.collectionId);
- const sources = ['transformers', 'openai'];
+ const sources = ['transformers', 'openai', 'palm'];
for (const source of sources) {
const index = await getIndex(collectionId, source, false);
diff --git a/src/palm-vectors.js b/src/makersuite-vectors.js
similarity index 65%
rename from src/palm-vectors.js
rename to src/makersuite-vectors.js
index 788b474cd..66d1a6fd8 100644
--- a/src/palm-vectors.js
+++ b/src/makersuite-vectors.js
@@ -6,15 +6,15 @@ const { SECRET_KEYS, readSecret } = require('./endpoints/secrets');
* @param {string} text - The text to get the vector for
* @returns {Promise} - The vector for the text
*/
-async function getPaLMVector(text) {
- const key = readSecret(SECRET_KEYS.PALM);
+async function getMakerSuiteVector(text) {
+ const key = readSecret(SECRET_KEYS.MAKERSUITE);
if (!key) {
- console.log('No PaLM key found');
- throw new Error('No PaLM key found');
+ console.log('No MakerSuite key found');
+ throw new Error('No MakerSuite key found');
}
- const response = await fetch(`https://generativelanguage.googleapis.com/v1beta2/models/embedding-gecko-001:embedText?key=${key}`, {
+ const response = await fetch(`https://generativelanguage.googleapis.com/v1beta/models/embedding-gecko-001:embedText?key=${key}`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
@@ -26,8 +26,8 @@ async function getPaLMVector(text) {
if (!response.ok) {
const text = await response.text();
- console.log('PaLM request failed', response.statusText, text);
- throw new Error('PaLM request failed');
+ console.log('MakerSuite request failed', response.statusText, text);
+ throw new Error('MakerSuite request failed');
}
const data = await response.json();
@@ -39,5 +39,5 @@ async function getPaLMVector(text) {
}
module.exports = {
- getPaLMVector,
+ getMakerSuiteVector,
};