Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .cursor/rules/api-layer.mdc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ api/
- For direct auth: validates credential format
- For OAuth: handles authorization flows
- For auth providers: validates provider exists and supports the source
- **POST /source-connections/{id}/verify-oauth**: Verifies claim-token ownership and triggers deferred sync. Called after the OAuth callback completes to prove the caller that initiated the flow is the one completing it. Required for all browser-based OAuth flows.

## Core Components

Expand Down
16 changes: 15 additions & 1 deletion .cursor/rules/connect-widget.mdc
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ requestClose(reason: "success" | "cancel" | "error");
```

### 3. OAuth Flow Security
OAuth uses same-origin popups with validated messaging:
OAuth uses same-origin popups with validated messaging and claim-token verification:
```typescript
// oauth-callback.tsx posts to same origin
window.opener.postMessage({ type: "OAUTH_COMPLETE", ...result }, window.location.origin);
Expand All @@ -121,6 +121,20 @@ const handler = (event: MessageEvent) => {
};
```

**Claim-token verification (two-step completion):**
After the OAuth popup completes, the Connect widget calls `verifyOAuth` with the claim token before considering the flow complete. The claim token is returned by the initial `createSourceConnection` call and stored in `claimTokenRef`. This ensures the caller that initiated the OAuth flow is the same one completing it.

See `useOAuthFlow.ts` lines 66-72 for the implementation:
```typescript
if (claimTokenRef.current) {
await apiClient.verifyOAuth(
result.source_connection_id,
claimTokenRef.current,
);
claimTokenRef.current = null;
}
```

### 4. Theming System
Fully customizable via CSS variables passed from parent:
```typescript
Expand Down
19 changes: 19 additions & 0 deletions backend/airweave/api/v1/endpoints/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
ConnectSessionCreate,
ConnectSessionResponse,
)
from airweave.schemas.source_connection import VerifyOAuthRequest

router = TrailingSlashRouter()

Expand Down Expand Up @@ -174,6 +175,24 @@ async def create_source_connection(
return await svc.create_source_connection(db, source_connection_in, session, session_token)


@router.post(
"/source-connections/{connection_id}/verify-oauth",
response_model=schemas.SourceConnection,
)
async def verify_oauth(
connection_id: UUID,
body: VerifyOAuthRequest,
db: AsyncSession = Depends(get_db),
session: ConnectSessionContext = Depends(deps.get_connect_session),
svc: ConnectServiceProtocol = Inject(ConnectServiceProtocol),
) -> schemas.SourceConnection:
"""Verify OAuth flow ownership via Connect session.

Authentication: Bearer <session_token>
"""
return await svc.verify_oauth(db, connection_id, body.claim_token, session)


# =============================================================================
# Sync Jobs
# =============================================================================
Expand Down
30 changes: 30 additions & 0 deletions backend/airweave/api/v1/endpoints/source_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
RateLimitErrorResponse,
ValidationErrorResponse,
)
from airweave.schemas.source_connection import VerifyOAuthRequest

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -97,6 +98,35 @@ async def oauth_callback(
)


@router.post(
"/{source_connection_id}/verify-oauth",
response_model=schemas.SourceConnection,
summary="Verify OAuth Flow",
description="""Verify ownership of an OAuth flow by presenting the claim token
returned during connection creation. This triggers the deferred sync.""",
responses={
200: {"model": schemas.SourceConnection, "description": "Verified source connection"},
403: {"description": "Invalid claim token or identity mismatch"},
404: {"model": NotFoundErrorResponse, "description": "Source Connection Not Found"},
},
)
async def verify_oauth(
*,
db: AsyncSession = Depends(get_db),
source_connection_id: UUID = Path(...),
body: VerifyOAuthRequest,
ctx: ApiContext = Depends(deps.get_context),
oauth_callback_svc: OAuthCallbackServiceProtocol = Inject(OAuthCallbackServiceProtocol),
) -> schemas.SourceConnection:
"""Verify OAuth flow ownership and trigger deferred sync."""
return await oauth_callback_svc.verify_oauth_flow(
db,
source_connection_id=source_connection_id,
claim_token=body.claim_token,
ctx=ctx,
)


@router.post(
"/",
response_model=schemas.SourceConnection,
Expand Down
29 changes: 15 additions & 14 deletions backend/airweave/core/container/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,20 +342,6 @@ def create_container(settings: Settings) -> Container:
deletion_service=deletion_service,
)

# -----------------------------------------------------------------
# Connect domain service
# -----------------------------------------------------------------
from airweave.domains.connect.service import ConnectService
from airweave.domains.organizations.repository import OrganizationRepository as ConnectOrgRepo

connect_service = ConnectService(
source_connection_service=source_connection_service,
source_service=source_deps["source_service"],
org_repo=ConnectOrgRepo(),
collection_repo=source_deps["collection_repo"],
sync_job_repo=source_deps["sync_job_repo"],
)

# -----------------------------------------------------------------
# Embedder registries + instances (deployment-wide singletons)
# -----------------------------------------------------------------
Expand Down Expand Up @@ -410,6 +396,21 @@ def create_container(settings: Settings) -> Container:
credential_encryptor=encryptor,
)

# -----------------------------------------------------------------
# Connect domain service (after oauth_callback_svc for DI)
# -----------------------------------------------------------------
from airweave.domains.connect.service import ConnectService
from airweave.domains.organizations.repository import OrganizationRepository as ConnectOrgRepo

connect_service = ConnectService(
source_connection_service=source_connection_service,
source_service=source_deps["source_service"],
org_repo=ConnectOrgRepo(),
collection_repo=source_deps["collection_repo"],
sync_job_repo=source_deps["sync_job_repo"],
oauth_callback_service=oauth_callback_svc,
)

# -----------------------------------------------------------------
# Browse tree service
# -----------------------------------------------------------------
Expand Down
10 changes: 10 additions & 0 deletions backend/airweave/domains/connect/fakes/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,16 @@ async def create_source_connection(
)
raise NotImplementedError("FakeConnectService.create_source_connection not seeded")

async def verify_oauth(
self,
db: AsyncSession,
connection_id: UUID,
claim_token: str,
session: ConnectSessionContext,
) -> Any:
self._calls.append(("verify_oauth", connection_id, claim_token, session))
raise NotImplementedError("FakeConnectService.verify_oauth not seeded")

async def get_connection_jobs(
self,
db: AsyncSession,
Expand Down
10 changes: 10 additions & 0 deletions backend/airweave/domains/connect/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,18 @@
session_token: str,
) -> schemas.SourceConnection:
"""Create a source connection via Connect session."""
...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.

async def verify_oauth(
self,
db: AsyncSession,
connection_id: UUID,
claim_token: str,
session: ConnectSessionContext,
) -> schemas.SourceConnection:
"""Verify OAuth flow ownership via Connect session."""
...

async def get_connection_jobs(
self,
db: AsyncSession,
Expand Down
21 changes: 21 additions & 0 deletions backend/airweave/domains/connect/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from airweave.domains.collections.protocols import CollectionRepositoryProtocol
from airweave.domains.connect.protocols import ConnectServiceProtocol
from airweave.domains.connect.types import MODES_CREATE, MODES_DELETE, MODES_VIEW
from airweave.domains.oauth.protocols import OAuthCallbackServiceProtocol
from airweave.domains.organizations.protocols import OrganizationRepositoryProtocol
from airweave.domains.source_connections.protocols import SourceConnectionServiceProtocol
from airweave.domains.sources.protocols import SourceServiceProtocol
Expand All @@ -44,12 +45,14 @@ def __init__( # noqa: D107
org_repo: OrganizationRepositoryProtocol,
collection_repo: CollectionRepositoryProtocol,
sync_job_repo: SyncJobRepositoryProtocol,
oauth_callback_service: OAuthCallbackServiceProtocol,
) -> None:
self._sc_service = source_connection_service
self._source_service = source_service
self._org_repo = org_repo
self._collection_repo = collection_repo
self._sync_job_repo = sync_job_repo
self._oauth_callback_service = oauth_callback_service

# ------------------------------------------------------------------
# Guards (private — all access checks live on the service)
Expand Down Expand Up @@ -346,6 +349,24 @@ async def create_source_connection(

return result # type: ignore[return-value]

async def verify_oauth(
self,
db: AsyncSession,
connection_id: UUID,
claim_token: str,
session: ConnectSessionContext,
) -> schemas.SourceConnection:
"""Verify OAuth flow ownership via Connect session."""
self._check_mode(session, MODES_CREATE, "verifying OAuth flow")
ctx = await self._build_context(db, session)

return await self._oauth_callback_service.verify_oauth_flow(
db,
source_connection_id=connection_id,
claim_token=claim_token,
ctx=ctx,
)

# ------------------------------------------------------------------
# Sync jobs
# ------------------------------------------------------------------
Expand Down
3 changes: 3 additions & 0 deletions backend/airweave/domains/connect/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,13 @@ def sync_job_repo():

@pytest.fixture
def connect_service(org_repo, sc_service, source_service, collection_repo, sync_job_repo):
from unittest.mock import AsyncMock

return ConnectService(
source_connection_service=sc_service,
source_service=source_service,
org_repo=org_repo,
collection_repo=collection_repo,
sync_job_repo=sync_job_repo,
oauth_callback_service=AsyncMock(),
)
Loading
Loading