Merge pull request #2590 from splitclover/patch-1

Add eventSource for generate_data, export functions for streaming/generation request
This commit is contained in:
Cohee 2024-07-30 17:18:41 +03:00 committed by GitHub
commit f48de733fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 7 additions and 2 deletions

View File

@ -439,6 +439,7 @@ export const event_types = {
GROUP_CHAT_CREATED: 'group_chat_created',
GENERATE_BEFORE_COMBINE_PROMPTS: 'generate_before_combine_prompts',
GENERATE_AFTER_COMBINE_PROMPTS: 'generate_after_combine_prompts',
GENERATE_AFTER_DATA: 'generate_after_data',
GROUP_MEMBER_DRAFTED: 'group_member_drafted',
WORLD_INFO_ACTIVATED: 'world_info_activated',
TEXT_COMPLETION_SETTINGS_READY: 'text_completion_settings_ready',
@ -4215,6 +4216,8 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
}
}
await eventSource.emit(event_types.GENERATE_AFTER_DATA, generate_data);
if (dryRun) {
generatedPromptCache = '';
return Promise.resolve();
@ -5078,7 +5081,7 @@ function setInContextMessages(lastmsg, type) {
* @param {object} data Generation data
* @returns {Promise<object>} Response data from the API
*/
async function sendGenerationRequest(type, data) {
export async function sendGenerationRequest(type, data) {
if (main_api === 'openai') {
return await sendOpenAIRequest(type, data.prompt, abortController.signal);
}
@ -5110,7 +5113,7 @@ async function sendGenerationRequest(type, data) {
* @param {object} data Generation data
* @returns {Promise<any>} Streaming generator
*/
async function sendStreamingRequest(type, data) {
export async function sendStreamingRequest(type, data) {
if (abortController?.signal?.aborted) {
throw new Error('Generation was aborted.');
}
@ -7921,6 +7924,8 @@ window['SillyTavern'].getContext = function () {
eventTypes: event_types,
addOneMessage: addOneMessage,
generate: Generate,
sendStreamingRequest: sendStreamingRequest,
sendGenerationRequest: sendGenerationRequest,
stopGeneration: stopGeneration,
getTokenCount: getTokenCount,
extensionPrompts: extension_prompts,