Add huggingface inference as text completion source

This commit is contained in:
Cohee
2024-06-28 18:17:27 +03:00
parent 6b204ada9f
commit bbb1a6e578
9 changed files with 122 additions and 8 deletions

View File

@@ -147,6 +147,19 @@ function getKoboldCppHeaders(directories) {
}) : {};
}
/**
 * Builds the request headers for the HuggingFace inference API.
 * Yields an empty object when no API key is stored, so unauthenticated
 * requests can still be attempted against self-hosted endpoints.
 * @param {import('./users').UserDirectoryList} directories User directories to read the secret from
 * @returns {object} Headers for the request
 */
function getHuggingFaceHeaders(directories) {
    const token = readSecret(directories, SECRET_KEYS.HUGGINGFACE);
    if (!token) {
        return {};
    }
    return { 'Authorization': `Bearer ${token}` };
}
function getOverrideHeaders(urlHost) {
const requestOverrides = getConfigValue('requestOverrides', []);
const overrideHeaders = requestOverrides?.find((e) => e.hosts?.includes(urlHost))?.headers;
@@ -187,6 +200,7 @@ function setAdditionalHeadersByType(requestHeaders, type, server, directories) {
[TEXTGEN_TYPES.OPENROUTER]: getOpenRouterHeaders,
[TEXTGEN_TYPES.KOBOLDCPP]: getKoboldCppHeaders,
[TEXTGEN_TYPES.LLAMACPP]: getLlamaCppHeaders,
[TEXTGEN_TYPES.HUGGINGFACE]: getHuggingFaceHeaders,
};
const getHeaders = headerGetters[type];

View File

@@ -216,6 +216,7 @@ const TEXTGEN_TYPES = {
INFERMATICAI: 'infermaticai',
DREAMGEN: 'dreamgen',
OPENROUTER: 'openrouter',
HUGGINGFACE: 'huggingface',
};
const INFERMATICAI_KEYS = [

View File

@@ -95,13 +95,14 @@ router.post('/status', jsonParser, async function (request, response) {
setAdditionalHeaders(request, args, baseUrl);
const apiType = request.body.api_type;
let url = baseUrl;
let result = '';
if (request.body.legacy_api) {
url += '/v1/model';
} else {
switch (request.body.api_type) {
switch (apiType) {
case TEXTGEN_TYPES.OOBA:
case TEXTGEN_TYPES.VLLM:
case TEXTGEN_TYPES.APHRODITE:
@@ -126,6 +127,9 @@ router.post('/status', jsonParser, async function (request, response) {
case TEXTGEN_TYPES.OLLAMA:
url += '/api/tags';
break;
case TEXTGEN_TYPES.HUGGINGFACE:
url += '/info';
break;
}
}
@@ -144,14 +148,18 @@ router.post('/status', jsonParser, async function (request, response) {
}
// Rewrap to OAI-like response
if (request.body.api_type === TEXTGEN_TYPES.TOGETHERAI && Array.isArray(data)) {
if (apiType === TEXTGEN_TYPES.TOGETHERAI && Array.isArray(data)) {
data = { data: data.map(x => ({ id: x.name, ...x })) };
}
if (request.body.api_type === TEXTGEN_TYPES.OLLAMA && Array.isArray(data.models)) {
if (apiType === TEXTGEN_TYPES.OLLAMA && Array.isArray(data.models)) {
data = { data: data.models.map(x => ({ id: x.name, ...x })) };
}
if (apiType === TEXTGEN_TYPES.HUGGINGFACE) {
data = { data: [] };
}
if (!Array.isArray(data.data)) {
console.log('Models response is not an array.');
return response.status(400);
@@ -163,7 +171,7 @@ router.post('/status', jsonParser, async function (request, response) {
// Set result to the first model ID
result = modelIds[0] || 'Valid';
if (request.body.api_type === TEXTGEN_TYPES.OOBA) {
if (apiType === TEXTGEN_TYPES.OOBA) {
try {
const modelInfoUrl = baseUrl + '/v1/internal/model/info';
const modelInfoReply = await fetch(modelInfoUrl, args);
@@ -178,7 +186,7 @@ router.post('/status', jsonParser, async function (request, response) {
} catch (error) {
console.error(`Failed to get Ooba model info: ${error}`);
}
} else if (request.body.api_type === TEXTGEN_TYPES.TABBY) {
} else if (apiType === TEXTGEN_TYPES.TABBY) {
try {
const modelInfoUrl = baseUrl + '/v1/model';
const modelInfoReply = await fetch(modelInfoUrl, args);
@@ -241,6 +249,7 @@ router.post('/generate', jsonParser, async function (request, response) {
case TEXTGEN_TYPES.KOBOLDCPP:
case TEXTGEN_TYPES.TOGETHERAI:
case TEXTGEN_TYPES.INFERMATICAI:
case TEXTGEN_TYPES.HUGGINGFACE:
url += '/v1/completions';
break;
case TEXTGEN_TYPES.DREAMGEN:

View File

@@ -41,6 +41,7 @@ const SECRET_KEYS = {
GROQ: 'api_key_groq',
AZURE_TTS: 'api_key_azure_tts',
ZEROONEAI: 'api_key_01ai',
HUGGINGFACE: 'api_key_huggingface',
};
// These are the keys that are safe to expose, even if allowKeysExposure is false