mirror of
				https://git.sr.ht/~tsileo/microblog.pub
				synced 2025-06-05 21:59:23 +02:00 
			
		
		
		
	Tweak middleware
This commit is contained in:
		| @@ -145,8 +145,7 @@ async def save_actor(db_session: AsyncSession, ap_actor: ap.RawObject) -> "Actor | ||||
|         handle=_handle(ap_actor), | ||||
|     ) | ||||
|     db_session.add(actor) | ||||
|     await db_session.commit() | ||||
|     await db_session.refresh(actor) | ||||
|     await db_session.flush() | ||||
|     return actor | ||||
|  | ||||
|  | ||||
|   | ||||
							
								
								
									
										43
									
								
								app/main.py
									
									
									
									
									
								
							
							
						
						
									
										43
									
								
								app/main.py
									
									
									
									
									
								
							| @@ -92,54 +92,29 @@ _RESIZED_CACHE: MutableMapping[tuple[str, int], tuple[bytes, str, Any]] = LFUCac | ||||
| class CustomMiddleware: | ||||
|     def __init__( | ||||
|         self, | ||||
|         app: "ASGI3Application", | ||||
|         app: ASGI3Application, | ||||
|     ) -> None: | ||||
|         self.app = app | ||||
|  | ||||
|     async def __call__( | ||||
|         self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable | ||||
|     ) -> None: | ||||
|         """ | ||||
|         if scope["type"] in ("http", "websocket"): | ||||
|             scope = cast(HTTPScope | WebSocketScope, scope) | ||||
|             client_addr: tuple[str, int] | None = scope.get("client") | ||||
|             client_host = client_addr[0] if client_addr else None | ||||
|  | ||||
|             if self.always_trust or client_host in self.trusted_hosts: | ||||
|                 headers = dict(scope["headers"]) | ||||
|  | ||||
|                 if b"x-forwarded-proto" in headers: | ||||
|                     # Determine if the incoming request was http or https based on | ||||
|                     # the X-Forwarded-Proto header. | ||||
|                     x_forwarded_proto = headers[b"x-forwarded-proto"].decode("latin1") | ||||
|                     scope["scheme"] = x_forwarded_proto.strip()  # type: ignore[index] | ||||
|  | ||||
|                 if b"x-forwarded-for" in headers: | ||||
|                     # Determine the client address from the last trusted IP in the | ||||
|                     # X-Forwarded-For header. We've lost the connecting client's port | ||||
|                     # information by now, so only include the host. | ||||
|                     x_forwarded_for = headers[b"x-forwarded-for"].decode("latin1") | ||||
|                     x_forwarded_for_hosts = [ | ||||
|                         item.strip() for item in x_forwarded_for.split(",") | ||||
|                     ] | ||||
|                     host = self.get_trusted_client_host(x_forwarded_for_hosts) | ||||
|                     port = 0 | ||||
|                     scope["client"] = (host, port)  # type: ignore[arg-type] | ||||
|         """ | ||||
|  | ||||
|         # We only care about HTTP requests | ||||
|         if scope["type"] != "http": | ||||
|             await self.app(scope, receive, send) | ||||
|             return | ||||
|  | ||||
|         instance = {"http_status_code": None} | ||||
|  | ||||
|         response_details = {} | ||||
|         start_time = time.perf_counter() | ||||
|         request_id = os.urandom(8).hex() | ||||
|  | ||||
|         async def send_wrapper(message: Message) -> None: | ||||
|             if message["type"] == "http.response.start": | ||||
|                 instance["http_status_code"] = message["status"] | ||||
|  | ||||
|                 # Extract the HTTP response status code | ||||
|                 response_details["status_code"] = message["status"] | ||||
|  | ||||
|                 # And add the security headers | ||||
|                 headers = MutableHeaders(scope=message) | ||||
|                 headers["X-Request-ID"] = request_id | ||||
|                 headers["Server"] = "microblogpub" | ||||
| @@ -160,6 +135,8 @@ class CustomMiddleware: | ||||
|  | ||||
|             await send(message)  # type: ignore | ||||
|  | ||||
|         # Make loguru ouput the request ID on every log statement within | ||||
|         # the request | ||||
|         with logger.contextualize(request_id=request_id): | ||||
|             client_host, client_port = scope["client"]  # type: ignore | ||||
|             scheme = scope["scheme"] | ||||
| @@ -175,7 +152,7 @@ class CustomMiddleware: | ||||
|             finally: | ||||
|                 elapsed_time = time.perf_counter() - start_time | ||||
|                 logger.info( | ||||
|                     f"status_code={instance['http_status_code']} " | ||||
|                     f"status_code={response_details['status_code']} " | ||||
|                     f"{elapsed_time=:.2f}s" | ||||
|                 ) | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user