Add shared utilities for generating text with WebLLM

This commit is contained in:
Cohee 2024-08-12 22:07:44 +03:00
parent 2bdc6f27cc
commit 77ab694ea0
1 changed files with 83 additions and 0 deletions

View File

@ -176,3 +176,86 @@ function throwIfInvalidModel(useReverseProxy) {
throw new Error('Custom API URL is not set.'); throw new Error('Custom API URL is not set.');
} }
} }
/**
* Check if the WebLLM extension is installed and supported.
* @returns {boolean} Whether the extension is installed and supported
*/
export function isWebLlmSupported() {
if (!('gpu' in navigator)) {
toastr.error('Your browser does not support the WebGPU API. Please use a different browser.', 'WebLLM', {
preventDuplicates: true,
timeOut: 0,
extendedTimeOut: 0,
});
return false;
}
if (!('llm' in SillyTavern)) {
toastr.error('WebLLM extension is not installed. Click here to install it.', 'WebLLM', {
timeOut: 0,
extendedTimeOut: 0,
preventDuplicates: true,
onclick: () => {
const button = document.getElementById('third_party_extension_button');
if (button) {
button.click();
}
const input = document.querySelector('dialog textarea');
if (input instanceof HTMLTextAreaElement) {
input.value = 'https://github.com/SillyTavern/Extension-WebLLM';
}
},
});
return false;
}
return true;
}
/**
* Generates text in response to a chat prompt using WebLLM.
* @param {any[]} messages Messages to use for generating
* @returns {Promise<string>} Generated response
*/
export async function generateWebLlmChatPrompt(messages) {
if (!isWebLlmSupported()) {
throw new Error('WebLLM extension is not installed.');
}
const engine = SillyTavern.llm;
const response = await engine.generateChatPrompt(messages);
return response;
}
/**
* Counts the number of tokens in the provided text using WebLLM's default model.
* @param {string} text Text to count tokens in
* @returns {Promise<number>} Number of tokens in the text
*/
export async function countWebLlmTokens(text) {
if (!isWebLlmSupported()) {
throw new Error('WebLLM extension is not installed.');
}
const engine = SillyTavern.llm;
const response = await engine.countTokens(text);
return response;
}
/**
* Gets the size of the context in the WebLLM's default model.
* @returns {Promise<number>} Size of the context in the WebLLM model
*/
export async function getWebLlmContextSize() {
if (!isWebLlmSupported()) {
throw new Error('WebLLM extension is not installed.');
}
const engine = SillyTavern.llm;
await engine.loadModel();
const model = await engine.getCurrentModelInfo();
return model?.context_size;
}