mirror of
				https://git.sr.ht/~tsileo/microblog.pub
				synced 2025-06-05 21:59:23 +02:00 
			
		
		
		
	Add support for custom webfinger domain
This commit is contained in:
		
							
								
								
									
										52
									
								
								app/actor.py
									
									
									
									
									
								
							
							
						
						
									
										52
									
								
								app/actor.py
									
									
									
									
									
								
							| @@ -6,6 +6,7 @@ from functools import cached_property | |||||||
| from typing import Union | from typing import Union | ||||||
| from urllib.parse import urlparse | from urllib.parse import urlparse | ||||||
|  |  | ||||||
|  | import httpx | ||||||
| from loguru import logger | from loguru import logger | ||||||
| from sqlalchemy import select | from sqlalchemy import select | ||||||
| from sqlalchemy.orm import joinedload | from sqlalchemy.orm import joinedload | ||||||
| @@ -13,6 +14,9 @@ from sqlalchemy.orm import joinedload | |||||||
| from app import activitypub as ap | from app import activitypub as ap | ||||||
| from app import media | from app import media | ||||||
| from app.config import BASE_URL | from app.config import BASE_URL | ||||||
|  | from app.config import USER_AGENT | ||||||
|  | from app.config import USERNAME | ||||||
|  | from app.config import WEBFINGER_DOMAIN | ||||||
| from app.database import AsyncSession | from app.database import AsyncSession | ||||||
| from app.utils.datetime import as_utc | from app.utils.datetime import as_utc | ||||||
| from app.utils.datetime import now | from app.utils.datetime import now | ||||||
| @@ -27,7 +31,38 @@ def _handle(raw_actor: ap.RawObject) -> str: | |||||||
|     if not domain.hostname: |     if not domain.hostname: | ||||||
|         raise ValueError(f"Invalid actor ID {ap_id}") |         raise ValueError(f"Invalid actor ID {ap_id}") | ||||||
|  |  | ||||||
|     return f'@{raw_actor["preferredUsername"]}@{domain.hostname}'  # type: ignore |     handle = f'@{raw_actor["preferredUsername"]}@{domain.hostname}'  # type: ignore | ||||||
|  |  | ||||||
|  |     # TODO: cleanup this | ||||||
|  |     # Next, check for custom webfinger domains | ||||||
|  |     resp: httpx.Response | None = None | ||||||
|  |     for url in { | ||||||
|  |         f"https://{domain.hostname}/.well-known/webfinger", | ||||||
|  |         f"https://{domain.hostname}/.well-known/webfinger", | ||||||
|  |     }: | ||||||
|  |         try: | ||||||
|  |             logger.info(f"Webfinger {handle} at {url}") | ||||||
|  |             resp = httpx.get( | ||||||
|  |                 url, | ||||||
|  |                 params={"resource": f"acct:{handle[1:]}"}, | ||||||
|  |                 headers={ | ||||||
|  |                     "User-Agent": USER_AGENT, | ||||||
|  |                 }, | ||||||
|  |                 follow_redirects=True, | ||||||
|  |             ) | ||||||
|  |             resp.raise_for_status() | ||||||
|  |             break | ||||||
|  |         except Exception: | ||||||
|  |             logger.exception(f"Failed to webfinger {handle}") | ||||||
|  |  | ||||||
|  |     if resp: | ||||||
|  |         try: | ||||||
|  |             json_resp = resp.json() | ||||||
|  |             if json_resp.get("subject", "").startswith("acct:"): | ||||||
|  |                 return json_resp["subject"].removeprefix("acct:") | ||||||
|  |         except Exception: | ||||||
|  |             logger.exception(f"Failed to parse webfinger response for {handle}") | ||||||
|  |     return handle | ||||||
|  |  | ||||||
|  |  | ||||||
| class Actor: | class Actor: | ||||||
| @@ -61,7 +96,7 @@ class Actor: | |||||||
|             return self.name |             return self.name | ||||||
|         return self.preferred_username |         return self.preferred_username | ||||||
|  |  | ||||||
|     @property |     @cached_property | ||||||
|     def handle(self) -> str: |     def handle(self) -> str: | ||||||
|         return _handle(self.ap_actor) |         return _handle(self.ap_actor) | ||||||
|  |  | ||||||
| @@ -143,13 +178,18 @@ class Actor: | |||||||
|  |  | ||||||
|  |  | ||||||
| class RemoteActor(Actor): | class RemoteActor(Actor): | ||||||
|     def __init__(self, ap_actor: ap.RawObject) -> None: |     def __init__(self, ap_actor: ap.RawObject, handle: str | None = None) -> None: | ||||||
|         if (ap_type := ap_actor.get("type")) not in ap.ACTOR_TYPES: |         if (ap_type := ap_actor.get("type")) not in ap.ACTOR_TYPES: | ||||||
|             raise ValueError(f"Unexpected actor type: {ap_type}") |             raise ValueError(f"Unexpected actor type: {ap_type}") | ||||||
|  |  | ||||||
|         self._ap_actor = ap_actor |         self._ap_actor = ap_actor | ||||||
|         self._ap_type = ap_type |         self._ap_type = ap_type | ||||||
|  |  | ||||||
|  |         if handle is None: | ||||||
|  |             handle = _handle(ap_actor) | ||||||
|  |  | ||||||
|  |         self._handle = handle | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def ap_actor(self) -> ap.RawObject: |     def ap_actor(self) -> ap.RawObject: | ||||||
|         return self._ap_actor |         return self._ap_actor | ||||||
| @@ -162,8 +202,12 @@ class RemoteActor(Actor): | |||||||
|     def is_from_db(self) -> bool: |     def is_from_db(self) -> bool: | ||||||
|         return False |         return False | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def handle(self) -> str: | ||||||
|  |         return self._handle | ||||||
|  |  | ||||||
| LOCAL_ACTOR = RemoteActor(ap_actor=ap.ME) |  | ||||||
|  | LOCAL_ACTOR = RemoteActor(ap_actor=ap.ME, handle=f"@{USERNAME}@{WEBFINGER_DOMAIN}") | ||||||
|  |  | ||||||
|  |  | ||||||
| async def save_actor(db_session: AsyncSession, ap_actor: ap.RawObject) -> "ActorModel": | async def save_actor(db_session: AsyncSession, ap_actor: ap.RawObject) -> "ActorModel": | ||||||
|   | |||||||
| @@ -117,6 +117,8 @@ class Config(pydantic.BaseModel): | |||||||
|  |  | ||||||
|     custom_content_security_policy: str | None = None |     custom_content_security_policy: str | None = None | ||||||
|  |  | ||||||
|  |     webfinger_domain: str | None = None | ||||||
|  |  | ||||||
|     # Config items to make tests easier |     # Config items to make tests easier | ||||||
|     sqlalchemy_database: str | None = None |     sqlalchemy_database: str | None = None | ||||||
|     key_path: str | None = None |     key_path: str | None = None | ||||||
| @@ -168,6 +170,10 @@ ID = f"{_SCHEME}://{DOMAIN}" | |||||||
| if CONFIG.id: | if CONFIG.id: | ||||||
|     ID = CONFIG.id |     ID = CONFIG.id | ||||||
| USERNAME = CONFIG.username | USERNAME = CONFIG.username | ||||||
|  |  | ||||||
|  | # Allow to use @handle@webfinger-domain.tld while hosting the server at domain.tld | ||||||
|  | WEBFINGER_DOMAIN = CONFIG.webfinger_domain or DOMAIN | ||||||
|  |  | ||||||
| MANUALLY_APPROVES_FOLLOWERS = CONFIG.manually_approves_followers | MANUALLY_APPROVES_FOLLOWERS = CONFIG.manually_approves_followers | ||||||
| HIDES_FOLLOWERS = CONFIG.hides_followers | HIDES_FOLLOWERS = CONFIG.hides_followers | ||||||
| HIDES_FOLLOWING = CONFIG.hides_following | HIDES_FOLLOWING = CONFIG.hides_following | ||||||
|   | |||||||
| @@ -62,6 +62,7 @@ from app.config import DOMAIN | |||||||
| from app.config import ID | from app.config import ID | ||||||
| from app.config import USER_AGENT | from app.config import USER_AGENT | ||||||
| from app.config import USERNAME | from app.config import USERNAME | ||||||
|  | from app.config import WEBFINGER_DOMAIN | ||||||
| from app.config import is_activitypub_requested | from app.config import is_activitypub_requested | ||||||
| from app.config import verify_csrf_token | from app.config import verify_csrf_token | ||||||
| from app.customization import get_custom_router | from app.customization import get_custom_router | ||||||
| @@ -1260,7 +1261,7 @@ async def wellknown_webfinger(resource: str) -> JSONResponse: | |||||||
|         raise HTTPException(status_code=404) |         raise HTTPException(status_code=404) | ||||||
|  |  | ||||||
|     out = { |     out = { | ||||||
|         "subject": f"acct:{USERNAME}@{DOMAIN}", |         "subject": f"acct:{USERNAME}@{WEBFINGER_DOMAIN}", | ||||||
|         "aliases": [ID], |         "aliases": [ID], | ||||||
|         "links": [ |         "links": [ | ||||||
|             { |             { | ||||||
|   | |||||||
| @@ -20,12 +20,16 @@ async def test_fetch_actor(async_db_session: AsyncSession, respx_mock) -> None: | |||||||
|         public_key="pk", |         public_key="pk", | ||||||
|     ) |     ) | ||||||
|     respx_mock.get(ra.ap_id).mock(return_value=httpx.Response(200, json=ra.ap_actor)) |     respx_mock.get(ra.ap_id).mock(return_value=httpx.Response(200, json=ra.ap_actor)) | ||||||
|  |     respx_mock.get( | ||||||
|  |         "https://example.com/.well-known/webfinger", | ||||||
|  |         params={"resource": "acct%3Atoto%40example.com"}, | ||||||
|  |     ).mock(return_value=httpx.Response(200, json={"subject": "acct:toto@example.com"})) | ||||||
|  |  | ||||||
|     # When fetching this actor for the first time |     # When fetching this actor for the first time | ||||||
|     saved_actor = await fetch_actor(async_db_session, ra.ap_id) |     saved_actor = await fetch_actor(async_db_session, ra.ap_id) | ||||||
|  |  | ||||||
|     # Then it has been fetched and saved in DB |     # Then it has been fetched and saved in DB | ||||||
|     assert respx.calls.call_count == 1 |     assert respx.calls.call_count == 2 | ||||||
|     assert ( |     assert ( | ||||||
|         await async_db_session.execute(select(models.Actor)) |         await async_db_session.execute(select(models.Actor)) | ||||||
|     ).scalar_one().ap_id == saved_actor.ap_id |     ).scalar_one().ap_id == saved_actor.ap_id | ||||||
| @@ -38,7 +42,7 @@ async def test_fetch_actor(async_db_session: AsyncSession, respx_mock) -> None: | |||||||
|     assert ( |     assert ( | ||||||
|         await async_db_session.execute(select(func.count(models.Actor.id))) |         await async_db_session.execute(select(func.count(models.Actor.id))) | ||||||
|     ).scalar_one() == 1 |     ).scalar_one() == 1 | ||||||
|     assert respx.calls.call_count == 1 |     assert respx.calls.call_count == 2 | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_sqlalchemy_factory(db: Session) -> None: | def test_sqlalchemy_factory(db: Session) -> None: | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user