diff --git a/app/indieauth.py b/app/indieauth.py index 57640b5..7e657ea 100644 --- a/app/indieauth.py +++ b/app/indieauth.py @@ -10,6 +10,8 @@ from fastapi import Form from fastapi import HTTPException from fastapi import Request from fastapi.responses import JSONResponse +from fastapi.security import HTTPBasic +from fastapi.security import HTTPBasicCredentials from loguru import logger from pydantic import BaseModel from sqlalchemy import select @@ -26,6 +28,8 @@ from app.redirect import redirect from app.utils import indieauth from app.utils.datetime import now +basic_auth = HTTPBasic() + router = APIRouter() @@ -496,19 +500,49 @@ async def indieauth_revocation_endpoint( @router.post("/token_introspection") async def oauth_introspection_endpoint( request: Request, - access_token_info: AccessTokenInfo = Depends(enforce_access_token), + credentials: HTTPBasicCredentials = Depends(basic_auth), + db_session: AsyncSession = Depends(get_db_session), token: str = Form(), ) -> JSONResponse: - # Ensure the requested token is the same as bearer token - if token != access_token_info.access_token: - raise HTTPException(status_code=401, detail="access token required") + registered_client = ( + await db_session.scalars( + select(models.OAuthClient).where( + models.OAuthClient.client_id == credentials.username, + models.OAuthClient.client_secret == credentials.password, + ) + ) + ).one_or_none() + if not registered_client: + raise HTTPException(status_code=401, detail="unauthenticated") + + access_token = ( + await db_session.scalars( + select(models.IndieAuthAccessToken) + .where(models.IndieAuthAccessToken.access_token == token) + .join( + models.IndieAuthAuthorizationRequest, + models.IndieAuthAccessToken.indieauth_authorization_request_id + == models.IndieAuthAuthorizationRequest.id, + ) + .where( + models.IndieAuthAuthorizationRequest.client_id == credentials.username + ) + ) + ).one_or_none() + if not access_token: + return JSONResponse(content={"active": False}) return JSONResponse( content={ "active": True, - "client_id": access_token_info.client_id, - "scope": " ".join(access_token_info.scopes), - "exp": access_token_info.exp, + "client_id": credentials.username, + "scope": access_token.scope, + "exp": int( + ( + access_token.created_at.replace(tzinfo=timezone.utc) + + timedelta(seconds=access_token.expires_in) + ).timestamp() + ), }, status_code=200, )