From ed214cf0e753cf54658c5644535413b7e0f70a95 Mon Sep 17 00:00:00 2001 From: Thomas Sileo Date: Sun, 18 Dec 2022 12:55:24 +0100 Subject: [PATCH] Add OAuth refresh token support --- ...0333f5a_add_oauth_refresh_token_support.py | 36 +++++++++++ app/indieauth.py | 59 ++++++++++++++----- app/main.py | 14 ++++- app/models.py | 2 + 4 files changed, 94 insertions(+), 17 deletions(-) create mode 100644 alembic/versions/2022_12_18_1126-a209f0333f5a_add_oauth_refresh_token_support.py diff --git a/alembic/versions/2022_12_18_1126-a209f0333f5a_add_oauth_refresh_token_support.py b/alembic/versions/2022_12_18_1126-a209f0333f5a_add_oauth_refresh_token_support.py new file mode 100644 index 0000000..8e486b9 --- /dev/null +++ b/alembic/versions/2022_12_18_1126-a209f0333f5a_add_oauth_refresh_token_support.py @@ -0,0 +1,36 @@ +"""Add OAuth refresh token support + +Revision ID: a209f0333f5a +Revises: 4ab54becec04 +Create Date: 2022-12-18 11:26:31.976348+00:00 + +""" +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision = 'a209f0333f5a' +down_revision = '4ab54becec04' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('indieauth_access_token', schema=None) as batch_op: + batch_op.add_column(sa.Column('refresh_token', sa.String(), nullable=True)) + batch_op.add_column(sa.Column('was_refreshed', sa.Boolean(), server_default='0', nullable=False)) + batch_op.create_index(batch_op.f('ix_indieauth_access_token_refresh_token'), ['refresh_token'], unique=True) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('indieauth_access_token', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('ix_indieauth_access_token_refresh_token')) + batch_op.drop_column('was_refreshed') + batch_op.drop_column('refresh_token') + + # ### end Alembic commands ### diff --git a/app/indieauth.py b/app/indieauth.py index bf31a2b..31ea458 100644 --- a/app/indieauth.py +++ b/app/indieauth.py @@ -270,29 +270,54 @@ async def indieauth_token_endpoint( form_data = await request.form() logger.info(f"{form_data=}") grant_type = form_data.get("grant_type", "authorization_code") - if grant_type != "authorization_code": + if grant_type not in ["authorization_code", "refresh_token"]: raise ValueError(f"Invalid grant_type {grant_type}") - code = form_data["code"] - # These must match the params from the first request client_id = form_data["client_id"] - redirect_uri = form_data["redirect_uri"] - # code_verifier is optional for backward compat code_verifier = form_data.get("code_verifier") - is_code_valid, auth_code_request = await _check_auth_code( - db_session, - code=code, - client_id=client_id, - redirect_uri=redirect_uri, - code_verifier=code_verifier, - ) - if not is_code_valid or (auth_code_request and not auth_code_request.scope): - return JSONResponse( - content={"error": "invalid_grant"}, - status_code=400, + if grant_type == "authorization_code": + code = form_data["code"] + redirect_uri = form_data["redirect_uri"] + # code_verifier is optional for backward compat + is_code_valid, auth_code_request = await _check_auth_code( + db_session, + code=code, + client_id=client_id, + redirect_uri=redirect_uri, + code_verifier=code_verifier, ) + if not is_code_valid or (auth_code_request and not auth_code_request.scope): + return JSONResponse( + content={"error": "invalid_grant"}, + status_code=400, + ) + + elif grant_type == "refresh_token": + refresh_token = form_data["refresh_token"] + access_token = ( + await db_session.scalars( + select(models.IndieAuthAccessToken) + .where( + models.IndieAuthAccessToken.refresh_token == refresh_token, + models.IndieAuthAccessToken.was_refreshed.is_(False), + ) + .options( + joinedload( + models.IndieAuthAccessToken.indieauth_authorization_request + ) + ) + ) + ).one_or_none() + if not access_token: + raise ValueError("invalid refresh token") + + if access_token.indieauth_authorization_request.client_id != client_id: + raise ValueError("invalid client ID") + + auth_code_request = access_token.indieauth_authorization_request + access_token.was_refreshed = True if not auth_code_request: raise ValueError("Should never happen") @@ -300,6 +325,7 @@ async def indieauth_token_endpoint( access_token = models.IndieAuthAccessToken( indieauth_authorization_request_id=auth_code_request.id, access_token=secrets.token_urlsafe(32), + refresh_token=secrets.token_urlsafe(32), expires_in=3600, scope=auth_code_request.scope, ) @@ -309,6 +335,7 @@ async def indieauth_token_endpoint( return JSONResponse( content={ "access_token": access_token.access_token, + "refresh_token": access_token.refresh_token, "token_type": "Bearer", "scope": auth_code_request.scope, "me": config.ID + "/", diff --git a/app/main.py b/app/main.py index 277b385..31c1104 100644 --- a/app/main.py +++ b/app/main.py @@ -631,6 +631,19 @@ async def outbox( ) +@app.post("/outbox") +async def post_inbox( + request: Request, + db_session: AsyncSession = Depends(get_db_session), + access_token_info: indieauth.AccessTokenInfo = Depends( + indieauth.enforce_access_token + ), +) -> ActivityPubResponse: + payload = await request.json() + logger.info(f"{payload=}") + raise ValueError("TODO") + + @app.get("/featured") async def featured( db_session: AsyncSession = Depends(get_db_session), @@ -1055,7 +1068,6 @@ async def get_inbox( page: bool | None = None, next_cursor: str | None = None, ) -> ActivityPubResponse: - logger.info(f"{page=}/{next_cursor=}") where = [ models.InboxObject.ap_type.in_( ["Create", "Follow", "Like", "Announce", "Undo", "Update"] diff --git a/app/models.py b/app/models.py index 393538c..0cf15c1 100644 --- a/app/models.py +++ b/app/models.py @@ -471,9 +471,11 @@ class IndieAuthAccessToken(Base): ) access_token = Column(String, nullable=False, unique=True, index=True) + refresh_token = Column(String, nullable=True, unique=True, index=True) expires_in = Column(Integer, nullable=False) scope = Column(String, nullable=False) is_revoked = Column(Boolean, nullable=False, default=False) + was_refreshed = Column(Boolean, nullable=False, default=False, server_default="0") class OAuthClient(Base):