Add direct OpenRouter connection and PaLM models to Window selection

Cohee
2023-06-30 00:32:52 +03:00
parent 757e9b672a
commit f532192726
5 changed files with 137 additions and 31 deletions


@@ -86,7 +86,8 @@ const gpt3_16k_max = 16383;
const gpt4_max = 8191;
const gpt_neox_max = 2048;
const gpt4_32k_max = 32767;
-const claude_max = 7500;
+const claude_max = 8000; // We have a proper tokenizer, so theoretically could be larger (up to 9k)
+const palm2_max = 8000; // The real context window is 8192, spare some for padding due to using turbo tokenizer
const claude_100k_max = 99000;
const unlocked_max = 100 * 1024;
const oai_max_temp = 2.0;
@@ -132,6 +133,7 @@ const default_settings = {
legacy_streaming: false,
chat_completion_source: chat_completion_sources.OPENAI,
max_context_unlocked: false,
+use_openrouter: false,
};
const oai_settings = {
@@ -165,6 +167,7 @@ const oai_settings = {
legacy_streaming: false,
chat_completion_source: chat_completion_sources.OPENAI,
max_context_unlocked: false,
+use_openrouter: false,
};
let openai_setting_names;
@@ -568,8 +571,8 @@ async function sendWindowAIRequest(openai_msgs_tosend, signal, stream) {
const currentModel = await window.ai.getCurrentModel();
let temperature = parseFloat(oai_settings.temp_openai);
-if (currentModel.includes('claude') && temperature > claude_max_temp) {
-console.warn(`Claude model only supports temperature up to ${claude_max_temp}. Clamping ${temperature} to ${claude_max_temp}.`);
+if ((currentModel.includes('claude') || currentModel.includes('palm-2')) && temperature > claude_max_temp) {
+console.warn(`Claude and PaLM models only support temperature up to ${claude_max_temp}. Clamping ${temperature} to ${claude_max_temp}.`);
temperature = claude_max_temp;
}
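For orientation, a minimal standalone sketch of the clamp above. The 1.0 ceiling is an assumption: this diff only shows oai_max_temp = 2.0, not the value of claude_max_temp, and the model id is a made-up PaLM identifier.

    const claude_max_temp = 1.0; // assumed ceiling; the real value is defined elsewhere in this file
    const currentModel = 'palm-2-chat-bison'; // hypothetical PaLM model id from Window.ai
    let temperature = 1.5;
    if ((currentModel.includes('claude') || currentModel.includes('palm-2')) && temperature > claude_max_temp) {
        temperature = claude_max_temp; // clamped to 1.0 before the request is sent
    }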
@@ -649,6 +652,19 @@ async function sendWindowAIRequest(openai_msgs_tosend, signal, stream) {
}
}
+function getChatCompletionModel() {
+switch (oai_settings.chat_completion_source) {
+case chat_completion_sources.CLAUDE:
+return oai_settings.claude_model;
+case chat_completion_sources.OPENAI:
+return oai_settings.openai_model;
+case chat_completion_sources.WINDOWAI:
+return oai_settings.windowai_model;
+default:
+throw new Error(`Unknown chat completion source: ${oai_settings.chat_completion_source}`);
+}
+}
async function sendOpenAIRequest(type, openai_msgs_tosend, signal) {
// Provide default abort signal
if (!signal) {
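A quick usage sketch of the getChatCompletionModel() dispatcher added in the hunk above; the model id below is a placeholder, not one this commit pins down.

    // The active chat_completion_source, not the UI widget, decides which stored
    // model id goes out with the request.
    oai_settings.chat_completion_source = chat_completion_sources.WINDOWAI;
    oai_settings.windowai_model = 'anthropic/claude-instant-v1'; // placeholder id
    getChatCompletionModel(); // -> 'anthropic/claude-instant-v1'
    // An unrecognized source throws, so a new source must extend the switch.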
@@ -661,23 +677,24 @@ async function sendOpenAIRequest(type, openai_msgs_tosend, signal) {
let logit_bias = {};
const isClaude = oai_settings.chat_completion_source == chat_completion_sources.CLAUDE;
+const isOpenRouter = oai_settings.use_openrouter && oai_settings.chat_completion_source == chat_completion_sources.WINDOWAI;
const stream = type !== 'quiet' && oai_settings.stream_openai;
// If we're using the window.ai extension, use that instead
// Doesn't support logit bias yet
-if (oai_settings.chat_completion_source == chat_completion_sources.WINDOWAI) {
+if (oai_settings.chat_completion_source == chat_completion_sources.WINDOWAI && !oai_settings.use_openrouter) {
return sendWindowAIRequest(openai_msgs_tosend, signal, stream);
}
if (oai_settings.bias_preset_selected
&& !isClaude // Claude doesn't support logit bias
&& oai_settings.chat_completion_source == chat_completion_sources.OPENAI
&& Array.isArray(oai_settings.bias_presets[oai_settings.bias_preset_selected])
&& oai_settings.bias_presets[oai_settings.bias_preset_selected].length) {
logit_bias = biasCache || await calculateLogitBias();
biasCache = logit_bias;
}
-const model = isClaude ? oai_settings.claude_model : oai_settings.openai_model;
+const model = getChatCompletionModel();
const generate_data = {
"messages": openai_msgs_tosend,
"model": model,
@@ -691,6 +708,7 @@ async function sendOpenAIRequest(type, openai_msgs_tosend, signal) {
"reverse_proxy": oai_settings.reverse_proxy,
"logit_bias": logit_bias,
"use_claude": isClaude,
"use_openrouter": isOpenRouter,
};
const generate_url = '/generate_openai';
@@ -767,7 +785,7 @@ function getStreamingReply(getMessage, data) {
if (oai_settings.chat_completion_source == chat_completion_sources.CLAUDE) {
getMessage = data.completion || "";
} else {
getMessage += data.choices[0]["delta"]["content"] || "";
getMessage += data.choices[0]?.delta?.content || data.choices[0]?.message?.content || "";
}
return getMessage;
}
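For reference, a sketch of the two payload shapes the widened fallback above now tolerates; both shapes are assumptions based on the OpenAI-style stream format that OpenRouter relays, and the example assumes a non-Claude chat_completion_source so the else branch runs.

    // Incremental chunk (assumed shape): content arrives under choices[0].delta
    const deltaChunk = { choices: [{ delta: { content: 'Hel' } }] };
    // Whole-message chunk (assumed shape): content arrives under choices[0].message
    const messageChunk = { choices: [{ message: { content: 'lo' } }] };

    let reply = '';
    reply = getStreamingReply(reply, deltaChunk);   // 'Hel' via delta.content
    reply = getStreamingReply(reply, messageChunk); // 'Hello' via the message.content fallback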
@@ -979,6 +997,7 @@ function loadOpenAISettings(data, settings) {
oai_settings.claude_model = settings.claude_model ?? default_settings.claude_model;
oai_settings.windowai_model = settings.windowai_model ?? default_settings.windowai_model;
oai_settings.chat_completion_source = settings.chat_completion_source ?? default_settings.chat_completion_source;
+oai_settings.use_openrouter = settings.use_openrouter ?? default_settings.use_openrouter;
if (settings.nsfw_toggle !== undefined) oai_settings.nsfw_toggle = !!settings.nsfw_toggle;
if (settings.keep_example_dialogue !== undefined) oai_settings.keep_example_dialogue = !!settings.keep_example_dialogue;
@@ -1055,11 +1074,12 @@ function loadOpenAISettings(data, settings) {
$('#chat_completion_source').val(oai_settings.chat_completion_source).trigger('change');
$('#oai_max_context_unlocked').prop('checked', oai_settings.max_context_unlocked);
+$('#use_openrouter').prop('checked', oai_settings.use_openrouter);
}
async function getStatusOpen() {
if (is_get_status_openai) {
-if (oai_settings.chat_completion_source == chat_completion_sources.WINDOWAI) {
+if (oai_settings.chat_completion_source == chat_completion_sources.WINDOWAI && !oai_settings.use_openrouter) {
let status;
if ('ai' in window) {
@@ -1082,6 +1102,7 @@ async function getStatusOpen() {
let data = {
reverse_proxy: oai_settings.reverse_proxy,
+use_openrouter: oai_settings.use_openrouter && oai_settings.chat_completion_source == chat_completion_sources.WINDOWAI,
};
return jQuery.ajax({
@@ -1089,7 +1110,7 @@ async function getStatusOpen() {
url: '/getstatus_openai', //
data: JSON.stringify(data),
beforeSend: function () {
-if (oai_settings.reverse_proxy) {
+if (oai_settings.reverse_proxy && !data.use_openrouter) {
validateReverseProxy();
}
},
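A short hedged sketch of the beforeSend guard above. validateReverseProxy()'s body is not part of this diff, so the rationale is assumed: with OpenRouter enabled the status request is routed by the backend rather than through a user-supplied reverse proxy, so validating the proxy URL would be pointless.

    const data = {
        reverse_proxy: 'http://127.0.0.1:5001/v1', // hypothetical proxy URL
        use_openrouter: true,
    };
    if (oai_settings.reverse_proxy && !data.use_openrouter) {
        validateReverseProxy(); // skipped in this configuration
    }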
@@ -1157,6 +1178,7 @@ async function saveOpenAIPreset(name, settings) {
openai_model: settings.openai_model,
claude_model: settings.claude_model,
windowai_model: settings.windowai_model,
+use_openrouter: settings.use_openrouter,
temperature: settings.temp_openai,
frequency_penalty: settings.freq_pen_openai,
presence_penalty: settings.pres_pen_openai,
@@ -1530,6 +1552,7 @@ function onSettingsPresetChange() {
nsfw_avoidance_prompt: ['#nsfw_avoidance_prompt_textarea', 'nsfw_avoidance_prompt', false],
wi_format: ['#wi_format_textarea', 'wi_format', false],
stream_openai: ['#stream_toggle', 'stream_openai', true],
+use_openrouter: ['#use_openrouter', 'use_openrouter', true],
};
for (const [key, [selector, setting, isCheckbox]] of Object.entries(settingsToUpdate)) {
@@ -1605,6 +1628,9 @@ function onModelChange() {
else if (value.includes('gpt-4')) {
$('#openai_max_context').attr('max', gpt4_max);
}
+else if (value.includes('palm-2')) {
+$('#openai_max_context').attr('max', palm2_max);
+}
else if (value.includes('GPT-NeoXT')) {
$('#openai_max_context').attr('max', gpt_neox_max);
}
@@ -1616,7 +1642,7 @@ function onModelChange() {
oai_settings.openai_max_context = Math.min(Number($('#openai_max_context').attr('max')), oai_settings.openai_max_context);
$('#openai_max_context').val(oai_settings.openai_max_context).trigger('input');
-if (value.includes('claude')) {
+if (value.includes('claude') || value.includes('palm-2')) {
oai_settings.temp_openai = Math.min(claude_max_temp, oai_settings.temp_openai);
$('#temp_openai').attr('max', claude_max_temp).val(oai_settings.temp_openai).trigger('input');
}
@@ -1682,6 +1708,17 @@ async function onConnectButtonClick(e) {
if (oai_settings.chat_completion_source == chat_completion_sources.WINDOWAI) {
is_get_status_openai = true;
is_api_button_press_openai = true;
+const api_key_openrouter = $('#api_key_openrouter').val().trim();
+if (api_key_openrouter.length) {
+await writeSecret(SECRET_KEYS.OPENROUTER, api_key_openrouter);
+}
+if (oai_settings.use_openrouter && !secret_state[SECRET_KEYS.OPENROUTER]) {
+console.log('No secret key saved for OpenRouter');
+return;
+}
return await getStatusOpen();
}
@@ -1955,6 +1992,12 @@ $(document).ready(function () {
saveSettingsDebounced();
});
+$('#use_openrouter').on('input', function () {
+oai_settings.use_openrouter = !!$(this).prop('checked');
+reconnectOpenAi();
+saveSettingsDebounced();
+});
$("#api_button_openai").on("click", onConnectButtonClick);
$("#openai_reverse_proxy").on("input", onReverseProxyInput);
$("#model_openai_select").on("change", onModelChange);