Fix continue on forced OpenRouter instruct. Display proper itemized prompt

Cohee 2023-11-05 02:20:15 +02:00
parent 06d5675d6f
commit 88df8501b3
2 changed files with 66 additions and 9 deletions

script.js

@@ -93,6 +93,7 @@ import {
openai_messages_count,
chat_completion_sources,
getChatCompletionModel,
+ isOpenRouterWithInstruct,
} from "./scripts/openai.js";
import {
@@ -811,6 +812,29 @@ export async function saveItemizedPrompts(chatId) {
}
}
+ /**
+  * Replaces the itemized prompt text for a message.
+  * @param {number} mesId Message ID to update the itemized prompt for
+  * @param {string} promptText New raw prompt text
+  * @returns {Promise<void>}
+  */
+ export async function replaceItemizedPromptText(mesId, promptText) {
+     if (!Array.isArray(itemizedPrompts)) {
+         itemizedPrompts = [];
+     }
+     const itemizedPrompt = itemizedPrompts.find(x => x.mesId === mesId);
+     if (!itemizedPrompt) {
+         return;
+     }
+     itemizedPrompt.rawPrompt = promptText;
+ }
/**
* Empties the itemized prompts array and caches.
*/
export async function clearItemizedPrompts() {
try {
await promptStorage.clear();
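For reference, a minimal sketch of how the new helper is meant to be called, mirroring the call sites added to scripts/openai.js below (type and convertedText stand in for the caller's real values):

// After rewriting the outgoing prompt (e.g. converting it to instruct text),
// sync the stored itemized prompt so the prompt viewer shows what was
// actually sent rather than the pre-conversion chat messages.
const messageId = getNextMessageId(type);
replaceItemizedPromptText(messageId, convertedText);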
@@ -2893,7 +2917,8 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject,
if (isContinue) {
// Coping mechanism for OAI spacing
- if ((main_api === 'openai') && !cyclePrompt.endsWith(' ')) {
+ const isForceInstruct = isOpenRouterWithInstruct();
+ if (main_api === 'openai' && !isForceInstruct && !cyclePrompt.endsWith(' ')) {
cyclePrompt += ' ';
continue_mag += ' ';
}
@@ -3290,8 +3315,15 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject,
thisPromptBits = additionalPromptStuff;
//console.log(thisPromptBits);
+ const itemizedIndex = itemizedPrompts.findIndex((item) => item.mesId === thisPromptBits['mesId']);
+ if (itemizedIndex !== -1) {
+     itemizedPrompts[itemizedIndex] = thisPromptBits;
+ }
+ else {
+     itemizedPrompts.push(thisPromptBits);
+ }
- itemizedPrompts.push(thisPromptBits);
console.debug(`pushed prompt bits to itemizedPrompts array. Length is now: ${itemizedPrompts.length}`);
if (main_api == 'openai') {
@@ -3510,7 +3542,7 @@ function unblockGeneration() {
$("#send_textarea").removeAttr('disabled');
}
- function getNextMessageId(type) {
+ export function getNextMessageId(type) {
return type == 'swipe' ? Number(count_view_mes - 1) : Number(count_view_mes);
}
@@ -9044,6 +9076,9 @@ jQuery(async function () {
registerDebugFunction('clearPrompts', 'Delete itemized prompts', 'Deletes all itemized prompts from the local storage.', async () => {
await clearItemizedPrompts();
- await reloadCurrentChat();
+ toastr.info('Itemized prompts deleted.');
+ if (getCurrentChatId()) {
+     await reloadCurrentChat();
+ }
});
});

scripts/openai.js

@@ -25,6 +25,9 @@ import {
event_types,
substituteParams,
MAX_INJECTION_DEPTH,
+ getStoppingStrings,
+ getNextMessageId,
+ replaceItemizedPromptText,
} from "../script.js";
import { groups, selected_group } from "./group-chats.js";
@@ -367,14 +370,19 @@ function convertChatCompletionToInstruct(messages, type) {
}
const isImpersonate = type === 'impersonate';
+ const isContinue = type === 'continue';
const promptName = isImpersonate ? name1 : name2;
- const promptLine = formatInstructModePrompt(promptName, isImpersonate, '', name1, name2).trimStart();
+ const promptLine = isContinue ? '' : formatInstructModePrompt(promptName, isImpersonate, '', name1, name2).trimStart();
- const prompt = [systemPromptText, examplesText, chatMessagesText, promptLine]
+ let prompt = [systemPromptText, examplesText, chatMessagesText, promptLine]
.filter(x => x)
.map(x => x.endsWith('\n') ? x : `${x}\n`)
.join('');
+ if (isContinue) {
+     prompt = prompt.replace(/\n$/, '');
+ }
return prompt;
}
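Why the continue branch matters: for a continue, the converted instruct prompt must end exactly where the assistant's partial message stops, with no fresh prompt line and no trailing newline, so the model resumes mid-message. An illustrative comparison (the exact text depends on the configured instruct template; the sequences and names below are made up):

// Normal generation: the prompt ends with a fresh prompt line, e.g.
//   "...### Response:\nSeraphina:"
// Continue: no prompt line is appended and the final "\n" is stripped, e.g.
//   "...Seraphina: The forest was"
// so the completion picks up the unfinished sentence in place.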
@@ -578,6 +586,10 @@ function populationInjectionPrompts(prompts) {
openai_msgs = openai_msgs.reverse();
}
+ export function isOpenRouterWithInstruct() {
+     return oai_settings.chat_completion_source === chat_completion_sources.OPENROUTER && oai_settings.openrouter_force_instruct && power_user.instruct.enabled;
+ }
/**
* Populates the chat history of the conversation.
*
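In plain terms, the new isOpenRouterWithInstruct() helper above is true only when all three settings line up; a sketch of the triggering configuration (assignments shown purely to illustrate the condition):

// All three must hold for the helper to return true:
oai_settings.chat_completion_source = chat_completion_sources.OPENROUTER;
oai_settings.openrouter_force_instruct = true; // "Force instruct mode" toggle
power_user.instruct.enabled = true;            // instruct formatting enabled
isOpenRouterWithInstruct(); // => true; any other combination => false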
@@ -604,7 +616,8 @@ function populateChatHistory(prompts, chatCompletion, type = null, cyclePrompt =
// Reserve budget for continue nudge
let continueMessage = null;
- if (type === 'continue' && cyclePrompt) {
+ const instruct = isOpenRouterWithInstruct();
+ if (type === 'continue' && cyclePrompt && !instruct) {
const continuePrompt = new Prompt({
identifier: 'continueNudge',
role: 'system',
@@ -1249,7 +1262,7 @@ function saveModelList(data) {
}
}
- async function sendAltScaleRequest(openai_msgs_tosend, logit_bias, signal) {
+ async function sendAltScaleRequest(openai_msgs_tosend, logit_bias, signal, type) {
const generate_url = '/generate_altscale';
let firstSysMsgs = []
@@ -1269,6 +1282,8 @@ async function sendAltScaleRequest(openai_msgs_tosend, logit_bias, signal) {
}, "");
openai_msgs_tosend = substituteParams(joinedSubsequentMsgs);
+ const messageId = getNextMessageId(type);
+ replaceItemizedPromptText(messageId, openai_msgs_tosend);
const generate_data = {
sysprompt: joinedSysMsgs,
@@ -1304,6 +1319,7 @@ async function sendOpenAIRequest(type, openai_msgs_tosend, signal) {
openai_msgs_tosend = openai_msgs_tosend.filter(msg => msg && typeof msg === 'object');
let logit_bias = {};
+ const messageId = getNextMessageId(type);
const isClaude = oai_settings.chat_completion_source == chat_completion_sources.CLAUDE;
const isOpenRouter = oai_settings.chat_completion_source == chat_completion_sources.OPENROUTER;
const isScale = oai_settings.chat_completion_source == chat_completion_sources.SCALE;
@@ -1317,6 +1333,7 @@ async function sendOpenAIRequest(type, openai_msgs_tosend, signal) {
if (isTextCompletion && isOpenRouter) {
openai_msgs_tosend = convertChatCompletionToInstruct(openai_msgs_tosend, type);
+ replaceItemizedPromptText(messageId, openai_msgs_tosend);
}
if (isAI21 || isPalm) {
@@ -1325,6 +1342,7 @@ async function sendOpenAIRequest(type, openai_msgs_tosend, signal) {
return acc + (prefix ? (selected_group ? "\n" : prefix + " ") : "") + obj.content + "\n";
}, "");
openai_msgs_tosend = substituteParams(joinedMsgs) + (isImpersonate ? `${name1}:` : `${name2}:`);
+ replaceItemizedPromptText(messageId, openai_msgs_tosend);
}
// If we're using the window.ai extension, use that instead
@@ -1343,7 +1361,7 @@ async function sendOpenAIRequest(type, openai_msgs_tosend, signal) {
}
if (isScale && oai_settings.use_alt_scale) {
- return sendAltScaleRequest(openai_msgs_tosend, logit_bias, signal)
+ return sendAltScaleRequest(openai_msgs_tosend, logit_bias, signal, type);
}
const model = getChatCompletionModel();
@@ -1382,6 +1400,10 @@ async function sendOpenAIRequest(type, openai_msgs_tosend, signal) {
generate_data['use_openrouter'] = true;
generate_data['top_k'] = Number(oai_settings.top_k_openai);
generate_data['use_fallback'] = oai_settings.openrouter_use_fallback;
+ if (isTextCompletion) {
+     generate_data['stop'] = getStoppingStrings(isImpersonate);
+ }
}
if (isScale) {
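A note on the stop-strings addition above: when OpenRouter runs as a text completion with instruct formatting, nothing would otherwise stop the model from writing the next turn's header itself, so the instruct-derived stopping strings are passed through as stop. A hedged sketch (the actual contents depend on the active template and character names):

const stopStrings = getStoppingStrings(isImpersonate);
// Illustrative contents: name prefixes that mark the start of a new turn,
// e.g. "\nUser:" or "\nSeraphina:". Sending them as generate_data['stop']
// makes the completion end before the model begins the next turn.
generate_data['stop'] = stopStrings;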