Switch to raw ASGI middleware

This commit is contained in:
Thomas Sileo 2022-07-14 12:13:23 +02:00
parent dd50db40d9
commit a39f874ad5
4 changed files with 107 additions and 60 deletions

View File

@ -29,4 +29,7 @@ def now() -> datetime.datetime:
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
async with async_session() as session:
yield session
try:
yield session
finally:
await session.close()

View File

@ -9,6 +9,10 @@ from typing import MutableMapping
from typing import Type
import httpx
from asgiref.typing import ASGI3Application
from asgiref.typing import ASGIReceiveCallable
from asgiref.typing import ASGISendCallable
from asgiref.typing import Scope
from cachetools import LFUCache
from fastapi import Depends
from fastapi import FastAPI
@ -28,7 +32,9 @@ from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from starlette.background import BackgroundTask
from starlette.datastructures import MutableHeaders
from starlette.responses import JSONResponse
from starlette.types import Message
from app import activitypub as ap
from app import admin
@ -82,12 +88,107 @@ _RESIZED_CACHE: MutableMapping[tuple[str, int], tuple[bytes, str, Any]] = LFUCac
# - [ ] Dockerization
# - [ ] cleanup tasks
class CustomMiddleware:
def __init__(
self,
app: "ASGI3Application",
) -> None:
self.app = app
async def __call__(
self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
) -> None:
"""
if scope["type"] in ("http", "websocket"):
scope = cast(HTTPScope | WebSocketScope, scope)
client_addr: tuple[str, int] | None = scope.get("client")
client_host = client_addr[0] if client_addr else None
if self.always_trust or client_host in self.trusted_hosts:
headers = dict(scope["headers"])
if b"x-forwarded-proto" in headers:
# Determine if the incoming request was http or https based on
# the X-Forwarded-Proto header.
x_forwarded_proto = headers[b"x-forwarded-proto"].decode("latin1")
scope["scheme"] = x_forwarded_proto.strip() # type: ignore[index]
if b"x-forwarded-for" in headers:
# Determine the client address from the last trusted IP in the
# X-Forwarded-For header. We've lost the connecting client's port
# information by now, so only include the host.
x_forwarded_for = headers[b"x-forwarded-for"].decode("latin1")
x_forwarded_for_hosts = [
item.strip() for item in x_forwarded_for.split(",")
]
host = self.get_trusted_client_host(x_forwarded_for_hosts)
port = 0
scope["client"] = (host, port) # type: ignore[arg-type]
"""
if scope["type"] != "http":
await self.app(scope, receive, send)
return
instance = {"http_status_code": None}
start_time = time.perf_counter()
request_id = os.urandom(8).hex()
async def send_wrapper(message: Message) -> None:
if message["type"] == "http.response.start":
instance["http_status_code"] = message["status"]
headers = MutableHeaders(scope=message)
headers["X-Request-ID"] = request_id
headers["Server"] = "microblogpub"
headers[
"referrer-policy"
] = "no-referrer, strict-origin-when-cross-origin"
headers["x-content-type-options"] = "nosniff"
headers["x-xss-protection"] = "1; mode=block"
headers["x-frame-options"] = "SAMEORIGIN"
# TODO(ts): disallow inline CSS?
headers["content-security-policy"] = (
"default-src 'self'" + " style-src 'self' 'unsafe-inline';"
)
if not DEBUG:
headers[
"strict-transport-security"
] = "max-age=63072000; includeSubdomains"
await send(message) # type: ignore
with logger.contextualize(request_id=request_id):
client_host, client_port = scope["client"] # type: ignore
scheme = scope["scheme"]
server_host, server_port = scope["server"] # type: ignore
request_method = scope["method"]
request_path = scope["path"]
logger.info(
f"{client_host}:{client_port} - "
f"{request_method} {scheme}://{server_host}:{server_port}{request_path}"
)
try:
await self.app(scope, receive, send_wrapper) # type: ignore
finally:
elapsed_time = time.perf_counter() - start_time
logger.info(
f"status_code={instance['http_status_code']} "
f"{elapsed_time=:.2f}s"
)
return None
app = FastAPI(docs_url=None, redoc_url=None)
app.mount("/static", StaticFiles(directory="app/static"), name="static")
app.include_router(admin.router, prefix="/admin")
app.include_router(admin.unauthenticated_router, prefix="/admin")
app.include_router(indieauth.router)
app.include_router(webmentions.router)
app.add_middleware(CustomMiddleware)
logger.configure(extra={"request_id": "no_req_id"})
logger.remove()
@ -100,64 +201,6 @@ logger_format = (
logger.add(sys.stdout, format=logger_format)
@app.middleware("http")
async def request_middleware(request, call_next):
start_time = time.perf_counter()
request_id = os.urandom(8).hex()
with logger.contextualize(request_id=request_id):
logger.info(
f"{request.client.host}:{request.client.port} - "
f"{request.method} {request.url}"
)
try:
response = await call_next(request)
response.headers["X-Request-ID"] = request_id
response.headers["Server"] = "microblogpub"
elapsed_time = time.perf_counter() - start_time
logger.info(f"status_code={response.status_code} {elapsed_time=:.2f}s")
return response
except Exception:
logger.exception("Request failed")
raise
@app.middleware("http")
async def add_security_headers(request: Request, call_next):
try:
response = await call_next(request)
except RuntimeError as exc:
# https://github.com/encode/starlette/discussions/1527#discussioncomment-2234702
if await request.is_disconnected() and str(exc) == "No response returned.":
return Response(status_code=204)
response.headers["referrer-policy"] = "no-referrer, strict-origin-when-cross-origin"
response.headers["x-content-type-options"] = "nosniff"
response.headers["x-xss-protection"] = "1; mode=block"
response.headers["x-frame-options"] = "SAMEORIGIN"
if request.url.path.startswith("/admin/login") or (
is_current_user_admin(request)
and not (
request.url.path.startswith("/attachments")
or request.url.path.startswith("/proxy")
or request.url.path.startswith("/static")
)
):
# Prevent caching (to prevent caching CSRF tokens)
response.headers["Cache-Control"] = "private"
# TODO(ts): disallow inline CSS?
if DEBUG:
return response
response.headers["content-security-policy"] = (
"default-src 'self'" + " style-src 'self' 'unsafe-inline';"
)
if not DEBUG:
response.headers[
"strict-transport-security"
] = "max-age=63072000; includeSubdomains"
return response
class ActivityPubResponse(JSONResponse):
media_type = "application/activity+json"

2
poetry.lock generated
View File

@ -1202,7 +1202,7 @@ dev = ["pytest (>=4.6.2)", "black (>=19.3b0)"]
[metadata]
lock-version = "1.1"
python-versions = "^3.10"
content-hash = "7bc5ba65a004438ac015dcd01c27e1d327dbf491f9f881a48a2a790bb0bbf710"
content-hash = "4353bb98b40254eea5277799de3329b6658e21178a6da44113e78c897c7f140b"
[metadata.files]
aiosqlite = [

View File

@ -40,6 +40,7 @@ aiosqlite = "^0.17.0"
cachetools = "^5.2.0"
humanize = "^4.2.3"
tabulate = "^0.8.10"
asgiref = "^3.5.2"
[tool.poetry.dev-dependencies]
black = "^22.3.0"