mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-02-20 22:20:39 +01:00
Merge branch 'staging' into vectordb-with-extras
This commit is contained in:
commit
a1c7e2918b
@ -4360,7 +4360,7 @@
|
||||
</div>
|
||||
<div class="world_entry_form_control">
|
||||
<small class="textAlignCenter">Logic</small>
|
||||
<select name="entryLogicType" class="widthFitContent margin0">
|
||||
<select name="entryLogicType" class="text_pole widthFitContent margin0">
|
||||
<option value="0">AND ANY</option>
|
||||
<option value="3">AND ALL</option>
|
||||
<option value="1">NOT ALL</option>
|
||||
@ -4379,6 +4379,28 @@
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div name="perEntryOverridesBlock" class="flex-container wide100p alignitemscenter">
|
||||
<div class="world_entry_form_control flex1">
|
||||
<small class="textAlignCenter">Scan Depth</small>
|
||||
<input class="text_pole" name="scanDepth" type="number" placeholder="Use global setting" max="100">
|
||||
</div>
|
||||
<div class="world_entry_form_control flex1">
|
||||
<small class="textAlignCenter">Case-Sensitive</small>
|
||||
<select name="caseSensitive" class="text_pole widthNatural margin0">
|
||||
<option value="null">Use global setting</option>
|
||||
<option value="true">Yes</option>
|
||||
<option value="false">No</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="world_entry_form_control flex1">
|
||||
<small class="textAlignCenter">Match Whole Words</small>
|
||||
<select name="matchWholeWords" class="text_pole widthNatural margin0">
|
||||
<option value="null">Use global setting</option>
|
||||
<option value="true">Yes</option>
|
||||
<option value="false">No</option>
|
||||
</select>
|
||||
</div>
|
||||
</div>
|
||||
<div name="contentAndCharFilterBlock" class="world_entry_thin_controls flex2">
|
||||
<div class="world_entry_form_control flex1">
|
||||
<label for="content ">
|
||||
|
@ -1447,6 +1447,7 @@ async function printMessages() {
|
||||
}
|
||||
|
||||
async function clearChat() {
|
||||
closeMessageEditor();
|
||||
count_view_mes = 0;
|
||||
extension_prompts = {};
|
||||
if (is_delete_mode) {
|
||||
|
@ -1160,6 +1160,7 @@ function tryParseStreamingError(response, decoded) {
|
||||
}
|
||||
|
||||
checkQuotaError(data);
|
||||
checkModerationError(data);
|
||||
|
||||
if (data.error) {
|
||||
toastr.error(data.error.message || response.statusText, 'Chat Completion API');
|
||||
@ -1187,6 +1188,15 @@ function checkQuotaError(data) {
|
||||
}
|
||||
}
|
||||
|
||||
function checkModerationError(data) {
|
||||
const moderationError = data?.error?.message?.includes('requires moderation');
|
||||
if (moderationError) {
|
||||
const moderationReason = `Reasons: ${data?.error?.metadata?.reasons?.join(', ') ?? '(N/A)'}`;
|
||||
const flaggedText = data?.error?.metadata?.flagged_input ?? '(N/A)';
|
||||
toastr.info(flaggedText, moderationReason, { timeOut: 10000 });
|
||||
}
|
||||
}
|
||||
|
||||
async function sendWindowAIRequest(messages, signal, stream) {
|
||||
if (!('ai' in window)) {
|
||||
return showWindowExtensionError();
|
||||
@ -1688,6 +1698,7 @@ async function sendOpenAIRequest(type, messages, signal) {
|
||||
const data = await response.json();
|
||||
|
||||
checkQuotaError(data);
|
||||
checkModerationError(data);
|
||||
|
||||
if (data.error) {
|
||||
toastr.error(data.error.message || response.statusText, 'API returned an error');
|
||||
|
@ -664,7 +664,7 @@ function randValuesCallback(from, to, args) {
|
||||
if (args.round == 'floor') {
|
||||
return Math.floor(value);
|
||||
}
|
||||
return value;
|
||||
return String(value);
|
||||
}
|
||||
|
||||
export function registerVariableCommands() {
|
||||
|
@ -70,6 +70,135 @@ const SORT_ORDER_KEY = 'world_info_sort_order';
|
||||
const METADATA_KEY = 'world_info';
|
||||
|
||||
const DEFAULT_DEPTH = 4;
|
||||
const MAX_SCAN_DEPTH = 100;
|
||||
|
||||
/**
|
||||
* Represents a scanning buffer for one evaluation of World Info.
|
||||
*/
|
||||
class WorldInfoBuffer {
|
||||
// Typedef area
|
||||
/** @typedef {{scanDepth?: number, caseSensitive?: boolean, matchWholeWords?: boolean}} WIScanEntry The entry that triggered the scan */
|
||||
// End typedef area
|
||||
|
||||
/**
|
||||
* @type {string[]} Array of messages sorted by ascending depth
|
||||
*/
|
||||
#depthBuffer = [];
|
||||
|
||||
/**
|
||||
* @type {string[]} Array of strings added by recursive scanning
|
||||
*/
|
||||
#recurseBuffer = [];
|
||||
|
||||
/**
|
||||
* @type {number} The skew of the global scan depth. Used in "min activations"
|
||||
*/
|
||||
#skew = 0;
|
||||
|
||||
/**
|
||||
* Initialize the buffer with the given messages.
|
||||
* @param {string[]} messages Array of messages to add to the buffer
|
||||
*/
|
||||
constructor(messages) {
|
||||
this.#initDepthBuffer(messages);
|
||||
}
|
||||
|
||||
/**
|
||||
* Populates the buffer with the given messages.
|
||||
* @param {string[]} messages Array of messages to add to the buffer
|
||||
* @returns {void} Hardly seen nothing down here
|
||||
*/
|
||||
#initDepthBuffer(messages) {
|
||||
for (let depth = 0; depth < MAX_SCAN_DEPTH; depth++) {
|
||||
if (messages[depth]) {
|
||||
this.#depthBuffer[depth] = messages[depth].trim();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a string that respects the case sensitivity setting
|
||||
* @param {string} str The string to transform
|
||||
* @param {WIScanEntry} entry The entry that triggered the scan
|
||||
* @returns {string} The transformed string
|
||||
*/
|
||||
#transformString(str, entry) {
|
||||
const caseSensitive = entry.caseSensitive ?? world_info_case_sensitive;
|
||||
return caseSensitive ? str : str.toLowerCase();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets all messages up to the given depth + recursion buffer.
|
||||
* @param {WIScanEntry} entry The entry that triggered the scan
|
||||
* @returns {string} A slice of buffer until the given depth (inclusive)
|
||||
*/
|
||||
get(entry) {
|
||||
let depth = entry.scanDepth ?? (world_info_depth + this.#skew);
|
||||
|
||||
if (depth < 0) {
|
||||
console.error(`Invalid WI scan depth ${depth}. Must be >= 0`);
|
||||
return '';
|
||||
}
|
||||
|
||||
if (depth > MAX_SCAN_DEPTH) {
|
||||
console.warn(`Invalid WI scan depth ${depth}. Truncating to ${MAX_SCAN_DEPTH}`);
|
||||
depth = MAX_SCAN_DEPTH;
|
||||
}
|
||||
|
||||
let result = this.#depthBuffer.slice(0, depth).join('\n');
|
||||
|
||||
if (this.#recurseBuffer.length > 0) {
|
||||
result += '\n' + this.#recurseBuffer.join('\n');
|
||||
}
|
||||
|
||||
return this.#transformString(result, entry);
|
||||
}
|
||||
|
||||
/**
|
||||
* Matches the given string against the buffer.
|
||||
* @param {string} haystack The string to search in
|
||||
* @param {string} needle The string to search for
|
||||
* @param {WIScanEntry} entry The entry that triggered the scan
|
||||
* @returns {boolean} True if the string was found in the buffer
|
||||
*/
|
||||
matchKeys(haystack, needle, entry) {
|
||||
const transformedString = this.#transformString(needle, entry);
|
||||
const matchWholeWords = entry.matchWholeWords ?? world_info_match_whole_words;
|
||||
|
||||
if (matchWholeWords) {
|
||||
const keyWords = transformedString.split(/\s+/);
|
||||
|
||||
if (keyWords.length > 1) {
|
||||
return haystack.includes(transformedString);
|
||||
}
|
||||
else {
|
||||
const regex = new RegExp(`\\b${escapeRegex(transformedString)}\\b`);
|
||||
if (regex.test(haystack)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return haystack.includes(transformedString);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a message to the recursion buffer.
|
||||
* @param {string} message The message to add
|
||||
*/
|
||||
addRecurse(message) {
|
||||
this.#recurseBuffer.push(message);
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds an increment to depth skew.
|
||||
*/
|
||||
addSkew() {
|
||||
this.#skew++;
|
||||
}
|
||||
}
|
||||
|
||||
export function getWorldInfoSettings() {
|
||||
return {
|
||||
@ -790,6 +919,9 @@ const originalDataKeyMap = {
|
||||
'key': 'keys',
|
||||
'keysecondary': 'secondary_keys',
|
||||
'selective': 'selective',
|
||||
'matchWholeWords': 'extensions.match_whole_words',
|
||||
'caseSensitive': 'extensions.case_sensitive',
|
||||
'scanDepth': 'extensions.scan_depth',
|
||||
};
|
||||
|
||||
function setOriginalDataValue(data, uid, key, value) {
|
||||
@ -1167,7 +1299,7 @@ function getWorldEntry(name, data, entry) {
|
||||
probabilityInput.data('uid', entry.uid);
|
||||
probabilityInput.on('input', function () {
|
||||
const uid = $(this).data('uid');
|
||||
const value = parseInt($(this).val());
|
||||
const value = Number($(this).val());
|
||||
|
||||
data.entries[uid].probability = !isNaN(value) ? value : null;
|
||||
|
||||
@ -1370,6 +1502,57 @@ function getWorldEntry(name, data, entry) {
|
||||
updateEditor(navigation_option.previous);
|
||||
});
|
||||
|
||||
// scan depth
|
||||
const scanDepthInput = template.find('input[name="scanDepth"]');
|
||||
scanDepthInput.data('uid', entry.uid);
|
||||
scanDepthInput.on('input', function () {
|
||||
const uid = $(this).data('uid');
|
||||
const isEmpty = $(this).val() === '';
|
||||
const value = Number($(this).val());
|
||||
|
||||
// Clamp if necessary
|
||||
if (value < 0) {
|
||||
$(this).val(0).trigger('input');
|
||||
return;
|
||||
}
|
||||
|
||||
if (value > MAX_SCAN_DEPTH) {
|
||||
$(this).val(MAX_SCAN_DEPTH).trigger('input');
|
||||
return;
|
||||
}
|
||||
|
||||
data.entries[uid].scanDepth = !isEmpty && !isNaN(value) && value >= 0 && value < MAX_SCAN_DEPTH ? Math.floor(value) : null;
|
||||
setOriginalDataValue(data, uid, 'extensions.scan_depth', data.entries[uid].scanDepth);
|
||||
saveWorldInfo(name, data);
|
||||
});
|
||||
scanDepthInput.val(entry.scanDepth ?? null).trigger('input');
|
||||
|
||||
// case sensitive select
|
||||
const caseSensitiveSelect = template.find('select[name="caseSensitive"]');
|
||||
caseSensitiveSelect.data('uid', entry.uid);
|
||||
caseSensitiveSelect.on('input', function () {
|
||||
const uid = $(this).data('uid');
|
||||
const value = $(this).val();
|
||||
|
||||
data.entries[uid].caseSensitive = value === 'null' ? null : value === 'true';
|
||||
setOriginalDataValue(data, uid, 'extensions.case_sensitive', data.entries[uid].caseSensitive);
|
||||
saveWorldInfo(name, data);
|
||||
});
|
||||
caseSensitiveSelect.val((entry.caseSensitive === null || entry.caseSensitive === undefined) ? 'null' : entry.caseSensitive ? 'true' : 'false').trigger('input');
|
||||
|
||||
// match whole words select
|
||||
const matchWholeWordsSelect = template.find('select[name="matchWholeWords"]');
|
||||
matchWholeWordsSelect.data('uid', entry.uid);
|
||||
matchWholeWordsSelect.on('input', function () {
|
||||
const uid = $(this).data('uid');
|
||||
const value = $(this).val();
|
||||
|
||||
data.entries[uid].matchWholeWords = value === 'null' ? null : value === 'true';
|
||||
setOriginalDataValue(data, uid, 'extensions.match_whole_words', data.entries[uid].matchWholeWords);
|
||||
saveWorldInfo(name, data);
|
||||
});
|
||||
matchWholeWordsSelect.val((entry.matchWholeWords === null || entry.matchWholeWords === undefined) ? 'null' : entry.matchWholeWords ? 'true' : 'false').trigger('input');
|
||||
|
||||
template.find('.inline-drawer-content').css('display', 'none'); //entries start collapsed
|
||||
|
||||
function updatePosOrdDisplay(uid) {
|
||||
@ -1428,6 +1611,9 @@ const newEntryTemplate = {
|
||||
useProbability: true,
|
||||
depth: DEFAULT_DEPTH,
|
||||
group: '',
|
||||
scanDepth: null,
|
||||
caseSensitive: null,
|
||||
matchWholeWords: null,
|
||||
};
|
||||
|
||||
function createWorldInfoEntry(name, data, fromSlashCommand = false) {
|
||||
@ -1585,11 +1771,6 @@ async function createNewWorldInfo(worldInfoName) {
|
||||
}
|
||||
}
|
||||
|
||||
// Gets a string that respects the case sensitivity setting
|
||||
function transformString(str) {
|
||||
return world_info_case_sensitive ? str : str.toLowerCase();
|
||||
}
|
||||
|
||||
async function getCharacterLore() {
|
||||
const character = characters[this_chid];
|
||||
const name = character?.name;
|
||||
@ -1711,11 +1892,10 @@ async function getSortedEntries() {
|
||||
|
||||
async function checkWorldInfo(chat, maxContext) {
|
||||
const context = getContext();
|
||||
const messagesToLookBack = world_info_depth * 2 || 1;
|
||||
const buffer = new WorldInfoBuffer(chat);
|
||||
|
||||
// Combine the chat
|
||||
let textToScan = chat.slice(0, messagesToLookBack).join('');
|
||||
let minActivationMsgIndex = messagesToLookBack; // tracks chat index to satisfy `world_info_min_activations`
|
||||
let minActivationMsgIndex = world_info_depth; // tracks chat index to satisfy `world_info_min_activations`
|
||||
|
||||
// Add the depth or AN if enabled
|
||||
// Put this code here since otherwise, the chat reference is modified
|
||||
@ -1723,14 +1903,11 @@ async function checkWorldInfo(chat, maxContext) {
|
||||
if (context.extensionPrompts[key]?.scan) {
|
||||
const prompt = getExtensionPromptByName(key);
|
||||
if (prompt) {
|
||||
textToScan = `${prompt}\n${textToScan}`;
|
||||
buffer.addRecurse(prompt);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Transform the resulting string
|
||||
textToScan = transformString(textToScan);
|
||||
|
||||
let needsToScan = true;
|
||||
let token_budget_overflowed = false;
|
||||
let count = 0;
|
||||
@ -1809,10 +1986,11 @@ async function checkWorldInfo(chat, maxContext) {
|
||||
|
||||
primary: for (let key of entry.key) {
|
||||
const substituted = substituteParams(key);
|
||||
const textToScan = buffer.get(entry);
|
||||
|
||||
console.debug(`${entry.uid}: ${substituted}`);
|
||||
|
||||
if (substituted && matchKeys(textToScan, substituted.trim())) {
|
||||
if (substituted && buffer.matchKeys(textToScan, substituted.trim(), entry)) {
|
||||
console.debug(`WI UID ${entry.uid} found by primary match: ${substituted}.`);
|
||||
|
||||
//selective logic begins
|
||||
@ -1826,7 +2004,7 @@ async function checkWorldInfo(chat, maxContext) {
|
||||
let hasAllMatch = true;
|
||||
secondary: for (let keysecondary of entry.keysecondary) {
|
||||
const secondarySubstituted = substituteParams(keysecondary);
|
||||
const hasSecondaryMatch = secondarySubstituted && matchKeys(textToScan, secondarySubstituted.trim());
|
||||
const hasSecondaryMatch = secondarySubstituted && buffer.matchKeys(textToScan, secondarySubstituted.trim(), entry);
|
||||
console.debug(`WI UID:${entry.uid}: Filtering for secondary keyword - "${secondarySubstituted}".`);
|
||||
|
||||
if (hasSecondaryMatch) {
|
||||
@ -1926,9 +2104,8 @@ async function checkWorldInfo(chat, maxContext) {
|
||||
.filter(x => !failedProbabilityChecks.has(x))
|
||||
.filter(x => !x.preventRecursion)
|
||||
.map(x => x.content).join('\n');
|
||||
const currentlyActivatedText = transformString(text);
|
||||
textToScan = (currentlyActivatedText + '\n' + textToScan);
|
||||
allActivatedText = (currentlyActivatedText + '\n' + allActivatedText);
|
||||
buffer.addRecurse(text);
|
||||
allActivatedText = (text + '\n' + allActivatedText);
|
||||
}
|
||||
|
||||
// world_info_min_activations
|
||||
@ -1941,8 +2118,8 @@ async function checkWorldInfo(chat, maxContext) {
|
||||
) || (minActivationMsgIndex >= chat.length);
|
||||
if (!over_max) {
|
||||
needsToScan = true;
|
||||
textToScan = transformString(chat.slice(minActivationMsgIndex, minActivationMsgIndex + 1).join(''));
|
||||
minActivationMsgIndex += 1;
|
||||
buffer.addSkew();
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -2069,29 +2246,6 @@ function filterByInclusionGroups(newEntries, allActivatedEntries) {
|
||||
}
|
||||
}
|
||||
|
||||
function matchKeys(haystack, needle) {
|
||||
const transformedString = transformString(needle);
|
||||
|
||||
if (world_info_match_whole_words) {
|
||||
const keyWords = transformedString.split(/\s+/);
|
||||
|
||||
if (keyWords.length > 1) {
|
||||
return haystack.includes(transformedString);
|
||||
}
|
||||
else {
|
||||
const regex = new RegExp(`\\b${escapeRegex(transformedString)}\\b`);
|
||||
if (regex.test(haystack)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
return haystack.includes(transformedString);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
function convertAgnaiMemoryBook(inputObj) {
|
||||
const outputObj = { entries: {} };
|
||||
|
||||
@ -2210,6 +2364,9 @@ function convertCharacterBook(characterBook) {
|
||||
depth: entry.extensions?.depth ?? DEFAULT_DEPTH,
|
||||
selectiveLogic: entry.extensions?.selectiveLogic ?? world_info_logic.AND_ANY,
|
||||
group: entry.extensions?.group ?? '',
|
||||
scanDepth: entry.extensions?.scan_depth ?? null,
|
||||
caseSensitive: entry.extensions?.case_sensitive ?? null,
|
||||
matchWholeWords: entry.extensions?.match_whole_words ?? null,
|
||||
};
|
||||
});
|
||||
|
||||
@ -2245,7 +2402,7 @@ export function checkEmbeddedWorld(chid) {
|
||||
const checkKey = `AlertWI_${characters[chid].avatar}`;
|
||||
const worldName = characters[chid]?.data?.extensions?.world;
|
||||
if (!localStorage.getItem(checkKey) && (!worldName || !world_names.includes(worldName))) {
|
||||
localStorage.setItem(checkKey, 1);
|
||||
localStorage.setItem(checkKey, 'true');
|
||||
|
||||
if (power_user.world_import_dialog) {
|
||||
const html = `<h3>This character has an embedded World/Lorebook.</h3>
|
||||
|
@ -538,7 +538,6 @@ hr {
|
||||
background-color: var(--SmartThemeChatTintColor);
|
||||
-webkit-backdrop-filter: blur(var(--SmartThemeBlurStrength));
|
||||
text-shadow: 0px 0px calc(var(--shadowWidth) * 1px) var(--SmartThemeShadowColor);
|
||||
scrollbar-width: thin;
|
||||
flex-direction: column;
|
||||
z-index: 30;
|
||||
}
|
||||
@ -979,7 +978,6 @@ textarea {
|
||||
font-size: var(--mainFontSize);
|
||||
font-family: "Noto Sans", "Noto Color Emoji", sans-serif;
|
||||
padding: 5px 10px;
|
||||
scrollbar-width: thin;
|
||||
max-height: 90vh;
|
||||
max-height: 90svh;
|
||||
}
|
||||
@ -3125,7 +3123,6 @@ a {
|
||||
box-shadow: none;
|
||||
border-radius: 10px;
|
||||
overflow: hidden;
|
||||
scrollbar-width: thin;
|
||||
flex-flow: column;
|
||||
min-width: 100px;
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
const TASK = 'feature-extraction';
|
||||
|
||||
/**
|
||||
* Gets the vectorized text in form of an array of numbers.
|
||||
* @param {string} text - The text to vectorize
|
||||
* @returns {Promise<number[]>} - The vectorized text in form of an array of numbers
|
||||
*/
|
||||
@ -12,6 +13,20 @@ async function getTransformersVector(text) {
|
||||
return vector;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the vectorized texts in form of an array of arrays of numbers.
|
||||
* @param {string[]} texts - The texts to vectorize
|
||||
* @returns {Promise<number[][]>} - The vectorized texts in form of an array of arrays of numbers
|
||||
*/
|
||||
async function getTransformersBatchVector(texts) {
|
||||
const result = [];
|
||||
for (const text of texts) {
|
||||
result.push(await getTransformersVector(text));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getTransformersVector,
|
||||
getTransformersBatchVector,
|
||||
};
|
||||
|
@ -831,7 +831,7 @@ router.post('/generate', jsonParser, function (request, response) {
|
||||
let json = await fetchResponse.json();
|
||||
response.send(json);
|
||||
console.log(json);
|
||||
console.log(json?.choices[0]?.message);
|
||||
console.log(json?.choices?.[0]?.message);
|
||||
} else if (fetchResponse.status === 429 && retries > 0) {
|
||||
console.log(`Out of quota, retrying in ${Math.round(timeout / 1000)}s`);
|
||||
setTimeout(() => {
|
||||
|
@ -388,6 +388,9 @@ function convertWorldInfoToCharacterBook(name, entries) {
|
||||
selectiveLogic: entry.selectiveLogic ?? 0,
|
||||
group: entry.group ?? '',
|
||||
prevent_recursion: entry.preventRecursion ?? false,
|
||||
scan_depth: entry.scanDepth ?? null,
|
||||
match_whole_words: entry.matchWholeWords ?? null,
|
||||
case_sensitive: entry.caseSensitive ?? null,
|
||||
},
|
||||
};
|
||||
|
||||
|
@ -27,6 +27,26 @@ async function getVector(source, sourceSettings, text) {
|
||||
throw new Error(`Unknown vector source ${source}`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the vector for the given text batch from the given source.
|
||||
* @param {string} source - The source of the vector
|
||||
* @param {string[]} texts - The array of texts to get the vector for
|
||||
* @returns {Promise<number[][]>} - The array of vectors for the texts
|
||||
*/
|
||||
async function getBatchVector(source, texts) {
|
||||
switch (source) {
|
||||
case 'mistral':
|
||||
case 'openai':
|
||||
return require('../openai-vectors').getOpenAIBatchVector(texts, source);
|
||||
case 'transformers':
|
||||
return require('../embedding').getTransformersBatchVector(texts);
|
||||
case 'palm':
|
||||
return require('../makersuite-vectors').getMakerSuiteBatchVector(texts);
|
||||
}
|
||||
|
||||
throw new Error(`Unknown vector source ${source}`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the index for the vector collection
|
||||
* @param {string} collectionId - The collection ID
|
||||
@ -56,12 +76,12 @@ async function insertVectorItems(collectionId, source, sourceSettings, items) {
|
||||
|
||||
await store.beginUpdate();
|
||||
|
||||
for (const item of items) {
|
||||
const text = item.text;
|
||||
const hash = item.hash;
|
||||
const index = item.index;
|
||||
const vector = await getVector(source, sourceSettings, text);
|
||||
await store.upsertItem({ vector: vector, metadata: { hash, text, index } });
|
||||
const vectors = await getBatchVector(source, items.map(x => x.text));
|
||||
|
||||
for (let i = 0; i < items.length; i++) {
|
||||
const item = items[i];
|
||||
const vector = vectors[i];
|
||||
await store.upsertItem({ vector: vector, metadata: { hash: item.hash, text: item.text, index: item.index } });
|
||||
}
|
||||
|
||||
await store.endUpdate();
|
||||
|
@ -1,6 +1,17 @@
|
||||
const fetch = require('node-fetch').default;
|
||||
const { SECRET_KEYS, readSecret } = require('./endpoints/secrets');
|
||||
|
||||
/**
|
||||
* Gets the vector for the given text from gecko model
|
||||
* @param {string[]} texts - The array of texts to get the vector for
|
||||
* @returns {Promise<number[][]>} - The array of vectors for the texts
|
||||
*/
|
||||
async function getMakerSuiteBatchVector(texts) {
|
||||
const promises = texts.map(text => getMakerSuiteVector(text));
|
||||
const vectors = await Promise.all(promises);
|
||||
return vectors;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the vector for the given text from PaLM gecko model
|
||||
* @param {string} text - The text to get the vector for
|
||||
@ -40,4 +51,5 @@ async function getMakerSuiteVector(text) {
|
||||
|
||||
module.exports = {
|
||||
getMakerSuiteVector,
|
||||
getMakerSuiteBatchVector,
|
||||
};
|
||||
|
@ -3,7 +3,7 @@ const { SECRET_KEYS, readSecret } = require('./endpoints/secrets');
|
||||
|
||||
const SOURCES = {
|
||||
'mistral': {
|
||||
secretKey: SECRET_KEYS.MISTRAL,
|
||||
secretKey: SECRET_KEYS.MISTRALAI,
|
||||
url: 'api.mistral.ai',
|
||||
model: 'mistral-embed',
|
||||
},
|
||||
@ -15,12 +15,12 @@ const SOURCES = {
|
||||
};
|
||||
|
||||
/**
|
||||
* Gets the vector for the given text from an OpenAI compatible endpoint.
|
||||
* @param {string} text - The text to get the vector for
|
||||
* Gets the vector for the given text batch from an OpenAI compatible endpoint.
|
||||
* @param {string[]} texts - The array of texts to get the vector for
|
||||
* @param {string} source - The source of the vector
|
||||
* @returns {Promise<number[]>} - The vector for the text
|
||||
* @returns {Promise<number[][]>} - The array of vectors for the texts
|
||||
*/
|
||||
async function getOpenAIVector(text, source) {
|
||||
async function getOpenAIBatchVector(texts, source) {
|
||||
const config = SOURCES[source];
|
||||
|
||||
if (!config) {
|
||||
@ -43,7 +43,7 @@ async function getOpenAIVector(text, source) {
|
||||
Authorization: `Bearer ${key}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
input: text,
|
||||
input: texts,
|
||||
model: config.model,
|
||||
}),
|
||||
});
|
||||
@ -55,16 +55,31 @@ async function getOpenAIVector(text, source) {
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
const vector = data?.data[0]?.embedding;
|
||||
|
||||
if (!Array.isArray(vector)) {
|
||||
if (!Array.isArray(data?.data)) {
|
||||
console.log('API response was not an array');
|
||||
throw new Error('API response was not an array');
|
||||
}
|
||||
|
||||
return vector;
|
||||
// Sort data by x.index to ensure the order is correct
|
||||
data.data.sort((a, b) => a.index - b.index);
|
||||
|
||||
const vectors = data.data.map(x => x.embedding);
|
||||
return vectors;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the vector for the given text from an OpenAI compatible endpoint.
|
||||
* @param {string} text - The text to get the vector for
|
||||
* @param {string} source - The source of the vector
|
||||
* @returns {Promise<number[]>} - The vector for the text
|
||||
*/
|
||||
async function getOpenAIVector(text, source) {
|
||||
const vectors = await getOpenAIBatchVector([text], source);
|
||||
return vectors[0];
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getOpenAIVector,
|
||||
getOpenAIBatchVector,
|
||||
};
|
||||
|
Loading…
x
Reference in New Issue
Block a user