Add Sd.next source

This commit is contained in:
Cohee 2023-10-07 18:30:06 +03:00
parent 27ce0b5eb7
commit 1dd6fa4b6a
3 changed files with 216 additions and 19 deletions

View File

@ -35,6 +35,7 @@ const sources = {
horde: 'horde',
auto: 'auto',
novel: 'novel',
vlad: 'vlad',
}
const generationMode = {
@ -165,6 +166,9 @@ const defaultSettings = {
auto_url: 'http://localhost:7860',
auto_auth: '',
vlad_url: 'http://localhost:7860',
vlad_auth: '',
hr_upscaler: 'Latent',
hr_scale: 2.0,
hr_scale_min: 1.0,
@ -187,12 +191,21 @@ const defaultSettings = {
novel_anlas_guard: false,
}
const getAutoRequestBody = () => ({ url: extension_settings.sd.auto_url, auth: extension_settings.sd.auto_auth });
function getSdRequestBody() {
switch (extension_settings.sd.source) {
case sources.vlad:
return { url: extension_settings.sd.vlad_url, auth: extension_settings.sd.vlad_auth };
case sources.auto:
return { url: extension_settings.sd.auto_url, auth: extension_settings.sd.auto_auth };
default:
throw new Error('Invalid SD source.');
}
}
function toggleSourceControls() {
$('.sd_settings [data-sd-source]').each(function () {
const source = $(this).data('sd-source');
$(this).toggle(source === extension_settings.sd.source);
const source = $(this).data('sd-source').split(',');
$(this).toggle(source.includes(extension_settings.sd.source));
});
}
@ -244,6 +257,8 @@ async function loadSettings() {
$('#sd_refine_mode').prop('checked', extension_settings.sd.refine_mode);
$('#sd_auto_url').val(extension_settings.sd.auto_url);
$('#sd_auto_auth').val(extension_settings.sd.auto_auth);
$('#sd_vlad_url').val(extension_settings.sd.vlad_url);
$('#sd_vlad_auth').val(extension_settings.sd.vlad_auth);
toggleSourceControls();
addPromptTemplates();
@ -285,7 +300,7 @@ function addPromptTemplates() {
async function refinePrompt(prompt) {
if (extension_settings.sd.refine_mode) {
const refinedPrompt = await callPopup('<h3>Review and edit the prompt:</h3>Press "Cancel" to abort the image generation.', 'input', prompt, { rows: 5, okButton: 'Generate' });
const refinedPrompt = await callPopup('<h3>Review and edit the prompt:</h3>Press "Cancel" to abort the image generation.', 'input', prompt.trim(), { rows: 5, okButton: 'Generate' });
if (refinedPrompt) {
return refinedPrompt;
@ -316,7 +331,7 @@ function onCharacterPromptInput() {
}
function getCharacterPrefix() {
if (selected_group) {
if (!this_chid || selected_group) {
return '';
}
@ -454,6 +469,16 @@ function onAutoAuthInput() {
saveSettingsDebounced();
}
function onVladUrlInput() {
extension_settings.sd.vlad_url = $('#sd_vlad_url').val();
saveSettingsDebounced();
}
function onVladAuthInput() {
extension_settings.sd.vlad_auth = $('#sd_vlad_auth').val();
saveSettingsDebounced();
}
function onHrUpscalerChange() {
extension_settings.sd.hr_upscaler = $('#sd_hr_upscaler').find(':selected').val();
saveSettingsDebounced();
@ -486,7 +511,7 @@ async function validateAutoUrl() {
const result = await fetch('/api/sd/ping', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify(getAutoRequestBody()),
body: JSON.stringify(getSdRequestBody()),
});
if (!result.ok) {
@ -501,6 +526,30 @@ async function validateAutoUrl() {
}
}
async function validateVladUrl() {
try {
if (!extension_settings.sd.vlad_url) {
throw new Error('URL is not set.');
}
const result = await fetch('/api/sd/ping', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify(getSdRequestBody()),
});
if (!result.ok) {
throw new Error('SD.Next returned an error.');
}
await loadSamplers();
await loadModels();
toastr.success('SD.Next API connected.');
} catch (error) {
toastr.error(`Could not validate SD.Next API: ${error.message}`);
}
}
async function onModelChange() {
extension_settings.sd.model = $('#sd_model').find(':selected').val();
saveSettingsDebounced();
@ -515,7 +564,7 @@ async function onModelChange() {
if (extension_settings.sd.source === sources.extras) {
await updateExtrasRemoteModel();
}
if (extension_settings.sd.source === sources.auto) {
if (extension_settings.sd.source === sources.auto || extension_settings.sd.source === sources.vlad) {
await updateAutoRemoteModel();
}
toastr.success('Model successfully loaded!', 'Stable Diffusion');
@ -526,7 +575,7 @@ async function getAutoRemoteModel() {
const result = await fetch('/api/sd/get-model', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify(getAutoRequestBody()),
body: JSON.stringify(getSdRequestBody()),
});
if (!result.ok) {
@ -546,7 +595,7 @@ async function getAutoRemoteUpscalers() {
const result = await fetch('/api/sd/upscalers', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify(getAutoRequestBody()),
body: JSON.stringify(getSdRequestBody()),
});
if (!result.ok) {
@ -561,12 +610,32 @@ async function getAutoRemoteUpscalers() {
}
}
async function getVladRemoteUpscalers() {
try {
const result = await fetch('/api/sd-next/upscalers', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify(getSdRequestBody()),
});
if (!result.ok) {
throw new Error('SD.Next returned an error.');
}
const data = await result.json();
return data;
} catch (error) {
console.error(error);
return [extension_settings.sd.hr_upscaler];
}
}
async function updateAutoRemoteModel() {
try {
const result = await fetch('/api/sd/set-model', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify({ ...getAutoRequestBody(), model: extension_settings.sd.model }),
body: JSON.stringify({ ...getSdRequestBody(), model: extension_settings.sd.model }),
});
if (!result.ok) {
@ -610,6 +679,9 @@ async function loadSamplers() {
case sources.novel:
samplers = await loadNovelSamplers();
break;
case sources.vlad:
samplers = await loadVladSamplers();
break;
}
for (const sampler of samplers) {
@ -661,7 +733,7 @@ async function loadAutoSamplers() {
const result = await fetch('/api/sd/samplers', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify(getAutoRequestBody()),
body: JSON.stringify(getSdRequestBody()),
});
if (!result.ok) {
@ -675,6 +747,29 @@ async function loadAutoSamplers() {
}
}
async function loadVladSamplers() {
if (!extension_settings.sd.vlad_url) {
return [];
}
try {
const result = await fetch('/api/sd/samplers', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify(getSdRequestBody()),
});
if (!result.ok) {
throw new Error('SD.Next returned an error.');
}
const data = await result.json();
return data;
} catch (error) {
return [];
}
}
async function loadNovelSamplers() {
if (!secret_state[SECRET_KEYS.NOVEL]) {
console.debug('NovelAI API key is not set.');
@ -709,6 +804,9 @@ async function loadModels() {
case sources.novel:
models = await loadNovelModels();
break;
case sources.vlad:
models = await loadVladModels();
break;
}
for (const model of models) {
@ -778,7 +876,7 @@ async function loadAutoModels() {
const result = await fetch('/api/sd/models', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify(getAutoRequestBody()),
body: JSON.stringify(getSdRequestBody()),
});
if (!result.ok) {
@ -806,6 +904,49 @@ async function loadAutoModels() {
}
}
async function loadVladModels() {
if (!extension_settings.sd.vlad_url) {
return [];
}
try {
const currentModel = await getAutoRemoteModel();
if (currentModel) {
extension_settings.sd.model = currentModel;
}
const result = await fetch('/api/sd/models', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify(getSdRequestBody()),
});
if (!result.ok) {
throw new Error('SD WebUI returned an error.');
}
const upscalers = await getVladRemoteUpscalers();
if (Array.isArray(upscalers) && upscalers.length > 0) {
$('#sd_hr_upscaler').empty();
for (const upscaler of upscalers) {
const option = document.createElement('option');
option.innerText = upscaler;
option.value = upscaler;
option.selected = upscaler === extension_settings.sd.hr_upscaler;
$('#sd_hr_upscaler').append(option);
}
}
const data = await result.json();
return data;
} catch (error) {
return [];
}
}
async function loadNovelModels() {
if (!secret_state[SECRET_KEYS.NOVEL]) {
console.debug('NovelAI API key is not set.');
@ -913,7 +1054,7 @@ async function generatePicture(_, trigger, message, callback) {
// if context.characterId is not null, then we get context.characters[context.characterId].avatar, else we get groupId and context.groups[groupId].id
// sadly, groups is not an array, but is a dict with keys being index numbers, so we have to filter it
const characterName = context.characterId ? context.characters[context.characterId].name : context.groups[Object.keys(context.groups).filter(x => context.groups[x].id === context.groupId)[0]].id.toString();
const characterName = context.characterId ? context.characters[context.characterId].name : context.groups[Object.keys(context.groups).filter(x => context.groups[x].id === context.groupId)[0]]?.id?.toString();
const prevSDHeight = extension_settings.sd.height;
const prevSDWidth = extension_settings.sd.width;
@ -988,7 +1129,7 @@ async function getPrompt(generationType, message, trigger, quiet_prompt) {
}
async function generatePrompt(quiet_prompt) {
const reply = await generateQuietPrompt(quiet_prompt);
const reply = await generateQuietPrompt(quiet_prompt, false);
return processReply(reply);
}
@ -1010,6 +1151,9 @@ async function sendGenerationRequest(generationType, prompt, characterName = nul
case sources.horde:
result = await generateHordeImage(prefixedPrompt);
break;
case sources.vlad:
result = await generateAutoImage(prefixedPrompt);
break;
case sources.auto:
result = await generateAutoImage(prefixedPrompt);
break;
@ -1121,7 +1265,7 @@ async function generateAutoImage(prompt) {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify({
...getAutoRequestBody(),
...getSdRequestBody(),
prompt: prompt,
negative_prompt: extension_settings.sd.negative_prompt,
sampler_name: extension_settings.sd.sampler,
@ -1325,6 +1469,8 @@ function isValidState() {
return true;
case sources.auto:
return !!extension_settings.sd.auto_url;
case sources.vlad:
return !!extension_settings.sd.vlad_url;
case sources.novel:
return secret_state[SECRET_KEYS.NOVEL];
}
@ -1357,7 +1503,7 @@ async function sdMessageButton(e) {
const message_id = $mes.attr('mesid');
const message = context.chat[message_id];
const characterName = message?.name || context.name2;
const characterFileName = context.characterId ? context.characters[context.characterId].name : context.groups[Object.keys(context.groups).filter(x => context.groups[x].id === context.groupId)[0]].id.toString();
const characterFileName = context.characterId ? context.characters[context.characterId].name : context.groups[Object.keys(context.groups).filter(x => context.groups[x].id === context.groupId)[0]]?.id?.toString();
const messageText = message?.mes;
const hasSavedImage = message?.extra?.image && message?.extra?.title;
@ -1445,6 +1591,9 @@ jQuery(async () => {
$('#sd_auto_validate').on('click', validateAutoUrl);
$('#sd_auto_url').on('input', onAutoUrlInput);
$('#sd_auto_auth').on('input', onAutoAuthInput);
$('#sd_vlad_validate').on('click', validateVladUrl);
$('#sd_vlad_url').on('input', onVladUrlInput);
$('#sd_vlad_auth').on('input', onVladAuthInput);
$('#sd_hr_upscaler').on('change', onHrUpscalerChange);
$('#sd_hr_scale').on('input', onHrScaleInput);
$('#sd_denoising_strength').on('input', onDenoisingStrengthInput);

View File

@ -17,6 +17,7 @@
<option value="extras">Extras API (local / remote)</option>
<option value="horde">Stable Horde</option>
<option value="auto">Stable Diffusion Web UI (AUTOMATIC1111)</option>
<option value="vlad">SD.Next (vladmandic)</option>
<option value="novel">NovelAI Diffusion</option>
</select>
<div data-sd-source="auto">
@ -34,6 +35,21 @@
<input id="sd_auto_auth" type="text" class="text_pole" placeholder="Example: username:password" value="" />
<i><b>Important:</b> run SD Web UI with the <tt>--api</tt> flag! The server must be accessible from the SillyTavern host machine.</i>
</div>
<div data-sd-source="vlad">
<label for="sd_vlad_url">SD.Next API URL</label>
<div class="flex-container flexnowrap">
<input id="sd_vlad_url" type="text" class="text_pole" placeholder="Example: {{vlad_url}}" value="{{vlad_url}}" />
<div id="sd_vlad_validate" class="menu_button menu_button_icon">
<i class="fa-solid fa-check"></i>
<span data-i18n="Connect">
Connect
</span>
</div>
</div>
<label for="sd_vlad_auth">Authentication (optional)</label>
<input id="sd_vlad_auth" type="text" class="text_pole" placeholder="Example: username:password" value="" />
<i>The server must be accessible from the SillyTavern host machine.</i>
</div>
<div data-sd-source="horde">
<i>Hint: Save an API key in Horde KoboldAI API settings to use it here.</i>
<label for="sd_horde_nsfw" class="checkbox_label">
@ -86,7 +102,7 @@
Hires. Fix
</label>
</div>
<div data-sd-source="auto">
<div data-sd-source="auto,vlad">
<label for="sd_hr_upscaler">Upscaler</label>
<select id="sd_hr_upscaler"></select>
<label for="sd_hr_scale">Upscale by (<span id="sd_hr_scale_value"></span>)</label>
@ -94,7 +110,7 @@
<label for="sd_denoising_strength">Denoising strength (<span id="sd_denoising_strength_value"></span>)</label>
<input id="sd_denoising_strength" type="range" min="{{denoising_strength_min}}" max="{{denoising_strength_max}}" step="{{denoising_strength_step}}" value="{{denoising_strength}}" />
<label for="sd_hr_second_pass_steps">Hires steps (2nd pass) (<span id="sd_hr_second_pass_steps_value"></span>)</label>
<input id="sd_hr_second_pass_steps" type="range" min="{{hr_second_pass_steps_min}}" max="{{hr_second_pass_steps_max}}" step="{{hr_second_pass_steps_max}}" value="{{hr_second_pass_steps}}" />
<input id="sd_hr_second_pass_steps" type="range" min="{{hr_second_pass_steps_min}}" max="{{hr_second_pass_steps_max}}" step="{{hr_second_pass_steps_step}}" value="{{hr_second_pass_steps}}" />
</div>
<div data-sd-source="novel">
<label for="sd_novel_upscale_ratio">Upscale by (<span id="sd_novel_upscale_ratio_value"></span>)</label>

View File

@ -10,7 +10,7 @@ function registerEndpoints(app, jsonParser) {
app.post('/api/sd/ping', jsonParser, async (request, response) => {
try {
const url = new URL(request.body.url);
url.pathname = '/internal/ping';
url.pathname = '/sdapi/v1/options';
const result = await fetch(url, {
method: 'GET',
@ -243,6 +243,38 @@ function registerEndpoints(app, jsonParser) {
return response.sendStatus(500);
}
});
app.post('/api/sd-next/upscalers', jsonParser, async (request, response) => {
try {
const url = new URL(request.body.url);
url.pathname = '/sdapi/v1/upscalers';
const result = await fetch(url, {
method: 'GET',
headers: {
'Authorization': getBasicAuthHeader(request.body.auth),
},
});
if (!result.ok) {
throw new Error('SD WebUI returned an error.');
}
// Vlad doesn't provide Latent Upscalers in the API, so we have to hardcode them here
const latentUpscalers = ['Latent', 'Latent (antialiased)', 'Latent (bicubic)', 'Latent (bicubic antialiased)', 'Latent (nearest)', 'Latent (nearest-exact)'];
const data = await result.json();
const names = data.map(x => x.name);
// 0 = None, then Latent Upscalers, then Upscalers
names.splice(1, 0, ...latentUpscalers);
return response.send(names);
} catch (error) {
console.log(error);
return response.sendStatus(500);
}
});
}
module.exports = {