diff --git a/src/BirdsiteLive.Common/Settings/InstanceSettings.cs b/src/BirdsiteLive.Common/Settings/InstanceSettings.cs index 0f701ff..4526c1b 100644 --- a/src/BirdsiteLive.Common/Settings/InstanceSettings.cs +++ b/src/BirdsiteLive.Common/Settings/InstanceSettings.cs @@ -16,5 +16,6 @@ public int FailingFollowerCleanUpThreshold { get; set; } = -1; public int UserCacheCapacity { get; set; } + public string IpWhiteListing { get; set; } } } diff --git a/src/BirdsiteLive/Middlewares/IpWhitelistingMiddleware.cs b/src/BirdsiteLive/Middlewares/IpWhitelistingMiddleware.cs new file mode 100644 index 0000000..1846d8f --- /dev/null +++ b/src/BirdsiteLive/Middlewares/IpWhitelistingMiddleware.cs @@ -0,0 +1,68 @@ +using BirdsiteLive.Common.Settings; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Threading.Tasks; + +namespace BirdsiteLive.Middlewares +{ + public class IpWhitelistingMiddleware + { + private readonly RequestDelegate _next; + private readonly ILogger _logger; + private readonly byte[][] _safelist; + private readonly bool _ipWhitelistingSet; + + public IpWhitelistingMiddleware( + RequestDelegate next, + ILogger logger, + InstanceSettings instanceSettings) + { + if (!string.IsNullOrWhiteSpace(instanceSettings.IpWhiteListing)) + { + var ips = instanceSettings.IpWhiteListing.Split(';'); + _safelist = new byte[ips.Length][]; + for (var i = 0; i < ips.Length; i++) + { + _safelist[i] = IPAddress.Parse(ips[i]).GetAddressBytes(); + } + _ipWhitelistingSet = true; + } + + _next = next; + _logger = logger; + } + + public async Task Invoke(HttpContext context) + { + //if (context.Request.Method != HttpMethod.Get.Method) + if (_ipWhitelistingSet) + { + var remoteIp = context.Connection.RemoteIpAddress; + _logger.LogDebug("Request from Remote IP address: {RemoteIp}", remoteIp); + + var bytes = remoteIp.GetAddressBytes(); + var badIp = true; + foreach (var address in _safelist) + { + if (address.SequenceEqual(bytes)) + { + badIp = false; + break; + } + } + + if (badIp) + { + _logger.LogWarning("Forbidden Request from Remote IP address: {RemoteIp}", remoteIp); + context.Response.StatusCode = (int)HttpStatusCode.NotFound; + return; + } + } + + await _next.Invoke(context); + } + } +} \ No newline at end of file diff --git a/src/BirdsiteLive/Startup.cs b/src/BirdsiteLive/Startup.cs index 16c7ee4..e088a56 100644 --- a/src/BirdsiteLive/Startup.cs +++ b/src/BirdsiteLive/Startup.cs @@ -8,6 +8,7 @@ using BirdsiteLive.Common.Structs; using BirdsiteLive.DAL.Contracts; using BirdsiteLive.DAL.Postgres.DataAccessLayers; using BirdsiteLive.DAL.Postgres.Settings; +using BirdsiteLive.Middlewares; using BirdsiteLive.Models; using BirdsiteLive.Twitter; using BirdsiteLive.Twitter.Tools; @@ -18,6 +19,7 @@ using Microsoft.AspNetCore.HttpsPolicy; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; namespace BirdsiteLive { @@ -132,6 +134,9 @@ namespace BirdsiteLive app.UseAuthorization(); + var instanceSettings = Configuration.GetSection("Instance").Get(); + app.UseMiddleware(instanceSettings); + app.UseEndpoints(endpoints => { endpoints.MapControllerRoute(