From bd223486dec85d083a207e461a6391c0064403f1 Mon Sep 17 00:00:00 2001
From: Cohee <18619528+Cohee1207@users.noreply.github.com>
Date: Thu, 14 Mar 2024 00:48:08 +0200
Subject: [PATCH] Include additional headers for all supported Text Completion
 types.

---
 default/config.yaml         |  4 ++++
 src/additional-headers.js   | 15 ++++++++++++++-
 src/endpoints/tokenizers.js |  2 +-
 3 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/default/config.yaml b/default/config.yaml
index 5925d573b..dedb5ac5f 100644
--- a/default/config.yaml
+++ b/default/config.yaml
@@ -35,11 +35,15 @@ skipContentCheck: false
 # Disable automatic chats backup
 disableChatBackup: false
 # API request overrides (for KoboldAI and Text Completion APIs)
+## Note: host includes the port number if it's not the default (80 or 443)
 ## Format is an array of objects:
 ## - hosts:
 ##   - example.com
 ##   headers:
 ##     Content-Type: application/json
+##   - 127.0.0.1:5001
+##   headers:
+##     User-Agent: "Googlebot/2.1 (+http://www.google.com/bot.html)"
 requestOverrides: []
 # -- PLUGIN CONFIGURATION --
 # Enable UI extensions
diff --git a/src/additional-headers.js b/src/additional-headers.js
index e7480f03e..e69872bf3 100644
--- a/src/additional-headers.js
+++ b/src/additional-headers.js
@@ -124,10 +124,23 @@ function setAdditionalHeaders(request, args, server) {
             headers = getKoboldCppHeaders();
             break;
         default:
-            headers = server ? getOverrideHeaders((new URL(server))?.host) : {};
+            headers = {};
             break;
     }
 
+    if (typeof server === 'string' && server.length > 0) {
+        try {
+            const url = new URL(server);
+            const overrideHeaders =  getOverrideHeaders(url.host);
+
+            if (overrideHeaders && Object.keys(overrideHeaders).length > 0) {
+                Object.assign(headers, overrideHeaders);
+            }
+        } catch {
+            // Do nothing
+        }
+    }
+
     Object.assign(args.headers, headers);
 }
 
diff --git a/src/endpoints/tokenizers.js b/src/endpoints/tokenizers.js
index 1ab7d77b8..615042a96 100644
--- a/src/endpoints/tokenizers.js
+++ b/src/endpoints/tokenizers.js
@@ -607,7 +607,7 @@ router.post('/remote/textgenerationwebui/encode', jsonParser, async function (re
             headers: { 'Content-Type': 'application/json' },
         };
 
-        setAdditionalHeaders(request, args, null);
+        setAdditionalHeaders(request, args, baseUrl);
 
         // Convert to string + remove trailing slash + /v1 suffix
         let url = String(baseUrl).replace(/\/$/, '').replace(/\/v1$/, '');