Add raw generate function

This commit is contained in:
Cohee 2023-10-11 17:56:52 +03:00
parent abb78d1d6b
commit c4fbc8373d
3 changed files with 98 additions and 14 deletions

View File

@ -2350,6 +2350,84 @@ class StreamingProcessor {
}
}
/**
* Generates a message using the provided prompt.
* @param {string} prompt Prompt to generate a message from
* @param {string} api API to use. Main API is used if not specified.
*/
export async function generateRaw(prompt, api) {
if (!api) {
api = main_api;
}
const abortController = new AbortController();
const isInstruct = power_user.instruct.enabled && main_api !== 'openai' && main_api !== 'novel';
prompt = substituteParams(prompt);
prompt = api == 'novel' ? adjustNovelInstructionPrompt(prompt) : prompt;
prompt = isInstruct ? formatInstructModeChat(name1, prompt, false, true, '', name1, name2, false) : prompt;
prompt = isInstruct ? (prompt + formatInstructModePrompt(name2, false, '', name1, name2)) : (prompt + '\n');
let generateData = {};
switch (api) {
case 'kobold':
case 'koboldhorde':
if (preset_settings === 'gui') {
generateData = { prompt: prompt, gui_settings: true, max_length: amount_gen, max_context_length: max_context, };
} else {
const koboldSettings = koboldai_settings[koboldai_setting_names[preset_settings]];
generateData = getKoboldGenerationData(prompt, koboldSettings, amount_gen, max_context, false, 'quiet');
}
break;
case 'novel':
const novelSettings = novelai_settings[novelai_setting_names[nai_settings.preset_settings_novel]];
generateData = getNovelGenerationData(prompt, novelSettings, amount_gen, false, null);
break;
case 'textgenerationwebui':
generateData = getTextGenGenerationData(prompt, amount_gen, false, null);
break;
case 'openai':
generateData = [{ role: 'user', content: prompt.trim() }];
}
let data = {};
if (api == 'koboldhorde') {
data = await generateHorde(prompt, generateData, abortController.signal, false);
} else if (api == 'openai') {
data = await sendOpenAIRequest('quiet', generateData, abortController.signal);
} else {
const generateUrl = getGenerateUrl(api);
const response = await fetch(generateUrl, {
method: 'POST',
headers: getRequestHeaders(),
cache: 'no-cache',
body: JSON.stringify(generateData),
signal: abortController.signal,
});
if (!response.ok) {
const error = await response.json();
throw error;
}
data = await response.json();
}
if (data.error) {
throw new Error(data.error);
}
const message = cleanUpMessage(extractMessageFromData(data), false, false, true);
if (!message) {
throw new Error('No message generated');
}
return message;
}
async function Generate(type, { automatic_trigger, force_name2, resolve, reject, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal } = {}, dryRun = false) {
console.log('Generate entered');
setGenerationProgress(0);
@ -3051,8 +3129,6 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject,
const cfgValues = cfgGuidanceScale && cfgGuidanceScale?.value !== 1 ? ({ guidanceScale: cfgGuidanceScale, negativePrompt: negativePrompt }) : null;
let this_amount_gen = Number(amount_gen); // how many tokens the AI will be requested to generate
let this_settings = koboldai_settings[koboldai_setting_names[preset_settings]];
let thisPromptBits = [];
// TODO: Make this a switch
@ -3071,14 +3147,13 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject,
};
if (preset_settings != 'gui') {
const this_settings = koboldai_settings[koboldai_setting_names[preset_settings]];
const maxContext = (adjustedParams && horde_settings.auto_adjust_context_length) ? adjustedParams.maxContextLength : max_context;
generate_data = getKoboldGenerationData(finalPrompt, this_settings, this_amount_gen, maxContext, isImpersonate, type);
}
}
else if (main_api == 'textgenerationwebui') {
generate_data = getTextGenGenerationData(finalPrompt, this_amount_gen, isImpersonate, cfgValues);
generate_data.use_mancer = isMancer();
generate_data.use_aphrodite = isAphrodite();
}
else if (main_api == 'novel') {
const this_settings = novelai_settings[novelai_setting_names[nai_settings.preset_settings_novel]];
@ -3119,7 +3194,7 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject,
console.log(generate_data.prompt);
}
let generate_url = getGenerateUrl();
let generate_url = getGenerateUrl(main_api);
console.debug('rungenerate calling API');
showStopButton();
@ -3169,7 +3244,7 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject,
}
}
else if (main_api == 'koboldhorde') {
generateHorde(finalPrompt, generate_data, abortController.signal).then(onSuccess).catch(onError);
generateHorde(finalPrompt, generate_data, abortController.signal, true).then(onSuccess).catch(onError);
}
else if (main_api == 'textgenerationwebui' && isStreamingEnabled() && type !== 'quiet') {
streamingProcessor.generator = await generateTextGenWithStreaming(generate_data, streamingProcessor.abortController.signal);
@ -3817,13 +3892,13 @@ function setInContextMessages(lastmsg, type) {
}
}
function getGenerateUrl() {
function getGenerateUrl(api) {
let generate_url = '';
if (main_api == 'kobold') {
if (api == 'kobold') {
generate_url = '/generate';
} else if (main_api == 'textgenerationwebui') {
} else if (api == 'textgenerationwebui') {
generate_url = '/generate_textgenerationwebui';
} else if (main_api == 'novel') {
} else if (api == 'novel') {
generate_url = '/api/novelai/generate';
}
return generate_url;
@ -8988,4 +9063,11 @@ jQuery(async function () {
await saveChatConditional();
await reloadCurrentChat();
});
registerDebugFunction('generationTest', 'Send a generation request', 'Generates text using the currently selected API.', async () => {
const text = prompt('Input text:', 'Hello');
toastr.info('Working on it...');
const message = await generateRaw(text, null);
alert(message);
});
});

View File

@ -92,7 +92,7 @@ async function adjustHordeGenerationParams(max_context_length, max_length) {
return { maxContextLength, maxLength };
}
async function generateHorde(prompt, params, signal) {
async function generateHorde(prompt, params, signal, reportProgress) {
validateHordeModel();
delete params.prompt;
@ -164,7 +164,7 @@ async function generateHorde(prompt, params, signal) {
}
if (statusCheckJson.done && Array.isArray(statusCheckJson.generations) && statusCheckJson.generations.length) {
setGenerationProgress(100);
reportProgress && setGenerationProgress(100);
const generatedText = statusCheckJson.generations[0].text;
const WorkerName = statusCheckJson.generations[0].worker_name;
const WorkerModel = statusCheckJson.generations[0].model;
@ -174,12 +174,12 @@ async function generateHorde(prompt, params, signal) {
}
else if (!queue_position_first) {
queue_position_first = statusCheckJson.queue_position;
setGenerationProgress(0);
reportProgress && setGenerationProgress(0);
}
else if (statusCheckJson.queue_position >= 0) {
let queue_position = statusCheckJson.queue_position;
const progress = Math.round(100 - (queue_position / queue_position_first * 100));
setGenerationProgress(progress);
reportProgress && setGenerationProgress(progress);
}
await delay(CHECK_INTERVAL);

View File

@ -437,5 +437,7 @@ export function getTextGenGenerationData(finalPrompt, this_amount_gen, isImperso
'mirostat_eta': textgenerationwebui_settings.mirostat_eta,
'grammar_string': textgenerationwebui_settings.grammar_string,
'custom_token_bans': getCustomTokenBans(),
'use_mancer': isMancer(),
'use_aphrodite': isAphrodite(),
};
}