Add raw generate function
parent abb78d1d6b
commit c4fbc8373d
public/script.js | 102
@@ -2350,6 +2350,84 @@ class StreamingProcessor {
     }
 }

+/**
+ * Generates a message using the provided prompt.
+ * @param {string} prompt Prompt to generate a message from
+ * @param {string} api API to use. Main API is used if not specified.
+ */
+export async function generateRaw(prompt, api) {
+    if (!api) {
+        api = main_api;
+    }
+
+    const abortController = new AbortController();
+    const isInstruct = power_user.instruct.enabled && main_api !== 'openai' && main_api !== 'novel';
+
+    prompt = substituteParams(prompt);
+    prompt = api == 'novel' ? adjustNovelInstructionPrompt(prompt) : prompt;
+    prompt = isInstruct ? formatInstructModeChat(name1, prompt, false, true, '', name1, name2, false) : prompt;
+    prompt = isInstruct ? (prompt + formatInstructModePrompt(name2, false, '', name1, name2)) : (prompt + '\n');
+
+    let generateData = {};
+
+    switch (api) {
+        case 'kobold':
+        case 'koboldhorde':
+            if (preset_settings === 'gui') {
+                generateData = { prompt: prompt, gui_settings: true, max_length: amount_gen, max_context_length: max_context, };
+            } else {
+                const koboldSettings = koboldai_settings[koboldai_setting_names[preset_settings]];
+                generateData = getKoboldGenerationData(prompt, koboldSettings, amount_gen, max_context, false, 'quiet');
+            }
+            break;
+        case 'novel':
+            const novelSettings = novelai_settings[novelai_setting_names[nai_settings.preset_settings_novel]];
+            generateData = getNovelGenerationData(prompt, novelSettings, amount_gen, false, null);
+            break;
+        case 'textgenerationwebui':
+            generateData = getTextGenGenerationData(prompt, amount_gen, false, null);
+            break;
+        case 'openai':
+            generateData = [{ role: 'user', content: prompt.trim() }];
+    }
+
+    let data = {};
+
+    if (api == 'koboldhorde') {
+        data = await generateHorde(prompt, generateData, abortController.signal, false);
+    } else if (api == 'openai') {
+        data = await sendOpenAIRequest('quiet', generateData, abortController.signal);
+    } else {
+        const generateUrl = getGenerateUrl(api);
+        const response = await fetch(generateUrl, {
+            method: 'POST',
+            headers: getRequestHeaders(),
+            cache: 'no-cache',
+            body: JSON.stringify(generateData),
+            signal: abortController.signal,
+        });
+
+        if (!response.ok) {
+            const error = await response.json();
+            throw error;
+        }
+
+        data = await response.json();
+    }
+
+    if (data.error) {
+        throw new Error(data.error);
+    }
+
+    const message = cleanUpMessage(extractMessageFromData(data), false, false, true);
+
+    if (!message) {
+        throw new Error('No message generated');
+    }
+
+    return message;
+}
+
 async function Generate(type, { automatic_trigger, force_name2, resolve, reject, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal } = {}, dryRun = false) {
     console.log('Generate entered');
     setGenerationProgress(0);
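
For orientation, here is a minimal sketch of how the new export could be called from another module. The import path and the wrapper function are hypothetical; generateRaw itself, its two arguments, and its throw-on-error behaviour come straight from the hunk above. Note that the instruct check reads the global main_api rather than the api argument, so overriding the API does not change instruct formatting.

// Hypothetical caller (illustrative only): ask a backend for a one-off completion.
import { generateRaw } from '../script.js';

async function summarizeText(text) {
    try {
        // Passing a falsy api falls back to the currently selected main_api.
        const summary = await generateRaw(`Summarize the following text:\n${text}`, null);
        return summary;
    } catch (err) {
        // generateRaw throws on backend errors and on empty generations.
        console.error('Raw generation failed:', err);
        return '';
    }
}
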
@@ -3051,8 +3129,6 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject,
     const cfgValues = cfgGuidanceScale && cfgGuidanceScale?.value !== 1 ? ({ guidanceScale: cfgGuidanceScale, negativePrompt: negativePrompt }) : null;

     let this_amount_gen = Number(amount_gen); // how many tokens the AI will be requested to generate
-    let this_settings = koboldai_settings[koboldai_setting_names[preset_settings]];
-
     let thisPromptBits = [];

     // TODO: Make this a switch
@@ -3071,14 +3147,13 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject,
         };

         if (preset_settings != 'gui') {
+            const this_settings = koboldai_settings[koboldai_setting_names[preset_settings]];
             const maxContext = (adjustedParams && horde_settings.auto_adjust_context_length) ? adjustedParams.maxContextLength : max_context;
             generate_data = getKoboldGenerationData(finalPrompt, this_settings, this_amount_gen, maxContext, isImpersonate, type);
         }
     }
     else if (main_api == 'textgenerationwebui') {
         generate_data = getTextGenGenerationData(finalPrompt, this_amount_gen, isImpersonate, cfgValues);
-        generate_data.use_mancer = isMancer();
-        generate_data.use_aphrodite = isAphrodite();
     }
     else if (main_api == 'novel') {
         const this_settings = novelai_settings[novelai_setting_names[nai_settings.preset_settings_novel]];
@@ -3119,7 +3194,7 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject,
             console.log(generate_data.prompt);
         }

-        let generate_url = getGenerateUrl();
+        let generate_url = getGenerateUrl(main_api);
         console.debug('rungenerate calling API');

         showStopButton();
@@ -3169,7 +3244,7 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject,
             }
         }
         else if (main_api == 'koboldhorde') {
-            generateHorde(finalPrompt, generate_data, abortController.signal).then(onSuccess).catch(onError);
+            generateHorde(finalPrompt, generate_data, abortController.signal, true).then(onSuccess).catch(onError);
         }
         else if (main_api == 'textgenerationwebui' && isStreamingEnabled() && type !== 'quiet') {
             streamingProcessor.generator = await generateTextGenWithStreaming(generate_data, streamingProcessor.abortController.signal);
@@ -3817,13 +3892,13 @@ function setInContextMessages(lastmsg, type) {
     }
 }

-function getGenerateUrl() {
+function getGenerateUrl(api) {
     let generate_url = '';
-    if (main_api == 'kobold') {
+    if (api == 'kobold') {
         generate_url = '/generate';
-    } else if (main_api == 'textgenerationwebui') {
+    } else if (api == 'textgenerationwebui') {
         generate_url = '/generate_textgenerationwebui';
-    } else if (main_api == 'novel') {
+    } else if (api == 'novel') {
         generate_url = '/api/novelai/generate';
     }
     return generate_url;
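
getGenerateUrl now resolves the endpoint from its argument instead of the global main_api, which is what lets generateRaw target a backend other than the active one. The return values below simply restate the branches above; the 'openai' case is included for illustration only and falls through to the empty string, since OpenAI requests go through sendOpenAIRequest rather than a generate URL.

getGenerateUrl('kobold');              // '/generate'
getGenerateUrl('textgenerationwebui'); // '/generate_textgenerationwebui'
getGenerateUrl('novel');               // '/api/novelai/generate'
getGenerateUrl('openai');              // ''  (no branch matches; OpenAI uses sendOpenAIRequest)
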
@@ -8988,4 +9063,11 @@ jQuery(async function () {
         await saveChatConditional();
         await reloadCurrentChat();
     });
+
+    registerDebugFunction('generationTest', 'Send a generation request', 'Generates text using the currently selected API.', async () => {
+        const text = prompt('Input text:', 'Hello');
+        toastr.info('Working on it...');
+        const message = await generateRaw(text, null);
+        alert(message);
+    });
 });
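
The debug entry above doubles as a usage example: it reads a prompt, generates against the currently selected API, and shows the cleaned-up message. A hypothetical variant pinned to a specific backend would follow the same registerDebugFunction signature; the id, labels, and the choice of 'novel' below are made up for illustration.

// Hypothetical debug probe: same signature as the registration above, but with an explicit api.
registerDebugFunction('rawNovelTest', 'Send a NovelAI raw request', 'Generates text with the NovelAI backend regardless of the selected main API.', async () => {
    const text = prompt('Input text:', 'Hello');
    toastr.info('Working on it...');
    // Any value handled by generateRaw's switch works here: 'kobold', 'koboldhorde', 'novel', 'textgenerationwebui', 'openai'.
    const message = await generateRaw(text, 'novel');
    alert(message);
});
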

@@ -92,7 +92,7 @@ async function adjustHordeGenerationParams(max_context_length, max_length) {
     return { maxContextLength, maxLength };
 }

-async function generateHorde(prompt, params, signal) {
+async function generateHorde(prompt, params, signal, reportProgress) {
     validateHordeModel();
     delete params.prompt;

@@ -164,7 +164,7 @@ async function generateHorde(prompt, params, signal) {
         }

         if (statusCheckJson.done && Array.isArray(statusCheckJson.generations) && statusCheckJson.generations.length) {
-            setGenerationProgress(100);
+            reportProgress && setGenerationProgress(100);
             const generatedText = statusCheckJson.generations[0].text;
             const WorkerName = statusCheckJson.generations[0].worker_name;
             const WorkerModel = statusCheckJson.generations[0].model;
@@ -174,12 +174,12 @@ async function generateHorde(prompt, params, signal) {
         }
         else if (!queue_position_first) {
             queue_position_first = statusCheckJson.queue_position;
-            setGenerationProgress(0);
+            reportProgress && setGenerationProgress(0);
         }
         else if (statusCheckJson.queue_position >= 0) {
             let queue_position = statusCheckJson.queue_position;
             const progress = Math.round(100 - (queue_position / queue_position_first * 100));
-            setGenerationProgress(progress);
+            reportProgress && setGenerationProgress(progress);
         }

         await delay(CHECK_INTERVAL);
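
The new reportProgress flag gates every setGenerationProgress call inside the polling loop, so only interactive generations drive the progress bar. The two call sites in this commit set it explicitly; the snippet below just restates them side by side.

// Interactive chat generation (inside Generate): keep updating the progress bar.
generateHorde(finalPrompt, generate_data, abortController.signal, true).then(onSuccess).catch(onError);

// Quiet generation (inside generateRaw): poll silently and leave the UI alone.
data = await generateHorde(prompt, generateData, abortController.signal, false);
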

@@ -437,5 +437,7 @@ export function getTextGenGenerationData(finalPrompt, this_amount_gen, isImperso
         'mirostat_eta': textgenerationwebui_settings.mirostat_eta,
         'grammar_string': textgenerationwebui_settings.grammar_string,
         'custom_token_bans': getCustomTokenBans(),
+        'use_mancer': isMancer(),
+        'use_aphrodite': isAphrodite(),
     };
 }
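
Moving 'use_mancer' and 'use_aphrodite' into getTextGenGenerationData means every caller, including generateRaw, now gets those flags in the payload instead of having Generate patch them on afterwards (the assignments removed in the script.js hunk above). A rough sketch of the resulting shape, with all other fields elided and the boolean type assumed from isMancer()/isAphrodite():

// Assumed payload shape (abridged): the flags now come back from the builder itself.
const payload = getTextGenGenerationData(prompt, amount_gen, false, null);
console.assert(typeof payload.use_mancer === 'boolean');
console.assert(typeof payload.use_aphrodite === 'boolean');
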