diff --git a/.cursor/rules/api-layer.mdc b/.cursor/rules/api-layer.mdc
index 309690801..4ce1c8d18 100644
--- a/.cursor/rules/api-layer.mdc
+++ b/.cursor/rules/api-layer.mdc
@@ -40,6 +40,7 @@ api/
   - For direct auth: validates credential format
   - For OAuth: handles authorization flows
   - For auth providers: validates provider exists and supports the source
+- **POST /source-connections/{id}/verify-oauth**: Verifies claim-token ownership and triggers deferred sync. Called after the OAuth callback completes to prove the caller that initiated the flow is the one completing it. Required for all browser-based OAuth flows.
 
 ## Core Components
 
diff --git a/.cursor/rules/connect-widget.mdc b/.cursor/rules/connect-widget.mdc
index 80d328af9..b1d82ff00 100644
--- a/.cursor/rules/connect-widget.mdc
+++ b/.cursor/rules/connect-widget.mdc
@@ -109,7 +109,7 @@ requestClose(reason: "success" | "cancel" | "error");
 ```
 
 ### 3. OAuth Flow Security
-OAuth uses same-origin popups with validated messaging:
+OAuth uses same-origin popups with validated messaging and claim-token verification:
 ```typescript
 // oauth-callback.tsx posts to same origin
 window.opener.postMessage({ type: "OAUTH_COMPLETE", ...result }, window.location.origin);
@@ -121,6 +121,20 @@ const handler = (event: MessageEvent) => {
 };
 ```
 
+**Claim-token verification (two-step completion):**
+After the OAuth popup completes, the Connect widget calls `verifyOAuth` with the claim token before considering the flow complete. The claim token is returned by the initial `createSourceConnection` call and stored in `claimTokenRef`. This ensures the caller that initiated the OAuth flow is the same one completing it.
+
+See `useOAuthFlow.ts` lines 66-72 for the implementation:
+```typescript
+if (claimTokenRef.current) {
+  await apiClient.verifyOAuth(
+    result.source_connection_id,
+    claimTokenRef.current,
+  );
+  claimTokenRef.current = null;
+}
+```
+
 ### 4. Theming System
 Fully customizable via CSS variables passed from parent:
 ```typescript
diff --git a/backend/airweave/api/v1/endpoints/connect.py b/backend/airweave/api/v1/endpoints/connect.py
index 2faba4be0..c26f7468b 100644
--- a/backend/airweave/api/v1/endpoints/connect.py
+++ b/backend/airweave/api/v1/endpoints/connect.py
@@ -40,6 +40,7 @@
     ConnectSessionCreate,
     ConnectSessionResponse,
 )
+from airweave.schemas.source_connection import VerifyOAuthRequest
 
 router = TrailingSlashRouter()
 
@@ -174,6 +175,24 @@ async def create_source_connection(
     return await svc.create_source_connection(db, source_connection_in, session, session_token)
 
 
+@router.post(
+    "/source-connections/{connection_id}/verify-oauth",
+    response_model=schemas.SourceConnection,
+)
+async def verify_oauth(
+    connection_id: UUID,
+    body: VerifyOAuthRequest,
+    db: AsyncSession = Depends(get_db),
+    session: ConnectSessionContext = Depends(deps.get_connect_session),
+    svc: ConnectServiceProtocol = Inject(ConnectServiceProtocol),
+) -> schemas.SourceConnection:
+    """Verify OAuth flow ownership via Connect session.
+
+    Authentication: Bearer
+    """
+    return await svc.verify_oauth(db, connection_id, body.claim_token, session)
+
+
 # =============================================================================
 # Sync Jobs
 # =============================================================================
diff --git a/backend/airweave/api/v1/endpoints/source_connections.py b/backend/airweave/api/v1/endpoints/source_connections.py
index b8ea8e9d8..2d280b1d4 100644
--- a/backend/airweave/api/v1/endpoints/source_connections.py
+++ b/backend/airweave/api/v1/endpoints/source_connections.py
@@ -38,6 +38,7 @@
     RateLimitErrorResponse,
     ValidationErrorResponse,
 )
+from airweave.schemas.source_connection import VerifyOAuthRequest
 
 logger = logging.getLogger(__name__)
 
@@ -97,6 +98,35 @@ async def oauth_callback(
     )
 
 
+@router.post(
+    "/{source_connection_id}/verify-oauth",
+    response_model=schemas.SourceConnection,
+    summary="Verify OAuth Flow",
+    description="""Verify ownership of an OAuth flow by presenting the claim token
+returned during connection creation. This triggers the deferred sync.""",
+    responses={
+        200: {"model": schemas.SourceConnection, "description": "Verified source connection"},
+        403: {"description": "Invalid claim token or identity mismatch"},
+        404: {"model": NotFoundErrorResponse, "description": "Source Connection Not Found"},
+    },
+)
+async def verify_oauth(
+    *,
+    db: AsyncSession = Depends(get_db),
+    source_connection_id: UUID = Path(...),
+    body: VerifyOAuthRequest,
+    ctx: ApiContext = Depends(deps.get_context),
+    oauth_callback_svc: OAuthCallbackServiceProtocol = Inject(OAuthCallbackServiceProtocol),
+) -> schemas.SourceConnection:
+    """Verify OAuth flow ownership and trigger deferred sync."""
+    return await oauth_callback_svc.verify_oauth_flow(
+        db,
+        source_connection_id=source_connection_id,
+        claim_token=body.claim_token,
+        ctx=ctx,
+    )
+
+
 @router.post(
     "/",
     response_model=schemas.SourceConnection,
diff --git a/backend/airweave/core/container/factory.py b/backend/airweave/core/container/factory.py
index f95e91d3f..82644c33c 100644
--- a/backend/airweave/core/container/factory.py
+++ b/backend/airweave/core/container/factory.py
@@ -342,20 +342,6 @@ def create_container(settings: Settings) -> Container:
         deletion_service=deletion_service,
     )
 
-    # -----------------------------------------------------------------
-    # Connect domain service
-    # -----------------------------------------------------------------
-    from airweave.domains.connect.service import ConnectService
-    from airweave.domains.organizations.repository import OrganizationRepository as ConnectOrgRepo
-
-    connect_service = ConnectService(
-        source_connection_service=source_connection_service,
-        source_service=source_deps["source_service"],
-        org_repo=ConnectOrgRepo(),
-        collection_repo=source_deps["collection_repo"],
-        sync_job_repo=source_deps["sync_job_repo"],
-    )
-
     # -----------------------------------------------------------------
     # Embedder registries + instances (deployment-wide singletons)
     # -----------------------------------------------------------------
@@ -410,6 +396,21 @@ def create_container(settings: Settings) -> Container:
         credential_encryptor=encryptor,
     )
 
+    # -----------------------------------------------------------------
+    # Connect domain service (after oauth_callback_svc for DI)
+    # -----------------------------------------------------------------
+    from airweave.domains.connect.service import ConnectService
+    from airweave.domains.organizations.repository import OrganizationRepository as ConnectOrgRepo
+
+    connect_service = ConnectService(
+        source_connection_service=source_connection_service,
+        source_service=source_deps["source_service"],
+        org_repo=ConnectOrgRepo(),
+        collection_repo=source_deps["collection_repo"],
+        sync_job_repo=source_deps["sync_job_repo"],
+        oauth_callback_service=oauth_callback_svc,
+    )
+
     # -----------------------------------------------------------------
     # Browse tree service
     # -----------------------------------------------------------------
diff --git a/backend/airweave/domains/connect/fakes/service.py b/backend/airweave/domains/connect/fakes/service.py
index 3a1205089..3eab37d8c 100644
--- a/backend/airweave/domains/connect/fakes/service.py
+++ b/backend/airweave/domains/connect/fakes/service.py
@@ -131,6 +131,16 @@ async def create_source_connection(
             )
         raise NotImplementedError("FakeConnectService.create_source_connection not seeded")
 
+    async def verify_oauth(
+        self,
+        db: AsyncSession,
+        connection_id: UUID,
+        claim_token: str,
+        session: ConnectSessionContext,
+    ) -> Any:
+        self._calls.append(("verify_oauth", connection_id, claim_token, session))
+        raise NotImplementedError("FakeConnectService.verify_oauth not seeded")
+
     async def get_connection_jobs(
         self,
         db: AsyncSession,
diff --git a/backend/airweave/domains/connect/protocols.py b/backend/airweave/domains/connect/protocols.py
index c32be1b64..3d7b43008 100644
--- a/backend/airweave/domains/connect/protocols.py
+++ b/backend/airweave/domains/connect/protocols.py
@@ -82,6 +82,16 @@ async def create_source_connection(
         """Create a source connection via Connect session."""
         ...
 
+    async def verify_oauth(
+        self,
+        db: AsyncSession,
+        connection_id: UUID,
+        claim_token: str,
+        session: ConnectSessionContext,
+    ) -> schemas.SourceConnection:
+        """Verify OAuth flow ownership via Connect session."""
+        ...
+
     async def get_connection_jobs(
         self,
         db: AsyncSession,
diff --git a/backend/airweave/domains/connect/service.py b/backend/airweave/domains/connect/service.py
index 014512d4f..d8a2d3e3c 100644
--- a/backend/airweave/domains/connect/service.py
+++ b/backend/airweave/domains/connect/service.py
@@ -18,6 +18,7 @@
 from airweave.domains.collections.protocols import CollectionRepositoryProtocol
 from airweave.domains.connect.protocols import ConnectServiceProtocol
 from airweave.domains.connect.types import MODES_CREATE, MODES_DELETE, MODES_VIEW
+from airweave.domains.oauth.protocols import OAuthCallbackServiceProtocol
 from airweave.domains.organizations.protocols import OrganizationRepositoryProtocol
 from airweave.domains.source_connections.protocols import SourceConnectionServiceProtocol
 from airweave.domains.sources.protocols import SourceServiceProtocol
@@ -44,12 +45,14 @@ def __init__( # noqa: D107
         org_repo: OrganizationRepositoryProtocol,
         collection_repo: CollectionRepositoryProtocol,
         sync_job_repo: SyncJobRepositoryProtocol,
+        oauth_callback_service: OAuthCallbackServiceProtocol,
     ) -> None:
         self._sc_service = source_connection_service
         self._source_service = source_service
         self._org_repo = org_repo
         self._collection_repo = collection_repo
         self._sync_job_repo = sync_job_repo
+        self._oauth_callback_service = oauth_callback_service
 
     # ------------------------------------------------------------------
     # Guards (private — all access checks live on the service)
@@ -346,6 +349,24 @@ async def create_source_connection(
 
         return result  # type: ignore[return-value]
 
+    async def verify_oauth(
+        self,
+        db: AsyncSession,
+        connection_id: UUID,
+        claim_token: str,
+        session: ConnectSessionContext,
+    ) -> schemas.SourceConnection:
+        """Verify OAuth flow ownership via Connect session."""
+        self._check_mode(session, MODES_CREATE, "verifying OAuth flow")
+        ctx = await self._build_context(db, session)
+
+        return await self._oauth_callback_service.verify_oauth_flow(
+            db,
+            source_connection_id=connection_id,
+            claim_token=claim_token,
+            ctx=ctx,
+        )
+
     # ------------------------------------------------------------------
     # Sync jobs
     # ------------------------------------------------------------------
diff --git a/backend/airweave/domains/connect/tests/conftest.py b/backend/airweave/domains/connect/tests/conftest.py
index d4cbb1fbe..601ab3bb1 100644
--- a/backend/airweave/domains/connect/tests/conftest.py
+++ b/backend/airweave/domains/connect/tests/conftest.py
@@ -84,10 +84,13 @@ def sync_job_repo():
 
 @pytest.fixture
 def connect_service(org_repo, sc_service, source_service, collection_repo, sync_job_repo):
+    from unittest.mock import AsyncMock
+
     return ConnectService(
         source_connection_service=sc_service,
         source_service=source_service,
         org_repo=org_repo,
         collection_repo=collection_repo,
         sync_job_repo=sync_job_repo,
+        oauth_callback_service=AsyncMock(),
     )
diff --git a/backend/airweave/domains/oauth/callback_service.py b/backend/airweave/domains/oauth/callback_service.py
index 7a2b179b3..fb505ed07 100644
--- a/backend/airweave/domains/oauth/callback_service.py
+++ b/backend/airweave/domains/oauth/callback_service.py
@@ -4,7 +4,10 @@
 source_connection_service_helpers completion logic with proper DI.
 """
 
+import hashlib
+import hmac
 from collections.abc import Mapping
+from datetime import datetime, timezone
 from typing import Any, Dict
 from uuid import UUID, uuid4
 
@@ -13,7 +16,7 @@
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from airweave import schemas
-from airweave.api.context import ApiContext
+from airweave.api.context import ApiContext, ConnectContext
 from airweave.core.events.source_connection import SourceConnectionLifecycleEvent
 from airweave.core.events.sync import SyncLifecycleEvent
 from airweave.core.logging import logger
@@ -156,11 +159,18 @@ async def complete_oauth2_callback(
         if not init_session:
             raise HTTPException(status_code=404, detail="OAuth2 session not found or expired")
 
+        if init_session.expires_at < datetime.now(timezone.utc):
+            raise HTTPException(status_code=410, detail="OAuth session expired")
+
         if init_session.status != ConnectionInitStatus.PENDING:
             raise HTTPException(
                 status_code=400, detail=f"OAuth session already {init_session.status}"
             )
 
+        init_session.status = ConnectionInitStatus.IN_PROGRESS
+        db.add(init_session)
+        await db.flush()
+
         ctx = await self._reconstruct_context(db, init_session)
 
         source_conn_shell = await self._sc_repo.get_by_init_session(
@@ -210,11 +220,18 @@ async def complete_oauth1_callback(
                 ),
             )
 
+        if init_session.expires_at < datetime.now(timezone.utc):
+            raise HTTPException(status_code=410, detail="OAuth session expired")
+
         if init_session.status != ConnectionInitStatus.PENDING:
             raise HTTPException(
                 status_code=400, detail=f"OAuth session already {init_session.status}"
             )
 
+        init_session.status = ConnectionInitStatus.IN_PROGRESS
+        db.add(init_session)
+        await db.flush()
+
         ctx = await self._reconstruct_context(db, init_session)
 
         source_conn_shell = await self._sc_repo.get_by_init_session(
@@ -343,6 +360,7 @@ async def _complete_oauth1_connection(
             auth_method_to_save,
             is_oauth1=True,
             ctx=ctx,
+            has_claim_token=bool(init_session.claim_token_hash),
         )
 
     async def _complete_oauth2_connection(
@@ -398,6 +416,7 @@ async def _complete_oauth2_connection(
             auth_method_to_save,
             is_oauth1=False,
             ctx=ctx,
+            has_claim_token=bool(init_session.claim_token_hash),
         )
 
     # ------------------------------------------------------------------
@@ -415,6 +434,7 @@ async def _complete_connection_common( # noqa: C901
         auth_method_to_save: AuthenticationMethod,
         is_oauth1: bool,
         ctx: ApiContext,
+        has_claim_token: bool = False,
     ) -> SourceConnection:
         """Common logic for completing OAuth connections (shared by OAuth1/OAuth2)."""
         validated_config = self._validate_config(source_entry, payload.get("config"))
@@ -531,12 +551,13 @@ async def _complete_connection_common( # noqa: C901
                     uow=uow,
                 )
 
-            await self._init_session_repo.mark_completed(
-                uow.session,
-                session_id=init_session_id,
-                final_connection_id=sc_update["connection_id"],
-                ctx=ctx,
-            )
+            if not has_claim_token:
+                await self._init_session_repo.mark_completed(
+                    uow.session,
+                    session_id=init_session_id,
+                    final_connection_id=sc_update["connection_id"],
+                    ctx=ctx,
+                )
 
             await uow.commit()
             await uow.session.refresh(source_conn)
@@ -573,6 +594,20 @@ async def _validate_oauth2_token_or_raise(
     # Private: finalization (response + sync trigger)
     # ------------------------------------------------------------------
 
+    async def _has_claim_token(
+        self,
+        db: AsyncSession,
+        source_conn: SourceConnection,
+        ctx: ApiContext,
+    ) -> bool:
+        """Check if source connection's init session has a claim token."""
+        if not source_conn.connection_init_session_id:
+            return False
+        session = await self._init_session_repo.get(
+            db, id=source_conn.connection_init_session_id, ctx=ctx
+        )
+        return bool(session and session.claim_token_hash)
+
     async def _finalize_callback(
         self,
         db: AsyncSession,
@@ -582,64 +617,18 @@ async def _finalize_callback(
         """Build response and trigger sync workflow if needed."""
         source_conn_response = await self._response_builder.build_response(db, source_conn, ctx)
 
-        if source_conn.sync_id:
-            sync = await self._sync_repo.get(db, id=source_conn.sync_id, ctx=ctx)
-            if sync:
-                jobs = await self._sync_job_repo.get_all_by_sync_id(db, sync_id=sync.id, ctx=ctx)
-                if jobs and len(jobs) > 0:
-                    sync_job = jobs[0]
-                    if sync_job.status == SyncJobStatus.PENDING:
-                        collection = await self._collection_repo.get_by_readable_id(
-                            db, readable_id=source_conn.readable_collection_id, ctx=ctx
-                        )
-                        if collection:
-                            collection_schema = schemas.CollectionRecord.model_validate(
-                                collection, from_attributes=True
-                            )
-                            sync_job_schema = schemas.SyncJob.model_validate(
-                                sync_job, from_attributes=True
-                            )
-                            sync_schema = schemas.Sync.model_validate(sync, from_attributes=True)
-
-                            if not source_conn.connection_id:
-                                raise ValueError(
-                                    f"Source connection {source_conn.id} has no connection_id"
-                                )
-                            conn_model = await self._connection_repo.get(
-                                db, id=source_conn.connection_id, ctx=ctx
-                            )
-                            if not conn_model:
-                                raise ValueError(
-                                    f"Connection {source_conn.connection_id} not found"
-                                )
-                            connection_schema = schemas.Connection.model_validate(
-                                conn_model, from_attributes=True
-                            )
-
-                            try:
-                                await self._event_bus.publish(
-                                    SyncLifecycleEvent.pending(
-                                        organization_id=ctx.organization.id,
-                                        source_connection_id=source_conn.id,
-                                        sync_job_id=sync_job_schema.id,
-                                        sync_id=sync_schema.id,
-                                        collection_id=collection_schema.id,
-                                        source_type=connection_schema.short_name,
-                                        collection_name=collection_schema.name,
-                                        collection_readable_id=collection_schema.readable_id,
-                                    )
-                                )
-                            except Exception as e:
-                                ctx.logger.warning(f"Failed to publish sync.pending event: {e}")
-
-                            await self._temporal_workflow_service.run_source_connection_workflow(
-                                sync=sync_schema,
-                                sync_job=sync_job_schema,
-                                collection=collection_schema,
-                                connection=connection_schema,
-                                ctx=ctx,
-                            )
+        should_defer_sync = bool(
+            source_conn.connection_init_session_id
+            and await self._has_claim_token(db, source_conn, ctx)
+        )
+
+        if source_conn.sync_id and not should_defer_sync:
+            await self._run_sync_workflow(db, source_conn, ctx)
 
+        # auth_completed fires at callback time regardless of whether the sync
+        # is deferred. Auth *is* complete once the provider redirects back;
+        # the subsequent verify-oauth call proves *who* initiated the flow,
+        # it does not change the auth state.
         await self._event_bus.publish(
             SourceConnectionLifecycleEvent.auth_completed(
                 organization_id=ctx.organization.id,
@@ -651,6 +640,159 @@ async def _finalize_callback(
 
         return source_conn_response
 
+    # ------------------------------------------------------------------
+    # Verify OAuth flow ownership
+    # ------------------------------------------------------------------
+
+    async def verify_oauth_flow(
+        self,
+        db: AsyncSession,
+        *,
+        source_connection_id: UUID,
+        claim_token: str,
+        ctx: ApiContext | ConnectContext,
+    ) -> SourceConnectionSchema:
+        """Verify OAuth flow ownership via claim token and trigger deferred sync."""
+        source_conn = await self._sc_repo.get(db, id=source_connection_id, ctx=ctx)
+        if not source_conn:
+            raise HTTPException(status_code=404, detail="Source connection not found")
+
+        if not source_conn.connection_init_session_id:
+            raise HTTPException(status_code=400, detail="No OAuth session for this connection")
+
+        init_session = await self._init_session_repo.get(
+            db, id=source_conn.connection_init_session_id, ctx=ctx
+        )
+        if not init_session or not init_session.claim_token_hash:
+            raise HTTPException(status_code=400, detail="No claim token on OAuth session")
+
+        expected_hash = hashlib.sha256(claim_token.encode()).hexdigest()
+        if not hmac.compare_digest(expected_hash, init_session.claim_token_hash):
+            raise HTTPException(status_code=403, detail="Invalid claim token")
+
+        if init_session.status != ConnectionInitStatus.IN_PROGRESS:
+            raise HTTPException(
+                status_code=400,
+                detail=f"OAuth session is {init_session.status}, expected in_progress",
+            )
+
+        # Identity check last: avoids leaking session state to wrong-identity callers
+        caller_user_id = getattr(ctx, "user_id", None)
+        caller_session_id = getattr(ctx, "session_id", None)
+        has_initiator = (
+            init_session.initiator_user_id is not None
+            or init_session.initiator_session_id is not None
+        )
+        if has_initiator:
+            user_match = (
+                caller_user_id is not None
+                and init_session.initiator_user_id is not None
+                and caller_user_id == init_session.initiator_user_id
+            )
+            session_match = (
+                caller_session_id is not None
+                and init_session.initiator_session_id is not None
+                and caller_session_id == init_session.initiator_session_id
+            )
+            if not user_match and not session_match:
+                raise HTTPException(status_code=403, detail="Caller identity mismatch")
+
+        # Mark session completed
+        await self._init_session_repo.mark_completed(
+            db,
+            session_id=init_session.id,
+            final_connection_id=source_conn.connection_id,
+            ctx=ctx,
+        )
+        await db.flush()
+
+        # Trigger deferred sync
+        api_ctx = (
+            ctx
+            if isinstance(ctx, ApiContext)
+            else await self._reconstruct_context(db, init_session)
+        )
+        return await self._trigger_deferred_sync(db, source_conn, api_ctx)
+
+    async def _trigger_deferred_sync(
+        self,
+        db: AsyncSession,
+        source_conn: SourceConnection,
+        ctx: ApiContext,
+    ) -> SourceConnectionSchema:
+        """Trigger sync workflow that was deferred during callback."""
+        source_conn_response = await self._response_builder.build_response(db, source_conn, ctx)
+
+        if source_conn.sync_id:
+            await self._run_sync_workflow(db, source_conn, ctx)
+
+        return source_conn_response
+
+    async def _run_sync_workflow(
+        self,
+        db: AsyncSession,
+        source_conn: SourceConnection,
+        ctx: ApiContext,
+    ) -> None:
+        """Publish sync.pending event and start the Temporal workflow.
+
+        Shared by the immediate callback path and the deferred verify path.
+        """
+        sync = await self._sync_repo.get(db, id=source_conn.sync_id, ctx=ctx)
+        if not sync:
+            return
+
+        jobs = await self._sync_job_repo.get_all_by_sync_id(db, sync_id=sync.id, ctx=ctx)
+        if not jobs:
+            return
+
+        sync_job = jobs[0]
+        if sync_job.status != SyncJobStatus.PENDING:
+            return
+
+        collection = await self._collection_repo.get_by_readable_id(
+            db, readable_id=source_conn.readable_collection_id, ctx=ctx
+        )
+        if not collection:
+            return
+
+        collection_schema = schemas.CollectionRecord.model_validate(
+            collection, from_attributes=True
+        )
+        sync_job_schema = schemas.SyncJob.model_validate(sync_job, from_attributes=True)
+        sync_schema = schemas.Sync.model_validate(sync, from_attributes=True)
+
+        if not source_conn.connection_id:
+            raise ValueError(f"Source connection {source_conn.id} has no connection_id")
+        conn_model = await self._connection_repo.get(db, id=source_conn.connection_id, ctx=ctx)
+        if not conn_model:
+            raise ValueError(f"Connection {source_conn.connection_id} not found")
+        connection_schema = schemas.Connection.model_validate(conn_model, from_attributes=True)
+
+        try:
+            await self._event_bus.publish(
+                SyncLifecycleEvent.pending(
+                    organization_id=ctx.organization.id,
+                    source_connection_id=source_conn.id,
+                    sync_job_id=sync_job_schema.id,
+                    sync_id=sync_schema.id,
+                    collection_id=collection_schema.id,
+                    source_type=connection_schema.short_name,
+                    collection_name=collection_schema.name,
+                    collection_readable_id=collection_schema.readable_id,
+                )
+            )
+        except Exception as e:
+            ctx.logger.warning(f"Failed to publish sync.pending event: {e}")
+
+        await self._temporal_workflow_service.run_source_connection_workflow(
+            sync=sync_schema,
+            sync_job=sync_job_schema,
+            collection=collection_schema,
+            connection=connection_schema,
+            ctx=ctx,
+        )
+
     # ------------------------------------------------------------------
     # Private: inline helpers
     # ------------------------------------------------------------------
diff --git a/backend/airweave/domains/oauth/fakes/flow_service.py b/backend/airweave/domains/oauth/fakes/flow_service.py
index d6d51fb6d..f1b29d586 100644
--- a/backend/airweave/domains/oauth/fakes/flow_service.py
+++ b/backend/airweave/domains/oauth/fakes/flow_service.py
@@ -183,6 +183,9 @@ async def create_init_session(
         redirect_url: Optional[str] = None,
         template_configs: Optional[dict] = None,
         additional_overrides: Optional[Dict[str, Any]] = None,
+        initiator_user_id: Optional[UUID] = None,
+        initiator_session_id: Optional[UUID] = None,
+        claim_token_hash: Optional[str] = None,
     ) -> Any:
         self._calls.append(("create_init_session", short_name, state))
         self._last_create_init_session_kwargs = {
@@ -196,6 +199,9 @@ async def create_init_session(
             "redirect_url": redirect_url,
             "template_configs": template_configs,
             "additional_overrides": additional_overrides,
+            "initiator_user_id": initiator_user_id,
+            "initiator_session_id": initiator_session_id,
+            "claim_token_hash": claim_token_hash,
         }
         return type("InitSession", (), {"id": uuid4()})()
 
diff --git a/backend/airweave/domains/oauth/fakes/repository.py b/backend/airweave/domains/oauth/fakes/repository.py
index 6ed3e959d..e649917b2 100644
--- a/backend/airweave/domains/oauth/fakes/repository.py
+++ b/backend/airweave/domains/oauth/fakes/repository.py
@@ -1,7 +1,7 @@
 """Fake repositories for OAuth domain testing."""
 
 from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, cast
 from uuid import UUID, uuid4
 
 from sqlalchemy.ext.asyncio import AsyncSession
@@ -57,13 +57,25 @@ class FakeOAuthInitSessionRepository:
     def __init__(self) -> None:
         self._store_by_state: dict[str, ConnectionInitSession] = {}
         self._store_by_token: dict[str, ConnectionInitSession] = {}
+        self._store_by_id: dict[UUID, ConnectionInitSession] = {}
         self._calls: list[tuple[Any, ...]] = []
 
     def seed_by_state(self, state: str, obj: ConnectionInitSession) -> None:
         self._store_by_state[state] = obj
+        if hasattr(obj, "id") and obj.id:
+            self._store_by_id[cast(UUID, obj.id)] = obj
 
     def seed_by_oauth_token(self, oauth_token: str, obj: ConnectionInitSession) -> None:
         self._store_by_token[oauth_token] = obj
+        if hasattr(obj, "id") and obj.id:
+            self._store_by_id[cast(UUID, obj.id)] = obj
+
+    def seed_by_id(self, id: UUID, obj: ConnectionInitSession) -> None:
+        self._store_by_id[id] = obj
+
+    async def get(self, db: AsyncSession, *, id: UUID, ctx: Any) -> Optional[ConnectionInitSession]:
+        self._calls.append(("get", id))
+        return self._store_by_id.get(id)
 
     async def get_by_state_no_auth(
         self, db: AsyncSession, *, state: str
diff --git a/backend/airweave/domains/oauth/flow_service.py b/backend/airweave/domains/oauth/flow_service.py
index b4353a4ca..06484fd34 100644
--- a/backend/airweave/domains/oauth/flow_service.py
+++ b/backend/airweave/domains/oauth/flow_service.py
@@ -300,6 +300,9 @@ async def create_init_session(
         redirect_url: Optional[str] = None,
         template_configs: Optional[dict] = None,
         additional_overrides: Optional[Dict[str, Any]] = None,
+        initiator_user_id: Optional[UUID] = None,
+        initiator_session_id: Optional[UUID] = None,
+        claim_token_hash: Optional[str] = None,
     ) -> ConnectionInitSession:
         """Persist an init session for a new OAuth flow.

@@ -334,6 +337,9 @@ async def create_init_session(
                 "status": ConnectionInitStatus.PENDING,
                 "expires_at": expires_at,
                 "redirect_session_id": redirect_session_id,
+                "initiator_user_id": initiator_user_id,
+                "initiator_session_id": initiator_session_id,
+                "claim_token_hash": claim_token_hash,
             },
             ctx=ctx,
             uow=uow,
diff --git a/backend/airweave/domains/oauth/protocols.py b/backend/airweave/domains/oauth/protocols.py
index 225a0cc55..6c80cc11d 100644
--- a/backend/airweave/domains/oauth/protocols.py
+++ b/backend/airweave/domains/oauth/protocols.py
@@ -6,7 +6,7 @@
 
 from sqlalchemy.ext.asyncio import AsyncSession
 
-from airweave.api.context import ApiContext
+from airweave.api.context import ApiContext, ConnectContext
 from airweave.core.logging import ContextualLogger
 from airweave.db.unit_of_work import UnitOfWork
 from airweave.domains.oauth.types import OAuth1TokenResponse, OAuthBrowserInitiationResult
@@ -156,6 +156,16 @@ async def create(
         """Persist a new init session."""
         ...
 
+    async def get(
+        self,
+        db: AsyncSession,
+        *,
+        id: UUID,
+        ctx: ApiContext,
+    ) -> Optional[ConnectionInitSession]:
+        """Fetch an init session by ID (org-scoped)."""
+        ...
+
     async def mark_completed(
         self,
         db: AsyncSession,
@@ -277,6 +287,9 @@ async def create_init_session(
         redirect_url: Optional[str] = None,
         template_configs: Optional[dict] = None,
         additional_overrides: Optional[Dict[str, Any]] = None,
+        initiator_user_id: Optional[UUID] = None,
+        initiator_session_id: Optional[UUID] = None,
+        claim_token_hash: Optional[str] = None,
     ) -> ConnectionInitSession:
         """Persist an init session for a new OAuth flow."""
         ...
@@ -329,3 +342,14 @@ async def complete_oauth1_callback(
         Exchange verifier, wire credential + connection, trigger sync.
         """
         ...
+
+    async def verify_oauth_flow(
+        self,
+        db: AsyncSession,
+        *,
+        source_connection_id: UUID,
+        claim_token: str,
+        ctx: ApiContext | ConnectContext,
+    ) -> SourceConnectionSchema:
+        """Verify OAuth flow ownership via claim token and trigger deferred sync."""
+        ...
diff --git a/backend/airweave/domains/oauth/tests/test_callback_service.py b/backend/airweave/domains/oauth/tests/test_callback_service.py
index ab7d732b5..84af26a40 100644
--- a/backend/airweave/domains/oauth/tests/test_callback_service.py
+++ b/backend/airweave/domains/oauth/tests/test_callback_service.py
@@ -62,6 +62,10 @@ def _init_session(
     session_id: UUID | None = None,
     payload: dict | None = None,
     overrides: dict | None = None,
+    expires_at: datetime | None = None,
+    initiator_user_id: UUID | None = None,
+    initiator_session_id: UUID | None = None,
+    claim_token_hash: str | None = None,
 ) -> ConnectionInitSession:
     return ConnectionInitSession(
         id=session_id or SESSION_ID,
@@ -71,7 +75,10 @@ def _init_session(
         organization_id=organization_id or ORG_ID,
         payload=payload or {},
         overrides=overrides or {},
-        expires_at=datetime(2099, 1, 1, tzinfo=timezone.utc),
+        expires_at=expires_at or datetime(2099, 1, 1, tzinfo=timezone.utc),
+        initiator_user_id=initiator_user_id,
+        initiator_session_id=initiator_session_id,
+        claim_token_hash=claim_token_hash,
     )
 
 
@@ -137,6 +144,7 @@ def _service(
 
 
 DB = AsyncMock()
+DB.add = MagicMock()
 
 
 # ---------------------------------------------------------------------------
@@ -876,6 +884,72 @@ async def commit(self):
 
         svc._sync_lifecycle.provision_sync.assert_not_awaited()
 
+    async def test_claim_token_session_skips_mark_completed(self):
+        svc = _service()
+        conn_id = uuid4()
+        svc._credential_repo.create = AsyncMock(return_value=SimpleNamespace(id=uuid4()))
+        svc._connection_repo.create = AsyncMock(return_value=SimpleNamespace(id=conn_id))
+        svc._collection_repo.get_by_readable_id = AsyncMock(
+            return_value=SimpleNamespace(id=uuid4(), readable_id="col-abc")
+        )
+        sc_id = uuid4()
+        svc._sc_repo.update = AsyncMock(
+            return_value=SimpleNamespace(id=sc_id, connection_id=conn_id)
+        )
+        svc._init_session_repo.mark_completed = AsyncMock()
+        svc._sync_record_service.resolve_destination_ids = AsyncMock(return_value=[uuid4()])
+        sync_id = uuid4()
+        svc._sync_lifecycle.provision_sync = AsyncMock(
+            return_value=SimpleNamespace(sync_id=sync_id)
+        )
+
+        from airweave.domains.oauth import callback_service as callback_module
+
+        class _FakeUOW:
+            def __init__(self, db):
+                self.session = db
+
+            async def __aenter__(self):
+                return self
+
+            async def __aexit__(self, exc_type, exc, tb):
+                return False
+
+            async def commit(self):
+                return None
+
+        db = AsyncMock()
+        db.flush = AsyncMock()
+        db.refresh = AsyncMock()
+        monkeypatch = pytest.MonkeyPatch()
+        monkeypatch.setattr(callback_module, "UnitOfWork", _FakeUOW)
+        try:
+            source_entry = SimpleNamespace(
+                short_name="github",
+                name="GitHub",
+                auth_config_ref=type("GitHubAuth", (), {}),
+                oauth_type="access_only",
+                config_ref=None,
+                source_class_ref=SimpleNamespace(federated_search=False),
+            )
+            shell = _source_conn_shell()
+            await svc._complete_connection_common(
+                db,
+                source_entry,
+                shell,
+                SESSION_ID,
+                {"name": "n", "readable_collection_id": "col-abc"},
+                {"access_token": "tok"},
+                AuthenticationMethod.OAUTH_BROWSER,
+                is_oauth1=False,
+                ctx=_ctx(),
+                has_claim_token=True,
+            )
+        finally:
+            monkeypatch.undo()
+
+        svc._init_session_repo.mark_completed.assert_not_awaited()
+
     async def test_non_federated_source_provisions_sync_with_cron_schedule(self):
         svc = _service()
         conn_id = uuid4()
@@ -959,7 +1033,12 @@ async def test_no_sync_id_just_returns_response(self):
         event_bus = AsyncMock()
         event_bus.publish = AsyncMock()
 
-        source_conn = SimpleNamespace(sync_id=None, id=uuid4(), connection_id=uuid4())
+        source_conn = SimpleNamespace(
+            sync_id=None,
+            id=uuid4(),
+            connection_id=uuid4(),
+            connection_init_session_id=None,
+        )
         svc = _service(response_builder=builder, event_bus=event_bus)
         result = await svc._finalize_callback(DB, source_conn, _ctx())
 
@@ -985,6 +1064,7 @@ async def test_triggers_workflow_when_pending_job_exists(self):
             sync_id=sync_id,
             connection_id=conn_id,
             readable_collection_id="col-abc",
+            connection_init_session_id=None,
         )
 
         # Seed sync repo
@@ -1084,6 +1164,7 @@ async def test_no_pending_jobs_skips_workflow(self):
             sync_id=sync_id,
             connection_id=uuid4(),
             readable_collection_id="col-abc",
+            connection_init_session_id=None,
         )
 
         sync_repo = FakeSyncRepository()
@@ -1118,6 +1199,7 @@ async def test_running_job_skips_workflow(self):
             sync_id=sync_id,
             connection_id=conn_id,
             readable_collection_id="col-abc",
+            connection_init_session_id=None,
         )
 
         from airweave import schemas
@@ -1174,6 +1256,7 @@ async def test_missing_connection_id_raises_value_error(self):
             sync_id=sync_id,
             connection_id=None,
             readable_collection_id="col-abc",
+            connection_init_session_id=None,
         )
 
         from airweave import schemas
@@ -1240,7 +1323,12 @@ async def test_auth_completed_event_failure_is_fatal(self):
         event_bus = AsyncMock()
         event_bus.publish = AsyncMock(side_effect=RuntimeError("pub-fail"))
         ctx = _ctx()
-        source_conn = SimpleNamespace(sync_id=None, id=uuid4(), connection_id=uuid4())
+        source_conn = SimpleNamespace(
+            sync_id=None,
+            id=uuid4(),
+            connection_id=uuid4(),
+            connection_init_session_id=None,
+        )
         svc = _service(response_builder=builder, event_bus=event_bus)
         with pytest.raises(RuntimeError, match="pub-fail"):
             await svc._finalize_callback(DB, source_conn, ctx)
@@ -1257,6 +1345,7 @@ async def test_pending_job_with_connection_executes_event_payload_path(self):
             sync_id=sync_id,
             connection_id=conn_id,
             readable_collection_id="col-abc",
+            connection_init_session_id=None,
         )
 
         from airweave import schemas
@@ -1365,3 +1454,437 @@ def test_encryptor_is_stored(self):
         encryptor = FakeCredentialEncryptor()
         svc = _service(credential_encryptor=encryptor)
         assert svc._credential_encryptor is encryptor
+
+
+# ---------------------------------------------------------------------------
+# Expiry enforcement
+# ---------------------------------------------------------------------------
+
+
+class TestExpiryEnforcement:
+    async def test_expired_oauth2_session_raises_410(self):
+        repo = FakeOAuthInitSessionRepository()
+        session = _init_session(
+            expires_at=datetime(2020, 1, 1, tzinfo=timezone.utc),
+        )
+        repo.seed_by_state("state-abc", session)
+
+        svc = _service(init_session_repo=repo)
+        with pytest.raises(HTTPException) as exc:
+            await svc.complete_oauth2_callback(DB, state="state-abc", code="c")
+        assert exc.value.status_code == 410
+        assert "expired" in exc.value.detail.lower()
+
+    async def test_expired_oauth1_session_raises_410(self):
+        repo = FakeOAuthInitSessionRepository()
+        session = _init_session(
+            expires_at=datetime(2020, 1, 1, tzinfo=timezone.utc),
+        )
+        repo.seed_by_oauth_token("tok1", session)
+
+        svc = _service(init_session_repo=repo)
+        with pytest.raises(HTTPException) as exc:
+            await svc.complete_oauth1_callback(DB, oauth_token="tok1", oauth_verifier="v")
+        assert exc.value.status_code == 410
+
+
+# ---------------------------------------------------------------------------
+# IN_PROGRESS replay protection
+# ---------------------------------------------------------------------------
+
+
+class TestInProgressReplayProtection:
+    async def test_in_progress_oauth2_session_raises_400(self):
+        repo = FakeOAuthInitSessionRepository()
+        session = _init_session(status=ConnectionInitStatus.IN_PROGRESS)
+        repo.seed_by_state("state-abc", session)
+
+        svc = _service(init_session_repo=repo)
+        with pytest.raises(HTTPException) as exc:
+            await svc.complete_oauth2_callback(DB, state="state-abc", code="c")
+        assert exc.value.status_code == 400
+
+    async def test_in_progress_oauth1_session_raises_400(self):
+        repo = FakeOAuthInitSessionRepository()
+        session = _init_session(status=ConnectionInitStatus.IN_PROGRESS)
+        repo.seed_by_oauth_token("tok1", session)
+
+        svc = _service(init_session_repo=repo)
+        with pytest.raises(HTTPException) as exc:
+
await svc.complete_oauth1_callback(DB, oauth_token="tok1", oauth_verifier="v") + assert exc.value.status_code == 400 + + +# --------------------------------------------------------------------------- +# Deferred sync when claim token is set +# --------------------------------------------------------------------------- + + +class TestDeferredSync: + async def test_sync_deferred_when_claim_token_set(self): + """_finalize_callback skips Temporal trigger when init session has claim_token_hash.""" + init_session_id = uuid4() + init_repo = FakeOAuthInitSessionRepository() + session = _init_session( + session_id=init_session_id, + claim_token_hash="abc123", + ) + init_repo.seed_by_id(init_session_id, session) + + response = MagicMock(id=uuid4(), short_name="github", readable_collection_id="col-abc") + builder = AsyncMock() + builder.build_response = AsyncMock(return_value=response) + + temporal_svc = AsyncMock() + temporal_svc.run_source_connection_workflow = AsyncMock() + + event_bus = AsyncMock() + event_bus.publish = AsyncMock() + + sync_id = uuid4() + source_conn = SimpleNamespace( + id=uuid4(), + sync_id=sync_id, + connection_id=uuid4(), + connection_init_session_id=init_session_id, + readable_collection_id="col-abc", + ) + + # Seed sync repo with a pending job + from airweave import schemas + + sync_repo = FakeSyncRepository() + sync_repo.seed( + sync_id, + schemas.Sync( + id=sync_id, + name="test-sync", + source_connection_id=source_conn.connection_id, + collection_id=uuid4(), + collection_readable_id="col-abc", + organization_id=ORG_ID, + created_at=NOW, + modified_at=NOW, + cron_schedule=None, + status="active", + source_connections=[], + destination_connections=[], + destination_connection_ids=[], + ), + ) + + from airweave.models.sync_job import SyncJob + + sync_job_repo = FakeSyncJobRepository() + sync_job_repo.seed_jobs_for_sync( + sync_id, + [ + SyncJob( + id=uuid4(), + sync_id=sync_id, + status=SyncJobStatus.PENDING, + organization_id=ORG_ID, + 
scheduled=False, + ) + ], + ) + + svc = _service( + init_session_repo=init_repo, + response_builder=builder, + temporal_workflow_service=temporal_svc, + event_bus=event_bus, + sync_repo=sync_repo, + sync_job_repo=sync_job_repo, + ) + + result = await svc._finalize_callback(DB, source_conn, _ctx()) + + assert result is response + # Sync should be deferred — workflow NOT triggered + temporal_svc.run_source_connection_workflow.assert_not_awaited() + + async def test_sync_triggered_when_no_claim_token(self): + """Backward compat: existing behavior preserved for sessions without claim_token_hash.""" + init_session_id = uuid4() + init_repo = FakeOAuthInitSessionRepository() + session = _init_session(session_id=init_session_id, claim_token_hash=None) + init_repo.seed_by_id(init_session_id, session) + + response = MagicMock(id=uuid4(), short_name="github", readable_collection_id="col-abc") + builder = AsyncMock() + builder.build_response = AsyncMock(return_value=response) + + temporal_svc = AsyncMock() + temporal_svc.run_source_connection_workflow = AsyncMock() + + event_bus = AsyncMock() + event_bus.publish = AsyncMock() + + sync_id = uuid4() + conn_id = uuid4() + source_conn = SimpleNamespace( + id=uuid4(), + sync_id=sync_id, + connection_id=conn_id, + connection_init_session_id=init_session_id, + readable_collection_id="col-abc", + ) + + from airweave import schemas + + sync_repo = FakeSyncRepository() + sync_repo.seed( + sync_id, + schemas.Sync( + id=sync_id, + name="test-sync", + source_connection_id=conn_id, + collection_id=uuid4(), + collection_readable_id="col-abc", + organization_id=ORG_ID, + created_at=NOW, + modified_at=NOW, + cron_schedule=None, + status="active", + source_connections=[], + destination_connections=[], + destination_connection_ids=[], + ), + ) + + from airweave.models.sync_job import SyncJob + + sync_job_repo = FakeSyncJobRepository() + sync_job_repo.seed_jobs_for_sync( + sync_id, + [ + SyncJob( + id=uuid4(), + sync_id=sync_id, + 
status=SyncJobStatus.PENDING, + organization_id=ORG_ID, + scheduled=False, + ) + ], + ) + + from airweave.models.collection import Collection + + collection_repo = FakeCollectionRepository() + col = Collection( + id=uuid4(), + name="Col", + readable_id="col-abc", + organization_id=ORG_ID, + vector_db_deployment_metadata_id=uuid4(), + ) + col.created_at = NOW + col.modified_at = NOW + collection_repo.seed_readable("col-abc", col) + + from airweave.models.connection import Connection + + connection_repo = FakeConnectionRepository() + connection = Connection( + id=conn_id, + organization_id=ORG_ID, + name="github-conn", + readable_id="conn-github-abc", + short_name="github", + integration_type="source", + status=ConnectionStatus.ACTIVE, + ) + connection.created_at = NOW + connection.modified_at = NOW + connection_repo.seed(conn_id, connection) + + svc = _service( + init_session_repo=init_repo, + response_builder=builder, + temporal_workflow_service=temporal_svc, + event_bus=event_bus, + sync_repo=sync_repo, + sync_job_repo=sync_job_repo, + collection_repo=collection_repo, + connection_repo=connection_repo, + ) + + result = await svc._finalize_callback(DB, source_conn, _ctx()) + + assert result is response + temporal_svc.run_source_connection_workflow.assert_awaited_once() + + +# --------------------------------------------------------------------------- +# verify_oauth_flow +# --------------------------------------------------------------------------- + + +class TestVerifyOAuthFlow: + def _make_claim_token(self) -> tuple[str, str]: + import hashlib + import secrets + + token = secrets.token_urlsafe(32) + token_hash = hashlib.sha256(token.encode()).hexdigest() + return token, token_hash + + async def test_happy_path_triggers_sync(self): + claim_token, claim_hash = self._make_claim_token() + user_id = uuid4() + + init_session_id = uuid4() + init_repo = FakeOAuthInitSessionRepository() + session = _init_session( + session_id=init_session_id, + 
status=ConnectionInitStatus.IN_PROGRESS, + claim_token_hash=claim_hash, + initiator_user_id=user_id, + ) + init_repo.seed_by_id(init_session_id, session) + + sc_repo = FakeSourceConnectionRepository() + shell = _source_conn_shell(init_session_id=init_session_id) + shell.connection_id = uuid4() + sc_repo.seed(shell.id, shell) + + org_repo = FakeOrganizationRepository() + org_repo.seed(ORG_ID, _organization()) + + response_obj = MagicMock(id=shell.id, short_name="github", readable_collection_id="col-abc") + builder = AsyncMock() + builder.build_response = AsyncMock(return_value=response_obj) + + event_bus = AsyncMock() + event_bus.publish = AsyncMock() + + svc = _service( + init_session_repo=init_repo, + sc_repo=sc_repo, + organization_repo=org_repo, + response_builder=builder, + event_bus=event_bus, + ) + + # Build a ctx with user_id + ctx = _ctx() + from airweave.schemas.user import User + + ctx.user = User( + id=user_id, + email="test@example.com", + created_at=NOW, + modified_at=NOW, + ) + + result = await svc.verify_oauth_flow( + DB, + source_connection_id=shell.id, + claim_token=claim_token, + ctx=ctx, + ) + + assert result is response_obj + # mark_completed should have been called + assert any(call[0] == "mark_completed" for call in init_repo._calls) + + async def test_wrong_token_raises_403(self): + _, claim_hash = self._make_claim_token() + + init_session_id = uuid4() + init_repo = FakeOAuthInitSessionRepository() + session = _init_session( + session_id=init_session_id, + status=ConnectionInitStatus.IN_PROGRESS, + claim_token_hash=claim_hash, + ) + init_repo.seed_by_id(init_session_id, session) + + sc_repo = FakeSourceConnectionRepository() + shell = _source_conn_shell(init_session_id=init_session_id) + shell.connection_id = uuid4() + sc_repo.seed(shell.id, shell) + + svc = _service(init_session_repo=init_repo, sc_repo=sc_repo) + + with pytest.raises(HTTPException) as exc: + await svc.verify_oauth_flow( + DB, + source_connection_id=shell.id, + 
claim_token="wrong-token", + ctx=_ctx(), + ) + assert exc.value.status_code == 403 + assert "Invalid claim token" in exc.value.detail + + async def test_wrong_user_raises_403(self): + claim_token, claim_hash = self._make_claim_token() + initiator_user_id = uuid4() + different_user_id = uuid4() + + init_session_id = uuid4() + init_repo = FakeOAuthInitSessionRepository() + session = _init_session( + session_id=init_session_id, + status=ConnectionInitStatus.IN_PROGRESS, + claim_token_hash=claim_hash, + initiator_user_id=initiator_user_id, + ) + init_repo.seed_by_id(init_session_id, session) + + sc_repo = FakeSourceConnectionRepository() + shell = _source_conn_shell(init_session_id=init_session_id) + shell.connection_id = uuid4() + sc_repo.seed(shell.id, shell) + + svc = _service(init_session_repo=init_repo, sc_repo=sc_repo) + + ctx = _ctx() + from airweave.schemas.user import User + + ctx.user = User( + id=different_user_id, + email="other@example.com", + created_at=NOW, + modified_at=NOW, + ) + + with pytest.raises(HTTPException) as exc: + await svc.verify_oauth_flow( + DB, + source_connection_id=shell.id, + claim_token=claim_token, + ctx=ctx, + ) + assert exc.value.status_code == 403 + assert "identity mismatch" in exc.value.detail.lower() + + async def test_already_completed_raises_400(self): + claim_token, claim_hash = self._make_claim_token() + + init_session_id = uuid4() + init_repo = FakeOAuthInitSessionRepository() + session = _init_session( + session_id=init_session_id, + status=ConnectionInitStatus.COMPLETED, + claim_token_hash=claim_hash, + ) + init_repo.seed_by_id(init_session_id, session) + + sc_repo = FakeSourceConnectionRepository() + shell = _source_conn_shell(init_session_id=init_session_id) + shell.connection_id = uuid4() + sc_repo.seed(shell.id, shell) + + svc = _service(init_session_repo=init_repo, sc_repo=sc_repo) + + with pytest.raises(HTTPException) as exc: + await svc.verify_oauth_flow( + DB, + source_connection_id=shell.id, + 
claim_token=claim_token, + ctx=_ctx(), + ) + assert exc.value.status_code == 400 + assert "completed" in exc.value.detail.lower() diff --git a/backend/airweave/domains/source_connections/create.py b/backend/airweave/domains/source_connections/create.py index a254b472d..ee3105cca 100644 --- a/backend/airweave/domains/source_connections/create.py +++ b/backend/airweave/domains/source_connections/create.py @@ -1,5 +1,6 @@ """Source connection creation service.""" +import hashlib import secrets from typing import Any, Optional from uuid import UUID @@ -358,6 +359,11 @@ async def _create_with_oauth_browser( except ValueError as exc: raise HTTPException(status_code=422, detail=str(exc)) from exc state = secrets.token_urlsafe(24) + claim_token = secrets.token_urlsafe(32) + claim_token_hash = hashlib.sha256(claim_token.encode()).hexdigest() + + initiator_user_id = ctx.user_id if hasattr(ctx, "user_id") else None + initiator_session_id: Optional[UUID] = getattr(ctx, "session_id", None) initiation_result = await self._oauth_flow_service.initiate_browser_flow( short_name=obj_in.short_name, @@ -425,6 +431,9 @@ async def _create_with_oauth_browser( redirect_url=obj_in.redirect_url, template_configs=template_configs, additional_overrides=initiation_result.additional_overrides, + initiator_user_id=initiator_user_id, + initiator_session_id=initiator_session_id, + claim_token_hash=claim_token_hash, ) await uow.session.flush() source_conn.connection_init_session_id = init_session.id @@ -433,7 +442,9 @@ async def _create_with_oauth_browser( await uow.commit() await uow.session.refresh(source_conn) - return await self._response_builder.build_response(db, source_conn, ctx) + return await self._response_builder.build_response( + db, source_conn, ctx, claim_token=claim_token + ) async def _create_redirect_session( self, db: AsyncSession, provider_auth_url: str, ctx: ApiContext, uow: UnitOfWork diff --git a/backend/airweave/domains/source_connections/fakes/response.py 
b/backend/airweave/domains/source_connections/fakes/response.py index e09599b9d..4e4ac6b89 100644 --- a/backend/airweave/domains/source_connections/fakes/response.py +++ b/backend/airweave/domains/source_connections/fakes/response.py @@ -35,7 +35,13 @@ def __init__(self, should_raise: Optional[Exception] = None) -> None: self._should_raise = should_raise async def build_response( - self, db: AsyncSession, source_conn: SourceConnection, ctx: ApiContext + self, + db: AsyncSession, + source_conn: SourceConnection, + ctx: ApiContext, + *, + claim_token: str | None = None, + **kwargs, ) -> SourceConnectionSchema: """Build a minimal SourceConnection from source_conn attributes.""" if self._should_raise: @@ -55,6 +61,7 @@ async def build_response( auth=AuthenticationDetails( method=AuthenticationMethod.DIRECT, authenticated=getattr(source_conn, "is_authenticated", True), + claim_token=claim_token, ), ) diff --git a/backend/airweave/domains/source_connections/protocols.py b/backend/airweave/domains/source_connections/protocols.py index f96f35b9d..217f72160 100644 --- a/backend/airweave/domains/source_connections/protocols.py +++ b/backend/airweave/domains/source_connections/protocols.py @@ -144,6 +144,7 @@ async def build_response( *, auth_url_override: Optional[str] = None, auth_url_expiry_override: Optional[datetime] = None, + claim_token: Optional[str] = None, ) -> SourceConnectionSchema: """Build full SourceConnection response from ORM object.""" ... 
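Reviewer note: the claim-token handling added in `create.py` above (generate with `secrets.token_urlsafe(32)`, persist only the SHA-256 hex digest) and the checks that `TestVerifyOAuthFlow` asserts (403 for a wrong token or a different user, 400 for an already completed session) can be sketched in isolation as below. This is a minimal sketch, not the code in `callback_service.py`: `VerifyError`, `InitSession`, `issue_claim_token`, and `verify_claim` are hypothetical names, the order of the checks is an assumption, and the constant-time comparison via `hmac.compare_digest` is an assumption about how the hashes might be compared.

```python
import hashlib
import hmac
import secrets
from dataclasses import dataclass
from typing import Optional
from uuid import UUID, uuid4


class VerifyError(Exception):
    """Hypothetical stand-in for HTTPException; carries an HTTP-style status."""

    def __init__(self, status_code: int, detail: str) -> None:
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail


@dataclass
class InitSession:
    """Illustrative subset of the ConnectionInitSession fields added in this diff."""

    claim_token_hash: Optional[str]
    initiator_user_id: Optional[UUID]
    status: str = "in_progress"


def issue_claim_token() -> tuple[str, str]:
    """Return (plaintext token for the caller, SHA-256 hex digest to persist)."""
    token = secrets.token_urlsafe(32)
    return token, hashlib.sha256(token.encode()).hexdigest()


def verify_claim(session: InitSession, claim_token: str, user_id: Optional[UUID]) -> None:
    """Raise VerifyError unless the caller proves ownership of the flow.

    The check order here is an assumption; the tests only pin the status codes.
    """
    if session.status == "completed":
        raise VerifyError(400, "OAuth flow already completed")

    # Hash the presented token and compare digests in constant time (assumption).
    presented = hashlib.sha256(claim_token.encode()).hexdigest()
    stored = session.claim_token_hash
    if stored is None or not hmac.compare_digest(presented, stored):
        raise VerifyError(403, "Invalid claim token")

    # Only enforce the identity binding when an initiator was recorded.
    if session.initiator_user_id is not None and user_id != session.initiator_user_id:
        raise VerifyError(403, "Caller identity mismatch")

    session.status = "completed"
```

Because only the digest is stored, a leaked database row does not reveal a usable token; the plaintext exists only in the create response and in the caller's memory.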
diff --git a/backend/airweave/domains/source_connections/response.py b/backend/airweave/domains/source_connections/response.py index c2e0f97f7..c460708d9 100644 --- a/backend/airweave/domains/source_connections/response.py +++ b/backend/airweave/domains/source_connections/response.py @@ -71,6 +71,7 @@ async def build_response( *, auth_url_override: Optional[str] = None, auth_url_expiry_override: Optional[datetime] = None, + claim_token: Optional[str] = None, ) -> SourceConnectionSchema: """Build complete SourceConnection response from an ORM object.""" auth = await self._build_auth_details( @@ -79,6 +80,7 @@ async def build_response( ctx, auth_url_override=auth_url_override, auth_url_expiry_override=auth_url_expiry_override, + claim_token=claim_token, ) schedule = await self._build_schedule_details(db, source_conn, ctx) sync_details = await self._build_sync_details(db, source_conn, ctx) @@ -159,6 +161,7 @@ async def _build_auth_details( *, auth_url_override: Optional[str] = None, auth_url_expiry_override: Optional[datetime] = None, + claim_token: Optional[str] = None, ) -> AuthenticationDetails: """Build authentication details section.""" actual_auth_method = await self._resolve_auth_method(db, source_conn, ctx) @@ -193,6 +196,7 @@ async def _build_auth_details( auth_url=auth_url, auth_url_expires=auth_url_expires, redirect_url=redirect_url, + claim_token=claim_token, ) async def _resolve_auth_method( diff --git a/backend/airweave/models/connection_init_session.py b/backend/airweave/models/connection_init_session.py index 7b19306e3..a4dc3dd2a 100644 --- a/backend/airweave/models/connection_init_session.py +++ b/backend/airweave/models/connection_init_session.py @@ -18,6 +18,7 @@ class ConnectionInitStatus: """String constants representing ConnectionInitSession lifecycle states.""" PENDING = "pending" + IN_PROGRESS = "in_progress" COMPLETED = "completed" EXPIRED = "expired" CANCELLED = "cancelled" @@ -54,6 +55,13 @@ class ConnectionInitSession(OrganizationBase): # 
Expiration for security; default TTL ~5 minutes can be applied at creation expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + # Caller-identity binding for claim-token verification + initiator_user_id: Mapped[Optional[UUID]] = mapped_column( + ForeignKey("user.id", ondelete="SET NULL"), nullable=True + ) + initiator_session_id: Mapped[Optional[UUID]] = mapped_column(nullable=True) + claim_token_hash: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) + # Set when finalized (optional) final_connection_id: Mapped[Optional[UUID]] = mapped_column( ForeignKey("connection.id", ondelete="SET NULL"), nullable=True diff --git a/backend/airweave/schemas/__init__.py b/backend/airweave/schemas/__init__.py index 26c434b64..50f102bdf 100644 --- a/backend/airweave/schemas/__init__.py +++ b/backend/airweave/schemas/__init__.py @@ -104,6 +104,7 @@ SourceConnectionUpdate, SyncDetails, SyncJobDetails, + VerifyOAuthRequest, ) from .source_rate_limit import ( SourceRateLimit, diff --git a/backend/airweave/schemas/source_connection.py b/backend/airweave/schemas/source_connection.py index 54f04c2bb..bf507829f 100644 --- a/backend/airweave/schemas/source_connection.py +++ b/backend/airweave/schemas/source_connection.py @@ -496,6 +496,11 @@ class AuthenticationDetails(BaseModel): auth_url: Optional[str] = Field(None, description="For pending OAuth flows") auth_url_expires: Optional[datetime] = None redirect_url: Optional[str] = None + claim_token: Optional[str] = Field( + None, + description="One-time token to verify OAuth flow ownership. 
" + "Only returned when creating an OAuth browser connection.", + ) # Provider-specific provider_readable_id: Optional[str] = None @@ -831,6 +836,12 @@ def determine_auth_method(source_conn: Any) -> AuthenticationMethod: return AuthenticationMethod.OAUTH_BROWSER +class VerifyOAuthRequest(BaseModel): + """Request body for verifying OAuth flow ownership.""" + + claim_token: str = Field(..., description="Claim token from create response") + + def compute_status( source_conn: Any, last_job_status: Optional[SyncJobStatus] = None, diff --git a/backend/alembic/versions/8bdd5dcf7837_add_oauth_claim_token_fields.py b/backend/alembic/versions/8bdd5dcf7837_add_oauth_claim_token_fields.py new file mode 100644 index 000000000..c8b1b626a --- /dev/null +++ b/backend/alembic/versions/8bdd5dcf7837_add_oauth_claim_token_fields.py @@ -0,0 +1,43 @@ +"""add oauth claim token fields + +Revision ID: 8bdd5dcf7837 +Revises: a1b2c3d4e5f7 +Create Date: 2026-03-13 13:28:16.019661 + +""" + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = "8bdd5dcf7837" +down_revision = "a1b2c3d4e5f7" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "connection_init_session", + sa.Column( + "initiator_user_id", + sa.Uuid(), + sa.ForeignKey("user.id", ondelete="SET NULL"), + nullable=True, + ), + ) + op.add_column( + "connection_init_session", + sa.Column("initiator_session_id", sa.Uuid(), nullable=True), + ) + op.add_column( + "connection_init_session", + sa.Column("claim_token_hash", sa.String(64), nullable=True), + ) + + +def downgrade() -> None: + op.drop_column("connection_init_session", "claim_token_hash") + op.drop_column("connection_init_session", "initiator_session_id") + op.drop_column("connection_init_session", "initiator_user_id") diff --git a/connect/src/lib/api.ts b/connect/src/lib/api.ts index 5b339361a..b2ad2f8cc 100644 --- a/connect/src/lib/api.ts +++ b/connect/src/lib/api.ts @@ -102,6 +102,19 @@ class ConnectApiClient { ); } + async verifyOAuth( + connectionId: string, + claimToken: string, + ): Promise { + return this.fetch( + `/connect/source-connections/${connectionId}/verify-oauth`, + { + method: "POST", + body: JSON.stringify({ claim_token: claimToken }), + }, + ); + } + async getConnectionJobs( connectionId: string, ): Promise { diff --git a/connect/src/lib/types/authentication.ts b/connect/src/lib/types/authentication.ts index fd56838ce..a633e55b7 100644 --- a/connect/src/lib/types/authentication.ts +++ b/connect/src/lib/types/authentication.ts @@ -35,5 +35,6 @@ export interface SourceConnectionCreateResponse { method: AuthenticationMethod; authenticated: boolean; auth_url?: string; + claim_token?: string; }; } diff --git a/connect/src/lib/useOAuthFlow.ts b/connect/src/lib/useOAuthFlow.ts index db9e53078..b10536a45 100644 --- a/connect/src/lib/useOAuthFlow.ts +++ b/connect/src/lib/useOAuthFlow.ts @@ -48,9 +48,11 @@ export function useOAuthFlow({ const oauthCompletedRef = useRef(false); // Track the connection ID created in this OAuth flow for 
cleanup on cancel const createdConnectionIdRef = useRef(null); + // Store claim token for verify-oauth call after popup completes + const claimTokenRef = useRef(null); const handleOAuthResult = useCallback( - (result: OAuthCallbackResult) => { + async (result: OAuthCallbackResult) => { // Mark OAuth as completed synchronously to prevent race condition with interval oauthCompletedRef.current = true; @@ -60,6 +62,23 @@ export function useOAuthFlow({ popupRef.current = null; if (result.status === "success" && result.source_connection_id) { + try { + if (claimTokenRef.current) { + await apiClient.verifyOAuth( + result.source_connection_id, + claimTokenRef.current, + ); + claimTokenRef.current = null; + } + } catch (err) { + setStatus("error"); + setError( + err instanceof Error + ? err.message + : "Failed to verify OAuth flow ownership", + ); + return; + } // Clear the ref since OAuth succeeded - don't delete the connection createdConnectionIdRef.current = null; setStatus("idle"); @@ -138,6 +157,7 @@ export function useOAuthFlow({ const response = await apiClient.createSourceConnection(payload); // Track the created connection for cleanup if user cancels createdConnectionIdRef.current = response.id; + claimTokenRef.current = response.auth?.claim_token ?? null; if (response.auth?.auth_url) { setStatus("waiting"); diff --git a/fern/docs/pages/connecting-sources/direct-oauth.mdx b/fern/docs/pages/connecting-sources/direct-oauth.mdx index 930229a3d..4d5df769b 100644 --- a/fern/docs/pages/connecting-sources/direct-oauth.mdx +++ b/fern/docs/pages/connecting-sources/direct-oauth.mdx @@ -46,10 +46,10 @@ The URL-based OAuth flow is the simplest way to connect sources. Users are redir ### How it works -1. **Initiate Connection**: Create a source connection using the URL-based authentication method +1. **Initiate Connection**: Create a source connection using the URL-based authentication method. The response includes a `claim_token` that you must store for the verification step. 
 2. **User Consent**: Users are redirected to the service provider's consent screen
 3. **Token Exchange**: Airweave exchanges the authorization code for access and refresh tokens
-4. **Data Sync**: The connection is established and data synchronization begins
+4. **Verify Ownership**: After the OAuth callback completes, call `POST /source-connections/{id}/verify-oauth` with the `claim_token` to prove your client initiated the flow. This triggers data synchronization.
 
 ### Creating a URL-based OAuth connection
 
@@ -131,9 +131,41 @@ curl -X POST https://api.airweave.ai/source-connections \
 
 When using URL-based OAuth, you'll receive an authorization URL that you need to redirect users to. After they complete the OAuth flow, they'll be redirected back to your specified callback URL with the necessary parameters.
 
-**Important**: OAuth browser flows (both standard and BYOC) cannot use `sync_immediately=true`. The sync will automatically start after the user completes the OAuth authentication flow. Setting `sync_immediately=true` will result in a validation error.
+**Important**: OAuth browser flows (both standard and BYOC) cannot use `sync_immediately=true`. After the user completes the OAuth consent screen, you must call the `verify-oauth` endpoint with the `claim_token` from the create response to trigger data synchronization. Setting `sync_immediately=true` will result in a validation error.
 
+### Verifying the OAuth flow
+
+After the OAuth callback completes and the user is redirected back to your application, call the verify-oauth endpoint with the `claim_token` received during connection creation. This proves your client initiated the flow and triggers the data sync.
+
+
+```Python title="Python"
+# After user completes OAuth and is redirected back:
+airweave.source_connections.verify_oauth(
+    id=source_connection.id,
+    claim_token=source_connection.auth.claim_token,
+)
+```
+
+```javascript title="Node.js"
+// After user completes OAuth and is redirected back:
+await airweave.sourceConnections.verifyOauth(
+  sourceConnection.id,
+  { claimToken: sourceConnection.auth.claimToken }
+);
+```
+
+```bash title="cURL"
+# After user completes OAuth and is redirected back:
+curl -X POST https://api.airweave.ai/source-connections/{connection-id}/verify-oauth \
+  -H "x-api-key: " \
+  -H "Content-Type: application/json" \
+  -d '{
+    "claim_token": ""
+}'
+```
+
+
 ### Redirect URI basics
 
 The `redirect_uri` is the URL where the user is sent after granting consent with the provider. It must exactly match an allowed redirect/callback URL configured for the OAuth app. Use it when you want users to return to your application after authentication (typical for hosted OAuth). For BYOC, configure the provider to allow Airweave's callback (`https://api.airweave.ai/oauth/callback`) and optionally set `redirect_uri` so Airweave can forward the user back to your app after the token exchange.