1
0
mirror of https://git.sr.ht/~tsileo/microblog.pub synced 2025-06-05 21:59:23 +02:00

Media proxy cleanup

This commit is contained in:
Thomas Sileo
2022-07-19 08:12:49 +02:00
parent 66a9778995
commit 9882fc555c
2 changed files with 52 additions and 55 deletions

View File

@@ -9,6 +9,7 @@ from typing import MutableMapping
from typing import Type
import httpx
import starlette
from asgiref.typing import ASGI3Application
from asgiref.typing import ASGIReceiveCallable
from asgiref.typing import ASGISendCallable
@@ -57,7 +58,6 @@ from app.config import DOMAIN
from app.config import ID
from app.config import USER_AGENT
from app.config import USERNAME
from app.config import generate_csrf_token
from app.config import is_activitypub_requested
from app.config import verify_csrf_token
from app.database import AsyncSession
@@ -76,6 +76,7 @@ _RESIZED_CACHE: MutableMapping[tuple[str, int], tuple[bytes, str, Any]] = LFUCac
# TODO(ts):
#
# Next:
# - Article support
# - indieauth tweaks
# - API for posting notes
# - allow to block servers
@@ -390,7 +391,6 @@ async def following(
.all()
)
# TODO: support next_cursor/prev_cursor
actors_metadata = {}
if is_current_user_admin(request):
actors_metadata = await get_actors_metadata(
@@ -482,13 +482,17 @@ async def _check_outbox_object_acl(
ap.VisibilityEnum.UNLISTED,
]:
return None
elif ap_object.visibility == ap.VisibilityEnum.FOLLOWERS_ONLY:
# Is the signing actor a follower?
followers = await boxes.fetch_actor_collection(
db_session, BASE_URL + "/followers"
)
if httpsig_info.signed_by_ap_actor_id in [actor.ap_id for actor in followers]:
return None
elif ap_object.visibility == ap.VisibilityEnum.DIRECT:
# Is the signing actor targeted in the object audience?
audience = ap_object.ap_object.get("to", []) + ap_object.ap_object.get("cc", [])
if httpsig_info.signed_by_ap_actor_id in audience:
return None
@@ -718,7 +722,7 @@ async def get_remote_follow(
db_session,
request,
"remote_follow.html",
{"remote_follow_csrf_token": generate_csrf_token()},
{},
)
@@ -733,6 +737,7 @@ async def post_remote_follow(
remote_follow_template = await get_remote_follow_template(profile)
if not remote_follow_template:
# TODO(ts): error message to user
raise HTTPException(status_code=404)
return RedirectResponse(
@@ -812,12 +817,9 @@ async def nodeinfo(
proxy_client = httpx.AsyncClient(follow_redirects=True, http2=True)
@app.get("/proxy/media/{encoded_url}")
async def serve_proxy_media(request: Request, encoded_url: str) -> StreamingResponse:
# Decode the base64-encoded URL
url = base64.urlsafe_b64decode(encoded_url).decode()
check_url(url)
async def _proxy_get(
request: starlette.requests.Request, url: str, stream: bool
) -> httpx.Response:
# Request the URL (and filter request headers)
proxy_req = proxy_client.build_request(
request.method,
@@ -830,27 +832,42 @@ async def serve_proxy_media(request: Request, encoded_url: str) -> StreamingResp
]
+ [(b"user-agent", USER_AGENT.encode())],
)
proxy_resp = await proxy_client.send(proxy_req, stream=True)
# Filter the headers
proxy_resp_headers = [
(k, v)
for (k, v) in proxy_resp.headers.items()
if k.lower()
in [
"content-length",
"content-type",
"content-range",
"accept-ranges" "etag",
"cache-control",
"expires",
"date",
"last-modified",
]
]
return await proxy_client.send(proxy_req, stream=stream)
def _filter_proxy_resp_headers(
proxy_resp: httpx.Response,
allowed_headers: list[str],
) -> dict[str, str]:
return {
k: v for (k, v) in proxy_resp.headers.items() if k.lower() in allowed_headers
}
@app.get("/proxy/media/{encoded_url}")
async def serve_proxy_media(request: Request, encoded_url: str) -> StreamingResponse:
# Decode the base64-encoded URL
url = base64.urlsafe_b64decode(encoded_url).decode()
check_url(url)
proxy_resp = await _proxy_get(request, url, stream=True)
return StreamingResponse(
proxy_resp.aiter_raw(),
status_code=proxy_resp.status_code,
headers=dict(proxy_resp_headers),
headers=_filter_proxy_resp_headers(
proxy_resp,
[
"content-length",
"content-type",
"content-range",
"accept-ranges" "etag",
"cache-control",
"expires",
"date",
"last-modified",
],
),
background=BackgroundTask(proxy_resp.aclose),
)
@@ -876,25 +893,7 @@ async def serve_proxy_media_resized(
headers=resp_headers,
)
# Request the URL (and filter request headers)
async with httpx.AsyncClient() as client:
proxy_resp = await client.get(
url,
headers=[
(k, v)
for (k, v) in request.headers.raw
if k.lower()
not in [
b"host",
b"cookie",
b"x-forwarded-for",
b"x-real-ip",
b"user-agent",
]
]
+ [(b"user-agent", USER_AGENT.encode())],
follow_redirects=True,
)
proxy_resp = await _proxy_get(request, url, stream=False)
if proxy_resp.status_code != 200:
return PlainTextResponse(
proxy_resp.content,
@@ -902,18 +901,16 @@ async def serve_proxy_media_resized(
)
# Filter the headers
proxy_resp_headers = {
k: v
for (k, v) in proxy_resp.headers.items()
if k.lower()
in [
proxy_resp_headers = _filter_proxy_resp_headers(
proxy_resp,
[
"content-type",
"etag",
"cache-control",
"expires",
"last-modified",
]
}
],
)
try:
out = BytesIO(proxy_resp.content)