Added stream support to "custom-request"
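This commit threads a `stream` flag through TextCompletionService and ChatCompletionService in custom-request.js. When `stream` is true, `sendRequest()` no longer resolves to extracted data; it resolves to a function that creates an AsyncGenerator of StreamResponse chunks, and a caller-supplied AbortSignal can now cancel the request. A minimal caller-side sketch (not part of this commit; the import path and the message/model values are assumptions for illustration):

    import { ChatCompletionService } from './custom-request.js';

    const controller = new AbortController();
    const result = await ChatCompletionService.processRequest(
        {
            stream: true,
            messages: [{ role: 'user', content: 'Hello!' }],
            chat_completion_source: 'openai', // assumed source identifier
            max_tokens: 256,
        },
        {},                 // options: no preset applied
        true,               // extractData (only affects the non-streaming path)
        controller.signal,  // lets the caller cancel mid-generation
    );

    // With stream=true the promise resolves to a generator factory;
    // each yielded chunk carries the text accumulated so far.
    for await (const { text, swipes, state } of result()) {
        console.log(text, state.reasoning);
    }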
@@ -3,10 +3,13 @@ import { extractMessageFromData, getGenerateUrl, getRequestHeaders } from '../sc
 import { getTextGenServer } from './textgen-settings.js';
 import { extractReasoningFromData } from './reasoning.js';
 import { formatInstructModeChat, formatInstructModePrompt, names_behavior_types } from './instruct-mode.js';
+import { getStreamingReply, tryParseStreamingError } from './openai.js';
+import EventSourceStream from './sse-stream.js';
 
 // #region Type Definitions
 /**
  * @typedef {Object} TextCompletionRequestBase
+ * @property {boolean?} [stream=false] - Whether to stream the response
  * @property {number} max_tokens - Maximum number of tokens to generate
  * @property {string} [model] - Optional model name
  * @property {string} api_type - Type of API to use
@@ -17,6 +20,7 @@ import { formatInstructModeChat, formatInstructModePrompt, names_behavior_types
 
 /**
  * @typedef {Object} TextCompletionPayloadBase
+ * @property {boolean?} [stream=false] - Whether to stream the response
  * @property {string} prompt - The text prompt for completion
  * @property {number} max_tokens - Maximum number of tokens to generate
  * @property {number} max_new_tokens - Alias for max_tokens
@@ -36,6 +40,7 @@ import { formatInstructModeChat, formatInstructModePrompt, names_behavior_types
 
 /**
  * @typedef {Object} ChatCompletionPayloadBase
+ * @property {boolean?} [stream=false] - Whether to stream the response
  * @property {ChatCompletionMessage[]} messages - Array of chat messages
  * @property {string} [model] - Optional model name to use for completion
  * @property {string} chat_completion_source - Source provider for chat completion
@@ -52,10 +57,20 @@ import { formatInstructModeChat, formatInstructModePrompt, names_behavior_types
  * @property {string} reasoning - Extracted reasoning.
  */
 
+/**
+ * @typedef {Object} StreamResponse
+ * @property {string} text - Generated text.
+ * @property {string[]} swipes - Generated swipes
+ * @property {Object} state - Generated state
+ * @property {string?} [state.reasoning] - Generated reasoning
+ * @property {string?} [state.image] - Generated image
+ * @returns {StreamResponse}
+ */
+
 // #endregion
 
 /**
- * Creates & sends a text completion request. Streaming is not supported.
+ * Creates & sends a text completion request.
  */
 export class TextCompletionService {
     static TYPE = 'textgenerationwebui';
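For orientation, a single chunk yielded by the new streaming generators has the StreamResponse shape defined above. An illustrative (made-up) value:

    /** @type {StreamResponse} */
    const chunk = {
        text: 'Once upon a',       // full text accumulated so far, not a delta
        swipes: [],                // alternate generations, filled for choices with index > 0
        state: { reasoning: '' },  // ChatCompletionService additionally tracks state.image
    };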
@@ -64,9 +79,10 @@ export class TextCompletionService {
      * @param {Record<string, any> & TextCompletionRequestBase & {prompt: string}} custom
      * @returns {TextCompletionPayload}
      */
-    static createRequestData({ prompt, max_tokens, model, api_type, api_server, temperature, min_p, ...props }) {
+    static createRequestData({ stream = false, prompt, max_tokens, model, api_type, api_server, temperature, min_p, ...props }) {
         const payload = {
             ...props,
+            stream,
             prompt,
             max_tokens,
             max_new_tokens: max_tokens,
@@ -75,7 +91,6 @@ export class TextCompletionService {
             api_server: api_server ?? getTextGenServer(api_type),
             temperature,
             min_p,
-            stream: false,
         };
 
         // Remove undefined values to avoid API errors
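The net effect of these two hunks is that the payload's stream flag now comes from the caller instead of being hardcoded to false. A hypothetical call (the api_type value is assumed):

    const payload = TextCompletionService.createRequestData({
        prompt: 'Once upon a time',
        max_tokens: 64,
        api_type: 'koboldcpp',  // assumed backend identifier
        stream: true,
    });
    // payload.stream === true; before this commit it was always false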
@@ -92,34 +107,81 @@ export class TextCompletionService {
      * Sends a text completion request to the specified server
      * @param {TextCompletionPayload} data Request data
      * @param {boolean?} extractData Extract message from the response. Default true
-     * @returns {Promise<ExtractedData | any>} Extracted data or the raw response
+     * @param {AbortSignal?} signal
+     * @returns {Promise<ExtractedData | (() => AsyncGenerator<StreamResponse>)>} If not streaming, returns extracted data; if streaming, returns a function that creates an AsyncGenerator
      * @throws {Error}
      */
-    static async sendRequest(data, extractData = true) {
-        const response = await fetch(getGenerateUrl(this.TYPE), {
+    static async sendRequest(data, extractData = true, signal = null) {
+        if (!data.stream) {
+            const response = await fetch(getGenerateUrl(this.TYPE), {
+                method: 'POST',
+                headers: getRequestHeaders(),
+                cache: 'no-cache',
+                body: JSON.stringify(data),
+                signal: signal ?? new AbortController().signal,
+            });
+
+            const json = await response.json();
+            if (!response.ok || json.error) {
+                throw json;
+            }
+
+            if (!extractData) {
+                return json;
+            }
+
+            return {
+                content: extractMessageFromData(json, this.TYPE),
+                reasoning: extractReasoningFromData(json, {
+                    mainApi: this.TYPE,
+                    textGenType: data.api_type,
+                    ignoreShowThoughts: true,
+                }),
+            };
+        }
+
+        const response = await fetch('/api/backends/text-completions/generate', {
             method: 'POST',
             headers: getRequestHeaders(),
             cache: 'no-cache',
             body: JSON.stringify(data),
-            signal: new AbortController().signal,
+            signal: signal ?? new AbortController().signal,
         });
 
-        const json = await response.json();
-        if (!response.ok || json.error) {
-            throw json;
+        if (!response.ok) {
+            const text = await response.text();
+            tryParseStreamingError(response, text, true);
+
+            throw new Error(`Got response status ${response.status}`);
         }
 
-        if (!extractData) {
-            return json;
-        }
+        const eventStream = new EventSourceStream();
+        response.body.pipeThrough(eventStream);
+        const reader = eventStream.readable.getReader();
+        return async function* streamData() {
+            let text = '';
+            const swipes = [];
+            const state = { reasoning: '' };
+            while (true) {
+                const { done, value } = await reader.read();
+                if (done) return;
+                if (value.data === '[DONE]') return;
 
-        return {
-            content: extractMessageFromData(json, this.TYPE),
-            reasoning: extractReasoningFromData(json, {
-                mainApi: this.TYPE,
-                textGenType: data.api_type,
-                ignoreShowThoughts: true,
-            }),
+                tryParseStreamingError(response, value.data, true);
+
+                let data = JSON.parse(value.data);
+
+                if (data?.choices?.[0]?.index > 0) {
+                    const swipeIndex = data.choices[0].index - 1;
+                    swipes[swipeIndex] = (swipes[swipeIndex] || '') + data.choices[0].text;
+                } else {
+                    const newText = data?.choices?.[0]?.text || data?.content || '';
+                    text += newText;
+                    state.reasoning += data?.choices?.[0]?.reasoning ?? '';
+                }
+
+                yield { text, swipes, state };
+            }
        };
    }
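A sketch of consuming the new text-completion streaming path (values assumed; error handling elided):

    const data = TextCompletionService.createRequestData({
        prompt: 'Hello',
        max_tokens: 128,
        api_type: 'koboldcpp',  // assumed backend identifier
        stream: true,
    });
    const streamFn = await TextCompletionService.sendRequest(data, true, null);
    for await (const chunk of streamFn()) {
        // chunk.text grows monotonically; chunk.state.reasoning collects
        // any reasoning tokens reported by the backend
        render(chunk.text);  // render() is a placeholder for UI code
    }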
@@ -130,13 +192,15 @@ export class TextCompletionService {
      * @param {string?} [options.presetName] - Name of the preset to use for generation settings
      * @param {string?} [options.instructName] - Name of instruct preset for message formatting
      * @param {boolean} extractData - Whether to extract structured data from response
-     * @returns {Promise<ExtractedData | any>} Extracted data or the raw response
+     * @param {AbortSignal?} [signal]
+     * @returns {Promise<ExtractedData | (() => AsyncGenerator<StreamResponse>)>} If not streaming, returns extracted data; if streaming, returns a function that creates an AsyncGenerator
      * @throws {Error}
      */
     static async processRequest(
         custom,
         options = {},
         extractData = true,
+        signal = null,
     ) {
         const { presetName, instructName } = options;
         let requestData = { ...custom };
@@ -220,7 +284,7 @@ export class TextCompletionService {
         // @ts-ignore
         const data = this.createRequestData(requestData);
 
-        return await this.sendRequest(data, extractData);
+        return await this.sendRequest(data, extractData, signal);
     }
 
     /**
@@ -256,7 +320,7 @@ export class TextCompletionService {
 }
 
 /**
- * Creates & sends a chat completion request. Streaming is not supported.
+ * Creates & sends a chat completion request.
  */
 export class ChatCompletionService {
     static TYPE = 'openai';
@@ -265,16 +329,16 @@ export class ChatCompletionService {
      * @param {ChatCompletionPayload} custom
      * @returns {ChatCompletionPayload}
      */
-    static createRequestData({ messages, model, chat_completion_source, max_tokens, temperature, custom_url, ...props }) {
+    static createRequestData({ stream = false, messages, model, chat_completion_source, max_tokens, temperature, custom_url, ...props }) {
         const payload = {
             ...props,
+            stream,
             messages,
             model,
             chat_completion_source,
             max_tokens,
             temperature,
             custom_url,
-            stream: false,
         };
 
         // Remove undefined values to avoid API errors
@@ -291,34 +355,74 @@ export class ChatCompletionService {
      * Sends a chat completion request
      * @param {ChatCompletionPayload} data Request data
      * @param {boolean?} extractData Extract message from the response. Default true
-     * @returns {Promise<ExtractedData | any>} Extracted data or the raw response
+     * @param {AbortSignal?} signal Abort signal
+     * @returns {Promise<ExtractedData | (() => AsyncGenerator<StreamResponse>)>} If not streaming, returns extracted data; if streaming, returns a function that creates an AsyncGenerator
      * @throws {Error}
      */
-    static async sendRequest(data, extractData = true) {
+    static async sendRequest(data, extractData = true, signal = null) {
         const response = await fetch('/api/backends/chat-completions/generate', {
             method: 'POST',
             headers: getRequestHeaders(),
             cache: 'no-cache',
             body: JSON.stringify(data),
-            signal: new AbortController().signal,
+            signal: signal ?? new AbortController().signal,
         });
 
-        const json = await response.json();
-        if (!response.ok || json.error) {
-            throw json;
+        if (!data.stream) {
+            const json = await response.json();
+            if (!response.ok || json.error) {
+                throw json;
+            }
+
+            if (!extractData) {
+                return json;
+            }
+
+            return {
+                content: extractMessageFromData(json, this.TYPE),
+                reasoning: extractReasoningFromData(json, {
+                    mainApi: this.TYPE,
+                    textGenType: data.chat_completion_source,
+                    ignoreShowThoughts: true,
+                }),
+            };
         }
 
-        if (!extractData) {
-            return json;
+        if (!response.ok) {
+            const text = await response.text();
+            tryParseStreamingError(response, text, true);
+
+            throw new Error(`Got response status ${response.status}`);
         }
 
-        return {
-            content: extractMessageFromData(json, this.TYPE),
-            reasoning: extractReasoningFromData(json, {
-                mainApi: this.TYPE,
-                textGenType: data.chat_completion_source,
-                ignoreShowThoughts: true,
-            }),
+        const eventStream = new EventSourceStream();
+        response.body.pipeThrough(eventStream);
+        const reader = eventStream.readable.getReader();
+        return async function* streamData() {
+            let text = '';
+            const swipes = [];
+            const state = { reasoning: '', image: '' };
+            while (true) {
+                const { done, value } = await reader.read();
+                if (done) return;
+                const rawData = value.data;
+                if (rawData === '[DONE]') return;
+                tryParseStreamingError(response, rawData, true);
+                const parsed = JSON.parse(rawData);
+
+                const reply = getStreamingReply(parsed, state, {
+                    chatCompletionSource: data.chat_completion_source,
+                    ignoreShowThoughts: true,
+                });
+                if (Array.isArray(parsed?.choices) && parsed?.choices?.[0]?.index > 0) {
+                    const swipeIndex = parsed.choices[0].index - 1;
+                    swipes[swipeIndex] = (swipes[swipeIndex] || '') + reply;
+                } else {
+                    text += reply;
+                }
+
+                yield { text, swipes: swipes, state };
+            }
        };
    }
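Because the fetch now honors a caller-supplied signal, an in-flight stream can be cancelled. A sketch, given a payload built with createRequestData above (treating the AbortError propagation as an assumption about how an aborted fetch surfaces through the piped SSE stream):

    const controller = new AbortController();
    const streamFn = await ChatCompletionService.sendRequest(payload, true, controller.signal);
    setTimeout(() => controller.abort(), 5000); // give up after 5 seconds

    try {
        for await (const { text } of streamFn()) {
            // consume partial text
        }
    } catch (error) {
        if (error.name === 'AbortError') {
            // reader.read() is expected to reject once the underlying fetch aborts
        }
    }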
@@ -327,11 +431,12 @@ export class ChatCompletionService {
      * @param {ChatCompletionPayload} custom
      * @param {Object} options - Configuration options
      * @param {string?} [options.presetName] - Name of the preset to use for generation settings
-     * @param {boolean} extractData - Whether to extract structured data from response
-     * @returns {Promise<ExtractedData | any>} Extracted data or the raw response
+     * @param {boolean} [extractData=true] - Whether to extract structured data from response
+     * @param {AbortSignal?} [signal] - Abort signal
+     * @returns {Promise<ExtractedData | (() => AsyncGenerator<StreamResponse>)>} If not streaming, returns extracted data; if streaming, returns a function that creates an AsyncGenerator
      * @throws {Error}
      */
-    static async processRequest(custom, options, extractData = true) {
+    static async processRequest(custom, options, extractData = true, signal = null) {
         const { presetName } = options;
         let requestData = { ...custom };
 
@@ -354,7 +459,7 @@ export class ChatCompletionService {
 
         const data = this.createRequestData(requestData);
 
-        return await this.sendRequest(data, extractData);
+        return await this.sendRequest(data, extractData, signal);
     }
 
     /**