diff --git a/public/script.js b/public/script.js index 603cf7033..7efe1d231 100644 --- a/public/script.js +++ b/public/script.js @@ -6,6 +6,7 @@ import { loadKoboldSettings, formatKoboldUrl, getKoboldGenerationData, + canUseKoboldStopSequence, } from "./scripts/kai-settings.js"; import { @@ -518,6 +519,11 @@ async function getStatus() { is_pygmalion = false; } + // determine if we can use stop sequence + if (main_api == "kobold") { + kai_settings.use_stop_sequence = canUseKoboldStopSequence(data.version); + } + // determine if streaming is enabled for ooba if (main_api == 'textgenerationwebui' && typeof data.gradio_config == 'string') { try { @@ -2196,6 +2202,9 @@ function saveReply(type, getMessage, this_mes_is_name) { type = 'normal'; } + const img = extractImageFromMessage(getMessage); + getMessage = img.getMessage; + if (type === 'swipe') { chat[chat.length - 1]['swipes'][chat[chat.length - 1]['swipes'].length] = getMessage; if (chat[chat.length - 1]['swipe_id'] === chat[chat.length - 1]['swipes'].length - 1) { @@ -2225,12 +2234,32 @@ function saveReply(type, getMessage, this_mes_is_name) { chat[chat.length - 1]['is_name'] = true; chat[chat.length - 1]['force_avatar'] = avatarImg; } - //console.log('runGenerate calls addOneMessage'); + + saveImageToMessage(img, chat[chat.length - 1]); addOneMessage(chat[chat.length - 1]); } return { type, getMessage }; } +function saveImageToMessage(img, mes) { + if (mes && img.image) { + if (typeof mes.extra !== 'object') { + mes.extra = {}; + } + mes.extra.image = img.image; + mes.title = img.title; + } +} + +function extractImageFromMessage(getMessage) { + const regex = //g; + const results = regex.exec(getMessage); + const image = results ? results[1] : ''; + const title = results ? results[2] : ''; + getMessage = getMessage.replace(regex, ''); + return { getMessage, image, title }; +} + function isMultigenEnabled() { return power_user.multigen && (main_api == 'textgenerationwebui' || main_api == 'kobold' || main_api == 'novel'); } diff --git a/public/scripts/kai-settings.js b/public/scripts/kai-settings.js index 9330a3b66..36701fcfd 100644 --- a/public/scripts/kai-settings.js +++ b/public/scripts/kai-settings.js @@ -8,6 +8,7 @@ export { loadKoboldSettings, formatKoboldUrl, getKoboldGenerationData, + canUseKoboldStopSequence, }; const kai_settings = { @@ -21,8 +22,11 @@ const kai_settings = { tfs: 1, rep_pen_slope: 0.9, single_line: false, + use_stop_sequence: false, }; +const MIN_STOP_SEQUENCE_VERSION = '1.2.2'; + function formatKoboldUrl(value) { try { const url = new URL(value); @@ -81,7 +85,7 @@ function getKoboldGenerationData(finalPromt, this_settings, this_amount_gen, thi s7: this_settings.sampler_order[6], use_world_info: false, singleline: kai_settings.single_line, - stop_sequence: [getStoppingStrings(isImpersonate, false)], + stop_sequence: kai_settings.use_stop_sequence ? [getStoppingStrings(isImpersonate, false)] : undefined, }; return generate_data; } @@ -152,6 +156,10 @@ const sliders = [ }, ]; +function canUseKoboldStopSequence(version) { + return version.localeCompare(MIN_STOP_SEQUENCE_VERSION, undefined, { numeric: true, sensitivity: 'base' }) > -1; +} + $(document).ready(function () { sliders.forEach(slider => { $(document).on("input", slider.sliderId, function () { diff --git a/public/style.css b/public/style.css index f266592d5..98f71a34e 100644 --- a/public/style.css +++ b/public/style.css @@ -344,7 +344,6 @@ code { outline: none; border: none; position: relative; - display: inline; opacity: 0.7; cursor: pointer; z-index: 2001; @@ -352,6 +351,8 @@ code { padding-top: 0; transition: 0.3s; font-size: 30px; + display: flex; + align-items: center; } .font-family-reset { @@ -407,6 +408,7 @@ code { display: flex; align-items: center; column-gap: 10px; + cursor: pointer; } .options-content a div:first-child { diff --git a/server.js b/server.js index d51b6bff9..7adbd78d9 100644 --- a/server.js +++ b/server.js @@ -270,8 +270,10 @@ app.post("/generate", jsonParser, async function (request, response_generate = r typical: request.body.typical, sampler_order: sampler_order, singleline: !!request.body.singleline, - stop_sequence: request.body.stop_sequence || [], }; + if (!!request.body.stop_sequence) { + this_settings['stop_sequence'] = request.body.stop_sequence; + } } console.log(this_settings); @@ -415,7 +417,7 @@ app.post("/generate_textgenerationwebui", jsonParser, async function (request, r try { for await (const text of readWebsocket()) { - if (text == null) { + if (text == null || typeof text !== 'string') { break; } @@ -536,7 +538,7 @@ app.post("/getchat", jsonParser, function (request, response) { }); -app.post("/getstatus", jsonParser, function (request, response_getstatus = response) { +app.post("/getstatus", jsonParser, async function (request, response_getstatus = response) { if (!request.body) return response_getstatus.sendStatus(400); api_server = request.body.api_server; main_api = request.body.main_api; @@ -547,10 +549,19 @@ app.post("/getstatus", jsonParser, function (request, response_getstatus = respo headers: { "Content-Type": "application/json" } }; var url = api_server + "/v1/model"; + let version = ''; if (main_api == "textgenerationwebui") { url = api_server; args = {} } + if (main_api == "kobold") { + try { + version = (await getAsync(api_server + "/v1/info/version")).result; + } + catch { + version = '0.0.0'; + } + } client.get(url, args, function (data, response) { if (response.statusCode == 200) { if (main_api == "textgenerationwebui") { @@ -568,8 +579,8 @@ app.post("/getstatus", jsonParser, function (request, response_getstatus = respo data = { result: "no_connection" }; } } else { + data.version = version; if (data.result != "ReadOnly") { - //response_getstatus.send(data.result); } else { data.result = "no_connection"; } @@ -2106,8 +2117,9 @@ app.post("/generate_openai", jsonParser, function (request, response_generate_op response_generate_openai.end(); }); } else { - console.log(response.data); response_generate_openai.send(response.data); + console.log(response.data); + console.log(response.data?.choices[0]?.message); } } else if (response.status == 400) { console.log('Validation error');