mirror of
				https://git.sr.ht/~tsileo/microblog.pub
				synced 2025-06-05 21:59:23 +02:00 
			
		
		
		
	Tweak middleware
This commit is contained in:
		| @@ -145,8 +145,7 @@ async def save_actor(db_session: AsyncSession, ap_actor: ap.RawObject) -> "Actor | |||||||
|         handle=_handle(ap_actor), |         handle=_handle(ap_actor), | ||||||
|     ) |     ) | ||||||
|     db_session.add(actor) |     db_session.add(actor) | ||||||
|     await db_session.commit() |     await db_session.flush() | ||||||
|     await db_session.refresh(actor) |  | ||||||
|     return actor |     return actor | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										43
									
								
								app/main.py
									
									
									
									
									
								
							
							
						
						
									
										43
									
								
								app/main.py
									
									
									
									
									
								
							| @@ -92,54 +92,29 @@ _RESIZED_CACHE: MutableMapping[tuple[str, int], tuple[bytes, str, Any]] = LFUCac | |||||||
| class CustomMiddleware: | class CustomMiddleware: | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         app: "ASGI3Application", |         app: ASGI3Application, | ||||||
|     ) -> None: |     ) -> None: | ||||||
|         self.app = app |         self.app = app | ||||||
|  |  | ||||||
|     async def __call__( |     async def __call__( | ||||||
|         self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable |         self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable | ||||||
|     ) -> None: |     ) -> None: | ||||||
|         """ |         # We only care about HTTP requests | ||||||
|         if scope["type"] in ("http", "websocket"): |  | ||||||
|             scope = cast(HTTPScope | WebSocketScope, scope) |  | ||||||
|             client_addr: tuple[str, int] | None = scope.get("client") |  | ||||||
|             client_host = client_addr[0] if client_addr else None |  | ||||||
|  |  | ||||||
|             if self.always_trust or client_host in self.trusted_hosts: |  | ||||||
|                 headers = dict(scope["headers"]) |  | ||||||
|  |  | ||||||
|                 if b"x-forwarded-proto" in headers: |  | ||||||
|                     # Determine if the incoming request was http or https based on |  | ||||||
|                     # the X-Forwarded-Proto header. |  | ||||||
|                     x_forwarded_proto = headers[b"x-forwarded-proto"].decode("latin1") |  | ||||||
|                     scope["scheme"] = x_forwarded_proto.strip()  # type: ignore[index] |  | ||||||
|  |  | ||||||
|                 if b"x-forwarded-for" in headers: |  | ||||||
|                     # Determine the client address from the last trusted IP in the |  | ||||||
|                     # X-Forwarded-For header. We've lost the connecting client's port |  | ||||||
|                     # information by now, so only include the host. |  | ||||||
|                     x_forwarded_for = headers[b"x-forwarded-for"].decode("latin1") |  | ||||||
|                     x_forwarded_for_hosts = [ |  | ||||||
|                         item.strip() for item in x_forwarded_for.split(",") |  | ||||||
|                     ] |  | ||||||
|                     host = self.get_trusted_client_host(x_forwarded_for_hosts) |  | ||||||
|                     port = 0 |  | ||||||
|                     scope["client"] = (host, port)  # type: ignore[arg-type] |  | ||||||
|         """ |  | ||||||
|  |  | ||||||
|         if scope["type"] != "http": |         if scope["type"] != "http": | ||||||
|             await self.app(scope, receive, send) |             await self.app(scope, receive, send) | ||||||
|             return |             return | ||||||
|  |  | ||||||
|         instance = {"http_status_code": None} |         response_details = {} | ||||||
|  |  | ||||||
|         start_time = time.perf_counter() |         start_time = time.perf_counter() | ||||||
|         request_id = os.urandom(8).hex() |         request_id = os.urandom(8).hex() | ||||||
|  |  | ||||||
|         async def send_wrapper(message: Message) -> None: |         async def send_wrapper(message: Message) -> None: | ||||||
|             if message["type"] == "http.response.start": |             if message["type"] == "http.response.start": | ||||||
|                 instance["http_status_code"] = message["status"] |  | ||||||
|  |  | ||||||
|  |                 # Extract the HTTP response status code | ||||||
|  |                 response_details["status_code"] = message["status"] | ||||||
|  |  | ||||||
|  |                 # And add the security headers | ||||||
|                 headers = MutableHeaders(scope=message) |                 headers = MutableHeaders(scope=message) | ||||||
|                 headers["X-Request-ID"] = request_id |                 headers["X-Request-ID"] = request_id | ||||||
|                 headers["Server"] = "microblogpub" |                 headers["Server"] = "microblogpub" | ||||||
| @@ -160,6 +135,8 @@ class CustomMiddleware: | |||||||
|  |  | ||||||
|             await send(message)  # type: ignore |             await send(message)  # type: ignore | ||||||
|  |  | ||||||
|  |         # Make loguru ouput the request ID on every log statement within | ||||||
|  |         # the request | ||||||
|         with logger.contextualize(request_id=request_id): |         with logger.contextualize(request_id=request_id): | ||||||
|             client_host, client_port = scope["client"]  # type: ignore |             client_host, client_port = scope["client"]  # type: ignore | ||||||
|             scheme = scope["scheme"] |             scheme = scope["scheme"] | ||||||
| @@ -175,7 +152,7 @@ class CustomMiddleware: | |||||||
|             finally: |             finally: | ||||||
|                 elapsed_time = time.perf_counter() - start_time |                 elapsed_time = time.perf_counter() - start_time | ||||||
|                 logger.info( |                 logger.info( | ||||||
|                     f"status_code={instance['http_status_code']} " |                     f"status_code={response_details['status_code']} " | ||||||
|                     f"{elapsed_time=:.2f}s" |                     f"{elapsed_time=:.2f}s" | ||||||
|                 ) |                 ) | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user