Merge pull request #2842 from SillyTavern/o1

OpenAI O1
Cohee 2024-09-13 21:05:00 +03:00 committed by GitHub
commit 5dd1d26350
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 74 additions and 65 deletions
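
Summary: this PR adds OpenAI o1 support. o1-preview and o1-mini join the model list; streaming is disabled for o1 models; o1 requests rewrite system messages as user messages, rename max_tokens to max_completion_tokens, and delete the streaming, sampling, logprobs, n, tool, and stop parameters; o1 token counting falls back to the gpt-4o tokenizer; the max response length input cap rises from 16384 to 65536; the generation circuit breaker (maxLoops) is removed; and tiktoken is bumped from 1.0.15 to 1.0.16.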

package-lock.json (generated)

@@ -46,7 +46,7 @@
         "sanitize-filename": "^1.6.3",
         "sillytavern-transformers": "2.14.6",
         "simple-git": "^3.19.1",
-        "tiktoken": "^1.0.15",
+        "tiktoken": "^1.0.16",
         "vectra": "^0.2.2",
         "wavefile": "^11.0.0",
         "write-file-atomic": "^5.0.1",
@@ -5751,9 +5751,10 @@
             "license": "MIT"
         },
         "node_modules/tiktoken": {
-            "version": "1.0.15",
-            "resolved": "https://registry.npmjs.org/tiktoken/-/tiktoken-1.0.15.tgz",
-            "integrity": "sha512-sCsrq/vMWUSEW29CJLNmPvWxlVp7yh2tlkAjpJltIKqp5CKf98ZNpdeHRmAlPVFlGEbswDc6SmI8vz64W/qErw=="
+            "version": "1.0.16",
+            "resolved": "https://registry.npmjs.org/tiktoken/-/tiktoken-1.0.16.tgz",
+            "integrity": "sha512-hRcORIGF2YlAgWx3nzrGJOrKSJwLoc81HpXmMQk89632XAgURc7IeV2FgQ2iXo9z/J96fCvpsHg2kWoHcbj9fg==",
+            "license": "MIT"
         },
         "node_modules/timm": {
             "version": "1.7.1",

package.json

@@ -36,7 +36,7 @@
         "sanitize-filename": "^1.6.3",
         "sillytavern-transformers": "2.14.6",
         "simple-git": "^3.19.1",
-        "tiktoken": "^1.0.15",
+        "tiktoken": "^1.0.16",
         "vectra": "^0.2.2",
         "wavefile": "^11.0.0",
         "write-file-atomic": "^5.0.1",

public/index.html

@@ -383,7 +383,7 @@
                         Max Response Length (tokens)
                     </div>
                     <div class="wide100p">
-                        <input type="number" id="openai_max_tokens" name="openai_max_tokens" class="text_pole" min="1" max="16384">
+                        <input type="number" id="openai_max_tokens" name="openai_max_tokens" class="text_pole" min="1" max="65536">
                     </div>
                 </div>
                 <div class="range-block" data-source="openai,custom">
@@ -2611,6 +2611,10 @@
                         <option value="gpt-4-0125-preview">gpt-4-0125-preview (2024)</option>
                         <option value="gpt-4-1106-preview">gpt-4-1106-preview (2023)</option>
                     </optgroup>
+                    <optgroup label="o1">
+                        <option value="o1-preview">o1-preview</option>
+                        <option value="o1-mini">o1-mini</option>
+                    </optgroup>
                     <optgroup label="Other">
                         <option value="text-davinci-003">text-davinci-003</option>
                         <option value="text-davinci-002">text-davinci-002</option>

public/script.js

@@ -881,7 +881,6 @@ let abortController;
 //css
 var css_send_form_display = $('<div id=send_form></div>').css('display');
-const MAX_GENERATION_LOOPS = 5;
 var kobold_horde_model = '';
@@ -2862,7 +2861,12 @@ export function getCharacterCardFields() {
 export function isStreamingEnabled() {
     const noStreamSources = [chat_completion_sources.SCALE];
-    return ((main_api == 'openai' && oai_settings.stream_openai && !noStreamSources.includes(oai_settings.chat_completion_source) && !(oai_settings.chat_completion_source == chat_completion_sources.MAKERSUITE && oai_settings.google_model.includes('bison')))
+    return (
+        (main_api == 'openai' &&
+            oai_settings.stream_openai &&
+            !noStreamSources.includes(oai_settings.chat_completion_source) &&
+            !(oai_settings.chat_completion_source == chat_completion_sources.OPENAI && oai_settings.openai_model.startsWith('o1-')) &&
+            !(oai_settings.chat_completion_source == chat_completion_sources.MAKERSUITE && oai_settings.google_model.includes('bison')))
         || (main_api == 'kobold' && kai_settings.streaming_kobold && kai_flags.can_use_streaming)
         || (main_api == 'novel' && nai_settings.streaming_novel)
         || (main_api == 'textgenerationwebui' && textgen_settings.streaming));
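
Note: a hedged illustration of the new guard's effect, assuming main_api is 'openai' and the OpenAI chat completion source is active (the identifiers are the real ones from the hunk above; the values are hypothetical):

    // Streaming stays on for a regular model...
    oai_settings.stream_openai = true;
    oai_settings.chat_completion_source = chat_completion_sources.OPENAI;
    oai_settings.openai_model = 'gpt-4o';
    isStreamingEnabled(); // true: every check in the OpenAI branch passes

    // ...but is forced off for any o1 model, regardless of the user setting.
    oai_settings.openai_model = 'o1-preview';
    isStreamingEnabled(); // false: openai_model.startsWith('o1-') trips the new clause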
@@ -3337,11 +3341,11 @@ function removeLastMessage() {
  * @param {GenerateOptions} options Generation options
  * @param {boolean} dryRun Whether to actually generate a message or just assemble the prompt
  * @returns {Promise<any>} Returns a promise that resolves when the text is done generating.
- * @typedef {{automatic_trigger?: boolean, force_name2?: boolean, quiet_prompt?: string, quietToLoud?: boolean, skipWIAN?: boolean, force_chid?: number, signal?: AbortSignal, quietImage?: string, maxLoops?: number, quietName?: string }} GenerateOptions
+ * @typedef {{automatic_trigger?: boolean, force_name2?: boolean, quiet_prompt?: string, quietToLoud?: boolean, skipWIAN?: boolean, force_chid?: number, signal?: AbortSignal, quietImage?: string, quietName?: string }} GenerateOptions
  */
-export async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, maxLoops, quietName } = {}, dryRun = false) {
+export async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName } = {}, dryRun = false) {
     console.log('Generate entered');
-    await eventSource.emit(event_types.GENERATION_STARTED, type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, maxLoops }, dryRun);
+    await eventSource.emit(event_types.GENERATION_STARTED, type, { automatic_trigger, force_name2, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage }, dryRun);
     setGenerationProgress(0);
     generation_started = new Date();
@@ -3403,7 +3407,7 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
     if (selected_group && !is_group_generating) {
         if (!dryRun) {
             // Returns the promise that generateGroupWrapper returns; resolves when generation is done
-            return generateGroupWrapper(false, type, { quiet_prompt, force_chid, signal: abortController.signal, quietImage, maxLoops });
+            return generateGroupWrapper(false, type, { quiet_prompt, force_chid, signal: abortController.signal, quietImage });
         }
         const characterIndexMap = new Map(characters.map((char, index) => [char.avatar, index]));
@@ -4435,53 +4439,30 @@ export async function Generate(type, { automatic_trigger, force_name2, quiet_pro
         const displayIncomplete = type === 'quiet' && !quietToLoud;
         getMessage = cleanUpMessage(getMessage, isImpersonate, isContinue, displayIncomplete);
 
-        if (getMessage.length > 0 || data.allowEmptyResponse) {
-            if (isImpersonate) {
-                $('#send_textarea').val(getMessage)[0].dispatchEvent(new Event('input', { bubbles: true }));
-                generatedPromptCache = '';
-                await eventSource.emit(event_types.IMPERSONATE_READY, getMessage);
-            }
-            else if (type == 'quiet') {
-                unblockGeneration(type);
-                return getMessage;
-            }
-            else {
-                // Without streaming we'll be having a full message on continuation. Treat it as a last chunk.
-                if (originalType !== 'continue') {
-                    ({ type, getMessage } = await saveReply(type, getMessage, false, title, swipes));
-                }
-                else {
-                    ({ type, getMessage } = await saveReply('appendFinal', getMessage, false, title, swipes));
-                }
-                // This relies on `saveReply` having been called to add the message to the chat, so it must be last.
-                parseAndSaveLogprobs(data, continue_mag);
-            }
-            if (type !== 'quiet') {
-                playMessageSound();
-            }
-        } else {
-            // If maxLoops is not passed in (e.g. first time generating), set it to MAX_GENERATION_LOOPS
-            maxLoops ??= MAX_GENERATION_LOOPS;
-            if (maxLoops === 0) {
-                if (type !== 'quiet') {
-                    throwCircuitBreakerError();
-                }
-                throw new Error('Generate circuit breaker interruption');
-            }
-            // regenerate with character speech reenforced
-            // to make sure we leave on swipe type while also adding the name2 appendage
-            await delay(1000);
-            // A message was already deleted on regeneration, so instead treat is as a normal gen
-            if (type === 'regenerate') {
-                type = 'normal';
-            }
-            // The first await is for waiting for the generate to start. The second one is waiting for it to finish
-            const result = await await Generate(type, { automatic_trigger, force_name2: true, quiet_prompt, quietToLoud, skipWIAN, force_chid, signal, quietImage, quietName, maxLoops: maxLoops - 1 });
-            return result;
-        }
+        if (isImpersonate) {
+            $('#send_textarea').val(getMessage)[0].dispatchEvent(new Event('input', { bubbles: true }));
+            generatedPromptCache = '';
+            await eventSource.emit(event_types.IMPERSONATE_READY, getMessage);
+        }
+        else if (type == 'quiet') {
+            unblockGeneration(type);
+            return getMessage;
+        }
+        else {
+            // Without streaming we'll be having a full message on continuation. Treat it as a last chunk.
+            if (originalType !== 'continue') {
+                ({ type, getMessage } = await saveReply(type, getMessage, false, title, swipes));
+            }
+            else {
+                ({ type, getMessage } = await saveReply('appendFinal', getMessage, false, title, swipes));
+            }
+            // This relies on `saveReply` having been called to add the message to the chat, so it must be last.
+            parseAndSaveLogprobs(data, continue_mag);
+        }
+
+        if (type !== 'quiet') {
+            playMessageSound();
+        }
 
         if (power_user.auto_swipe) {
@@ -5254,11 +5235,6 @@ function getGenerateUrl(api) {
     }
 }
 
-function throwCircuitBreakerError() {
-    callPopup(`Could not extract reply in ${MAX_GENERATION_LOOPS} attempts. Try generating again`, 'text');
-    unblockGeneration();
-}
-
 function extractTitleFromData(data) {
     if (main_api == 'koboldhorde') {
         return data.workerName;

public/scripts/openai.js

@@ -1797,7 +1797,7 @@ async function sendOpenAIRequest(type, messages, signal) {
     const isQuiet = type === 'quiet';
     const isImpersonate = type === 'impersonate';
     const isContinue = type === 'continue';
-    const stream = oai_settings.stream_openai && !isQuiet && !isScale && !(isGoogle && oai_settings.google_model.includes('bison'));
+    const stream = oai_settings.stream_openai && !isQuiet && !isScale && !(isGoogle && oai_settings.google_model.includes('bison')) && !(isOAI && oai_settings.openai_model.startsWith('o1-'));
     const useLogprobs = !!power_user.request_token_probabilities;
     const canMultiSwipe = oai_settings.n > 1 && !isContinue && !isImpersonate && !isQuiet && (isOAI || isCustom);
@@ -1960,12 +1960,35 @@ async function sendOpenAIRequest(type, messages, signal) {
         generate_data['seed'] = oai_settings.seed;
     }
 
-    await eventSource.emit(event_types.CHAT_COMPLETION_SETTINGS_READY, generate_data);
-
     if (isFunctionCallingSupported() && !stream) {
         await registerFunctionTools(type, generate_data);
     }
 
+    if (isOAI && oai_settings.openai_model.startsWith('o1-')) {
+        generate_data.messages.forEach((msg) => {
+            if (msg.role === 'system') {
+                msg.role = 'user';
+            }
+        });
+        generate_data.max_completion_tokens = generate_data.max_tokens;
+        delete generate_data.max_tokens;
+        delete generate_data.stream;
+        delete generate_data.logprobs;
+        delete generate_data.top_logprobs;
+        delete generate_data.n;
+        delete generate_data.temperature;
+        delete generate_data.top_p;
+        delete generate_data.frequency_penalty;
+        delete generate_data.presence_penalty;
+        delete generate_data.tools;
+        delete generate_data.tool_choice;
+        delete generate_data.stop;
+        // It does support logit_bias, but the tokenizer used and its effect is yet unknown.
+        // delete generate_data.logit_bias;
+    }
+
+    await eventSource.emit(event_types.CHAT_COMPLETION_SETTINGS_READY, generate_data);
+
     const generate_url = '/api/backends/chat-completions/generate';
     const response = await fetch(generate_url, {
         method: 'POST',
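
To make the o1 branch concrete, a hedged before/after sketch of generate_data (illustrative values only; a real payload carries more fields):

    // Before the o1 branch runs (a typical chat completion payload):
    const generate_data = {
        model: 'o1-preview',
        messages: [{ role: 'system', content: 'You are a helpful assistant.' }],
        max_tokens: 4096,
        stream: true,
        temperature: 1.0,
        top_p: 1.0,
    };

    // After: system roles are rewritten to user, max_tokens is renamed,
    // and the streaming/sampling/tool fields are deleted outright:
    // {
    //     model: 'o1-preview',
    //     messages: [{ role: 'user', content: 'You are a helpful assistant.' }],
    //     max_completion_tokens: 4096,
    // }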
@@ -2111,7 +2134,6 @@ async function checkFunctionToolCalls(data) {
             const args = toolCall.function;
             console.log('Function tool call:', toolCall);
             await eventSource.emit(event_types.LLM_FUNCTION_TOOL_CALL, args);
-            data.allowEmptyResponse = true;
         }
     }
@@ -2125,7 +2147,6 @@ async function checkFunctionToolCalls(data) {
             /** @type {FunctionToolCall} */
             const args = { name: content.name, arguments: JSON.stringify(content.input) };
             await eventSource.emit(event_types.LLM_FUNCTION_TOOL_CALL, args);
-            data.allowEmptyResponse = true;
         }
     }
 }
@@ -2140,7 +2161,6 @@ async function checkFunctionToolCalls(data) {
             const args = { name: toolCall.name, arguments: JSON.stringify(toolCall.parameters) };
             console.log('Function tool call:', toolCall);
             await eventSource.emit(event_types.LLM_FUNCTION_TOOL_CALL, args);
-            data.allowEmptyResponse = true;
         }
     }
 }
@@ -3905,6 +3925,9 @@ function getMaxContextOpenAI(value) {
     if (oai_settings.max_context_unlocked) {
         return unlocked_max;
     }
+    else if (value.startsWith('o1-')) {
+        return max_128k;
+    }
     else if (value.includes('chatgpt-4o-latest') || value.includes('gpt-4-turbo') || value.includes('gpt-4o') || value.includes('gpt-4-1106') || value.includes('gpt-4-0125') || value.includes('gpt-4-vision')) {
         return max_128k;
     }

src/endpoints/backends/chat-completions.js

@@ -965,6 +965,7 @@ router.post('/generate', jsonParser, function (request, response) {
         'model': request.body.model,
         'temperature': request.body.temperature,
         'max_tokens': request.body.max_tokens,
+        'max_completion_tokens': request.body.max_completion_tokens,
         'stream': request.body.stream,
         'presence_penalty': request.body.presence_penalty,
         'frequency_penalty': request.body.frequency_penalty,
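
A hedged example of a body this endpoint would now forward for an o1 request (values are illustrative; after the client-side stripping above, max_completion_tokens is the only length control left):

    {
        "model": "o1-mini",
        "messages": [{ "role": "user", "content": "Hello" }],
        "max_completion_tokens": 4096
    }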

src/endpoints/tokenizers.js

@@ -350,6 +350,10 @@ function getWebTokenizersChunks(tokenizer, ids) {
  * @returns {string} Tokenizer model to use
  */
 function getTokenizerModel(requestModel) {
+    if (requestModel.includes('o1-preview') || requestModel.includes('o1-mini')) {
+        return 'gpt-4o';
+    }
+
     if (requestModel.includes('gpt-4o')) {
         return 'gpt-4o';
     }