diff --git a/.cursor/rules/auth-providers.mdc b/.cursor/rules/auth-providers.mdc index 8289be182..407cde2b0 100644 --- a/.cursor/rules/auth-providers.mdc +++ b/.cursor/rules/auth-providers.mdc @@ -50,51 +50,46 @@ SLUG_NAME_MAPPING = { ``` -## Integration with TokenManager +## Integration with TokenProviderProtocol -The `TokenManager` class orchestrates token refresh during long-running syncs: +Auth providers integrate via `AuthProviderTokenProvider`, one of three `TokenProviderProtocol` implementations: -### Initialization +### TokenProviderProtocol ```python -# SyncFactory creates TokenManager with optional auth provider -token_manager = TokenManager( - db=db, - source_connection=connection, - auth_provider_instance=auth_provider, # Optional - initial_credentials=credentials +class TokenProviderProtocol(Protocol): + async def get_token(self) -> str: ... + async def force_refresh(self) -> str: ... +``` + +### AuthProviderTokenProvider +Created by `SourceLifecycleService._configure_token_provider()` when a source connection has an `auth_provider_connection_id`. + +```python +# Lifecycle builds the provider with the auth provider instance +provider = AuthProviderTokenProvider( + auth_provider=auth_provider_instance, + source_short_name="slack", + auth_config_fields=["access_token", "refresh_token"], + logger=logger, ) ``` -### Refresh Flow -1. **Check refresh capability** (`_determine_refresh_capability()`): - - Direct injection tokens → no refresh - - Auth provider present → always refreshable - - Standard OAuth → attempt refresh - -2. **Proactive refresh** (`get_valid_token()`): - - Refreshes tokens every 25 minutes (before 1-hour expiry) - - Uses async lock to prevent concurrent refreshes - - Falls back to stored token if refresh fails - -3. 
**Auth provider refresh** (`_refresh_via_auth_provider()`): - ```python - # TokenManager calls auth provider for fresh credentials - fresh_creds = await auth_provider.get_creds_for_source( - source_short_name="slack", - source_auth_config_fields=["access_token", "refresh_token"] - ) - # Updates database with new credentials - await crud.integration_credential.update(db, credential, fresh_creds) - ``` - -4. **Fallback OAuth refresh** (`_refresh_via_oauth()`): - - Uses oauth2_service if no auth provider - - Creates separate DB session to avoid transaction issues - -### Credential Priority (in SyncFactory) -1. Direct token injection (highest) -2. Auth provider instance -3. Database credentials with OAuth refresh +#### Refresh Flow +1. **`get_token()`** — calls `auth_provider.get_creds_for_source()` each time to fetch fresh credentials from the external service (Pipedream/Composio) +2. **`force_refresh()`** — same as `get_token()` (auth providers always return fresh creds) +3. **Retry** — uses tenacity (3 attempts, exponential backoff) on `AuthProviderServerError` / `AuthProviderRateLimitError` +4. **Error translation** — `AuthProviderError` subtypes are mapped to `TokenProviderError` subtypes: + - `AuthProviderAuthError` → `TokenCredentialsInvalidError` + - `AuthProviderAccountNotFoundError` → `TokenProviderAccountGoneError` + - `AuthProviderMissingFieldsError` → `TokenProviderMissingCredsError` + - `AuthProviderConfigError` → `TokenProviderConfigError` + - `AuthProviderRateLimitError` → `TokenProviderRateLimitError` + - `AuthProviderServerError` → `TokenProviderServerError` + +### Provider Resolution (in SourceLifecycleService) +1. Auth provider connection present → `AuthProviderTokenProvider` +2. OAuth credentials with `oauth_type` → `OAuthTokenProvider` +3. 
Direct token injection → `StaticTokenProvider` ## Database Schema - **auth_providers**: Provider definitions from decorators diff --git a/.cursor/rules/source-connector-implementation.mdc b/.cursor/rules/source-connector-implementation.mdc index f89e92fcf..4902777ca 100644 --- a/.cursor/rules/source-connector-implementation.mdc +++ b/.cursor/rules/source-connector-implementation.mdc @@ -403,7 +403,7 @@ from typing import Any, AsyncGenerator, Dict, List, Optional import httpx from tenacity import retry, stop_after_attempt, wait_exponential -from airweave.core.exceptions import TokenRefreshError +from airweave.domains.sources.exceptions import SourceAuthError, SourceRateLimitError, SourceServerError from airweave.platform.decorators import source from airweave.platform.entities._base import Breadcrumb, ChunkEntity from airweave.platform.entities.{short_name} import ( @@ -436,21 +436,24 @@ class MyConnectorSource(BaseSource): @classmethod async def create( - cls, access_token: str, config: Optional[Dict[str, Any]] = None + cls, + *, + credentials: Optional[Dict[str, Any]] = None, + config: Optional[Dict[str, Any]] = None, + **kwargs: Any, ) -> "MyConnectorSource": """Create and configure the source. Args: - access_token: OAuth access token or API key + credentials: Decrypted auth credentials (access_token, refresh_token, etc.) config: Optional configuration (e.g., workspace filters) + **kwargs: Additional keyword arguments (token_provider, logger, etc.) Returns: Configured source instance """ instance = cls() - instance.access_token = access_token - # Store config as instance attributes if config: instance.workspace_id = config.get("workspace_id") instance.exclude_pattern = config.get("exclude_pattern", "") @@ -464,13 +467,12 @@ class MyConnectorSource(BaseSource): """Generate all entities from the source. This is the main entry point called by the sync engine. + Token is obtained via self.get_access_token() (delegates to the injected TokenProvider). 
""" async with self.http_client() as client: - # Generate entities hierarchically async for top_level in self._generate_top_level(client): yield top_level - # Generate children with breadcrumb tracking async for child in self._generate_children(client, top_level): yield child @@ -491,20 +493,21 @@ class MyConnectorSource(BaseSource): #### 1. The `create()` Classmethod -This is called once when a sync starts: +This is called once when a sync starts. The `SourceLifecycleService` passes keyword arguments: ```python @classmethod async def create( - cls, access_token: str, config: Optional[Dict[str, Any]] = None + cls, + *, + credentials: Optional[Dict[str, Any]] = None, + config: Optional[Dict[str, Any]] = None, + **kwargs: Any, ) -> "MyConnectorSource": """Create and configure the source.""" instance = cls() - instance.access_token = access_token - # Parse config fields if config: - # Store as instance attributes for use in generate_entities() instance.workspace_filter = config.get("workspace_filter", "") instance.include_archived = config.get("include_archived", False) else: @@ -514,6 +517,8 @@ async def create( return instance ``` +**Note:** Do NOT store `access_token` as an instance attribute. Tokens are managed by the injected `TokenProvider` and accessed via `self.get_access_token()`. + #### 2. 
The `generate_entities()` Method This is an async generator that yields entities: @@ -536,7 +541,7 @@ async def generate_entities(self) -> AsyncGenerator[ChunkEntity, None]: workspace_breadcrumb = Breadcrumb( entity_id=workspace.entity_id, name=workspace.name, - type="workspace" + entity_type="WorkspaceEntity", ) # Child entities @@ -546,7 +551,7 @@ async def generate_entities(self) -> AsyncGenerator[ChunkEntity, None]: project_breadcrumb = Breadcrumb( entity_id=project.entity_id, name=project.name, - type="project" + entity_type="ProjectEntity", ) breadcrumbs = [workspace_breadcrumb, project_breadcrumb] @@ -557,7 +562,7 @@ async def generate_entities(self) -> AsyncGenerator[ChunkEntity, None]: #### 3. Making API Requests with Token Refresh -Always use this pattern for authenticated requests: +Use `self.get_access_token()` (delegates to the injected `TokenProvider`) and `self.refresh_on_unauthorized()` for 401 recovery: ```python @retry( @@ -571,55 +576,41 @@ async def _get_with_auth( url: str, params: Optional[Dict[str, Any]] = None ) -> Dict: - """Make authenticated GET request with automatic token refresh. 
- - This method handles: - - Token refresh on 401 errors - - Retries with exponential backoff - - Proper error logging - """ - # Get a valid token (will refresh if needed) + """Make authenticated GET request with automatic token refresh.""" access_token = await self.get_access_token() if not access_token: - raise ValueError("No access token available") + raise SourceAuthError(source=self.short_name, message="No access token available") headers = {"Authorization": f"Bearer {access_token}"} try: response = await client.get(url, headers=headers, params=params) - # Handle 401 Unauthorized - token might have expired if response.status_code == 401: - self.logger.warning(f"Received 401 for {url}, refreshing token...") - - if self.token_manager: - try: - # Force refresh the token - new_token = await self.token_manager.refresh_on_unauthorized() - headers = {"Authorization": f"Bearer {new_token}"} - - # Retry with new token - self.logger.info(f"Retrying with refreshed token: {url}") - response = await client.get(url, headers=headers, params=params) - - except TokenRefreshError as e: - self.logger.error(f"Failed to refresh token: {str(e)}") - response.raise_for_status() - else: - self.logger.error("No token manager available") - response.raise_for_status() + self.logger.warning(f"Received 401 for {url}, forcing token refresh...") + new_token = await self.refresh_on_unauthorized() + headers = {"Authorization": f"Bearer {new_token}"} + response = await client.get(url, headers=headers, params=params) + + if response.status_code == 429: + retry_after = response.headers.get("Retry-After") + raise SourceRateLimitError( + source=self.short_name, + retry_after=int(retry_after) if retry_after else None, + ) response.raise_for_status() return response.json() except httpx.HTTPStatusError as e: + if e.response.status_code >= 500: + raise SourceServerError(source=self.short_name, status_code=e.response.status_code) self.logger.error(f"HTTP error: {e.response.status_code} for {url}") raise 
- except Exception as e: - self.logger.error(f"Unexpected error: {url}, {str(e)}") - raise ``` +**Note:** `self.refresh_on_unauthorized()` calls `self._token_provider.force_refresh()` under the hood. If the token provider doesn't support refresh (e.g., `StaticTokenProvider`), it raises `TokenRefreshNotSupportedError`. + ### Handling Hierarchical Data Use breadcrumbs to track entity relationships: @@ -1049,7 +1040,7 @@ async for task in self._generate_tasks(client, project, breadcrumbs): task_breadcrumb = Breadcrumb( entity_id=task.entity_id, name=task.name, - type="task" + entity_type="TaskEntity", ) task_breadcrumbs = [*breadcrumbs, task_breadcrumb] @@ -1170,7 +1161,7 @@ async def _generate_projects(self, client, workspace): - [ ] Auth config class added to `platform/configs/auth.py` - [ ] Auth config referenced in source `@source` decorator - [ ] Source implements `create()`, `generate_entities()`, and `validate()` -- [ ] Token refresh is handled via `_get_with_auth()` pattern +- [ ] Token refresh handled via `self.get_access_token()` + `self.refresh_on_unauthorized()` pattern - [ ] File entities use `process_file_entity()` - [ ] Logging uses proper levels (INFO for milestones, DEBUG for details) - [ ] OAuth config is in `dev.integrations.yaml` (human already set this up) diff --git a/.cursor/rules/source-contract-redesign.mdc b/.cursor/rules/source-contract-redesign.mdc new file mode 100644 index 000000000..2004f3948 --- /dev/null +++ b/.cursor/rules/source-contract-redesign.mdc @@ -0,0 +1,187 @@ +--- +description: Source contract v2 redesign — handover context for continuing the refactor +alwaysApply: false +--- + +# Source Contract v2 Redesign — Handover + +## The problem + +The `BaseSource` contract was designed 14 months ago for OSS contributors to build sources easily. The system grew complex, but the contract didn't evolve. Result: + +- **Setter soup**: 7 `set_*()` calls across 2 files after construction. Source is never "fully configured" at any point. 
+- **Untyped creation**: `create(credentials: Optional[Any], config: Optional[Dict])` — nobody knows what credentials actually is. +- **No exceptions**: Sources catch bare `httpx.HTTPStatusError`, pipeline can't distinguish "skip entity" from "abort sync". +- **Token tangle**: `self.access_token` dynamic attribute + `TokenManager` + `get_access_token()` fallback with `getattr`. +- **Hidden deps**: `generate_entities()` takes no params but relies on cursor, file_downloader, node_selections stored on self. +- **God class**: BaseSource is 800 lines mixing auth, HTTP, content cleaning, concurrency, and OAuth validation. + +## The target contract + +```python +class BaseSource: + # Construction: typed deps, no setters + @classmethod + async def create(cls, *, auth: TokenProvider | AuthConfig, config: Optional[BaseModel], + logger: ContextualLogger) -> "BaseSource": ... + + # Methods carry their operation context + async def generate_entities(self, *, cursor: SyncCursor, files: FileService) -> AsyncGenerator[BaseEntity, None]: ... + async def validate(self) -> bool: ... + async def search(self, *, query: str, limit: int) -> AsyncGenerator[BaseEntity, None]: ... + async def get_browse_children(self, *, parent_node_id: str | None = None) -> list[BrowseNode]: ... + async def execute_tool(self, *, tool_name: str, parameters: dict) -> ToolResult: ... # future +``` + +One factory (`SourceLifecycleService`) for ALL creation contexts: sync, search, browse, validation, tool calls. 
+ +## What's done + +### Exception hierarchy (`domains/sources/exceptions.py`) +Full typed hierarchy under `SourceError(AirweaveException)`: +- `SourceAuthError` > `SourceTokenRefreshError` — 401 / refresh failure +- `SourceRateLimitError(retry_after)` — 429 +- `SourceServerError(status_code)` — 5xx, timeout, connection errors (canonical upstream-failure exception) + - `SourceTemporaryError` / `SourcePermanentError` are **aliases** for `SourceServerError` (backward compat) +- `SourceEntityError(entity_id)` > `SourceEntityForbiddenError`, `SourceEntityNotFoundError`, `SourceEntitySkippedError` +- `SourceFileDownloadError(file_url)` +- `SourceCreationError`, `SourceValidationError` + +Token provider exceptions (`domains/sources/token_providers/exceptions.py`): +- `TokenProviderError(SourceError)` — base for all token-provider failures + - `TokenCredentialsInvalidError` — expired/revoked token or refresh_token + - `TokenProviderAccountGoneError` — external account deleted (Composio/Pipedream) + - `TokenProviderConfigError` — fundamental misconfiguration + - `TokenProviderMissingCredsError` — response lacks required fields + - `TokenProviderRateLimitError` — upstream rate-limiting + - `TokenProviderServerError` — server error (5xx / timeout) + - `TokenRefreshNotSupportedError` — static token / no refresh_token + +OAuth refresh exceptions (`domains/oauth/exceptions.py`): +- `OAuthRefreshError` — base for all token-refresh failures + - `OAuthRefreshTokenRevokedError` — 401 (refresh_token dead) + - `OAuthRefreshBadRequestError` — 400 / invalid_grant + - `OAuthRefreshRateLimitError` — 429 / exhausted retries + - `OAuthRefreshServerError` — 5xx / timeout / connection error + - `OAuthRefreshCredentialMissingError` — no connection or credential in DB + +Auth provider exceptions (`domains/auth_provider/exceptions.py`): +- `AuthProviderError` — base for all auth-provider failures + - `AuthProviderAuthError` → `TokenCredentialsInvalidError` + - 
`AuthProviderAccountNotFoundError` → `TokenProviderAccountGoneError` + - `AuthProviderMissingFieldsError` → `TokenProviderMissingCredsError` + - `AuthProviderConfigError` → `TokenProviderConfigError` + - `AuthProviderRateLimitError` → `TokenProviderRateLimitError` + - `AuthProviderServerError` → `TokenProviderServerError` + +Each token provider translates its upstream exceptions into `TokenProviderError` subtypes. +Sources should raise `SourceError` subtypes. Pipeline routes by type (skip entity, retry, abort). + +### TokenProvider protocol (`domains/sources/token_providers/`) +Protocol with `get_token()` / `force_refresh()`. Three implementations: +- `OAuthTokenProvider` — expiry-aware refresh (80% of `expires_in`, clamped [60s, 50min]) + asyncio.Lock + tenacity retries (3× on 5xx/429) + delegates to `oauth2_service.refresh_and_persist()` +- `StaticTokenProvider` — raw string, raises `TokenRefreshNotSupportedError` (with `provider_kind` metadata) on `force_refresh()` +- `AuthProviderTokenProvider` — delegates to Pipedream/Composio `BaseAuthProvider.get_creds_for_source()` with tenacity retries (3× on server/rate-limit errors) + +`OAuthTokenProvider` constructor: `(credentials, *, oauth_type, oauth2_service, source_short_name, connection_id, ctx, logger, config_fields)`. `can_refresh` is derived internally from `oauth_type ∈ {with_refresh, with_rotating_refresh}` AND `_has_refresh_token(credentials)` (module-level helper). + +All providers translate upstream errors into `TokenProviderError` subtypes (see exception hierarchy above). 
+ +**Resolution matrix** (`_configure_token_provider` in lifecycle.py): +| Situation | Provider | `can_refresh` | +|-----------|----------|---------------| +| Direct token injection (`access_token` param) | `StaticTokenProvider` | n/a | +| `oauth_type=None` (GitHub, Slab, Stripe, SharePoint2019) | None — source manages its own auth | n/a | +| `ACCESS_ONLY` (Linear, Slack, Notion, Monday, ClickUp, Todoist, Intercom) | `OAuthTokenProvider` | `False` | +| `WITH_REFRESH` (Airtable, Asana, Google Drive, Salesforce, etc.) | `OAuthTokenProvider` | `True` | +| `WITH_ROTATING_REFRESH` (Confluence, Jira) | `OAuthTokenProvider` | `True` | +| Auth provider (Pipedream/Composio, any source) | `AuthProviderTokenProvider` | n/a | +| Proxy mode | None (HTTP client handles auth) | n/a | + +### BaseSource token cleanup +- `_token_manager` removed. `_token_provider` is the single path. +- `get_access_token()` → `_token_provider.get_token()` +- `refresh_on_unauthorized()` → `_token_provider.force_refresh()` +- All ~25 source files migrated. Zero `token_manager` references remain. + +### oauth2_service +Added `refresh_and_persist()` — full load + decrypt + refresh + persist-rotation cycle in one call. OAuthTokenProvider is thin (~185 lines, zero infra deps). + +## What's remaining (prioritized) + +### Phase 1: Stop the bleeding (URGENT — production bug) +Sources `except Exception` and swallow everything. A 401 on every API call results in a "successful" sync with 0 entities (observed in Linear sync). ~186 `except Exception` instances across ~30 source files. + +**1a. Sources raise typed exceptions:** +- 401 after refresh attempt → `SourceAuthError` (abort sync) +- 429 → `SourceRateLimitError` (retry) +- 403 on single entity → `SourceEntityForbiddenError` (skip entity) +- 5xx / timeout → `SourceServerError` (retry) +- File download failure → `SourceFileDownloadError` (skip entity) + +**1b. 
Pipeline routes exceptions:** +- `SourceAuthError` → mark sync as auth_failed, abort +- `SourceEntityError` subtypes → skip entity, increment skipped counter +- `SourceRateLimitError` / `SourceServerError` → retry or abort after N retries +- Orchestrator (`platform/sync/orchestrator.py`), stream (`platform/sync/stream.py`) need updates + +### Phase 2: Kill setters, make construction explicit +**2a. Construction-time deps into `create()` params** — `logger`, `token_provider`, `http_client_factory` passed at creation, not via `set_*()` after. + +**2b. Operation-time deps into method params** — `generate_entities(*, cursor, files, node_selections)` instead of `self._cursor` / `self._file_downloader` stored on self. + +**2c. Update SourceLifecycleService** — passes deps at construction. + +**2d. Update SourceContextBuilder** — passes cursor/files/selections as method params. + +### Phase 3: Make `create()` typed +**3a. New `create()` signature** — Token sources: `create(*, auth: TokenProvider, config: Optional[BaseModel])`. Structured-credential sources: `create(*, auth: AuthConfig, config: Optional[BaseModel])`. + +**3b. Update all ~30 source implementations.** + +### Phase 4: Extract utilities from BaseSource +- `_validate_oauth2` (~150 lines) → standalone or on TokenProvider +- `clean_content_for_embedding` (~40 lines) → standalone utility +- `process_entities_concurrent` → stays (used by 5+ sources, only needs self.logger) + +### Phase 5: Lifecycle consolidation +- Validation path currently bypasses lifecycle service (direct `source_cls.create()`) — bring it in +- Absorb `SourceContextBuilder` sync-specific concerns into lifecycle service + +## Design decisions + +**HTTP client**: NOT abstracting auth into the HTTP layer. Auth header format is too diverse across sources (Bearer, `token`, NTLM, custom headers like `Notion-Version`). AirweaveHttpClient stays as transport + rate limiting only. Sources own their auth headers. 
+ +**`get_token_for_resource`**: NOT on the TokenProvider protocol. Only SharePoint Online uses it. Accessed via `hasattr` check on _token_provider. The method exists on OAuthTokenProvider but not on the protocol. + +**No backwards compat**: We refactor thoroughly. No shims, no dual support. + +**No lazy imports**: Dependencies injected explicitly at construction. No `import` inside methods to avoid circular deps — fix the architecture instead. + +**No `Any` types**: Use protocols, typed params, `ContextualLogger` not `Any`. + +**Exceptions propagate**: Token providers throw `TokenProviderError` subtypes (which inherit from `SourceError`), not swallow errors. Pipeline/orchestrator decides what to do. + +**FileService concern**: When `generate_entities` receives `files: FileService` as param, FileService.download_from_url still needs the source's HTTP client and token provider. Either the pipeline pre-wires FileService with the source's auth, or the source passes them. This needs design thought. + +## Key files + +| File | Role | +|------|------| +| `platform/sources/_base.py` | BaseSource contract (setters still exist, needs `create()` rewrite) | +| `domains/sources/exceptions.py` | Full exception hierarchy (done) | +| `domains/sources/token_providers/` | Protocol + 3 implementations (done) | +| `domains/sources/lifecycle.py` | Creates + configures sources. Has `_configure_token_provider`. 
| +| `platform/builders/source.py` | Sync-specific: cursor, file_downloader, node_selections (to be absorbed) | +| `domains/oauth/oauth2_service.py` | `refresh_and_persist()` (done) | +| `tests/unit/platform/sync/test_token_providers.py` | 21 tests covering all 3 providers | + +## Source auth patterns (reference) + +| Pattern | Sources | How token is obtained | +|---------|---------|----------------------| +| Bearer OAuth | asana, slack, google_drive, hubspot, airtable, zendesk, sharepoint_online, notion, linear, jira | `get_access_token()` → TokenProvider | +| Bearer API key | attio, slab | `self.api_key` (static, no refresh) | +| `token` header (PAT) | github | `self.personal_access_token` | +| NTLM | sharepoint2019v2 | username/password/domain from auth config | +| Cross-resource | sharepoint_online | `get_token_for_resource(scope)` on OAuthTokenProvider | diff --git a/backend/airweave/core/container/factory.py b/backend/airweave/core/container/factory.py index daad8d330..46b403e91 100644 --- a/backend/airweave/core/container/factory.py +++ b/backend/airweave/core/container/factory.py @@ -384,6 +384,7 @@ def create_container(settings: Settings) -> Container: init_session_repo=init_session_repo, response_builder=sync_deps["response_builder"], source_registry=source_deps["source_registry"], + source_lifecycle=source_deps["source_lifecycle_service"], sync_lifecycle=sync_deps["sync_lifecycle"], sync_record_service=sync_deps["sync_record_service"], temporal_workflow_service=sync_deps["temporal_workflow_service"], diff --git a/backend/airweave/platform/auth_providers/_base.py b/backend/airweave/domains/auth_provider/_base.py similarity index 96% rename from backend/airweave/platform/auth_providers/_base.py rename to backend/airweave/domains/auth_provider/_base.py index add13ecdd..629962723 100644 --- a/backend/airweave/platform/auth_providers/_base.py +++ b/backend/airweave/domains/auth_provider/_base.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, 
Set from airweave.core.logging import logger -from airweave.platform.auth_providers.auth_result import AuthResult +from airweave.domains.auth_provider.auth_result import AuthResult class BaseAuthProvider(ABC): @@ -66,7 +66,7 @@ async def validate(self) -> bool: True if the connection is valid, False otherwise Raises: - HTTPException: If validation fails with detailed error message + AuthProviderError: If validation fails (subclass depends on cause). """ pass diff --git a/backend/airweave/platform/auth_providers/auth_result.py b/backend/airweave/domains/auth_provider/auth_result.py similarity index 100% rename from backend/airweave/platform/auth_providers/auth_result.py rename to backend/airweave/domains/auth_provider/auth_result.py diff --git a/backend/airweave/domains/auth_provider/exceptions.py b/backend/airweave/domains/auth_provider/exceptions.py new file mode 100644 index 000000000..aa6f9831f --- /dev/null +++ b/backend/airweave/domains/auth_provider/exceptions.py @@ -0,0 +1,123 @@ +"""Auth provider domain exceptions. + +Hierarchy +--------- +AuthProviderError — base for all auth-provider failures +├── AuthProviderAuthError — provider rejected our credentials (401) +├── AuthProviderAccountNotFoundError— connected account not found (404) +├── AuthProviderMissingFieldsError — response lacks required credential fields +├── AuthProviderConfigError — app mismatch, unsupported source, etc. 
+├── AuthProviderRateLimitError — provider is throttling us (429) +└── AuthProviderServerError — 5xx / timeout / connection issue + +Translation to token-provider exceptions (done by the token provider): + AuthProviderAuthError → TokenCredentialsInvalidError + AuthProviderAccountNotFoundError→ TokenProviderAccountGoneError + AuthProviderMissingFieldsError → TokenProviderMissingCredsError + AuthProviderConfigError → TokenProviderConfigError + AuthProviderRateLimitError → TokenProviderRateLimitError + AuthProviderServerError → TokenProviderServerError +""" + +from typing import Optional + + +class AuthProviderError(Exception): + """Base for all auth-provider runtime errors. + + Every subclass carries ``provider_name`` so callers can log and + route without inspecting the message. + """ + + def __init__(self, message: str, *, provider_name: str = ""): + """Initialize AuthProviderError.""" + self.provider_name = provider_name + super().__init__(message) + + +# -- Credential / auth failures ------------------------------------------------ + + +class AuthProviderAuthError(AuthProviderError): + """Provider rejected our credentials (HTTP 401). + + Typically means the client_id/client_secret or API key is invalid + or has been revoked. + """ + + pass + + +class AuthProviderAccountNotFoundError(AuthProviderError): + """Connected account does not exist in the provider (HTTP 404).""" + + def __init__(self, message: str, *, provider_name: str = "", account_id: str = ""): + """Initialize AuthProviderAccountNotFoundError.""" + self.account_id = account_id + super().__init__(message, provider_name=provider_name) + + +class AuthProviderMissingFieldsError(AuthProviderError): + """Provider response lacks required credential fields. + + The account exists and responded, but the credential dict does not + contain the fields the source requires (e.g. ``access_token``). 
+ """ + + def __init__( + self, + message: str, + *, + provider_name: str = "", + missing_fields: Optional[list[str]] = None, + available_fields: Optional[list[str]] = None, + ): + """Initialize AuthProviderMissingFieldsError.""" + self.missing_fields = missing_fields or [] + self.available_fields = available_fields or [] + super().__init__(message, provider_name=provider_name) + + +class AuthProviderConfigError(AuthProviderError): + """Static configuration issue — wrong app, blocked source, etc. + + Retrying will never fix this; the connection setup is wrong. + """ + + pass + + +# -- Rate-limit / server errors ------------------------------------------------ + + +class AuthProviderRateLimitError(AuthProviderError): + """Provider is throttling requests (HTTP 429).""" + + def __init__( + self, + message: str = "Auth provider rate limit exceeded", + *, + provider_name: str = "", + retry_after: float = 30.0, + ): + """Initialize AuthProviderRateLimitError.""" + self.retry_after = retry_after + super().__init__(message, provider_name=provider_name) + + +class AuthProviderServerError(AuthProviderError): + """Server error — 5xx, timeout, connection refused.""" + + def __init__( + self, + message: str = "Auth provider server error", + *, + provider_name: str = "", + status_code: Optional[int] = None, + ): + """Initialize AuthProviderServerError.""" + self.status_code = status_code + super().__init__(message, provider_name=provider_name) + + +AuthProviderTemporaryError = AuthProviderServerError diff --git a/backend/airweave/platform/auth_providers/__init__.py b/backend/airweave/domains/auth_provider/providers/__init__.py similarity index 82% rename from backend/airweave/platform/auth_providers/__init__.py rename to backend/airweave/domains/auth_provider/providers/__init__.py index 4d3556c03..8d5dd9306 100644 --- a/backend/airweave/platform/auth_providers/__init__.py +++ b/backend/airweave/domains/auth_provider/providers/__init__.py @@ -1,4 +1,4 @@ -"""All auth provider 
connectors.""" +"""Auth provider implementations.""" from .composio import ComposioAuthProvider from .pipedream import PipedreamAuthProvider diff --git a/backend/airweave/platform/auth_providers/composio.py b/backend/airweave/domains/auth_provider/providers/composio.py similarity index 72% rename from backend/airweave/platform/auth_providers/composio.py rename to backend/airweave/domains/auth_provider/providers/composio.py index cc401eff9..7808b4333 100644 --- a/backend/airweave/platform/auth_providers/composio.py +++ b/backend/airweave/domains/auth_provider/providers/composio.py @@ -3,13 +3,18 @@ from typing import Any, Dict, List, Optional, Set import httpx -from fastapi import HTTPException from airweave.core.credential_sanitizer import ( safe_log_credentials, sanitize_credentials_dict, ) -from airweave.platform.auth_providers._base import BaseAuthProvider +from airweave.domains.auth_provider._base import BaseAuthProvider +from airweave.domains.auth_provider.exceptions import ( + AuthProviderAccountNotFoundError, + AuthProviderMissingFieldsError, + AuthProviderRateLimitError, + AuthProviderTemporaryError, +) from airweave.platform.configs.auth import ComposioAuthConfig from airweave.platform.configs.config import ComposioConfig from airweave.platform.decorators import auth_provider @@ -120,8 +125,12 @@ async def _get_with_auth( JSON response Raises: - httpx.HTTPStatusError: If the request fails + AuthProviderAuthError: 401 from Composio. + AuthProviderRateLimitError: 429 from Composio. + AuthProviderTemporaryError: 5xx or network error. 
""" + from airweave.domains.auth_provider.exceptions import AuthProviderAuthError + headers = {"x-api-key": self.api_key} try: @@ -129,11 +138,33 @@ async def _get_with_auth( response.raise_for_status() return response.json() except httpx.HTTPStatusError as e: - self.logger.error(f"HTTP error from Composio API: {e.response.status_code} for {url}") - raise - except Exception as e: - self.logger.error(f"Unexpected error accessing Composio API: {url}, {str(e)}") + status = e.response.status_code + self.logger.error(f"HTTP error from Composio API: {status} for {url}") + if status == 401: + raise AuthProviderAuthError( + "Composio API key is invalid or revoked", + provider_name="composio", + ) from e + if status == 429: + retry_after = float(e.response.headers.get("retry-after", 30)) + raise AuthProviderRateLimitError( + "Composio API rate-limited", + provider_name="composio", + retry_after=retry_after, + ) from e + if status >= 500: + raise AuthProviderTemporaryError( + f"Composio API returned {status}", + provider_name="composio", + status_code=status, + ) from e raise + except (httpx.ConnectError, httpx.TimeoutException) as e: + self.logger.error(f"Network error accessing Composio API: {url}, {e}") + raise AuthProviderTemporaryError( + f"Composio API unreachable: {e}", + provider_name="composio", + ) from e async def _get_all_connected_accounts(self, client: httpx.AsyncClient) -> List[Dict[str, Any]]: """Fetch all connected accounts from Composio with pagination until exhaustion. @@ -190,7 +221,8 @@ async def get_creds_for_source( Credentials dictionary for the source Raises: - HTTPException: If no credentials found for the source + AuthProviderAccountNotFoundError: If the account is not found. + AuthProviderMissingFieldsError: If required credential fields are absent. 
""" # Map Airweave source name to Composio slug if needed composio_slug = self._get_composio_slug(source_short_name) @@ -234,7 +266,7 @@ async def get_creds_for_source( safe_log_credentials( found_credentials, self.logger.info, - f"\n🔑 [Composio] Retrieved credentials for '{source_short_name}':", + f"[Composio] Retrieved credentials for '{source_short_name}':", ) return found_credentials @@ -252,7 +284,7 @@ async def _get_source_connected_accounts( List of connected accounts for the source Raises: - HTTPException: If no accounts found for the source + AuthProviderAccountNotFoundError: If no accounts match the source. """ self.logger.info("🌐 [Composio] Fetching connected accounts from Composio API...") @@ -262,7 +294,7 @@ async def _get_source_connected_accounts( all_toolkits = { acc.get("toolkit", {}).get("slug", "unknown") for acc in all_connected_accounts } - self.logger.info(f"\n🔧 [Composio] Available toolkit slugs: {sorted(all_toolkits)}\n") + self.logger.info(f"[Composio] Available toolkit slugs: {sorted(all_toolkits)}") source_connected_accounts = [ connected_account @@ -271,19 +303,19 @@ async def _get_source_connected_accounts( ] self.logger.info( - f"\n🎯 [Composio] Found {len(source_connected_accounts)} accounts matching " - f"slug '{composio_slug}'\n" + f"[Composio] Found {len(source_connected_accounts)} accounts matching " + f"slug '{composio_slug}'" ) if not source_connected_accounts: self.logger.error( - f"\n❌ [Composio] No connected accounts found for slug '{composio_slug}'. " - f"Available slugs: {sorted(all_toolkits)}\n" + f"[Composio] No connected accounts found for slug '{composio_slug}'. 
" + f"Available slugs: {sorted(all_toolkits)}" ) - raise HTTPException( - status_code=404, - detail=f"No connected accounts found for source " - f"'{source_short_name}' (Composio slug: '{composio_slug}') in Composio.", + raise AuthProviderAccountNotFoundError( + f"No connected accounts found for source " + f"'{source_short_name}' (Composio slug: '{composio_slug}')", + provider_name="composio", ) # Log details of each matching account @@ -291,7 +323,7 @@ async def _get_source_connected_accounts( acc_id = account.get("id") int_id = account.get("auth_config", {}).get("id") self.logger.info( - f"\n 📌 Account {i + 1}: account_id='{acc_id}', auth_config_id='{int_id}'\n" + f"[Composio] Account {i + 1}: account_id='{acc_id}', auth_config_id='{int_id}'" ) return source_connected_accounts @@ -309,7 +341,7 @@ def _find_matching_connection( The credential dictionary for the matching connection Raises: - HTTPException: If no matching connection found + AuthProviderAccountNotFoundError: If no matching connection found. """ source_creds_dict = None @@ -325,8 +357,8 @@ def _find_matching_connection( if auth_config_id == self.auth_config_id and account_id == self.account_id: self.logger.info( - f"\n✅ [Composio] Found matching connection! 
" - f"auth_config_id='{auth_config_id}', account_id='{account_id}'\n" + f"[Composio] Found matching connection: " + f"auth_config_id='{auth_config_id}', account_id='{account_id}'" ) source_creds_dict = connected_account.get("state", {}).get("val") @@ -336,28 +368,25 @@ def _find_matching_connection( # Log available credential fields if source_creds_dict: available_fields = list(source_creds_dict.keys()) - self.logger.info( - f"\n🔓 [Composio] Available credential fields: {available_fields}\n" - ) + self.logger.info(f"[Composio] Available credential fields: {available_fields}") # Log credential fields safely without exposing values sanitized_preview = sanitize_credentials_dict( source_creds_dict, show_lengths=False ) - self.logger.debug( - f"\n🔓 [Composio] Credential fields preview: {sanitized_preview}\n" - ) + self.logger.debug(f"[Composio] Credential fields preview: {sanitized_preview}") break if not source_creds_dict: self.logger.error( - f"\n❌ [Composio] No matching connection found with " - f"auth_config_id='{self.auth_config_id}' and account_id='{self.account_id}'\n" + f"[Composio] No matching connection found with " + f"auth_config_id='{self.auth_config_id}' and account_id='{self.account_id}'" ) - raise HTTPException( - status_code=404, - detail=f"No matching connection in Composio with auth_config_id=" + raise AuthProviderAccountNotFoundError( + f"No matching Composio connection with auth_config_id=" f"'{self.auth_config_id}' and account_id='{self.account_id}' " - f"for source '{source_short_name}'.", + f"for source '{source_short_name}'", + provider_name="composio", + account_id=self.account_id, ) return source_creds_dict @@ -381,7 +410,7 @@ def _map_and_validate_fields( Dictionary with mapped credentials Raises: - HTTPException: If required (non-optional) fields are missing + AuthProviderMissingFieldsError: If required fields are absent. 
""" missing_required_fields = [] found_credentials = {} @@ -405,16 +434,15 @@ def _map_and_validate_fields( found = False for field_to_check in possible_fields: if field_to_check in source_creds_dict: - # Store with the original Airweave field name found_credentials[airweave_field] = source_creds_dict[field_to_check] if airweave_field != field_to_check: self.logger.info( - f"\n 🔄 Mapped field '{airweave_field}' to Composio field " - f"'{field_to_check}'\n" + f"[Composio] Mapped field '{airweave_field}' " + f"to Composio field '{field_to_check}'" ) self.logger.info( - f"\n ✅ Found field: '{airweave_field}' (as '{field_to_check}' " - f"in Composio)\n" + f"[Composio] Found field: '{airweave_field}' " + f"(as '{field_to_check}' in Composio)" ) found = True break @@ -422,34 +450,35 @@ def _map_and_validate_fields( if not found: if airweave_field in _optional_fields: self.logger.info( - f"\n ⏭️ Skipping optional field: '{airweave_field}' " - f"(not available in Composio)\n" + f"[Composio] Skipping optional field: '{airweave_field}' " + f"(not available in Composio)" ) else: missing_required_fields.append(airweave_field) self.logger.warning( - f"\n ❌ Missing required field: '{airweave_field}' (looked for " - f"{possible_fields} in Composio)\n" + f"[Composio] Missing required field: '{airweave_field}' " + f"(looked for {possible_fields} in Composio)" ) if missing_required_fields: available_fields = list(source_creds_dict.keys()) self.logger.error( - f"\n❌ [Composio] Missing required fields! " + f"[Composio] Missing required fields! " f"Required: {[f for f in source_auth_config_fields if f not in _optional_fields]}, " f"Missing: {missing_required_fields}, " - f"Available in Composio: {available_fields}\n" + f"Available in Composio: {available_fields}" ) - raise HTTPException( - status_code=422, - detail=f"Missing required auth fields for source '{source_short_name}': " - f"{missing_required_fields}. 
" - f"Available fields in Composio credentials: {available_fields}", + raise AuthProviderMissingFieldsError( + f"Missing required auth fields for source '{source_short_name}': " + f"{missing_required_fields}", + provider_name="composio", + missing_fields=missing_required_fields, + available_fields=available_fields, ) self.logger.info( - f"\n✅ [Composio] Successfully retrieved {len(found_credentials)} " - f"credential fields for source '{source_short_name}'\n" + f"[Composio] Successfully retrieved {len(found_credentials)} " + f"credential fields for source '{source_short_name}'" ) return found_credentials @@ -485,18 +514,23 @@ async def validate(self) -> bool: """Validate that the Composio connection works by testing API access. Returns: - True if the connection is valid + True if the connection is valid. Raises: - HTTPException: If validation fails with detailed error message + AuthProviderAuthError: Invalid or revoked API key (401/403). + AuthProviderTemporaryError: Transient failure from Composio. + AuthProviderConfigError: Other non-transient failure. 
""" + from airweave.domains.auth_provider.exceptions import ( + AuthProviderAuthError, + AuthProviderConfigError, + ) + try: self.logger.info("🔍 [Composio] Validating API key...") async with httpx.AsyncClient() as client: headers = {"x-api-key": self.api_key} - - # Test API access with the v3 connected accounts endpoint url = "https://backend.composio.dev/api/v3/connected_accounts" response = await client.get(url, headers=headers) response.raise_for_status() @@ -505,22 +539,30 @@ async def validate(self) -> bool: return True except httpx.HTTPStatusError as e: - error_msg = f"Composio API key validation failed: {e.response.status_code}" - if e.response.status_code == 401: - error_msg += " - Invalid API key" - elif e.response.status_code == 403: - error_msg += " - Access denied" - else: - try: - error_detail = e.response.json().get("message", e.response.text) - error_msg += f" - {error_detail}" - except Exception: - error_msg += f" - {e.response.text}" + status = e.response.status_code + if status in (401, 403): + error_msg = ( + f"Composio API key validation failed: {status} - " + f"{'Invalid API key' if status == 401 else 'Access denied'}" + ) + self.logger.error(f"❌ [Composio] {error_msg}") + raise AuthProviderAuthError(error_msg, provider_name="composio") from e + + try: + detail = e.response.json().get("message", e.response.text) + except Exception: + detail = e.response.text + error_msg = f"Composio API key validation failed: {status} - {detail}" self.logger.error(f"❌ [Composio] {error_msg}") - raise HTTPException(status_code=422, detail=error_msg) - except Exception as e: - error_msg = f"Composio API key validation failed: {str(e)}" + if status >= 500: + raise AuthProviderTemporaryError( + error_msg, provider_name="composio", status_code=status + ) from e + raise AuthProviderConfigError(error_msg, provider_name="composio") from e + + except (httpx.ConnectError, httpx.TimeoutException) as e: + error_msg = f"Composio unreachable during validation: {e}" 
self.logger.error(f"❌ [Composio] {error_msg}") - raise HTTPException(status_code=422, detail=error_msg) + raise AuthProviderTemporaryError(error_msg, provider_name="composio") from e diff --git a/backend/airweave/platform/auth_providers/klavis.py b/backend/airweave/domains/auth_provider/providers/klavis.py similarity index 100% rename from backend/airweave/platform/auth_providers/klavis.py rename to backend/airweave/domains/auth_provider/providers/klavis.py diff --git a/backend/airweave/platform/auth_providers/pipedream.py b/backend/airweave/domains/auth_provider/providers/pipedream.py similarity index 73% rename from backend/airweave/platform/auth_providers/pipedream.py rename to backend/airweave/domains/auth_provider/providers/pipedream.py index 03b5025a2..7b169cd24 100644 --- a/backend/airweave/platform/auth_providers/pipedream.py +++ b/backend/airweave/domains/auth_provider/providers/pipedream.py @@ -4,11 +4,18 @@ from typing import Any, Dict, List, Optional, Set import httpx -from fastapi import HTTPException from airweave.core.credential_sanitizer import safe_log_credentials -from airweave.platform.auth_providers._base import BaseAuthProvider -from airweave.platform.auth_providers.auth_result import AuthResult +from airweave.domains.auth_provider._base import BaseAuthProvider +from airweave.domains.auth_provider.auth_result import AuthResult +from airweave.domains.auth_provider.exceptions import ( + AuthProviderAccountNotFoundError, + AuthProviderAuthError, + AuthProviderConfigError, + AuthProviderMissingFieldsError, + AuthProviderRateLimitError, + AuthProviderTemporaryError, +) from airweave.platform.configs.auth import PipedreamAuthConfig from airweave.platform.configs.config import PipedreamConfig from airweave.platform.decorators import auth_provider @@ -173,7 +180,9 @@ async def _ensure_valid_token(self) -> str: A valid access token Raises: - HTTPException: If token refresh fails + AuthProviderAuthError: If client credentials are rejected. 
+ AuthProviderRateLimitError: If token endpoint is rate-limited. + AuthProviderTemporaryError: If token endpoint returns a server error. """ current_time = time.time() @@ -210,13 +219,26 @@ async def _ensure_valid_token(self) -> str: return self._access_token except httpx.HTTPStatusError as e: + status = e.response.status_code self.logger.error( - f"❌ [Pipedream] Failed to refresh token: {e.response.status_code} - " - f"{e.response.text}" + f"❌ [Pipedream] Failed to refresh token: {status} - {e.response.text}" ) - raise HTTPException( - status_code=500, - detail=f"Failed to refresh Pipedream access token: {e.response.text}", + if status == 401: + raise AuthProviderAuthError( + f"Pipedream rejected client credentials ({status})", + provider_name="pipedream", + ) from e + if status == 429: + retry_after = float(e.response.headers.get("retry-after", 30)) + raise AuthProviderRateLimitError( + "Pipedream token endpoint rate-limited", + provider_name="pipedream", + retry_after=retry_after, + ) from e + raise AuthProviderTemporaryError( + f"Pipedream token endpoint returned {status}", + provider_name="pipedream", + status_code=status, ) from e async def _get_with_auth( @@ -270,7 +292,8 @@ async def get_creds_for_source( Credentials dictionary for the source Raises: - HTTPException: If no credentials found for the source + AuthProviderAccountNotFoundError: If the account is not found. + AuthProviderMissingFieldsError: If required credential fields are absent. 
""" # Map Airweave source name to Pipedream app slug if needed pipedream_app_slug = self._get_pipedream_app_slug(source_short_name) @@ -309,7 +332,7 @@ async def get_creds_for_source( safe_log_credentials( found_credentials, self.logger.info, - f"\n🔑 [Pipedream] Retrieved credentials for '{source_short_name}':", + f"[Pipedream] Retrieved credentials for '{source_short_name}':", ) return found_credentials @@ -317,16 +340,17 @@ async def validate(self) -> bool: """Validate that the Pipedream connection works by testing client credentials. Returns: - True if the connection is valid + True if the connection is valid. Raises: - HTTPException: If validation fails with detailed error message + AuthProviderAuthError: Invalid client credentials (401). + AuthProviderConfigError: Bad request or unexpected response. + AuthProviderTemporaryError: Transient failure from Pipedream. """ try: self.logger.info("🔍 [Pipedream] Validating client credentials...") async with httpx.AsyncClient() as client: - # Test OAuth token generation with client credentials token_data = { "grant_type": "client_credentials", "client_id": self.client_id, @@ -338,37 +362,45 @@ async def validate(self) -> bool: token_response = response.json() if "access_token" not in token_response: - raise HTTPException( - status_code=422, detail="Pipedream API returned invalid token response" + raise AuthProviderConfigError( + "Pipedream API returned token response without access_token", + provider_name="pipedream", ) self.logger.info("✅ [Pipedream] Client credentials validated successfully") return True + except AuthProviderConfigError: + raise except httpx.HTTPStatusError as e: - error_msg = f"Pipedream client credentials validation failed: {e.response.status_code}" - if e.response.status_code == 401: - error_msg += " - Invalid client credentials" - elif e.response.status_code == 400: - try: - error_detail = e.response.json().get("error_description", e.response.text) - error_msg += f" - {error_detail}" - except 
Exception: - error_msg += " - Bad request" - else: - try: - error_detail = e.response.json().get("error", e.response.text) - error_msg += f" - {error_detail}" - except Exception: - error_msg += f" - {e.response.text}" + status = e.response.status_code + if status == 401: + error_msg = ( + "Pipedream client credentials validation failed: Invalid client credentials" + ) + self.logger.error(f"❌ [Pipedream] {error_msg}") + raise AuthProviderAuthError(error_msg, provider_name="pipedream") from e + + try: + detail = e.response.json().get( + "error_description", e.response.json().get("error", e.response.text) + ) + except Exception: + detail = e.response.text + error_msg = f"Pipedream client credentials validation failed: {status} - {detail}" self.logger.error(f"❌ [Pipedream] {error_msg}") - raise HTTPException(status_code=422, detail=error_msg) - except Exception as e: - error_msg = f"Pipedream client credentials validation failed: {str(e)}" + if status >= 500: + raise AuthProviderTemporaryError( + error_msg, provider_name="pipedream", status_code=status + ) from e + raise AuthProviderConfigError(error_msg, provider_name="pipedream") from e + + except (httpx.ConnectError, httpx.TimeoutException) as e: + error_msg = f"Pipedream unreachable during validation: {e}" self.logger.error(f"❌ [Pipedream] {error_msg}") - raise HTTPException(status_code=422, detail=error_msg) + raise AuthProviderTemporaryError(error_msg, provider_name="pipedream") from e async def _get_account_with_credentials( self, client: httpx.AsyncClient, pipedream_app_slug: str, source_short_name: str @@ -384,34 +416,35 @@ async def _get_account_with_credentials( Account data with credentials Raises: - HTTPException: If account not found or credentials not available + AuthProviderAccountNotFoundError: Account does not exist. + AuthProviderConfigError: Account is for a different app. + PipedreamDefaultOAuthException: Default OAuth — caller should use proxy. + AuthProviderTemporaryError: Transient HTTP failure. 
""" - # Build the API URL for the specific account url = f"https://api.pipedream.com/v1/connect/{self.project_id}/accounts/{self.account_id}" self.logger.info(f"🌐 [Pipedream] Fetching account from: {url}") try: - # Include credentials in the response params = {"include_credentials": "true"} account_data = await self._get_with_auth(client, url, params) - # Verify it's the right app if account_data.get("app", {}).get("name_slug") != pipedream_app_slug: + actual = account_data.get("app", {}).get("name_slug") self.logger.error( - f"❌ [Pipedream] Account app mismatch. Expected '{pipedream_app_slug}', " - f"got '{account_data.get('app', {}).get('name_slug')}'" + f"❌ [Pipedream] Account app mismatch. " + f"Expected '{pipedream_app_slug}', got '{actual}'" ) - raise HTTPException( - status_code=400, - detail=f"Account {self.account_id} is not for app '{pipedream_app_slug}'", + raise AuthProviderConfigError( + f"Account {self.account_id} is for app '{actual}', " + f"expected '{pipedream_app_slug}'", + provider_name="pipedream", ) - # Check if credentials are included if "credentials" not in account_data: self.logger.error( - "❌ [Pipedream] No credentials in response. This usually means the account " - "was created with Pipedream's default OAuth client, not a custom one." + "❌ [Pipedream] No credentials in response. This usually means " + "the account was created with Pipedream's default OAuth client." 
) raise PipedreamDefaultOAuthException(source_short_name) @@ -422,14 +455,33 @@ async def _get_account_with_credentials( return account_data + except (AuthProviderConfigError, PipedreamDefaultOAuthException): + raise except httpx.HTTPStatusError as e: - if e.response.status_code == 404: - raise HTTPException( - status_code=404, detail=f"Pipedream account not found: {self.account_id}" + status = e.response.status_code + if status == 404: + raise AuthProviderAccountNotFoundError( + f"Pipedream account not found: {self.account_id}", + provider_name="pipedream", + account_id=self.account_id, + ) from e + if status == 429: + retry_after = float(e.response.headers.get("retry-after", 30)) + raise AuthProviderRateLimitError( + "Pipedream API rate-limited while fetching account", + provider_name="pipedream", + retry_after=retry_after, + ) from e + if status >= 500: + raise AuthProviderTemporaryError( + f"Pipedream API returned {status} while fetching account", + provider_name="pipedream", + status_code=status, ) from e - raise HTTPException( - status_code=e.response.status_code, - detail=f"Failed to fetch Pipedream account: {e.response.text}", + raise AuthProviderTemporaryError( + f"Pipedream API error {status}: {e.response.text[:200]}", + provider_name="pipedream", + status_code=status, ) from e def _extract_and_map_credentials( @@ -451,7 +503,7 @@ def _extract_and_map_credentials( Dictionary with mapped credentials Raises: - HTTPException: If required (non-optional) fields are missing + AuthProviderMissingFieldsError: If required fields are absent. 
""" credentials = account_data.get("credentials", {}) missing_required_fields = [] @@ -462,55 +514,54 @@ def _extract_and_map_credentials( self.logger.info(f"📦 [Pipedream] Available credential fields: {list(credentials.keys())}") for airweave_field in source_auth_config_fields: - # Map the field name if needed (per-source override, then global) pipedream_field = self._map_field_name( airweave_field, source_short_name=source_short_name ) if airweave_field != pipedream_field: self.logger.info( - f"\n 🔄 Mapped field '{airweave_field}' to Pipedream field " - f"'{pipedream_field}'\n" + f"[Pipedream] Mapped field '{airweave_field}' " + f"to Pipedream field '{pipedream_field}'" ) if pipedream_field in credentials: - # Store with the original Airweave field name found_credentials[airweave_field] = credentials[pipedream_field] self.logger.info( - f"\n ✅ Found field: '{airweave_field}' (as '{pipedream_field}' " - f"in Pipedream)\n" + f"[Pipedream] Found field: '{airweave_field}' " + f"(as '{pipedream_field}' in Pipedream)" ) else: if airweave_field in _optional_fields: self.logger.info( - f"\n ⏭️ Skipping optional field: '{airweave_field}' " - f"(not available in Pipedream)\n" + f"[Pipedream] Skipping optional field: '{airweave_field}' " + f"(not available in Pipedream)" ) else: missing_required_fields.append(airweave_field) self.logger.warning( - f"\n ❌ Missing required field: '{airweave_field}' (looked for " - f"'{pipedream_field}' in Pipedream)\n" + f"[Pipedream] Missing required field: '{airweave_field}' " + f"(looked for '{pipedream_field}' in Pipedream)" ) if missing_required_fields: available_fields = list(credentials.keys()) self.logger.error( - f"\n❌ [Pipedream] Missing required fields! " + f"[Pipedream] Missing required fields! 
" f"Required: {[f for f in source_auth_config_fields if f not in _optional_fields]}, " f"Missing: {missing_required_fields}, " - f"Available in Pipedream: {available_fields}\n" + f"Available in Pipedream: {available_fields}" ) - raise HTTPException( - status_code=422, - detail=f"Missing required auth fields for source '{source_short_name}': " - f"{missing_required_fields}. " - f"Available fields in Pipedream credentials: {available_fields}", + raise AuthProviderMissingFieldsError( + f"Missing required auth fields for source '{source_short_name}': " + f"{missing_required_fields}", + provider_name="pipedream", + missing_fields=missing_required_fields, + available_fields=available_fields, ) self.logger.info( - f"\n✅ [Pipedream] Successfully retrieved {len(found_credentials)} " - f"credential fields for source '{source_short_name}'\n" + f"[Pipedream] Successfully retrieved {len(found_credentials)} " + f"credential fields for source '{source_short_name}'" ) return found_credentials diff --git a/backend/airweave/domains/auth_provider/registry.py b/backend/airweave/domains/auth_provider/registry.py index a587168c3..5fa438ccc 100644 --- a/backend/airweave/domains/auth_provider/registry.py +++ b/backend/airweave/domains/auth_provider/registry.py @@ -3,7 +3,7 @@ from airweave.core.logging import logger from airweave.domains.auth_provider.protocols import AuthProviderRegistryProtocol from airweave.domains.auth_provider.types import AuthProviderRegistryEntry -from airweave.platform.auth_providers import ALL_AUTH_PROVIDERS +from airweave.domains.auth_provider.providers import ALL_AUTH_PROVIDERS from airweave.platform.configs._base import Fields registry_logger = logger.with_prefix("AuthProviderRegistry: ").with_context( diff --git a/backend/airweave/domains/auth_provider/tests/test_composio.py b/backend/airweave/domains/auth_provider/tests/test_composio.py new file mode 100644 index 000000000..64a793a32 --- /dev/null +++ 
b/backend/airweave/domains/auth_provider/tests/test_composio.py @@ -0,0 +1,610 @@ +"""Unit tests for ComposioAuthProvider.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from airweave.domains.auth_provider.exceptions import ( + AuthProviderAccountNotFoundError, + AuthProviderAuthError, + AuthProviderConfigError, + AuthProviderMissingFieldsError, + AuthProviderRateLimitError, + AuthProviderTemporaryError, +) +from airweave.domains.auth_provider.providers.composio import ComposioAuthProvider + + +# --------------------------------------------------------------------------- +# create() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_create_sets_instance_attrs(): + provider = await ComposioAuthProvider.create( + credentials={"api_key": "test-key"}, + config={"auth_config_id": "cfg-1", "account_id": "acc-1"}, + ) + assert provider.api_key == "test-key" + assert provider.auth_config_id == "cfg-1" + assert provider.account_id == "acc-1" + assert provider._last_credential_blob is None + + +@pytest.mark.asyncio +async def test_create_handles_partial_config(): + provider = await ComposioAuthProvider.create( + credentials={"api_key": "k"}, + config={}, + ) + assert provider.api_key == "k" + assert provider.auth_config_id is None + assert provider.account_id is None + + +# --------------------------------------------------------------------------- +# _get_with_auth() — error branches (401, 429, 5xx, ConnectError, TimeoutException) +# --------------------------------------------------------------------------- + + +def _make_http_status_error(status_code: int, headers: dict | None = None): + resp = MagicMock() + resp.status_code = status_code + resp.headers = headers or {} + return httpx.HTTPStatusError("err", request=MagicMock(), response=resp) + + +@pytest.mark.asyncio +async def test_get_with_auth_401_raises_auth_error(): + provider = await 
ComposioAuthProvider.create( + credentials={"api_key": "k"}, + config={"auth_config_id": "c", "account_id": "a"}, + ) + mock_client = AsyncMock() + mock_client.get = AsyncMock( + side_effect=_make_http_status_error(401), + ) + + with pytest.raises(AuthProviderAuthError) as exc_info: + await provider._get_with_auth(mock_client, "https://example.com/api") + + assert "invalid or revoked" in str(exc_info.value).lower() + assert exc_info.value.provider_name == "composio" + + +@pytest.mark.asyncio +async def test_get_with_auth_429_raises_rate_limit_with_retry_after(): + provider = await ComposioAuthProvider.create( + credentials={"api_key": "k"}, + config={"auth_config_id": "c", "account_id": "a"}, + ) + mock_client = AsyncMock() + resp = MagicMock() + resp.status_code = 429 + resp.headers = {"retry-after": "120"} + mock_client.get = AsyncMock(side_effect=httpx.HTTPStatusError("err", request=MagicMock(), response=resp)) + + with pytest.raises(AuthProviderRateLimitError) as exc_info: + await provider._get_with_auth(mock_client, "https://example.com/api") + + assert exc_info.value.retry_after == 120.0 + assert exc_info.value.provider_name == "composio" + + +@pytest.mark.asyncio +async def test_get_with_auth_429_default_retry_after(): + provider = await ComposioAuthProvider.create( + credentials={"api_key": "k"}, + config={"auth_config_id": "c", "account_id": "a"}, + ) + mock_client = AsyncMock() + resp = MagicMock() + resp.status_code = 429 + resp.headers = {} + mock_client.get = AsyncMock(side_effect=httpx.HTTPStatusError("err", request=MagicMock(), response=resp)) + + with pytest.raises(AuthProviderRateLimitError) as exc_info: + await provider._get_with_auth(mock_client, "https://example.com/api") + + assert exc_info.value.retry_after == 30.0 + + +@pytest.mark.asyncio +async def test_get_with_auth_5xx_raises_temporary_error(): + provider = await ComposioAuthProvider.create( + credentials={"api_key": "k"}, + config={"auth_config_id": "c", "account_id": "a"}, + ) + 
mock_client = AsyncMock() + mock_client.get = AsyncMock(side_effect=_make_http_status_error(502)) + + with pytest.raises(AuthProviderTemporaryError) as exc_info: + await provider._get_with_auth(mock_client, "https://example.com/api") + + assert "502" in str(exc_info.value) + assert exc_info.value.status_code == 502 + assert exc_info.value.provider_name == "composio" + + +@pytest.mark.asyncio +async def test_get_with_auth_connect_error_raises_temporary(): + provider = await ComposioAuthProvider.create( + credentials={"api_key": "k"}, + config={"auth_config_id": "c", "account_id": "a"}, + ) + mock_client = AsyncMock() + mock_client.get = AsyncMock(side_effect=httpx.ConnectError("connection refused")) + + with pytest.raises(AuthProviderTemporaryError) as exc_info: + await provider._get_with_auth(mock_client, "https://example.com/api") + + assert "unreachable" in str(exc_info.value).lower() or "connection" in str(exc_info.value).lower() + assert exc_info.value.provider_name == "composio" + + +@pytest.mark.asyncio +async def test_get_with_auth_timeout_raises_temporary(): + provider = await ComposioAuthProvider.create( + credentials={"api_key": "k"}, + config={"auth_config_id": "c", "account_id": "a"}, + ) + mock_client = AsyncMock() + mock_client.get = AsyncMock(side_effect=httpx.TimeoutException("timeout")) + + with pytest.raises(AuthProviderTemporaryError) as exc_info: + await provider._get_with_auth(mock_client, "https://example.com/api") + + assert exc_info.value.provider_name == "composio" + + +@pytest.mark.asyncio +async def test_get_with_auth_success_returns_json(): + provider = await ComposioAuthProvider.create( + credentials={"api_key": "k"}, + config={"auth_config_id": "c", "account_id": "a"}, + ) + mock_client = AsyncMock() + mock_resp = MagicMock() + mock_resp.json.return_value = {"items": [], "total": 0} + mock_resp.raise_for_status = MagicMock() + mock_client.get = AsyncMock(return_value=mock_resp) + + result = await provider._get_with_auth(mock_client, 
"https://example.com/api") + assert result == {"items": [], "total": 0} + + +# --------------------------------------------------------------------------- +# _get_source_connected_accounts() — no match -> AuthProviderAccountNotFoundError +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_get_source_connected_accounts_no_match_raises(): + provider = await ComposioAuthProvider.create( + credentials={"api_key": "k"}, + config={"auth_config_id": "cfg-1", "account_id": "acc-1"}, + ) + # All accounts have different slug + provider._get_all_connected_accounts = AsyncMock( + return_value=[ + {"id": "a1", "toolkit": {"slug": "slack"}, "auth_config": {"id": "x"}}, + ] + ) + + with pytest.raises(AuthProviderAccountNotFoundError) as exc_info: + await provider._get_source_connected_accounts( + MagicMock(), composio_slug="googledrive", source_short_name="google_drive" + ) + + assert "No connected accounts" in str(exc_info.value) + assert exc_info.value.provider_name == "composio" + + +@pytest.mark.asyncio +async def test_get_source_connected_accounts_empty_list_raises(): + provider = await ComposioAuthProvider.create( + credentials={"api_key": "k"}, + config={"auth_config_id": "cfg-1", "account_id": "acc-1"}, + ) + provider._get_all_connected_accounts = AsyncMock(return_value=[]) + + with pytest.raises(AuthProviderAccountNotFoundError): + await provider._get_source_connected_accounts( + MagicMock(), composio_slug="slack", source_short_name="slack" + ) + + +# --------------------------------------------------------------------------- +# _find_matching_connection() — match caches blob; no match raises +# --------------------------------------------------------------------------- + + +def test_find_matching_connection_match_caches_blob(): + provider = ComposioAuthProvider() + provider.auth_config_id = "cfg-1" + provider.account_id = "acc-1" + provider._last_credential_blob = None + + creds = {"generic_api_key": 
"secret", "instance_url": "https://x.com"} + accounts = [ + { + "id": "acc-1", + "auth_config": {"id": "cfg-1"}, + "state": {"val": creds}, + } + ] + + result = provider._find_matching_connection(accounts, "salesforce") + assert result == creds + assert provider._last_credential_blob == creds + + +def test_find_matching_connection_no_match_raises(): + provider = ComposioAuthProvider() + provider.auth_config_id = "cfg-1" + provider.account_id = "acc-1" + + accounts = [ + {"id": "other", "auth_config": {"id": "other-cfg"}, "state": {"val": {}}}, + ] + + with pytest.raises(AuthProviderAccountNotFoundError) as exc_info: + provider._find_matching_connection(accounts, "slack") + + assert "No matching Composio connection" in str(exc_info.value) + assert exc_info.value.account_id == "acc-1" + assert exc_info.value.provider_name == "composio" + + +# --------------------------------------------------------------------------- +# _map_and_validate_fields() — missing required raises +# --------------------------------------------------------------------------- + + +def test_map_and_validate_fields_missing_required_raises(): + provider = ComposioAuthProvider() + # No api_key, generic_api_key, or access_token — all required fields missing + source_creds = {"other_field": "x"} + + with pytest.raises(AuthProviderMissingFieldsError) as exc_info: + provider._map_and_validate_fields( + source_creds, + source_auth_config_fields=["api_key", "refresh_token"], + source_short_name="slack", + ) + + assert "api_key" in exc_info.value.missing_fields + assert "refresh_token" in exc_info.value.missing_fields + assert exc_info.value.available_fields == ["other_field"] + assert exc_info.value.provider_name == "composio" + + +def test_map_and_validate_fields_optional_skipped(): + provider = ComposioAuthProvider() + source_creds = {"generic_api_key": "key"} + + result = provider._map_and_validate_fields( + source_creds, + source_auth_config_fields=["api_key", "optional_field"], + 
source_short_name="stripe", + optional_fields={"optional_field"}, + ) + assert result == {"api_key": "key"} + + +def test_map_and_validate_fields_api_key_tries_multiple(): + provider = ComposioAuthProvider() + # api_key maps to generic_api_key, but we also try access_token + source_creds = {"access_token": "oauth-tok"} + + result = provider._map_and_validate_fields( + source_creds, + source_auth_config_fields=["api_key"], + source_short_name="stripe", + ) + assert result == {"api_key": "oauth-tok"} + + +# --------------------------------------------------------------------------- +# get_config_for_source() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_get_config_for_source_extracts_from_cached_blob(): + provider = ComposioAuthProvider() + provider._last_credential_blob = {"instance_url": "https://my.salesforce.com", "org_id": "00D123"} + + result = await provider.get_config_for_source( + source_short_name="salesforce", + source_config_field_mappings={"instance_url": "instance_url", "org_id": "org_id"}, + ) + assert result == {"instance_url": "https://my.salesforce.com", "org_id": "00D123"} + + +@pytest.mark.asyncio +async def test_get_config_for_source_empty_blob_returns_empty(): + provider = ComposioAuthProvider() + provider._last_credential_blob = None + + result = await provider.get_config_for_source( + source_short_name="slack", + source_config_field_mappings={"team_id": "team_id"}, + ) + assert result == {} + + +# --------------------------------------------------------------------------- +# validate() — 401, 403, 5xx, other status, ConnectError, non-JSON fallback +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_validate_401_raises_auth_error(): + provider = await ComposioAuthProvider.create( + credentials={"api_key": "k"}, + config={"auth_config_id": "c", "account_id": "a"}, + ) + + mock_resp = MagicMock() + 
mock_resp.status_code = 401 + mock_resp.raise_for_status.side_effect = httpx.HTTPStatusError( + "401", request=MagicMock(), response=mock_resp + ) + + mock_client = MagicMock() + mock_client.get = AsyncMock(return_value=mock_resp) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + with patch("airweave.domains.auth_provider.providers.composio.httpx.AsyncClient", return_value=mock_client): + with pytest.raises(AuthProviderAuthError) as exc_info: + await provider.validate() + + assert "Invalid API key" in str(exc_info.value) + assert exc_info.value.provider_name == "composio" + + +@pytest.mark.asyncio +async def test_validate_403_raises_auth_error(): + provider = await ComposioAuthProvider.create( + credentials={"api_key": "k"}, + config={"auth_config_id": "c", "account_id": "a"}, + ) + + mock_resp = MagicMock() + mock_resp.status_code = 403 + mock_resp.raise_for_status.side_effect = httpx.HTTPStatusError( + "403", request=MagicMock(), response=mock_resp + ) + + mock_client = MagicMock() + mock_client.get = AsyncMock(return_value=mock_resp) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + with patch("airweave.domains.auth_provider.providers.composio.httpx.AsyncClient", return_value=mock_client): + with pytest.raises(AuthProviderAuthError) as exc_info: + await provider.validate() + + assert "Access denied" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_validate_5xx_raises_temporary_error(): + provider = await ComposioAuthProvider.create( + credentials={"api_key": "k"}, + config={"auth_config_id": "c", "account_id": "a"}, + ) + + mock_resp = MagicMock() + mock_resp.status_code = 503 + mock_resp.json.return_value = {"message": "Service unavailable"} + mock_resp.text = "Service unavailable" + mock_resp.raise_for_status.side_effect = httpx.HTTPStatusError( + "503", request=MagicMock(), response=mock_resp + ) + + 
mock_client = MagicMock() + mock_client.get = AsyncMock(return_value=mock_resp) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + with patch("airweave.domains.auth_provider.providers.composio.httpx.AsyncClient", return_value=mock_client): + with pytest.raises(AuthProviderTemporaryError) as exc_info: + await provider.validate() + + assert "503" in str(exc_info.value) + assert exc_info.value.status_code == 503 + + +@pytest.mark.asyncio +async def test_validate_other_status_raises_config_error(): + provider = await ComposioAuthProvider.create( + credentials={"api_key": "k"}, + config={"auth_config_id": "c", "account_id": "a"}, + ) + + mock_resp = MagicMock() + mock_resp.status_code = 400 + mock_resp.json.return_value = {"message": "Bad request"} + mock_resp.text = "Bad request" + mock_resp.raise_for_status.side_effect = httpx.HTTPStatusError( + "400", request=MagicMock(), response=mock_resp + ) + + mock_client = MagicMock() + mock_client.get = AsyncMock(return_value=mock_resp) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + with patch("airweave.domains.auth_provider.providers.composio.httpx.AsyncClient", return_value=mock_client): + with pytest.raises(AuthProviderConfigError) as exc_info: + await provider.validate() + + assert "400" in str(exc_info.value) + assert exc_info.value.provider_name == "composio" + + +@pytest.mark.asyncio +async def test_validate_non_json_fallback_uses_text(): + provider = await ComposioAuthProvider.create( + credentials={"api_key": "k"}, + config={"auth_config_id": "c", "account_id": "a"}, + ) + + mock_resp = MagicMock() + mock_resp.status_code = 400 + mock_resp.json.side_effect = ValueError("not json") + mock_resp.text = "plain text error" + mock_resp.raise_for_status.side_effect = httpx.HTTPStatusError( + "400", request=MagicMock(), response=mock_resp + ) + + mock_client = MagicMock() + 
mock_client.get = AsyncMock(return_value=mock_resp) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + with patch("airweave.domains.auth_provider.providers.composio.httpx.AsyncClient", return_value=mock_client): + with pytest.raises(AuthProviderConfigError) as exc_info: + await provider.validate() + + assert "plain text error" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_validate_connect_error_raises_temporary(): + provider = await ComposioAuthProvider.create( + credentials={"api_key": "k"}, + config={"auth_config_id": "c", "account_id": "a"}, + ) + + mock_client = MagicMock() + mock_client.get = AsyncMock(side_effect=httpx.ConnectError("connection refused")) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + with patch("airweave.domains.auth_provider.providers.composio.httpx.AsyncClient", return_value=mock_client): + with pytest.raises(AuthProviderTemporaryError) as exc_info: + await provider.validate() + + assert "unreachable" in str(exc_info.value).lower() or "connection" in str(exc_info.value).lower() + + +@pytest.mark.asyncio +async def test_validate_timeout_raises_temporary(): + provider = await ComposioAuthProvider.create( + credentials={"api_key": "k"}, + config={"auth_config_id": "c", "account_id": "a"}, + ) + + mock_client = MagicMock() + mock_client.get = AsyncMock(side_effect=httpx.TimeoutException("timeout")) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + with patch("airweave.domains.auth_provider.providers.composio.httpx.AsyncClient", return_value=mock_client): + with pytest.raises(AuthProviderTemporaryError): + await provider.validate() + + +@pytest.mark.asyncio +async def test_validate_success_returns_true(): + provider = await ComposioAuthProvider.create( + credentials={"api_key": "k"}, + config={"auth_config_id": 
"c", "account_id": "a"}, + ) + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.raise_for_status = MagicMock() + + mock_client = MagicMock() + mock_client.get = AsyncMock(return_value=mock_resp) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + with patch("airweave.domains.auth_provider.providers.composio.httpx.AsyncClient", return_value=mock_client): + result = await provider.validate() + + assert result is True + + +# --------------------------------------------------------------------------- +# get_creds_for_source() — integration via mocked HTTP +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_get_creds_for_source_happy_path(): + provider = await ComposioAuthProvider.create( + credentials={"api_key": "test-key"}, + config={"auth_config_id": "cfg-1", "account_id": "acc-1"}, + ) + + def mock_get(url, headers=None, params=None): + resp = MagicMock() + resp.json.return_value = { + "items": [ + { + "id": "acc-1", + "auth_config": {"id": "cfg-1"}, + "toolkit": {"slug": "slack"}, + "state": {"val": {"access_token": "tok123", "oauth_access_token": "tok123"}}, + } + ] + } + resp.raise_for_status = MagicMock() + return resp + + mock_client = MagicMock() + mock_client.get = AsyncMock(side_effect=mock_get) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + with patch("airweave.domains.auth_provider.providers.composio.httpx.AsyncClient", return_value=mock_client): + creds = await provider.get_creds_for_source( + source_short_name="slack", + source_auth_config_fields=["access_token"], + ) + + assert creds == {"access_token": "tok123"} + + +@pytest.mark.asyncio +async def test_get_creds_for_source_slug_mapping(): + provider = await ComposioAuthProvider.create( + credentials={"api_key": "k"}, + config={"auth_config_id": "cfg-1", "account_id": 
"acc-1"}, + ) + + def mock_get(url, headers=None, params=None): + resp = MagicMock() + resp.json.return_value = { + "items": [ + { + "id": "acc-1", + "auth_config": {"id": "cfg-1"}, + "toolkit": {"slug": "googledrive"}, + "state": {"val": {"access_token": "tok"}}, + } + ] + } + resp.raise_for_status = MagicMock() + return resp + + mock_client = MagicMock() + mock_client.get = AsyncMock(side_effect=mock_get) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + with patch("airweave.domains.auth_provider.providers.composio.httpx.AsyncClient", return_value=mock_client): + creds = await provider.get_creds_for_source( + source_short_name="google_drive", + source_auth_config_fields=["access_token"], + ) + + assert creds == {"access_token": "tok"} diff --git a/backend/airweave/domains/auth_provider/tests/test_exceptions.py b/backend/airweave/domains/auth_provider/tests/test_exceptions.py new file mode 100644 index 000000000..c88e90f7c --- /dev/null +++ b/backend/airweave/domains/auth_provider/tests/test_exceptions.py @@ -0,0 +1,55 @@ +"""Tests for auth provider exception __init__ bodies.""" + +from airweave.domains.auth_provider.exceptions import ( + AuthProviderAccountNotFoundError, + AuthProviderError, + AuthProviderMissingFieldsError, + AuthProviderRateLimitError, + AuthProviderServerError, + AuthProviderTemporaryError, +) + + +def test_base_error_stores_provider_name(): + err = AuthProviderError("boom", provider_name="pipedream") + assert err.provider_name == "pipedream" + assert str(err) == "boom" + + +def test_account_not_found_stores_account_id(): + err = AuthProviderAccountNotFoundError( + "gone", provider_name="composio", account_id="acc-123" + ) + assert err.account_id == "acc-123" + assert err.provider_name == "composio" + + +def test_missing_fields_stores_fields(): + err = AuthProviderMissingFieldsError( + "missing", + provider_name="pipedream", + missing_fields=["api_key"], + 
available_fields=["access_token", "refresh_token"], + ) + assert err.missing_fields == ["api_key"] + assert err.available_fields == ["access_token", "refresh_token"] + + +def test_missing_fields_defaults_to_empty(): + err = AuthProviderMissingFieldsError("missing") + assert err.missing_fields == [] + assert err.available_fields == [] + + +def test_rate_limit_stores_retry_after(): + err = AuthProviderRateLimitError("slow down", provider_name="composio", retry_after=60.0) + assert err.retry_after == 60.0 + + +def test_server_error_stores_status_code(): + err = AuthProviderServerError("oops", provider_name="pipedream", status_code=502) + assert err.status_code == 502 + + +def test_temporary_error_is_server_error_alias(): + assert AuthProviderTemporaryError is AuthProviderServerError diff --git a/backend/airweave/domains/auth_provider/tests/test_pipedream.py b/backend/airweave/domains/auth_provider/tests/test_pipedream.py new file mode 100644 index 000000000..e58a3eb35 --- /dev/null +++ b/backend/airweave/domains/auth_provider/tests/test_pipedream.py @@ -0,0 +1,723 @@ +"""Tests for PipedreamAuthProvider.""" + +import json +from unittest.mock import AsyncMock, patch + +import httpx +import pytest + +from airweave.domains.auth_provider.auth_result import AuthResult +from airweave.domains.auth_provider.exceptions import ( + AuthProviderAccountNotFoundError, + AuthProviderAuthError, + AuthProviderConfigError, + AuthProviderMissingFieldsError, + AuthProviderRateLimitError, + AuthProviderTemporaryError, +) +from airweave.domains.auth_provider.providers.pipedream import ( + PipedreamAuthProvider, + PipedreamDefaultOAuthException, +) + +PIPEDREAM_MODULE = "airweave.domains.auth_provider.providers.pipedream" + + +def _make_response(status_code: int = 200, json_body=None, text: str = "", headers=None): + """Build httpx.Response with controllable body and headers.""" + content = json.dumps(json_body).encode() if json_body is not None else text.encode() + hdrs = dict(headers or 
{}) + if json_body is not None and "content-type" not in (k.lower() for k in hdrs): + hdrs["content-type"] = "application/json" + return httpx.Response( + status_code=status_code, + content=content, + headers=hdrs, + request=httpx.Request("POST", "https://api.pipedream.com/v1/oauth/token"), + ) + + +# --------------------------------------------------------------------------- +# create() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_create_success(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "proj-1", "account_id": "acc-1", "environment": "production"}, + ) + assert provider.client_id == "cid" + assert provider.client_secret == "csec" + assert provider.project_id == "proj-1" + assert provider.account_id == "acc-1" + assert provider.environment == "production" + assert provider._access_token is None + assert provider._token_expires_at == 0 + + +@pytest.mark.asyncio +async def test_create_credentials_none_raises(): + with pytest.raises(ValueError, match="credentials parameter is required"): + await PipedreamAuthProvider.create(credentials=None) + + +@pytest.mark.asyncio +async def test_create_config_none_uses_defaults(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config=None, + ) + assert provider.project_id is None + assert provider.account_id is None + assert provider.environment == "production" + + +# --------------------------------------------------------------------------- +# PipedreamDefaultOAuthException +# --------------------------------------------------------------------------- + + +def test_pipedream_default_oauth_exception_default_message(): + exc = PipedreamDefaultOAuthException("slack") + assert exc.source_short_name == "slack" + assert "slack" in str(exc) + assert "default OAuth" in str(exc) + + +def 
test_pipedream_default_oauth_exception_custom_message(): + exc = PipedreamDefaultOAuthException("slack", message="custom") + assert str(exc) == "custom" + + +# --------------------------------------------------------------------------- +# _ensure_valid_token() error branches (lines 222, 226-227, 231-233, 238) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_ensure_valid_token_401_raises_auth_error(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "p", "account_id": "a"}, + ) + resp = _make_response(401, {"error": "invalid_client"}) + + mock_post = AsyncMock(side_effect=httpx.HTTPStatusError("401", request=resp.request, response=resp)) + + mock_client = AsyncMock() + mock_client.post = mock_post + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + with patch(f"{PIPEDREAM_MODULE}.httpx.AsyncClient", return_value=mock_client): + with pytest.raises(AuthProviderAuthError, match="rejected client credentials"): + await provider._ensure_valid_token() + + +@pytest.mark.asyncio +async def test_ensure_valid_token_429_raises_rate_limit_with_retry_after(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "p", "account_id": "a"}, + ) + resp = _make_response(429, {"error": "rate_limited"}, headers={"retry-after": "120"}) + + mock_post = AsyncMock(side_effect=httpx.HTTPStatusError("429", request=resp.request, response=resp)) + + mock_client = AsyncMock() + mock_client.post = mock_post + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + with patch(f"{PIPEDREAM_MODULE}.httpx.AsyncClient", return_value=mock_client): + with pytest.raises(AuthProviderRateLimitError) as exc_info: + await 
provider._ensure_valid_token() + assert exc_info.value.retry_after == 120.0 + + +@pytest.mark.asyncio +async def test_ensure_valid_token_429_default_retry_after(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "p", "account_id": "a"}, + ) + resp = _make_response(429, {"error": "rate_limited"}) + + mock_post = AsyncMock(side_effect=httpx.HTTPStatusError("429", request=resp.request, response=resp)) + + mock_client = AsyncMock() + mock_client.post = mock_post + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + with patch(f"{PIPEDREAM_MODULE}.httpx.AsyncClient", return_value=mock_client): + with pytest.raises(AuthProviderRateLimitError) as exc_info: + await provider._ensure_valid_token() + assert exc_info.value.retry_after == 30.0 + + +@pytest.mark.asyncio +async def test_ensure_valid_token_500_raises_temporary_error(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "p", "account_id": "a"}, + ) + resp = _make_response(500, {"error": "internal"}) + + mock_post = AsyncMock(side_effect=httpx.HTTPStatusError("500", request=resp.request, response=resp)) + + mock_client = AsyncMock() + mock_client.post = mock_post + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + with patch(f"{PIPEDREAM_MODULE}.httpx.AsyncClient", return_value=mock_client): + with pytest.raises(AuthProviderTemporaryError, match="returned 500"): + await provider._ensure_valid_token() + + +# --------------------------------------------------------------------------- +# validate() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_validate_success(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": 
"cid", "client_secret": "csec"}, + config={"project_id": "p", "account_id": "a"}, + ) + mock_response = _make_response(200, {"access_token": "tok", "expires_in": 3600}) + + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + with patch(f"{PIPEDREAM_MODULE}.httpx.AsyncClient", return_value=mock_client): + result = await provider.validate() + assert result is True + + +@pytest.mark.asyncio +async def test_validate_missing_access_token_raises_config_error(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "p", "account_id": "a"}, + ) + mock_response = _make_response(200, {"token_type": "bearer"}) + + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + with patch(f"{PIPEDREAM_MODULE}.httpx.AsyncClient", return_value=mock_client): + with pytest.raises(AuthProviderConfigError, match="without access_token"): + await provider.validate() + + +@pytest.mark.asyncio +async def test_validate_401_raises_auth_error(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "p", "account_id": "a"}, + ) + resp = _make_response(401, {"error": "invalid_client"}) + + mock_post = AsyncMock(side_effect=httpx.HTTPStatusError("401", request=resp.request, response=resp)) + + mock_client = AsyncMock() + mock_client.post = mock_post + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + with patch(f"{PIPEDREAM_MODULE}.httpx.AsyncClient", return_value=mock_client): + with pytest.raises(AuthProviderAuthError, match="Invalid client credentials"): + await 
provider.validate() + + +@pytest.mark.asyncio +async def test_validate_5xx_raises_temporary_error(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "p", "account_id": "a"}, + ) + resp = _make_response(502, {"error": "bad_gateway", "error_description": "upstream down"}) + + mock_post = AsyncMock(side_effect=httpx.HTTPStatusError("502", request=resp.request, response=resp)) + + mock_client = AsyncMock() + mock_client.post = mock_post + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + with patch(f"{PIPEDREAM_MODULE}.httpx.AsyncClient", return_value=mock_client): + with pytest.raises(AuthProviderTemporaryError, match="502"): + await provider.validate() + + +@pytest.mark.asyncio +async def test_validate_4xx_other_raises_config_error(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "p", "account_id": "a"}, + ) + resp = _make_response(400, {"error": "invalid_request", "error_description": "bad params"}) + + mock_post = AsyncMock(side_effect=httpx.HTTPStatusError("400", request=resp.request, response=resp)) + + mock_client = AsyncMock() + mock_client.post = mock_post + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + with patch(f"{PIPEDREAM_MODULE}.httpx.AsyncClient", return_value=mock_client): + with pytest.raises(AuthProviderConfigError, match="400"): + await provider.validate() + + +@pytest.mark.asyncio +async def test_validate_detail_extraction_error_fallback(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "p", "account_id": "a"}, + ) + resp = _make_response(400, text="not json at all") + + mock_post = AsyncMock(side_effect=httpx.HTTPStatusError("400", 
request=resp.request, response=resp)) + + mock_client = AsyncMock() + mock_client.post = mock_post + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + with patch(f"{PIPEDREAM_MODULE}.httpx.AsyncClient", return_value=mock_client): + with pytest.raises(AuthProviderConfigError, match="not json at all"): + await provider.validate() + + +@pytest.mark.asyncio +async def test_validate_connect_error_raises_temporary(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "p", "account_id": "a"}, + ) + + mock_post = AsyncMock(side_effect=httpx.ConnectError("connection refused")) + + mock_client = AsyncMock() + mock_client.post = mock_post + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + with patch(f"{PIPEDREAM_MODULE}.httpx.AsyncClient", return_value=mock_client): + with pytest.raises(AuthProviderTemporaryError, match="unreachable"): + await provider.validate() + + +@pytest.mark.asyncio +async def test_validate_timeout_raises_temporary(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "p", "account_id": "a"}, + ) + + mock_post = AsyncMock(side_effect=httpx.TimeoutException("timed out")) + + mock_client = AsyncMock() + mock_client.post = mock_post + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + with patch(f"{PIPEDREAM_MODULE}.httpx.AsyncClient", return_value=mock_client): + with pytest.raises(AuthProviderTemporaryError, match="unreachable"): + await provider.validate() + + +@pytest.mark.asyncio +async def test_validate_reraises_config_error(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "p", "account_id": 
"a"}, + ) + mock_response = _make_response(200, {"token_type": "bearer"}) + + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + with patch(f"{PIPEDREAM_MODULE}.httpx.AsyncClient", return_value=mock_client): + with pytest.raises(AuthProviderConfigError) as exc_info: + await provider.validate() + assert "provider_name" in str(exc_info.value) or "access_token" in str(exc_info.value) + + +# --------------------------------------------------------------------------- +# _get_account_with_credentials() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_get_account_with_credentials_app_mismatch(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "proj", "account_id": "acc"}, + ) + account_data = {"app": {"name_slug": "wrong_app"}, "name": "acc"} + + mock_get_with_auth = AsyncMock(return_value=account_data) + + with patch.object(provider, "_get_with_auth", mock_get_with_auth): + with pytest.raises(AuthProviderConfigError, match="expected 'slack_v2'"): + await provider._get_account_with_credentials( + AsyncMock(), "slack_v2", "slack" + ) + + +@pytest.mark.asyncio +async def test_get_account_with_credentials_no_credentials_raises_default_oauth(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "proj", "account_id": "acc"}, + ) + account_data = {"app": {"name_slug": "slack_v2"}, "name": "acc"} + + mock_get_with_auth = AsyncMock(return_value=account_data) + + with patch.object(provider, "_get_with_auth", mock_get_with_auth): + with pytest.raises(PipedreamDefaultOAuthException) as exc_info: + await provider._get_account_with_credentials( + AsyncMock(), "slack_v2", "slack" + ) + assert 
exc_info.value.source_short_name == "slack" + + +@pytest.mark.asyncio +async def test_get_account_with_credentials_reraises_config_error(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "proj", "account_id": "acc"}, + ) + err = AuthProviderConfigError("bad config", provider_name="pipedream") + mock_get_with_auth = AsyncMock(side_effect=err) + + with patch.object(provider, "_get_with_auth", mock_get_with_auth): + with pytest.raises(AuthProviderConfigError, match="bad config"): + await provider._get_account_with_credentials( + AsyncMock(), "slack_v2", "slack" + ) + + +@pytest.mark.asyncio +async def test_get_account_with_credentials_reraises_default_oauth(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "proj", "account_id": "acc"}, + ) + err = PipedreamDefaultOAuthException("slack") + mock_get_with_auth = AsyncMock(side_effect=err) + + with patch.object(provider, "_get_with_auth", mock_get_with_auth): + with pytest.raises(PipedreamDefaultOAuthException): + await provider._get_account_with_credentials( + AsyncMock(), "slack_v2", "slack" + ) + + +@pytest.mark.asyncio +async def test_get_account_with_credentials_404_raises_account_not_found(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "proj", "account_id": "acc"}, + ) + resp = _make_response(404, {"error": "not_found"}) + err = httpx.HTTPStatusError("404", request=resp.request, response=resp) + mock_get_with_auth = AsyncMock(side_effect=err) + + with patch.object(provider, "_get_with_auth", mock_get_with_auth): + with pytest.raises(AuthProviderAccountNotFoundError) as exc_info: + await provider._get_account_with_credentials( + AsyncMock(), "slack_v2", "slack" + ) + assert exc_info.value.account_id == "acc" + + +@pytest.mark.asyncio +async def 
test_get_account_with_credentials_429_raises_rate_limit(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "proj", "account_id": "acc"}, + ) + resp = _make_response(429, {"error": "rate_limited"}, headers={"retry-after": "60"}) + err = httpx.HTTPStatusError("429", request=resp.request, response=resp) + mock_get_with_auth = AsyncMock(side_effect=err) + + with patch.object(provider, "_get_with_auth", mock_get_with_auth): + with pytest.raises(AuthProviderRateLimitError) as exc_info: + await provider._get_account_with_credentials( + AsyncMock(), "slack_v2", "slack" + ) + assert exc_info.value.retry_after == 60.0 + + +@pytest.mark.asyncio +async def test_get_account_with_credentials_5xx_raises_temporary(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "proj", "account_id": "acc"}, + ) + resp = _make_response(503, {"error": "unavailable"}) + err = httpx.HTTPStatusError("503", request=resp.request, response=resp) + mock_get_with_auth = AsyncMock(side_effect=err) + + with patch.object(provider, "_get_with_auth", mock_get_with_auth): + with pytest.raises(AuthProviderTemporaryError, match="503"): + await provider._get_account_with_credentials( + AsyncMock(), "slack_v2", "slack" + ) + + +@pytest.mark.asyncio +async def test_get_account_with_credentials_other_status_raises_temporary(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "proj", "account_id": "acc"}, + ) + resp = _make_response(418, text="I'm a teapot") + err = httpx.HTTPStatusError("418", request=resp.request, response=resp) + mock_get_with_auth = AsyncMock(side_effect=err) + + with patch.object(provider, "_get_with_auth", mock_get_with_auth): + with pytest.raises(AuthProviderTemporaryError, match="418"): + await provider._get_account_with_credentials( + 
AsyncMock(), "slack_v2", "slack" + ) + + +# --------------------------------------------------------------------------- +# _extract_and_map_credentials() missing fields (line 554) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_extract_and_map_credentials_missing_required_raises(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "p", "account_id": "a"}, + ) + account_data = {"credentials": {"oauth_client_id": "cid"}} + with pytest.raises(AuthProviderMissingFieldsError) as exc_info: + provider._extract_and_map_credentials( + account_data, + source_auth_config_fields=["access_token", "client_id"], + source_short_name="slack", + ) + assert "access_token" in exc_info.value.missing_fields + assert exc_info.value.available_fields == ["oauth_client_id"] + + +@pytest.mark.asyncio +async def test_extract_and_map_credentials_optional_skipped(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "p", "account_id": "a"}, + ) + account_data = {"credentials": {"oauth_access_token": "tok"}} + result = provider._extract_and_map_credentials( + account_data, + source_auth_config_fields=["access_token", "refresh_token"], + source_short_name="slack", + optional_fields={"refresh_token"}, + ) + assert result == {"access_token": "tok"} + + +@pytest.mark.asyncio +async def test_extract_and_map_credentials_success(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "p", "account_id": "a"}, + ) + account_data = { + "credentials": { + "oauth_access_token": "tok", + "oauth_refresh_token": "ref", + }, + } + result = provider._extract_and_map_credentials( + account_data, + source_auth_config_fields=["access_token", "refresh_token"], + source_short_name="slack", + ) + assert 
result == {"access_token": "tok", "refresh_token": "ref"} + + +@pytest.mark.asyncio +async def test_extract_and_map_credentials_source_field_mapping_coda(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "p", "account_id": "a"}, + ) + account_data = {"credentials": {"api_token": "coda-key"}} + result = provider._extract_and_map_credentials( + account_data, + source_auth_config_fields=["api_key"], + source_short_name="coda", + ) + assert result == {"api_key": "coda-key"} + + +# --------------------------------------------------------------------------- +# get_creds_for_source() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_get_creds_for_source_success(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "proj", "account_id": "acc"}, + ) + account_data = { + "app": {"name_slug": "slack_v2"}, + "name": "My Slack", + "credentials": {"oauth_access_token": "xoxb-tok"}, + } + + async def fake_get_account(client, slug, source): + return account_data + + with patch.object(provider, "_get_account_with_credentials", fake_get_account): + creds = await provider.get_creds_for_source( + "slack", ["access_token"] + ) + assert creds == {"access_token": "xoxb-tok"} + + +# --------------------------------------------------------------------------- +# get_auth_result() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_get_auth_result_blocked_source_returns_proxy(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "proj", "account_id": "acc"}, + ) + result = await provider.get_auth_result("github", ["access_token"]) + assert result.requires_proxy + assert result.proxy_config["reason"] == 
"blocked_source" + assert result.proxy_config["source"] == "github" + + +@pytest.mark.asyncio +async def test_get_auth_result_direct_success(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "proj", "account_id": "acc"}, + ) + account_data = { + "app": {"name_slug": "slack_v2"}, + "name": "My Slack", + "credentials": {"oauth_access_token": "xoxb-tok"}, + } + + async def fake_get_account(client, slug, source): + return account_data + + with patch.object(provider, "_get_account_with_credentials", fake_get_account): + result = await provider.get_auth_result("slack", ["access_token"]) + assert not result.requires_proxy + assert result.credentials == {"access_token": "xoxb-tok"} + + +@pytest.mark.asyncio +async def test_get_auth_result_default_oauth_returns_proxy(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "proj", "account_id": "acc"}, + ) + + async def fake_get_account(client, slug, source): + raise PipedreamDefaultOAuthException("slack") + + with patch.object(provider, "_get_account_with_credentials", fake_get_account): + result = await provider.get_auth_result("slack", ["access_token"]) + assert result.requires_proxy + assert result.proxy_config["reason"] == "default_oauth" + assert result.proxy_config["source"] == "slack" + + +@pytest.mark.asyncio +async def test_get_auth_result_with_source_config_mappings(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "proj", "account_id": "acc"}, + ) + account_data = { + "app": {"name_slug": "slack_v2"}, + "name": "My Slack", + "credentials": {"oauth_access_token": "xoxb-tok"}, + } + + async def fake_get_account(client, slug, source): + return account_data + + async def fake_get_config(source, mappings): + return {"instance_url": "https://example.slack.com"} + + 
with patch.object(provider, "_get_account_with_credentials", fake_get_account): + with patch.object(provider, "get_config_for_source", fake_get_config): + result = await provider.get_auth_result( + "slack", + ["access_token"], + source_config_field_mappings={"instance_url": "instance_url"}, + ) + assert result.source_config == {"instance_url": "https://example.slack.com"} + + +# --------------------------------------------------------------------------- +# _get_pipedream_app_slug, _map_field_name +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_get_pipedream_app_slug_mapped(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "p", "account_id": "a"}, + ) + assert provider._get_pipedream_app_slug("slack") == "slack_v2" + assert provider._get_pipedream_app_slug("apollo") == "apollo_io" + + +@pytest.mark.asyncio +async def test_get_pipedream_app_slug_unmapped(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "p", "account_id": "a"}, + ) + assert provider._get_pipedream_app_slug("gmail") == "gmail" + + +@pytest.mark.asyncio +async def test_map_field_name_default(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "p", "account_id": "a"}, + ) + assert provider._map_field_name("access_token") == "oauth_access_token" + + +@pytest.mark.asyncio +async def test_map_field_name_source_override(): + provider = await PipedreamAuthProvider.create( + credentials={"client_id": "cid", "client_secret": "csec"}, + config={"project_id": "p", "account_id": "a"}, + ) + assert provider._map_field_name("api_key", "coda") == "api_token" diff --git a/backend/airweave/domains/oauth/callback_service.py b/backend/airweave/domains/oauth/callback_service.py index 
72eab1699..3c31200f8 100644 --- a/backend/airweave/domains/oauth/callback_service.py +++ b/backend/airweave/domains/oauth/callback_service.py @@ -37,7 +37,15 @@ ResponseBuilderProtocol, SourceConnectionRepositoryProtocol, ) -from airweave.domains.sources.protocols import SourceRegistryProtocol +from airweave.domains.sources.exceptions import ( + SourceCreationError, + SourceNotFoundError, + SourceValidationError, +) +from airweave.domains.sources.protocols import ( + SourceLifecycleServiceProtocol, + SourceRegistryProtocol, +) from airweave.domains.sources.types import SourceRegistryEntry from airweave.domains.syncs.protocols import ( SyncJobRepositoryProtocol, @@ -79,6 +87,7 @@ def __init__( init_session_repo: OAuthInitSessionRepositoryProtocol, response_builder: ResponseBuilderProtocol, source_registry: SourceRegistryProtocol, + source_lifecycle: SourceLifecycleServiceProtocol, sync_lifecycle: SyncLifecycleServiceProtocol, sync_record_service: SyncRecordServiceProtocol, temporal_workflow_service: TemporalWorkflowServiceProtocol, @@ -97,6 +106,7 @@ def __init__( self._init_session_repo = init_session_repo self._response_builder = response_builder self._source_registry = source_registry + self._source_lifecycle = source_lifecycle self._sync_lifecycle = sync_lifecycle self._sync_record_service = sync_record_service self._temporal_workflow_service = temporal_workflow_service @@ -571,24 +581,19 @@ async def _validate_oauth2_token_or_raise( access_token: str, ctx: ApiContext, ) -> None: - """Validate OAuth2 token using source implementation; fail callback if invalid.""" + """Validate OAuth2 token using source lifecycle service; fail callback if invalid.""" if not source_entry: return try: - source_cls = source_entry.source_class_ref - - source_instance = await source_cls.create(access_token=access_token, config=None) - source_instance.set_logger(ctx.logger) - - if hasattr(source_instance, "validate"): - is_valid = await source_instance.validate() - if not is_valid: - 
raise HTTPException(status_code=400, detail="OAuth token is invalid") - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=400, detail=f"Token validation failed: {e}") from e + await self._source_lifecycle.validate( + short_name=source_entry.short_name, + credentials=access_token, + ) + except SourceNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except (SourceCreationError, SourceValidationError) as e: + raise HTTPException(status_code=400, detail=str(e)) from e # ------------------------------------------------------------------ # Private: finalization (response + sync trigger) diff --git a/backend/airweave/domains/oauth/exceptions.py b/backend/airweave/domains/oauth/exceptions.py new file mode 100644 index 000000000..18fa6e024 --- /dev/null +++ b/backend/airweave/domains/oauth/exceptions.py @@ -0,0 +1,106 @@ +"""OAuth domain exceptions. + +Hierarchy +--------- +OAuthRefreshError — base for all token-refresh failures +├── OAuthRefreshTokenRevokedError — 401 from token endpoint (refresh_token dead) +├── OAuthRefreshBadRequestError — 400 / invalid_grant +├── OAuthRefreshRateLimitError — 429 or exhausted rate-limit retries +├── OAuthRefreshServerError — 5xx / timeout / connection error +└── OAuthRefreshCredentialMissingError — no connection or credential in DB +""" + +from typing import Optional + + +class OAuthRefreshError(Exception): + """Base for all OAuth token-refresh failures. + + Carries ``integration_short_name`` so callers can log/route + without parsing the message. + """ + + def __init__(self, message: str, *, integration_short_name: str = ""): + """Initialize OAuthRefreshError.""" + self.integration_short_name = integration_short_name + super().__init__(message) + + +class OAuthRefreshTokenRevokedError(OAuthRefreshError): + """Token endpoint returned 401 — the refresh_token is expired or revoked. + + The user needs to re-authenticate. 
+ """ + + def __init__( + self, + message: str = "Refresh token revoked or expired", + *, + integration_short_name: str = "", + status_code: int = 401, + ): + """Initialize OAuthRefreshTokenRevokedError.""" + self.status_code = status_code + super().__init__(message, integration_short_name=integration_short_name) + + +class OAuthRefreshBadRequestError(OAuthRefreshError): + """Token endpoint returned 400 — invalid grant or malformed request. + + Typically means the refresh_token is invalid (e.g. already rotated + and the old one was replayed). + """ + + def __init__( + self, + message: str = "Invalid grant or malformed refresh request", + *, + integration_short_name: str = "", + error_code: str = "", + ): + """Initialize OAuthRefreshBadRequestError.""" + self.error_code = error_code + super().__init__(message, integration_short_name=integration_short_name) + + +class OAuthRefreshRateLimitError(OAuthRefreshError): + """Token endpoint is rate-limiting us (429 or equivalent). + + Raised after exhausting internal retries. + """ + + def __init__( + self, + message: str = "OAuth token refresh rate-limited", + *, + integration_short_name: str = "", + retry_after: float = 30.0, + ): + """Initialize OAuthRefreshRateLimitError.""" + self.retry_after = retry_after + super().__init__(message, integration_short_name=integration_short_name) + + +class OAuthRefreshServerError(OAuthRefreshError): + """The token endpoint returned a server error (5xx, timeout, connection error).""" + + def __init__( + self, + message: str = "Token endpoint returned a server error", + *, + integration_short_name: str = "", + status_code: Optional[int] = None, + ): + """Initialize OAuthRefreshServerError.""" + self.status_code = status_code + super().__init__(message, integration_short_name=integration_short_name) + + +class OAuthRefreshCredentialMissingError(OAuthRefreshError): + """Connection or credential record not found in the database. 
+ + Either the connection was deleted, or the credential row is missing. + This is a data-integrity / configuration issue, not a transient failure. + """ + + pass diff --git a/backend/airweave/domains/oauth/fakes/oauth2_service.py b/backend/airweave/domains/oauth/fakes/oauth2_service.py index 4230426bb..796e40e12 100644 --- a/backend/airweave/domains/oauth/fakes/oauth2_service.py +++ b/backend/airweave/domains/oauth/fakes/oauth2_service.py @@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from airweave.api.context import ApiContext +from airweave.domains.oauth.types import RefreshResult from airweave.platform.auth.schemas import OAuth2Settings, OAuth2TokenResponse @@ -189,3 +190,25 @@ async def refresh_access_token( if not resp: raise ValueError(f"No seeded refresh response for {integration_short_name}") return resp + + async def refresh_and_persist( + self, + db: AsyncSession, + integration_short_name: str, + connection_id: UUID, + ctx: ApiContext, + config_fields: Optional[dict[str, str]] = None, + ) -> RefreshResult: + """Fake refresh_and_persist — returns seeded RefreshResult.""" + self._calls.append( + ("refresh_and_persist", db, integration_short_name, connection_id, ctx, config_fields) + ) + if self._should_raise: + raise self._should_raise + resp = self._refresh_responses.get(integration_short_name) + if not resp: + raise ValueError(f"No seeded refresh response for {integration_short_name}") + return RefreshResult( + access_token=resp.access_token, + expires_in=resp.expires_in, + ) diff --git a/backend/airweave/domains/oauth/oauth2_service.py b/backend/airweave/domains/oauth/oauth2_service.py index 6e8a95fb4..0f6b6d58b 100644 --- a/backend/airweave/domains/oauth/oauth2_service.py +++ b/backend/airweave/domains/oauth/oauth2_service.py @@ -15,14 +15,23 @@ from airweave import schemas from airweave.api.context import ApiContext from airweave.core.config.settings import Settings -from airweave.core.exceptions import NotFoundException, 
TokenRefreshError +from airweave.core.exceptions import NotFoundException from airweave.core.logging import ContextualLogger from airweave.core.protocols.encryption import CredentialEncryptor from airweave.core.shared_models import ConnectionStatus from airweave.db.unit_of_work import UnitOfWork from airweave.domains.connections.protocols import ConnectionRepositoryProtocol from airweave.domains.credentials.protocols import IntegrationCredentialRepositoryProtocol +from airweave.domains.oauth.exceptions import ( + OAuthRefreshBadRequestError, + OAuthRefreshCredentialMissingError, + OAuthRefreshError, + OAuthRefreshRateLimitError, + OAuthRefreshServerError, + OAuthRefreshTokenRevokedError, +) from airweave.domains.oauth.protocols import OAuth2ServiceProtocol +from airweave.domains.oauth.types import RefreshResult from airweave.domains.sources.protocols import SourceRegistryProtocol from airweave.models.integration_credential import IntegrationType from airweave.platform.auth.schemas import ( @@ -312,7 +321,11 @@ async def refresh_access_token( OAuth2TokenResponse: The response containing the new access token and other details. Raises: - TokenRefreshError: If token refresh fails. + OAuthRefreshTokenRevokedError: If the refresh token is expired or revoked (401). + OAuthRefreshBadRequestError: If the token endpoint returns 400. + OAuthRefreshRateLimitError: If rate-limited after exhausting retries. + OAuthRefreshServerError: If the token endpoint returns 5xx or times out. + OAuthRefreshCredentialMissingError: If no refresh token in credentials. NotFoundException: If the integration is not found. 
""" try: @@ -356,7 +369,13 @@ async def refresh_access_token( ctx.logger, integration_config, refresh_token, client_id, client_secret ) - response = await self._make_token_request(ctx.logger, backend_url, headers, payload) + response = await self._make_token_request( + ctx.logger, + backend_url, + headers, + payload, + integration_short_name=integration_short_name, + ) oauth2_token_response = await self._handle_token_response( db, response, integration_config, ctx, connection_id @@ -364,6 +383,9 @@ async def refresh_access_token( return oauth2_token_response + except OAuthRefreshError: + raise + except Exception as e: ctx.logger.error( f"Token refresh failed for organization {ctx.organization.id} and " @@ -371,6 +393,66 @@ async def refresh_access_token( ) raise + async def refresh_and_persist( + self, + db: AsyncSession, + integration_short_name: str, + connection_id: UUID, + ctx: ApiContext, + config_fields: Optional[dict[str, str]] = None, + ) -> RefreshResult: + """Load credentials, refresh token, persist rotation, return new access token. + + Convenience method that combines credential loading with refresh_access_token. + + Args: + db: Database session. + integration_short_name: Source short name. + connection_id: Connection UUID (for credential lookup and rotation persistence). + ctx: API context. + config_fields: Optional config fields for templated backend URLs. + + Returns: + RefreshResult with the new access token and optional expires_in. + + Raises: + OAuthRefreshCredentialMissingError: If connection or credential not found. + OAuthRefreshTokenRevokedError: If the refresh token is expired or revoked. + OAuthRefreshBadRequestError: If the token endpoint returns 400. + OAuthRefreshRateLimitError: If rate-limited after exhausting retries. + OAuthRefreshServerError: If the token endpoint returns 5xx or times out. 
+ """ + connection = await self.conn_repo.get(db=db, id=connection_id, ctx=ctx) + if not connection or not connection.integration_credential_id: + raise OAuthRefreshCredentialMissingError( + f"Connection {connection_id} not found or has no credential", + integration_short_name=integration_short_name, + ) + + credential = await self.cred_repo.get( + db=db, id=connection.integration_credential_id, ctx=ctx + ) + if not credential: + raise OAuthRefreshCredentialMissingError( + "Integration credential not found", + integration_short_name=integration_short_name, + ) + + decrypted = self.encryptor.decrypt(credential.encrypted_credentials) + + response = await self.refresh_access_token( + db=db, + integration_short_name=integration_short_name, + ctx=ctx, + connection_id=connection_id, + decrypted_credential=decrypted, + config_fields=config_fields, + ) + return RefreshResult( + access_token=response.access_token, + expires_in=response.expires_in, + ) + # ------------------------------------------------------------------ # Private helpers # ------------------------------------------------------------------ @@ -419,13 +501,13 @@ async def _get_refresh_token( """Get refresh token from decrypted credentials. Raises: - TokenRefreshError: If no refresh token is found. + OAuthRefreshCredentialMissingError: If no refresh token is found. 
""" refresh_token = decrypted_credential.get("refresh_token", None) if not refresh_token: error_message = "No refresh token found" logger.error(error_message) - raise TokenRefreshError(error_message) + raise OAuthRefreshCredentialMissingError(error_message) return refresh_token async def _get_integration_config( @@ -556,9 +638,21 @@ def _is_oauth_rate_limit_error(self, response: httpx.Response) -> bool: return False async def _make_token_request( - self, logger: ContextualLogger, url: str, headers: dict[str, str], payload: dict[str, str] + self, + logger: ContextualLogger, + url: str, + headers: dict[str, str], + payload: dict[str, str], + integration_short_name: str = "", ) -> httpx.Response: - """Make the token refresh request with retry on rate limit.""" + """Make the token refresh request with retry on rate limit. + + Raises: + OAuthRefreshTokenRevokedError: On 401 from the token endpoint. + OAuthRefreshBadRequestError: On 400 / invalid_grant. + OAuthRefreshRateLimitError: After exhausting rate-limit retries. + OAuthRefreshServerError: On 5xx, timeout, or connection error. 
+ """ logger.info(f"Making token request to: {url}") max_retries = 5 @@ -594,9 +688,9 @@ async def _make_token_request( await asyncio.sleep(delay) continue + status = e.response.status_code logger.error( - f"HTTP error during token request: {e.response.status_code} " - f"{e.response.reason_phrase}" + f"HTTP error during token request: {status} {e.response.reason_phrase}" ) try: @@ -605,17 +699,77 @@ async def _make_token_request( except Exception: logger.error(f"Error response text: {e.response.text}") - raise + self._raise_typed_refresh_error(e, integration_short_name=integration_short_name) + + except (httpx.ConnectError, httpx.TimeoutException) as e: + logger.error(f"Connection/timeout error during token request: {e}") + raise OAuthRefreshServerError( + f"Token endpoint unreachable: {e}", + integration_short_name=integration_short_name, + ) from e + except Exception as e: logger.error(f"Unexpected error during token request: {str(e)}") - raise + raise OAuthRefreshServerError( + f"Unexpected error during token refresh: {e}", + integration_short_name=integration_short_name, + ) from e - raise httpx.HTTPStatusError( + raise OAuthRefreshRateLimitError( f"OAuth token request failed after {max_retries} retries (rate limited)", - request=httpx.Request("POST", url), - response=httpx.Response(429), + integration_short_name=integration_short_name, ) + def _raise_typed_refresh_error( + self, + exc: httpx.HTTPStatusError, + *, + integration_short_name: str = "", + ) -> None: + """Translate an httpx.HTTPStatusError into a typed OAuthRefresh exception. + + Always raises — never returns. 
+ """ + status = exc.response.status_code + detail = "" + try: + body = exc.response.json() + detail = body.get("error_description", body.get("error", "")) + except Exception: + detail = exc.response.text[:200] if exc.response.text else "" + + if status == 401: + raise OAuthRefreshTokenRevokedError( + f"Token endpoint returned 401: {detail}", + integration_short_name=integration_short_name, + status_code=status, + ) from exc + + if status == 400 or status == 403: + error_code = "" + try: + error_code = exc.response.json().get("error", "") + except Exception: + pass + raise OAuthRefreshBadRequestError( + f"Token endpoint returned {status}: {detail}", + integration_short_name=integration_short_name, + error_code=error_code, + ) from exc + + if status >= 500: + raise OAuthRefreshServerError( + f"Token endpoint returned {status}: {detail}", + integration_short_name=integration_short_name, + status_code=status, + ) from exc + + raise OAuthRefreshServerError( + f"Token endpoint returned unexpected {status}: {detail}", + integration_short_name=integration_short_name, + status_code=status, + ) from exc + async def _handle_token_response( self, db: AsyncSession, diff --git a/backend/airweave/domains/oauth/protocols.py b/backend/airweave/domains/oauth/protocols.py index 6c80cc11d..bd13bf855 100644 --- a/backend/airweave/domains/oauth/protocols.py +++ b/backend/airweave/domains/oauth/protocols.py @@ -9,7 +9,11 @@ from airweave.api.context import ApiContext, ConnectContext from airweave.core.logging import ContextualLogger from airweave.db.unit_of_work import UnitOfWork -from airweave.domains.oauth.types import OAuth1TokenResponse, OAuthBrowserInitiationResult +from airweave.domains.oauth.types import ( + OAuth1TokenResponse, + OAuthBrowserInitiationResult, + RefreshResult, +) from airweave.models.connection_init_session import ConnectionInitSession from airweave.platform.auth.schemas import OAuth1Settings, OAuth2Settings, OAuth2TokenResponse from 
airweave.schemas.source_connection import SourceConnection as SourceConnectionSchema @@ -124,6 +128,20 @@ async def refresh_access_token( """Refresh an OAuth2 access token.""" ... + async def refresh_and_persist( + self, + db: AsyncSession, + integration_short_name: str, + connection_id: UUID, + ctx: ApiContext, + config_fields: Optional[dict] = None, + ) -> RefreshResult: + """Load credentials, refresh token, persist rotated refresh_token. + + Returns a RefreshResult with the new access_token and optional expires_in. + """ + ... + # --------------------------------------------------------------------------- # Init session + redirect session repositories diff --git a/backend/airweave/domains/oauth/tests/test_callback_service.py b/backend/airweave/domains/oauth/tests/test_callback_service.py index 22975e808..0d37baa9f 100644 --- a/backend/airweave/domains/oauth/tests/test_callback_service.py +++ b/backend/airweave/domains/oauth/tests/test_callback_service.py @@ -24,6 +24,7 @@ from airweave.domains.oauth.types import OAuth1TokenResponse from airweave.domains.organizations.fakes.repository import FakeOrganizationRepository from airweave.domains.source_connections.fakes.repository import FakeSourceConnectionRepository +from airweave.domains.sources.exceptions import SourceNotFoundError, SourceValidationError from airweave.domains.syncs.fakes.sync_job_repository import FakeSyncJobRepository from airweave.domains.syncs.fakes.sync_repository import FakeSyncRepository from airweave.models.connection_init_session import ConnectionInitSession, ConnectionInitStatus @@ -118,6 +119,7 @@ def _service( oauth_flow_service=None, response_builder=None, source_registry=None, + source_lifecycle=None, sync_lifecycle=None, sync_record_service=None, temporal_workflow_service=None, @@ -128,6 +130,7 @@ def _service( init_session_repo=init_session_repo or FakeOAuthInitSessionRepository(), response_builder=response_builder or AsyncMock(), source_registry=source_registry or MagicMock(), + 
source_lifecycle=source_lifecycle or AsyncMock(), sync_lifecycle=sync_lifecycle or AsyncMock(), sync_record_service=sync_record_service or AsyncMock(), temporal_workflow_service=temporal_workflow_service or AsyncMock(), @@ -249,21 +252,14 @@ async def test_invalid_oauth2_token_fails_fast_with_400(self): OAuth2TokenResponse(access_token="bad-token", token_type="bearer") ) - class _InvalidSource: - def set_logger(self, _logger): - return None - - async def validate(self): - return False - - class _SourceClass: - @staticmethod - async def create(access_token, config): # noqa: ARG004 - return _InvalidSource() - registry = MagicMock() registry.get.return_value = SimpleNamespace( - source_class_ref=_SourceClass, short_name="github" + source_class_ref=MagicMock(), short_name="github", auth_config_ref=None, + ) + + source_lifecycle = AsyncMock() + source_lifecycle.validate = AsyncMock( + side_effect=SourceValidationError("github", "validate() returned False") ) svc = _service( @@ -272,12 +268,13 @@ async def create(access_token, config): # noqa: ARG004 sc_repo=sc_repo, oauth_flow_service=oauth_flow, source_registry=registry, + source_lifecycle=source_lifecycle, ) with pytest.raises(HTTPException) as exc: await svc.complete_oauth2_callback(DB, state="state-abc", code="c") assert exc.value.status_code == 400 - assert "token" in exc.value.detail.lower() + assert "validation failed" in exc.value.detail.lower() assert all(call[0] != "mark_completed" for call in init_repo._calls) async def test_validation_exception_fails_fast_with_400(self): @@ -298,21 +295,14 @@ async def test_validation_exception_fails_fast_with_400(self): OAuth2TokenResponse(access_token="token", token_type="bearer") ) - class _BrokenSource: - def set_logger(self, _logger): - return None - - async def validate(self): - raise RuntimeError("provider error") - - class _SourceClass: - @staticmethod - async def create(access_token, config): # noqa: ARG004 - return _BrokenSource() - registry = MagicMock() 
registry.get.return_value = SimpleNamespace( - source_class_ref=_SourceClass, short_name="github" + source_class_ref=MagicMock(), short_name="github", auth_config_ref=None, + ) + + source_lifecycle = AsyncMock() + source_lifecycle.validate = AsyncMock( + side_effect=SourceValidationError("github", "validation raised: provider error") ) svc = _service( @@ -321,6 +311,7 @@ async def create(access_token, config): # noqa: ARG004 sc_repo=sc_repo, oauth_flow_service=oauth_flow, source_registry=registry, + source_lifecycle=source_lifecycle, ) with pytest.raises(HTTPException) as exc: await svc.complete_oauth2_callback(DB, state="state-abc", code="c") @@ -329,6 +320,47 @@ async def create(access_token, config): # noqa: ARG004 assert "validation failed" in exc.value.detail.lower() assert all(call[0] != "mark_completed" for call in init_repo._calls) + async def test_source_not_found_fails_with_404(self): + init_repo = FakeOAuthInitSessionRepository() + session = _init_session() + init_repo.seed_by_state("state-abc", session) + + org_repo = FakeOrganizationRepository() + org_repo.seed(ORG_ID, _organization()) + + sc_repo = FakeSourceConnectionRepository() + shell = _source_conn_shell() + sc_repo.seed(shell.id, shell) + sc_repo.seed_init_session(SESSION_ID, session) + + oauth_flow = FakeOAuthFlowService() + oauth_flow.seed_oauth2_response( + OAuth2TokenResponse(access_token="bad-token", token_type="bearer") + ) + + registry = MagicMock() + registry.get.return_value = SimpleNamespace( + source_class_ref=MagicMock(), short_name="github", auth_config_ref=None, + ) + + source_lifecycle = AsyncMock() + source_lifecycle.validate = AsyncMock( + side_effect=SourceNotFoundError("unknown_source") + ) + + svc = _service( + init_session_repo=init_repo, + organization_repo=org_repo, + sc_repo=sc_repo, + oauth_flow_service=oauth_flow, + source_registry=registry, + source_lifecycle=source_lifecycle, + ) + with pytest.raises(HTTPException) as exc: + await svc.complete_oauth2_callback(DB, 
state="state-abc", code="c") + + assert exc.value.status_code == 404 + async def test_happy_path_delegates_and_finalizes(self): init_repo = FakeOAuthInitSessionRepository() session = _init_session() diff --git a/backend/airweave/domains/oauth/tests/test_oauth2_service.py b/backend/airweave/domains/oauth/tests/test_oauth2_service.py index 8d5c30a4a..3b04fdb54 100644 --- a/backend/airweave/domains/oauth/tests/test_oauth2_service.py +++ b/backend/airweave/domains/oauth/tests/test_oauth2_service.py @@ -38,10 +38,18 @@ from airweave.api.context import ApiContext from airweave.core.exceptions import NotFoundException, TokenRefreshError from airweave.core.logging import logger -from airweave.core.shared_models import AuthMethod, IntegrationType +from airweave.core.shared_models import AuthMethod, ConnectionStatus, IntegrationType from airweave.domains.connections.fakes.repository import FakeConnectionRepository from airweave.domains.credentials.fakes.repository import FakeIntegrationCredentialRepository +from airweave.domains.oauth.exceptions import ( + OAuthRefreshBadRequestError, + OAuthRefreshCredentialMissingError, + OAuthRefreshRateLimitError, + OAuthRefreshServerError, + OAuthRefreshTokenRevokedError, +) from airweave.domains.oauth.oauth2_service import OAuth2Service +from airweave.domains.oauth.types import RefreshResult from airweave.domains.sources.fakes.registry import FakeSourceRegistry from airweave.models.connection import Connection from airweave.models.integration_credential import IntegrationCredential @@ -1150,14 +1158,14 @@ class RefreshCase: expect_access_token="new-at", ), RefreshCase( - "no refresh token → TokenRefreshError", + "no refresh token → OAuthRefreshCredentialMissingError", {"access_token": "only-at"}, True, "with_refresh", 200, {}, True, - TokenRefreshError, + OAuthRefreshCredentialMissingError, ), RefreshCase( "integration config not found → NotFoundException", @@ -1448,7 +1456,7 @@ async def fake_post(url, headers=None, data=None): 
@pytest.mark.asyncio async def test_make_token_request_non_retryable_error_raises(): - """500 error (not rate limit) → raises immediately.""" + """500 error (not rate limit) → raises OAuthRefreshServerError.""" svc = _svc() log = logger.with_context(test="non_retry") @@ -1465,10 +1473,266 @@ async def fake_post(url, headers=None, data=None): "airweave.domains.oauth.oauth2_service.httpx.AsyncClient", return_value=mock_client, ): - with pytest.raises(httpx.HTTPStatusError): + with pytest.raises(OAuthRefreshServerError): await svc._make_token_request(log, "https://p.com/token", {}, {}) +# =========================================================================== +# Exception __init__ coverage (exceptions.py:62-63, 80-81) +# =========================================================================== + + +def test_oauth_refresh_bad_request_error_stores_error_code(): + err = OAuthRefreshBadRequestError("msg", error_code="invalid_grant") + assert err.error_code == "invalid_grant" + assert str(err) == "msg" + + +def test_oauth_refresh_rate_limit_error_stores_retry_after(): + err = OAuthRefreshRateLimitError("msg", retry_after=60.0) + assert err.retry_after == 60.0 + assert str(err) == "msg" + + +# =========================================================================== +# refresh_and_persist (oauth2_service.py:425-451) +# =========================================================================== + + +@pytest.mark.asyncio +async def test_refresh_and_persist_no_connection_raises(): + """Connection not found → OAuthRefreshCredentialMissingError.""" + deps = Deps() + svc = deps.build() + ctx = _make_ctx() + with pytest.raises(OAuthRefreshCredentialMissingError): + await svc.refresh_and_persist(MagicMock(), "slack", uuid4(), ctx) + + +@pytest.mark.asyncio +async def test_refresh_and_persist_no_credential_raises(): + """Connection exists but credential not found → OAuthRefreshCredentialMissingError.""" + deps = Deps() + conn_id = uuid4() + cred_id = uuid4() + conn = 
Connection( + id=conn_id, + name="test-conn", + readable_id="test-conn-001", + short_name="slack", + integration_type=IntegrationType.SOURCE, + status=ConnectionStatus.ACTIVE, + organization_id=ORG_ID, + created_by_email="test@test.com", + modified_by_email="test@test.com", + integration_credential_id=cred_id, + ) + deps.conn_repo.seed(conn_id, conn) + svc = deps.build() + ctx = _make_ctx() + with pytest.raises(OAuthRefreshCredentialMissingError): + await svc.refresh_and_persist(MagicMock(), "slack", conn_id, ctx) + + +@pytest.mark.asyncio +async def test_refresh_and_persist_happy_path(): + """Full flow: loads connection, decrypts credential, refreshes, returns RefreshResult.""" + deps = Deps() + conn_id = uuid4() + cred_id = uuid4() + conn = Connection( + id=conn_id, + name="test-conn", + readable_id="test-conn-002", + short_name="slack", + integration_type=IntegrationType.SOURCE, + status=ConnectionStatus.ACTIVE, + organization_id=ORG_ID, + created_by_email="test@test.com", + modified_by_email="test@test.com", + integration_credential_id=cred_id, + ) + deps.conn_repo.seed(conn_id, conn) + + encrypted = deps.encryptor.encrypt({"access_token": "old", "refresh_token": "rt"}) + cred = IntegrationCredential( + id=cred_id, + name="test-cred", + integration_short_name="slack", + integration_type=IntegrationType.SOURCE, + authentication_method=AuthenticationMethod.OAUTH_TOKEN, + organization_id=ORG_ID, + created_by_email="test@test.com", + modified_by_email="test@test.com", + encrypted_credentials=encrypted, + ) + deps.cred_repo.seed(cred_id, cred) + + svc = deps.build() + mock_response = OAuth2TokenResponse( + access_token="new-tok", token_type="bearer", expires_in=3600 + ) + svc.refresh_access_token = AsyncMock(return_value=mock_response) + + ctx = _make_ctx() + result = await svc.refresh_and_persist(MagicMock(), "slack", conn_id, ctx) + assert result.access_token == "new-tok" + assert result.expires_in == 3600 + + +# 
=========================================================================== +# _make_token_request error branches (oauth2_service.py:704-718) +# =========================================================================== + + +@pytest.mark.asyncio +async def test_make_token_request_connect_error_raises_server_error(): + """httpx.ConnectError → OAuthRefreshServerError.""" + svc = _svc() + log = logger.with_context(test="connect_err") + + async def fake_post(url, headers=None, data=None): + raise httpx.ConnectError("Connection refused") + + mock_client = AsyncMock() + mock_client.post = fake_post + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + with patch( + "airweave.domains.oauth.oauth2_service.httpx.AsyncClient", + return_value=mock_client, + ): + with pytest.raises(OAuthRefreshServerError, match="unreachable"): + await svc._make_token_request(log, "https://p.com/token", {}, {}) + + +@pytest.mark.asyncio +async def test_make_token_request_timeout_raises_server_error(): + """httpx.TimeoutException → OAuthRefreshServerError.""" + svc = _svc() + log = logger.with_context(test="timeout") + + async def fake_post(url, headers=None, data=None): + raise httpx.TimeoutException("timed out") + + mock_client = AsyncMock() + mock_client.post = fake_post + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + with patch( + "airweave.domains.oauth.oauth2_service.httpx.AsyncClient", + return_value=mock_client, + ): + with pytest.raises(OAuthRefreshServerError, match="unreachable"): + await svc._make_token_request(log, "https://p.com/token", {}, {}) + + +@pytest.mark.asyncio +async def test_make_token_request_generic_exception_raises_server_error(): + """Random exception → OAuthRefreshServerError.""" + svc = _svc() + log = logger.with_context(test="generic") + + async def fake_post(url, headers=None, data=None): + raise RuntimeError("something 
unexpected") + + mock_client = AsyncMock() + mock_client.post = fake_post + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + with patch( + "airweave.domains.oauth.oauth2_service.httpx.AsyncClient", + return_value=mock_client, + ): + with pytest.raises(OAuthRefreshServerError, match="Unexpected"): + await svc._make_token_request(log, "https://p.com/token", {}, {}) + + +@pytest.mark.asyncio +async def test_make_token_request_exhausts_retries_on_rate_limit(): + """All retries fail with 429 → OAuthRefreshRateLimitError.""" + svc = _svc() + log = logger.with_context(test="exhaust_retries") + + async def fake_post(url, headers=None, data=None): + resp = _make_httpx_response(429, {"error": "rate_limited"}) + raise httpx.HTTPStatusError("429", request=resp.request, response=resp) + + mock_client = AsyncMock() + mock_client.post = fake_post + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + with ( + patch( + "airweave.domains.oauth.oauth2_service.httpx.AsyncClient", + return_value=mock_client, + ), + patch("airweave.domains.oauth.oauth2_service.asyncio.sleep", new_callable=AsyncMock), + ): + with pytest.raises(OAuthRefreshRateLimitError): + await svc._make_token_request(log, "https://p.com/token", {}, {}) + + +# =========================================================================== +# _raise_typed_refresh_error branches (oauth2_service.py:738-767) +# =========================================================================== + + +def test_raise_typed_refresh_error_401_raises_revoked(): + svc = _svc() + resp = _make_httpx_response(401, {"error": "invalid_token"}) + exc = httpx.HTTPStatusError("err", request=resp.request, response=resp) + with pytest.raises(OAuthRefreshTokenRevokedError): + svc._raise_typed_refresh_error(exc) + + +def test_raise_typed_refresh_error_400_raises_bad_request(): + svc = _svc() + resp = 
_make_httpx_response(400, {"error": "invalid_grant", "error_description": "expired"}) + exc = httpx.HTTPStatusError("err", request=resp.request, response=resp) + with pytest.raises(OAuthRefreshBadRequestError) as exc_info: + svc._raise_typed_refresh_error(exc) + assert exc_info.value.error_code == "invalid_grant" + + +def test_raise_typed_refresh_error_403_raises_bad_request(): + svc = _svc() + resp = _make_httpx_response(403, {"error": "forbidden"}) + exc = httpx.HTTPStatusError("err", request=resp.request, response=resp) + with pytest.raises(OAuthRefreshBadRequestError): + svc._raise_typed_refresh_error(exc) + + +def test_raise_typed_refresh_error_500_raises_server_error(): + svc = _svc() + resp = _make_httpx_response(500, {"error": "internal"}) + exc = httpx.HTTPStatusError("err", request=resp.request, response=resp) + with pytest.raises(OAuthRefreshServerError): + svc._raise_typed_refresh_error(exc) + + +def test_raise_typed_refresh_error_unexpected_status_raises_server_error(): + """418 or any other status → OAuthRefreshServerError.""" + svc = _svc() + resp = _make_httpx_response(418, {"error": "teapot"}) + exc = httpx.HTTPStatusError("err", request=resp.request, response=resp) + with pytest.raises(OAuthRefreshServerError, match="unexpected 418"): + svc._raise_typed_refresh_error(exc) + + +def test_raise_typed_refresh_error_non_json_body_still_works(): + """Response body is not JSON — detail extraction fallback to .text.""" + svc = _svc() + resp = _make_httpx_response(401, text="Not authorized") + exc = httpx.HTTPStatusError("err", request=resp.request, response=resp) + with pytest.raises(OAuthRefreshTokenRevokedError): + svc._raise_typed_refresh_error(exc) + + # =========================================================================== # _get_redirect_url — uses injected settings # =========================================================================== diff --git a/backend/airweave/domains/oauth/types.py b/backend/airweave/domains/oauth/types.py index 
c1a1847ba..d2452cf99 100644 --- a/backend/airweave/domains/oauth/types.py +++ b/backend/airweave/domains/oauth/types.py @@ -27,3 +27,15 @@ class OAuthBrowserInitiationResult: client_secret: Optional[str] oauth_client_mode: str additional_overrides: Dict[str, Any] = field(default_factory=dict) + + +@dataclass(frozen=True, slots=True) +class RefreshResult: + """Result of an OAuth2 token refresh. + + Carries the new access token alongside the provider-reported lifetime + so callers can schedule the next refresh accurately. + """ + + access_token: str + expires_in: Optional[int] = None diff --git a/backend/airweave/domains/sources/exceptions.py b/backend/airweave/domains/sources/exceptions.py index a9b23380b..a3622bf6b 100644 --- a/backend/airweave/domains/sources/exceptions.py +++ b/backend/airweave/domains/sources/exceptions.py @@ -1,29 +1,334 @@ -"""Source domain exceptions.""" +"""Source domain exceptions. -from airweave.core.exceptions import NotFoundException +Hierarchy +--------- +NotFoundException +└── SourceNotFoundError — source short_name not in registry + +SourceError (AirweaveException) — base for ALL source runtime errors +├── SourceCreationError — source_class.create() failed +├── SourceValidationError — source.validate() returned False / raised +│ +│ Runtime errors (during generate_entities / search / ACL / browse / tool calls) +├── SourceAuthError — 401 after token refresh attempt → abort sync +│ └── SourceTokenRefreshError — token refresh itself failed +├── SourceRateLimitError — 429 from upstream API → retry with backoff +├── SourceServerError — upstream server error (5xx / timeout / connection) +│ +│ Per-entity errors (skip the entity, continue the sync) +├── SourceEntityError — base for single-entity failures +│ ├── SourceEntityForbiddenError — 403 on one entity +│ ├── SourceEntityNotFoundError — 404 on one entity +│ └── SourceEntitySkippedError — source decided to skip (too large, unsupported, etc.) 
+│ +│ File download errors +└── SourceFileDownloadError — file download failed for an entity +""" + +from typing import Optional + +from airweave.core.exceptions import ( + AirweaveException, + NotFoundException, +) + +# --------------------------------------------------------------------------- +# Lifecycle exceptions (registry / creation / validation) +# --------------------------------------------------------------------------- class SourceNotFoundError(NotFoundException): """Raised when a source with the given short_name does not exist or is hidden.""" def __init__(self, short_name: str): + """Create a new SourceNotFoundError. + + Args: + short_name: The source short_name that was not found. + """ self.short_name = short_name super().__init__(f"Source not found: {short_name}") -class SourceCreationError(Exception): - """Raised when source_class.create() fails (bad credential format, missing fields, etc.).""" +# --------------------------------------------------------------------------- +# Runtime base +# --------------------------------------------------------------------------- + + +class SourceError(AirweaveException): + """Base for all source runtime errors. + + Every subclass carries source_short_name so the pipeline and + orchestrator can log and route without inspecting the message. + """ + + def __init__(self, message: str, *, source_short_name: str = ""): + """Create a new SourceError. + + Args: + message: Human-readable error description. + source_short_name: Identifier of the source that raised the error. 
+ """ + self.source_short_name = source_short_name + super().__init__(message) + + +# --------------------------------------------------------------------------- +# Lifecycle errors (inherit from SourceError so callers can catch broadly) +# --------------------------------------------------------------------------- + + +class SourceCreationError(SourceError): + """source_class.create() failed (bad credentials, missing fields, etc.).""" def __init__(self, short_name: str, reason: str): + """Create a new SourceCreationError. + + Args: + short_name: Source identifier. + reason: Why creation failed. + """ self.short_name = short_name self.reason = reason - super().__init__(f"Failed to create source '{short_name}': {reason}") + super().__init__( + f"Failed to create source '{short_name}': {reason}", + source_short_name=short_name, + ) -class SourceValidationError(Exception): - """Raised when source.validate() fails or returns False.""" +class SourceValidationError(SourceError): + """source.validate() returned False or raised.""" def __init__(self, short_name: str, reason: str): + """Create a new SourceValidationError. + + Args: + short_name: Source identifier. + reason: Why validation failed. + """ self.short_name = short_name self.reason = reason - super().__init__(f"Validation failed for source '{short_name}': {reason}") + super().__init__( + f"Validation failed for source '{short_name}': {reason}", + source_short_name=short_name, + ) + + +# --------------------------------------------------------------------------- +# Auth errors +# --------------------------------------------------------------------------- + + +class SourceAuthError(SourceError): + """401 Unauthorized after token refresh attempt. + + The pipeline should abort the sync — credentials are invalid or revoked. + """ + + def __init__( + self, + message: str = "Authentication failed after token refresh", + *, + source_short_name: str = "", + status_code: int = 401, + ): + """Create a new SourceAuthError. 
+ + Args: + message: Human-readable error description. + source_short_name: Source identifier. + status_code: HTTP status code that triggered this (usually 401). + """ + self.status_code = status_code + super().__init__(message, source_short_name=source_short_name) + + +class SourceTokenRefreshError(SourceAuthError): + """Token refresh failed — the underlying OAuth or auth-provider call did not succeed. + + Raised by TokenProvider implementations when refresh is attempted but fails. + Subclass of SourceAuthError so pipelines that catch auth errors broadly + will also catch refresh failures. + """ + + def __init__( + self, + message: str = "Token refresh failed", + *, + source_short_name: str = "", + ): + """Create a new SourceTokenRefreshError. + + Args: + message: Human-readable error description. + source_short_name: Source identifier. + """ + super().__init__(message, source_short_name=source_short_name, status_code=401) + + +# --------------------------------------------------------------------------- +# Rate limiting +# --------------------------------------------------------------------------- + + +class SourceRateLimitError(SourceError): + """429 Too Many Requests from the upstream API. + + The retry decorator / pipeline should wait ``retry_after`` seconds + before retrying the request. + """ + + def __init__( + self, + *, + retry_after: float, + source_short_name: str = "", + message: Optional[str] = None, + ): + """Create a new SourceRateLimitError. + + Args: + retry_after: Seconds to wait before retrying. + source_short_name: Source identifier. + message: Optional custom message. 
+ """ + self.retry_after = retry_after + msg = message or f"Rate limited — retry after {retry_after:.1f}s" + super().__init__(msg, source_short_name=source_short_name) + + +# --------------------------------------------------------------------------- +# Upstream server errors +# --------------------------------------------------------------------------- + + +class SourceServerError(SourceError): + """Upstream server error (5xx, timeout, connection reset, or other non-auth failure).""" + + def __init__( + self, + message: str = "Upstream server error", + *, + source_short_name: str = "", + status_code: Optional[int] = None, + ): + """Create a new SourceServerError. + + Args: + message: Human-readable error description. + source_short_name: Source identifier. + status_code: HTTP status code if applicable. + """ + self.status_code = status_code + super().__init__(message, source_short_name=source_short_name) + + +SourceTemporaryError = SourceServerError +SourcePermanentError = SourceServerError + + +# --------------------------------------------------------------------------- +# Per-entity errors (skip the entity, continue the sync) +# --------------------------------------------------------------------------- + + +class SourceEntityError(SourceError): + """Base for errors tied to a single entity. + + The pipeline should skip this entity and continue processing. + """ + + def __init__( + self, + message: str, + *, + source_short_name: str = "", + entity_id: str = "", + ): + """Create a new SourceEntityError. + + Args: + message: Human-readable error description. + source_short_name: Source identifier. + entity_id: ID of the entity that failed. + """ + self.entity_id = entity_id + super().__init__(message, source_short_name=source_short_name) + + +class SourceEntityForbiddenError(SourceEntityError): + """403 Forbidden when accessing a specific entity. + + Common cause: the OAuth token lacks permission for this particular + resource (e.g. 
a private channel the bot isn't in). Skip and continue. + """ + + pass + + +class SourceEntityNotFoundError(SourceEntityError): + """404 Not Found for a specific entity. + + The entity was deleted or moved between listing and fetching. Skip. + """ + + pass + + +class SourceEntitySkippedError(SourceEntityError): + """Source intentionally skipped an entity (too large, unsupported type, etc.). + + Not an error per se, but the pipeline should count it as skipped. + """ + + def __init__( + self, + message: str = "Entity skipped by source", + *, + source_short_name: str = "", + entity_id: str = "", + reason: str = "", + ): + """Create a new SourceEntitySkippedError. + + Args: + message: Human-readable error description. + source_short_name: Source identifier. + entity_id: ID of the entity that was skipped. + reason: Why the entity was skipped. + """ + self.reason = reason + super().__init__(message, source_short_name=source_short_name, entity_id=entity_id) + + +# --------------------------------------------------------------------------- +# File download errors +# --------------------------------------------------------------------------- + + +class SourceFileDownloadError(SourceEntityError): + """File download failed for an entity. + + Treated as a per-entity skip — the sync continues. + """ + + def __init__( + self, + message: str = "File download failed", + *, + source_short_name: str = "", + status_code: Optional[int] = None, + entity_id: str = "", + file_url: str = "", + ): + """Create a new SourceFileDownloadError. + + Args: + message: Human-readable error description. + source_short_name: Source identifier. + status_code: HTTP status code if applicable. + entity_id: ID of the entity whose file failed to download. + file_url: URL that was attempted. 
+ """ + self.file_url = file_url + super().__init__(message, source_short_name=source_short_name, entity_id=entity_id) diff --git a/backend/airweave/domains/sources/lifecycle.py b/backend/airweave/domains/sources/lifecycle.py index 04261661d..0f7bb70ee 100644 --- a/backend/airweave/domains/sources/lifecycle.py +++ b/backend/airweave/domains/sources/lifecycle.py @@ -17,7 +17,10 @@ from airweave.core.exceptions import NotFoundException from airweave.core.logging import ContextualLogger from airweave.core.shared_models import FeatureFlag +from airweave.domains.auth_provider._base import BaseAuthProvider +from airweave.domains.auth_provider.auth_result import AuthProviderMode from airweave.domains.auth_provider.protocols import AuthProviderRegistryProtocol +from airweave.domains.auth_provider.providers.pipedream import PipedreamAuthProvider from airweave.domains.connections.protocols import ConnectionRepositoryProtocol from airweave.domains.credentials.protocols import ( IntegrationCredentialRepositoryProtocol, @@ -35,15 +38,13 @@ SourceLifecycleServiceProtocol, SourceRegistryProtocol, ) -from airweave.domains.sources.types import AuthConfig, SourceConnectionData -from airweave.platform.auth_providers._base import BaseAuthProvider -from airweave.platform.auth_providers.auth_result import AuthProviderMode -from airweave.platform.auth_providers.pipedream import PipedreamAuthProvider +from airweave.domains.sources.token_providers.auth_provider import AuthProviderTokenProvider +from airweave.domains.sources.token_providers.oauth import OAuthTokenProvider +from airweave.domains.sources.token_providers.static import StaticTokenProvider +from airweave.domains.sources.types import AuthConfig, SourceConnectionData, SourceRegistryEntry from airweave.platform.http_client import PipedreamProxyClient from airweave.platform.http_client.airweave_client import AirweaveHttpClient from airweave.platform.sources._base import BaseSource -from airweave.platform.sync.token_manager import 
TokenManager -from airweave.schemas.source_connection import OAuthType SourceCredentials = Union[str, dict, BaseModel] @@ -121,9 +122,10 @@ async def create( ) # 3. Process credentials for source consumption - source_credentials = self._process_credentials_for_source( + entry = self._source_registry.get(source_connection_data.short_name) + source_credentials = self._normalize_credentials( raw_credentials=auth_config.credentials, - source_connection_data=source_connection_data, + entry=entry, logger=logger, ) @@ -135,10 +137,7 @@ async def create( # 5. Configure source self._configure_logger(source, logger) self._configure_http_client_factory(source, auth_config) - self._configure_sync_identifiers(source, source_connection_data, ctx) - - await self._configure_token_manager( - db=db, + await self._configure_token_provider( source=source, source_connection_data=source_connection_data, source_credentials=auth_config.credentials, @@ -166,7 +165,9 @@ async def validate( ) -> None: """Validate credentials by creating a lightweight source and calling .validate(). - No token manager, no HTTP wrapping, no rate limiting — just create + validate. + No HTTP wrapping, no rate limiting — just create + validate. + The source's own access_token (set during create()) is used for + any API calls made during validation. Raises: SourceNotFoundError: If source short_name is not in the registry. 
@@ -180,8 +181,10 @@ async def validate( source_class = entry.source_class_ref + normalized = self._normalize_credentials(credentials, entry) + try: - source = await source_class.create(credentials, config=config) + source = await source_class.create(normalized, config=config) except Exception as exc: raise SourceCreationError(short_name, str(exc)) from exc @@ -559,53 +562,80 @@ async def _handle_auth_config_credentials( # Private: credential processing # ------------------------------------------------------------------ - def _process_credentials_for_source( + def _normalize_credentials( self, - raw_credentials: Union[dict, BaseModel], - source_connection_data: SourceConnectionData, - logger: ContextualLogger, + raw_credentials: Union[dict, BaseModel, str], + entry: SourceRegistryEntry, + logger: Optional[ContextualLogger] = None, ) -> SourceCredentials: - """Process raw credentials into the format expected by the source. + """Normalize raw credentials into the format expected by source.create(). Handles three cases: 1. OAuth sources without auth_config_class: Extract just the access_token string 2. Sources with auth_config_class and dict credentials: Convert to auth config object 3. Other sources: Pass through as-is + + Used by both the full create() path and the lightweight validate() path. 
""" - short_name = source_connection_data.short_name - oauth_type = source_connection_data.oauth_type + if isinstance(raw_credentials, str): + return raw_credentials - # Case 1: OAuth sources without auth_config_class need just the access_token string - entry = self._source_registry.get(short_name) - auth_config_ref = entry.auth_config_ref + creds_dict = self._to_creds_dict(raw_credentials, entry.short_name, logger) + if creds_dict is None: + return raw_credentials - if not auth_config_ref and oauth_type: - if isinstance(raw_credentials, dict) and "access_token" in raw_credentials: - logger.debug(f"Extracting access_token for OAuth source {short_name}") - return raw_credentials["access_token"] - elif isinstance(raw_credentials, str): - logger.debug(f"OAuth source {short_name} credentials already a string token") - return raw_credentials - else: - logger.warning( - f"OAuth source {short_name} credentials not in expected format: " - f"{type(raw_credentials)}" - ) - return raw_credentials + return self._process_creds_dict(creds_dict, raw_credentials, entry, logger) - # Case 2: Sources with auth_config_class and dict credentials - if auth_config_ref and isinstance(raw_credentials, dict): + @staticmethod + def _to_creds_dict( + raw_credentials: Union[dict, BaseModel, str], + short_name: str, + logger: Optional[ContextualLogger], + ) -> Optional[dict]: + """Convert raw credentials to a dict, or None if the type is unexpected.""" + if isinstance(raw_credentials, BaseModel): + return raw_credentials.model_dump() + if isinstance(raw_credentials, dict): + return raw_credentials + if logger: + logger.warning( + f"Source {short_name} credentials in unexpected format: {type(raw_credentials)}" + ) + return None + + @staticmethod + def _process_creds_dict( + creds_dict: dict, + raw_credentials: Union[dict, BaseModel, str], + entry: SourceRegistryEntry, + logger: Optional[ContextualLogger], + ) -> SourceCredentials: + """Process a credentials dict according to the source registry 
entry.""" + auth_config_ref = entry.auth_config_ref + short_name = entry.short_name + + if not auth_config_ref and entry.oauth_type: + if "access_token" in creds_dict: + if logger: + logger.debug(f"Extracting access_token for OAuth source {short_name}") + return creds_dict["access_token"] + if logger: + logger.warning(f"OAuth source {short_name} credentials missing access_token") + return raw_credentials + + if auth_config_ref: try: - processed_credentials = auth_config_ref.model_validate(raw_credentials) - logger.debug( - f"Converted credentials dict to {auth_config_ref.__name__} for {short_name}" - ) - return processed_credentials + validated = auth_config_ref.model_validate(creds_dict) + if logger: + logger.debug( + f"Converted credentials dict to {auth_config_ref.__name__} for {short_name}" + ) + return validated except Exception as e: - logger.error(f"Failed to convert credentials to auth config object: {e}") + if logger: + logger.error(f"Failed to convert credentials to auth config: {e}") raise - # Case 3: Pass through as-is return raw_credentials # ------------------------------------------------------------------ @@ -622,24 +652,8 @@ def _configure_http_client_factory(source: BaseSource, auth_config: AuthConfig) if auth_config.http_client_factory: source.set_http_client_factory(auth_config.http_client_factory) - @staticmethod - def _configure_sync_identifiers( - source: BaseSource, source_connection_data: SourceConnectionData, ctx: ApiContext - ) -> None: - try: - organization_id = ctx.organization.id - sc_id = source_connection_data.source_connection_id - if hasattr(source, "set_sync_identifiers") and sc_id: - source.set_sync_identifiers( - organization_id=str(organization_id), - source_connection_id=str(sc_id), - ) - except Exception: - pass # Non-fatal for older sources - - @staticmethod - async def _configure_token_manager( - db: AsyncSession, + async def _configure_token_provider( + self, source: BaseSource, source_connection_data: 
SourceConnectionData, source_credentials: SourceCredentials, @@ -648,65 +662,48 @@ async def _configure_token_manager( access_token: Optional[str], auth_config: AuthConfig, ) -> None: - """Set up token manager for OAuth sources that support refresh.""" + """Set up the appropriate TokenProvider for this source.""" auth_mode = auth_config.auth_mode auth_provider_instance: Optional[BaseAuthProvider] = auth_config.auth_provider_instance + short_name = source_connection_data.short_name if access_token is not None: - logger.debug( - f"Skipping token manager for {source_connection_data.short_name} " - f"— direct token injection" + source.set_token_provider( + StaticTokenProvider(access_token, source_short_name=short_name) ) return if auth_mode == AuthProviderMode.PROXY: - logger.info( - f"Skipping token manager for {source_connection_data.short_name} — proxy mode" - ) return - short_name = source_connection_data.short_name oauth_type = source_connection_data.oauth_type - if not oauth_type: return - if oauth_type not in (OAuthType.WITH_REFRESH, OAuthType.WITH_ROTATING_REFRESH): - logger.debug( - f"Skipping token manager for {short_name} — " - f"oauth_type={oauth_type} does not support refresh" - ) - return - try: - minimal_connection = type( - "SourceConnection", - (), - { - "id": source_connection_data.connection_id, - "integration_credential_id": source_connection_data.integration_credential_id, - "config_fields": source_connection_data.config_fields, - }, - )() - - token_manager = TokenManager( - db=db, - source_short_name=short_name, - source_connection=minimal_connection, - ctx=ctx, - initial_credentials=source_credentials, - is_direct_injection=False, - logger_instance=logger, - auth_provider_instance=auth_provider_instance, - ) - source.set_token_manager(token_manager) + if auth_provider_instance: + token_provider = AuthProviderTokenProvider( + auth_provider_instance=auth_provider_instance, + source_short_name=short_name, + source_registry=self._source_registry, + 
logger=logger, + ) + else: + token_provider = OAuthTokenProvider( + credentials=source_credentials, + oauth_type=oauth_type, + oauth2_service=self._oauth2_service, + source_short_name=short_name, + connection_id=source_connection_data.connection_id, + ctx=ctx, + logger=logger, + config_fields=source_connection_data.config_fields, + ) + + source.set_token_provider(token_provider) - logger.info( - f"Token manager initialized for OAuth source {short_name} " - f"(auth_provider: {'Yes' if auth_provider_instance else 'None'})" - ) except Exception as e: - logger.error(f"Failed to setup token manager for '{short_name}': {e}") + raise SourceCreationError(short_name, f"token provider setup failed: {e}") from e # ------------------------------------------------------------------ # Private: rate limiting wrapper diff --git a/backend/airweave/domains/sources/tests/test_lifecycle.py b/backend/airweave/domains/sources/tests/test_lifecycle.py index f82e7431f..f8dc73210 100644 --- a/backend/airweave/domains/sources/tests/test_lifecycle.py +++ b/backend/airweave/domains/sources/tests/test_lifecycle.py @@ -28,7 +28,7 @@ from airweave.domains.sources.lifecycle import SourceLifecycleService from airweave.domains.sources.tests.conftest import _make_ctx, _make_entry from airweave.domains.sources.types import AuthConfig, SourceConnectionData -from airweave.platform.auth_providers.auth_result import AuthProviderMode +from airweave.domains.auth_provider.auth_result import AuthProviderMode from airweave.platform.configs._base import Fields @@ -48,7 +48,7 @@ async def create(cls, credentials, config=None): instance._credentials = credentials instance._config = config instance._logger = None - instance._token_manager = None + instance._token_provider = None instance._sync_org_id = None instance._sync_sc_id = None return instance @@ -62,12 +62,8 @@ def set_logger(self, logger): def set_http_client_factory(self, factory): self._http_client_factory = factory - def set_sync_identifiers(self, 
organization_id, source_connection_id): - self._sync_org_id = organization_id - self._sync_sc_id = source_connection_id - - def set_token_manager(self, tm): - self._token_manager = tm + def set_token_provider(self, tp): + self._token_provider = tp class _StubSourceValidateFalse: @@ -626,7 +622,7 @@ async def test_handle_auth_config_oauth2_error_propagates(): # =========================================================================== -# _process_credentials_for_source() — table-driven +# _normalize_credentials() — table-driven # =========================================================================== @@ -658,7 +654,7 @@ class ProcessCredsCase: @pytest.mark.parametrize("case", PROCESS_CREDS_TABLE, ids=lambda c: c.id) -def test_process_credentials_for_source(case: ProcessCredsCase): +def test_normalize_credentials(case: ProcessCredsCase): mock_config_class = MagicMock() mock_config_class.__name__ = "TestAuthConfig" @@ -668,30 +664,30 @@ def test_process_credentials_for_source(case: ProcessCredsCase): mock_config_class.model_validate.return_value = "VALIDATED" entry = _entry_with_class("src", _StubSourceValid) + if case.oauth_type: + object.__setattr__(entry, "oauth_type", case.oauth_type) if case.has_auth_config_ref: object.__setattr__(entry, "auth_config_ref", mock_config_class) service = _make_service(source_entries=[entry]) ctx = _make_ctx() - data = _sc_data(short_name="src", oauth_type=case.oauth_type, - auth_config_class="TestAuthConfig" if case.has_auth_config_ref else None) if case.expect_error: with pytest.raises(case.expect_error): - service._process_credentials_for_source( + service._normalize_credentials( raw_credentials=case.raw_credentials, - source_connection_data=data, logger=ctx.logger, + entry=entry, logger=ctx.logger, ) else: - result = service._process_credentials_for_source( + result = service._normalize_credentials( raw_credentials=case.raw_credentials, - source_connection_data=data, logger=ctx.logger, + entry=entry, logger=ctx.logger, ) 
assert result == case.expected # =========================================================================== -# _configure_token_manager() — table-driven +# _configure_token_provider() — table-driven # =========================================================================== @@ -705,10 +701,12 @@ class TokenManagerCase: TOKEN_MANAGER_TABLE = [ - TokenManagerCase(id="skip-direct-injection", access_token="injected"), + TokenManagerCase(id="direct-injection", access_token="injected", + expect_tm_set=True), TokenManagerCase(id="skip-proxy-mode", auth_mode=AuthProviderMode.PROXY), TokenManagerCase(id="skip-no-oauth-type", oauth_type=None), - TokenManagerCase(id="skip-access-only-oauth", oauth_type="access_only"), + TokenManagerCase(id="access-only-no-refresh", oauth_type="access_only", + expect_tm_set=True), TokenManagerCase(id="happy-with-refresh", oauth_type="with_refresh", expect_tm_set=True), TokenManagerCase(id="happy-rotating-refresh", oauth_type="with_rotating_refresh", @@ -718,7 +716,7 @@ class TokenManagerCase: @pytest.mark.parametrize("case", TOKEN_MANAGER_TABLE, ids=lambda c: c.id) @pytest.mark.asyncio -async def test_configure_token_manager(case: TokenManagerCase): +async def test_configure_token_provider(case: TokenManagerCase): source = MagicMock() if case.expect_tm_set else await _StubSourceValid.create("tok") data = _sc_data(short_name="src", oauth_type=case.oauth_type) ctx = _make_ctx() @@ -728,23 +726,24 @@ async def test_configure_token_manager(case: TokenManagerCase): auth_provider_instance=None, auth_mode=case.auth_mode, ) + service = _make_service() if case.expect_tm_set: - with patch("airweave.domains.sources.lifecycle.TokenManager") as mock_tm: - mock_tm.return_value = MagicMock() - await SourceLifecycleService._configure_token_manager( - db=MagicMock(), source=source, source_connection_data=data, + with patch("airweave.domains.sources.lifecycle.OAuthTokenProvider") as mock_tp: + mock_tp.return_value = MagicMock() + await 
service._configure_token_provider( + source=source, source_connection_data=data, source_credentials="tok", ctx=ctx, logger=ctx.logger, access_token=case.access_token, auth_config=auth_config, ) - source.set_token_manager.assert_called_once() + source.set_token_provider.assert_called_once() else: - await SourceLifecycleService._configure_token_manager( - db=MagicMock(), source=source, source_connection_data=data, + await service._configure_token_provider( + source=source, source_connection_data=data, source_credentials="tok", ctx=ctx, logger=ctx.logger, access_token=case.access_token, auth_config=auth_config, ) - assert source._token_manager is None + assert source._token_provider is None # =========================================================================== @@ -785,20 +784,6 @@ def test_noop_when_none(self): source.set_http_client_factory.assert_not_called() -class TestConfigureSyncIdentifiers: - def test_happy(self): - source = MagicMock() - ctx = _make_ctx() - SourceLifecycleService._configure_sync_identifiers(source, _sc_data(), ctx) - source.set_sync_identifiers.assert_called_once() - - def test_swallows_exception(self): - source = MagicMock() - source.set_sync_identifiers.side_effect = AttributeError - ctx = _make_ctx() - SourceLifecycleService._configure_sync_identifiers(source, _sc_data(), ctx) - - # =========================================================================== # _wrap_source_with_airweave_client() — table-driven # =========================================================================== diff --git a/backend/airweave/domains/sources/token_providers/__init__.py b/backend/airweave/domains/sources/token_providers/__init__.py new file mode 100644 index 000000000..73b43f9a0 --- /dev/null +++ b/backend/airweave/domains/sources/token_providers/__init__.py @@ -0,0 +1,13 @@ +"""TokenProvider implementations.""" + +from airweave.domains.sources.token_providers.auth_provider import AuthProviderTokenProvider +from 
class AuthProviderTokenProvider(TokenProviderProtocol):
    """TokenProvider backed by an external auth provider (Pipedream / Composio).

    In *direct* mode the auth provider holds the user's OAuth connection
    and can vend fresh access tokens on demand. Nothing is cached on this
    object: every ``get_token()`` / ``force_refresh()`` call goes back to
    the provider.
    """

    def __init__(
        self,
        auth_provider_instance: BaseAuthProvider,
        source_short_name: str,
        source_registry: SourceRegistryProtocol,
        *,
        logger: ContextualLogger,
    ):
        """Initialize with an auth provider instance.

        Args:
            auth_provider_instance: A ``BaseAuthProvider`` subclass instance.
            source_short_name: Source identifier.
            source_registry: Registry to look up runtime auth field names.
            logger: Contextual logger with sync metadata.
        """
        self._provider = auth_provider_instance
        self._source_short_name = source_short_name
        self._source_registry = source_registry
        self._logger = logger

    async def _fetch_token(self) -> str:
        """Call the auth provider and extract the access token.

        Makes up to 3 attempts on transient failures (server errors, rate
        limits) before translating the final exception.

        Returns:
            The non-empty ``access_token`` string from the provider response.

        Raises:
            TokenCredentialsInvalidError: If the provider rejected our credentials.
            TokenProviderAccountGoneError: If the connected account was deleted.
            TokenProviderMissingCredsError: If the response lacks required fields
                or carries an empty / non-string ``access_token``.
            TokenProviderConfigError: If the provider configuration is invalid.
            TokenProviderRateLimitError: If the provider is throttling us.
            TokenProviderServerError: If the provider is temporarily unavailable,
                or on any other unexpected provider failure.
        """
        entry = self._source_registry.get(self._source_short_name)

        # The order of the handlers is kept exactly as written: more specific
        # auth-provider errors are matched before more general ones.
        try:
            creds = await self._call_provider_with_retry(entry)
        except AuthProviderAccountNotFoundError as e:
            raise TokenProviderAccountGoneError(
                f"Account deleted in auth provider for {self._source_short_name}: {e}",
                source_short_name=self._source_short_name,
                provider_kind=_PROVIDER_KIND,
                account_id=e.account_id,
            ) from e
        except AuthProviderAuthError as e:
            raise TokenCredentialsInvalidError(
                f"Auth provider credentials rejected for {self._source_short_name}: {e}",
                source_short_name=self._source_short_name,
                provider_kind=_PROVIDER_KIND,
            ) from e
        except AuthProviderMissingFieldsError as e:
            raise TokenProviderMissingCredsError(
                f"Auth provider response missing fields for {self._source_short_name}: {e}",
                source_short_name=self._source_short_name,
                provider_kind=_PROVIDER_KIND,
                missing_fields=e.missing_fields,
            ) from e
        except AuthProviderConfigError as e:
            raise TokenProviderConfigError(
                f"Auth provider misconfigured for {self._source_short_name}: {e}",
                source_short_name=self._source_short_name,
                provider_kind=_PROVIDER_KIND,
            ) from e
        except AuthProviderRateLimitError as e:
            raise TokenProviderRateLimitError(
                f"Auth provider rate-limited for {self._source_short_name}: {e}",
                source_short_name=self._source_short_name,
                provider_kind=_PROVIDER_KIND,
                retry_after=e.retry_after,
            ) from e
        except AuthProviderServerError as e:
            raise TokenProviderServerError(
                f"Auth provider server error for {self._source_short_name}: {e}",
                source_short_name=self._source_short_name,
                provider_kind=_PROVIDER_KIND,
                status_code=e.status_code,
            ) from e
        except Exception as e:
            # Anything unclassified is surfaced as a server-side failure.
            raise TokenProviderServerError(
                f"Unexpected auth provider error for {self._source_short_name}: {e}",
                source_short_name=self._source_short_name,
                provider_kind=_PROVIDER_KIND,
            ) from e

        # A present-but-unusable value (None, "", whitespace, non-string) is
        # treated the same as a missing key: callers are promised a real token.
        token = creds.get("access_token") if isinstance(creds, dict) else None
        if not isinstance(token, str) or not token.strip():
            raise TokenProviderMissingCredsError(
                f"No access_token in auth provider response for {self._source_short_name}",
                source_short_name=self._source_short_name,
                provider_kind=_PROVIDER_KIND,
                missing_fields=["access_token"],
            )

        return token

    @retry(
        retry=retry_if_exception_type((AuthProviderRateLimitError, AuthProviderServerError)),
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=30),
        reraise=True,
    )
    async def _call_provider_with_retry(self, entry) -> dict:
        """Fetch credentials from the auth provider, retrying transient errors.

        Only rate-limit and server errors are retried; with ``reraise=True``
        the last underlying exception (not a tenacity wrapper) propagates.

        Args:
            entry: Registry entry for this source; supplies the runtime auth
                field names requested from the provider.
        """
        return await self._provider.get_creds_for_source(
            source_short_name=self._source_short_name,
            source_auth_config_fields=entry.runtime_auth_all_fields,
            optional_fields=entry.runtime_auth_optional_fields,
        )

    async def get_token(self) -> str:
        """Return a fresh token from the auth provider.

        Raises:
            TokenProviderError: If the provider call fails (see _fetch_token).
        """
        return await self._fetch_token()

    async def force_refresh(self) -> str:
        """Force-refresh by re-calling the auth provider.

        This performs the same provider call as ``get_token()``, since no
        token is cached on this object.

        Raises:
            TokenProviderError: If the provider call fails (see _fetch_token).
        """
        return await self._fetch_token()
+ +Hierarchy +--------- +TokenProviderError (SourceError) — base for all token-provider failures +├── TokenCredentialsInvalidError — token / refresh_token expired or revoked +├── TokenProviderAccountGoneError — external account record deleted (Composio / Pipedream) +├── TokenProviderConfigError — fundamental misconfiguration +├── TokenProviderMissingCredsError — response lacks required credential fields +├── TokenProviderRateLimitError — upstream rate-limiting +├── TokenProviderServerError — server error (5xx / timeout) +└── TokenRefreshNotSupportedError — static token / no refresh_token + +Every exception carries: + ``source_short_name`` — inherited from SourceError + ``provider_kind`` — "oauth" | "auth_provider" | "static" +""" + +from typing import Optional + +from airweave.domains.sources.exceptions import SourceError + + +class TokenProviderError(SourceError): + """Base for all token-provider runtime failures. + + Sits directly under ``SourceError`` so the lifecycle can catch it + independently of the older ``SourceAuthError`` / ``SourceTokenRefreshError`` + hierarchy. + """ + + def __init__( + self, + message: str, + *, + source_short_name: str = "", + provider_kind: str = "", + ): + """Initialize TokenProviderError.""" + self.provider_kind = provider_kind + super().__init__(message, source_short_name=source_short_name) + + +class TokenCredentialsInvalidError(TokenProviderError): + """Token or refresh_token is expired, revoked, or otherwise invalid. + + The user needs to re-authenticate (OAuth) or re-link their account + (auth provider). + """ + + pass + + +class TokenProviderAccountGoneError(TokenProviderError): + """External account record was deleted from the auth provider. + + The Composio/Pipedream connected-account no longer exists. + The user needs to re-create the connection in the provider. 
+ """ + + def __init__( + self, + message: str, + *, + source_short_name: str = "", + provider_kind: str = "", + account_id: str = "", + ): + """Initialize TokenProviderAccountGoneError.""" + self.account_id = account_id + super().__init__(message, source_short_name=source_short_name, provider_kind=provider_kind) + + +class TokenProviderConfigError(TokenProviderError): + """Fundamental misconfiguration — retrying will never fix this. + + Examples: wrong app in the auth provider, missing template variables, + credential record missing from DB. + """ + + pass + + +class TokenProviderMissingCredsError(TokenProviderError): + """Response from the credential source lacks required fields. + + The account exists and responded, but the credential dict is + incomplete (e.g. no ``access_token``). + """ + + def __init__( + self, + message: str, + *, + source_short_name: str = "", + provider_kind: str = "", + missing_fields: Optional[list[str]] = None, + ): + """Initialize TokenProviderMissingCredsError.""" + self.missing_fields = missing_fields or [] + super().__init__(message, source_short_name=source_short_name, provider_kind=provider_kind) + + +class TokenProviderRateLimitError(TokenProviderError): + """Upstream credential source is rate-limiting us. + + The lifecycle / retry layer should wait ``retry_after`` seconds. 
+ """ + + def __init__( + self, + message: str = "Token provider rate-limited", + *, + source_short_name: str = "", + provider_kind: str = "", + retry_after: float = 30.0, + ): + """Initialize TokenProviderRateLimitError.""" + self.retry_after = retry_after + super().__init__(message, source_short_name=source_short_name, provider_kind=provider_kind) + + +class TokenProviderServerError(TokenProviderError): + """The credential source returned a server error (5xx, timeout, connection error).""" + + def __init__( + self, + message: str = "Token provider returned a server error", + *, + source_short_name: str = "", + provider_kind: str = "", + status_code: Optional[int] = None, + ): + """Initialize TokenProviderServerError.""" + self.status_code = status_code + super().__init__(message, source_short_name=source_short_name, provider_kind=provider_kind) + + +class TokenRefreshNotSupportedError(TokenProviderError): + """Token refresh is not supported by this provider. + + Raised by ``StaticTokenProvider.force_refresh()`` and by + ``OAuthTokenProvider.force_refresh()`` when no refresh_token exists. + """ + + pass diff --git a/backend/airweave/domains/sources/token_providers/oauth.py b/backend/airweave/domains/sources/token_providers/oauth.py new file mode 100644 index 000000000..bed0ec85b --- /dev/null +++ b/backend/airweave/domains/sources/token_providers/oauth.py @@ -0,0 +1,270 @@ +"""OAuthTokenProvider — proactive token refresh for OAuth2 sources. + +Thin wrapper: timer + lock + cache. The actual refresh is delegated to +``oauth2_service.refresh_and_persist()``. 
+""" + +from __future__ import annotations + +import asyncio +import time +from typing import TYPE_CHECKING, Optional, Union +from uuid import UUID + +from pydantic import BaseModel +from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential + +from airweave.core.logging import ContextualLogger +from airweave.db.session import get_db_context +from airweave.domains.oauth.exceptions import ( + OAuthRefreshBadRequestError, + OAuthRefreshCredentialMissingError, + OAuthRefreshRateLimitError, + OAuthRefreshServerError, + OAuthRefreshTokenRevokedError, +) +from airweave.domains.oauth.types import RefreshResult +from airweave.domains.sources.token_providers.exceptions import ( + TokenCredentialsInvalidError, + TokenProviderConfigError, + TokenProviderError, + TokenProviderRateLimitError, + TokenProviderServerError, + TokenRefreshNotSupportedError, +) +from airweave.domains.sources.token_providers.protocol import TokenProviderProtocol + +if TYPE_CHECKING: + from airweave.api.context import ApiContext + from airweave.domains.oauth.protocols import OAuth2ServiceProtocol + +_REFRESHABLE_OAUTH_TYPES = frozenset({"with_refresh", "with_rotating_refresh"}) +_DEFAULT_REFRESH_INTERVAL_SECONDS = 25 * 60 +_REFRESH_LIFETIME_FRACTION = 0.80 +_MIN_REFRESH_INTERVAL_SECONDS = 60 +_MAX_REFRESH_INTERVAL_SECONDS = 50 * 60 +_PROVIDER_KIND = "oauth" + + +class OAuthTokenProvider(TokenProviderProtocol): + """TokenProvider backed by OAuth2 credentials. + + Accepts raw credentials and determines refresh capability internally: + - If oauth_type supports refresh AND a refresh_token is present, + proactively refreshes before expiry. + - Otherwise serves the initial access_token as a static token. 
+ + """ + + def __init__( + self, + credentials: Union[str, dict, BaseModel], + *, + oauth_type: Optional[str], + oauth2_service: OAuth2ServiceProtocol, + source_short_name: str, + connection_id: UUID, + ctx: ApiContext, + logger: ContextualLogger, + config_fields: Optional[dict] = None, + ): + """Initialize the OAuth token provider. + + Args: + credentials: Raw credentials (str token, dict, or Pydantic model). + oauth_type: OAuth type from the source connection (e.g. "with_refresh"). + oauth2_service: Service that handles the actual refresh + persistence. + source_short_name: Source identifier. + connection_id: Connection UUID (passed to oauth2_service for refresh). + ctx: API context (passed to oauth2_service for refresh). + logger: Contextual logger with sync metadata. + config_fields: Optional config fields for templated backend URLs. + + Raises: + ValueError: If no access token can be extracted from credentials. + """ + token = _extract_access_token(credentials) + if not token: + raise ValueError(f"No access token found in credentials for {source_short_name}") + + self._token = token + self._oauth2_service = oauth2_service + self._source_short_name = source_short_name + self._connection_id = connection_id + self._ctx = ctx + self._logger = logger + self._config_fields = config_fields + self._can_refresh = oauth_type in _REFRESHABLE_OAUTH_TYPES and _has_refresh_token( + credentials + ) + self._needs_initial_refresh = True + self._expires_at: float = 0.0 + self._lock = asyncio.Lock() + + # ------------------------------------------------------------------ + # TokenProvider protocol + # ------------------------------------------------------------------ + + async def get_token(self) -> str: + """Return a valid access token, refreshing proactively if stale. + + Raises: + TokenProviderError: On refresh failure (see _refresh_and_translate). 
+ """ + if not self._can_refresh: + return self._token + + if not self._needs_initial_refresh and time.monotonic() < self._expires_at: + return self._token + + async with self._lock: + if not self._needs_initial_refresh and time.monotonic() < self._expires_at: + return self._token + + result = await self._refresh() + self._apply_refresh(result) + return self._token + + async def force_refresh(self) -> str: + """Force an immediate token refresh (e.g. after a 401). + + Raises: + TokenRefreshNotSupportedError: If refresh is not possible. + TokenProviderError: On refresh failure (see _translate_refresh_error). + """ + if not self._can_refresh: + raise TokenRefreshNotSupportedError( + f"Token refresh not supported for {self._source_short_name}", + source_short_name=self._source_short_name, + provider_kind=_PROVIDER_KIND, + ) + + async with self._lock: + self._logger.warning(f"Forcing token refresh for {self._source_short_name} due to 401") + result = await self._refresh() + self._apply_refresh(result) + return self._token + + # ------------------------------------------------------------------ + # Private + # ------------------------------------------------------------------ + + def _apply_refresh(self, result: RefreshResult) -> None: + """Update token and schedule next refresh based on expires_in.""" + self._token = result.access_token + self._needs_initial_refresh = False + + interval = self._compute_refresh_interval(result.expires_in) + self._expires_at = time.monotonic() + interval + + if result.expires_in is not None: + self._logger.debug( + f"Token for {self._source_short_name} expires in {result.expires_in}s, " + f"next refresh in {interval:.0f}s" + ) + + @staticmethod + def _compute_refresh_interval(expires_in: Optional[int]) -> float: + """Derive refresh interval from provider-reported expires_in. + + Uses 80% of the reported lifetime, clamped to [60s, 50min]. + Falls back to the default 25-min interval when expires_in is unavailable. 
+ """ + if expires_in is None or expires_in <= 0: + return _DEFAULT_REFRESH_INTERVAL_SECONDS + interval = expires_in * _REFRESH_LIFETIME_FRACTION + return max(_MIN_REFRESH_INTERVAL_SECONDS, min(interval, _MAX_REFRESH_INTERVAL_SECONDS)) + + async def _refresh(self) -> RefreshResult: + """Refresh the token via oauth2_service, translating errors to TokenProvider types. + + Retries up to 3 times on transient failures (5xx, rate limits) + before translating the final exception. + """ + try: + return await self._refresh_with_retry() + except Exception as e: + raise self._translate_refresh_error(e) from e + + @retry( + retry=retry_if_exception_type((OAuthRefreshServerError, OAuthRefreshRateLimitError)), + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=2, max=30), + reraise=True, + ) + async def _refresh_with_retry(self) -> RefreshResult: + async with get_db_context() as db: + return await self._oauth2_service.refresh_and_persist( + db=db, + integration_short_name=self._source_short_name, + connection_id=self._connection_id, + ctx=self._ctx, + config_fields=self._config_fields, + ) + + def _translate_refresh_error(self, exc: Exception) -> TokenProviderError: + """Map an OAuth refresh exception to the corresponding TokenProvider exception.""" + sn = self._source_short_name + + if isinstance(exc, (OAuthRefreshTokenRevokedError, OAuthRefreshBadRequestError)): + return TokenCredentialsInvalidError( + f"OAuth credentials invalid for {sn}: {exc}", + source_short_name=sn, + provider_kind=_PROVIDER_KIND, + ) + + if isinstance(exc, OAuthRefreshCredentialMissingError): + return TokenProviderConfigError( + f"Credential missing for {sn}: {exc}", + source_short_name=sn, + provider_kind=_PROVIDER_KIND, + ) + + if isinstance(exc, OAuthRefreshRateLimitError): + return TokenProviderRateLimitError( + f"OAuth rate-limited for {sn}: {exc}", + source_short_name=sn, + provider_kind=_PROVIDER_KIND, + retry_after=exc.retry_after, + ) + + if isinstance(exc, 
OAuthRefreshServerError): + return TokenProviderServerError( + f"OAuth server error for {sn}: {exc}", + source_short_name=sn, + provider_kind=_PROVIDER_KIND, + status_code=exc.status_code, + ) + + return TokenProviderServerError( + f"Unexpected OAuth error for {sn}: {exc}", + source_short_name=sn, + provider_kind=_PROVIDER_KIND, + ) + + +# --------------------------------------------------------------------------- +# Module-private credential helpers +# --------------------------------------------------------------------------- + + +def _extract_access_token(creds: Union[str, dict, object]) -> Optional[str]: + """Extract access token from credentials (str, dict, or object).""" + if isinstance(creds, str): + return creds + if isinstance(creds, dict): + return creds.get("access_token") + if hasattr(creds, "access_token"): + return creds.access_token + return None + + +def _has_refresh_token(creds: Union[str, dict, object]) -> bool: + """Check if credentials contain a non-empty refresh token.""" + if isinstance(creds, dict): + rt = creds.get("refresh_token") + return bool(rt and str(rt).strip()) + if hasattr(creds, "refresh_token"): + rt = creds.refresh_token + return bool(rt and str(rt).strip()) + return False diff --git a/backend/airweave/domains/sources/token_providers/protocol.py b/backend/airweave/domains/sources/token_providers/protocol.py new file mode 100644 index 000000000..84f2edbb7 --- /dev/null +++ b/backend/airweave/domains/sources/token_providers/protocol.py @@ -0,0 +1,41 @@ +"""TokenProvider protocol — the contract sources use to obtain auth tokens.""" + +from typing import Protocol, runtime_checkable + + +@runtime_checkable +class TokenProviderProtocol(Protocol): + """Provides auth tokens to sources. + + Sources call ``get_token()`` to obtain a valid token for building + their own auth headers. When a 401 is received, sources call + ``force_refresh()`` to get a fresh token after an explicit refresh. 
+ + Implementations: + - ``OAuthTokenProvider`` — proactive refresh, DB-backed credential store + - ``StaticTokenProvider`` — raw string (API keys, PATs, validation) + - ``AuthProviderTokenProvider`` — delegates to Pipedream / Composio + + All implementations raise exceptions from + ``domains.sources.token_providers.exceptions``: + - ``TokenCredentialsInvalidError`` — expired / revoked credentials + - ``TokenProviderAccountGoneError`` — external account deleted + - ``TokenProviderConfigError`` — fundamental misconfiguration + - ``TokenProviderMissingCredsError`` — response missing required fields + - ``TokenProviderRateLimitError`` — upstream rate-limiting + - ``TokenProviderServerError`` — server error (5xx / timeout) + - ``TokenRefreshNotSupportedError`` — static / no refresh_token + """ + + async def get_token(self) -> str: + """Return a valid token, refreshing proactively if stale.""" + ... + + async def force_refresh(self) -> str: + """Force an immediate token refresh (e.g. after a 401). + + Raises: + TokenRefreshNotSupportedError: If refresh is not supported. + TokenProviderError: If refresh fails. + """ + ... diff --git a/backend/airweave/domains/sources/token_providers/static.py b/backend/airweave/domains/sources/token_providers/static.py new file mode 100644 index 000000000..24c52a1ad --- /dev/null +++ b/backend/airweave/domains/sources/token_providers/static.py @@ -0,0 +1,42 @@ +"""StaticTokenProvider — holds a fixed token string. + +Used for API keys, personal access tokens, direct token injection, +and OAuth callback validation where no refresh is possible. 
+""" + +from airweave.domains.sources.token_providers.exceptions import TokenRefreshNotSupportedError +from airweave.domains.sources.token_providers.protocol import TokenProviderProtocol + +_PROVIDER_KIND = "static" + + +class StaticTokenProvider(TokenProviderProtocol): + """TokenProvider backed by a single immutable token string.""" + + def __init__(self, token: str, *, source_short_name: str = ""): + """Initialize with a raw token. + + Args: + token: The static token value. + source_short_name: Source identifier (for error context). + """ + if not token: + raise ValueError("StaticTokenProvider requires a non-empty token") + self._token = token + self._source_short_name = source_short_name + + async def get_token(self) -> str: + """Return the static token.""" + return self._token + + async def force_refresh(self) -> str: + """Always raises — static tokens cannot be refreshed. + + Raises: + TokenRefreshNotSupportedError: Refresh is not supported for static tokens. + """ + raise TokenRefreshNotSupportedError( + "Token refresh not supported (static token)", + source_short_name=self._source_short_name, + provider_kind=_PROVIDER_KIND, + ) diff --git a/backend/airweave/domains/sources/types.py b/backend/airweave/domains/sources/types.py index 4c9c8cebc..8eaf4e743 100644 --- a/backend/airweave/domains/sources/types.py +++ b/backend/airweave/domains/sources/types.py @@ -9,8 +9,8 @@ from airweave.core.protocols.registry import BaseRegistryEntry from airweave.models.connection import Connection from airweave.models.source_connection import SourceConnection -from airweave.platform.auth_providers._base import BaseAuthProvider -from airweave.platform.auth_providers.auth_result import AuthProviderMode +from airweave.domains.auth_provider._base import BaseAuthProvider +from airweave.domains.auth_provider.auth_result import AuthProviderMode from airweave.platform.configs._base import BaseConfig, Fields diff --git a/backend/airweave/platform/auth/__init__.py 
b/backend/airweave/platform/auth/__init__.py index e69de29bb..c75d25671 100644 --- a/backend/airweave/platform/auth/__init__.py +++ b/backend/airweave/platform/auth/__init__.py @@ -0,0 +1,4 @@ +"""OAuth integration settings and schemas (not source-specific). + +Token providers have moved to ``airweave.domains.sources.token_providers``. +""" diff --git a/backend/airweave/platform/sources/_base.py b/backend/airweave/platform/sources/_base.py index a3095d134..a202b60c1 100644 --- a/backend/airweave/platform/sources/_base.py +++ b/backend/airweave/platform/sources/_base.py @@ -20,6 +20,7 @@ ) if TYPE_CHECKING: + from airweave.domains.sources.token_providers.protocol import TokenProviderProtocol from airweave.platform.access_control.schemas import MembershipTuple import httpx @@ -62,13 +63,10 @@ class BaseSource: def __init__(self): """Initialize the base source.""" - self._logger: Optional[Any] = None # Store contextual logger as instance variable - self._token_manager: Optional[Any] = None # Store token manager for OAuth sources - self._http_client_factory: Optional[Callable] = None # Factory for creating HTTP clients - self._file_downloader: Optional[Any] = None # File download service - # Optional sync identifiers for multi-tenant scoped helpers - self._organization_id: Optional[str] = None - self._source_connection_id: Optional[str] = None + self._logger: Optional[Any] = None + self._token_provider: Optional["TokenProviderProtocol"] = None + self._http_client_factory: Optional[Callable] = None + self._file_downloader: Optional[Any] = None @property def logger(self): @@ -82,27 +80,18 @@ def set_logger(self, logger) -> None: """Set a contextual logger for this source.""" self._logger = logger - def set_sync_identifiers(self, organization_id: str, source_connection_id: str) -> None: - """Set sync-scoped identifiers for this source instance. 
- - These identifiers can be used by sources to persist auxiliary metadata - (e.g., schema catalogs) scoped to the current tenant/connection. - """ - self._organization_id = organization_id - self._source_connection_id = source_connection_id - @property - def token_manager(self): - """Get the token manager for this source.""" - return self._token_manager + def token_provider(self) -> Optional["TokenProviderProtocol"]: + """Get the token provider for this source.""" + return self._token_provider - def set_token_manager(self, token_manager) -> None: - """Set a token manager for this source. + def set_token_provider(self, provider: "TokenProviderProtocol") -> None: + """Set token provider for this source. Args: - token_manager: TokenManager instance for handling OAuth token refresh + provider: Any TokenProviderProtocol implementation. """ - self._token_manager = token_manager + self._token_provider = provider def set_http_client_factory(self, factory: Optional[Callable]) -> None: """Set the HTTP client factory for creating HTTP clients. @@ -221,44 +210,61 @@ def does_require_byoc(cls) -> bool: """Check if source requires user to bring their own OAuth client credentials.""" return cls.requires_byoc - async def get_access_token(self) -> Optional[str]: - """Get a valid access token using the token manager. + async def get_access_token(self) -> str: + """Get a valid access token, preferring the token provider when available. + + Falls back to self.access_token (set by create()) when no provider + is configured — e.g. during lightweight validation flows. Returns: - A valid access token if token manager is set and source uses OAuth, - None otherwise + A valid access token string. + + Raises: + RuntimeError: If neither a token provider nor self.access_token is available. 
async def refresh_on_unauthorized(self) -> Optional[str]:
    """Force-refresh the token after a 401 error.

    Delegates to ``force_refresh()`` on the configured token provider.
    Any exception raised by the provider propagates to the caller
    unchanged (nothing is caught here).

    Returns:
        Whatever the token provider's ``force_refresh()`` returns
        (a fresh access token).

    Raises:
        RuntimeError: If no token provider is configured on this source.
    """
    # Guard first: without a provider there is nothing to refresh with.
    if not self._token_provider:
        raise RuntimeError(
            f"{self.__class__.__name__}.refresh_on_unauthorized() called but no "
            f"token provider is configured. Ensure the lifecycle service "
            f"sets a TokenProvider before calling this method."
        )
    # Past the guard the provider is known to be set, so no second check
    # (and no implicit ``return None`` fall-through) is needed.
    return await self._token_provider.force_refresh()
""" - if not self._token_manager: - return None - return await self._token_manager.get_token_for_resource(resource_scope) + if self._token_provider and hasattr(self._token_provider, "get_token_for_resource"): + return await self._token_provider.get_token_for_resource(resource_scope) + return None @classmethod @abstractmethod @@ -449,7 +455,7 @@ async def _validate_oauth2( # noqa: C901 - or `ping_url` for a simple authorized GET using the access token, - or both (introspection first, then ping). - Token refresh is attempted automatically on 401 via `token_manager`. + Token refresh is attempted automatically on 401 via `token_provider`. Returns: True if the token is active and the endpoint(s) respond as expected; otherwise False. diff --git a/backend/airweave/platform/sources/airtable.py b/backend/airweave/platform/sources/airtable.py index d25516926..83a58d400 100644 --- a/backend/airweave/platform/sources/airtable.py +++ b/backend/airweave/platform/sources/airtable.py @@ -122,10 +122,9 @@ async def _get_with_auth( self.logger.warning(f"Received 401 Unauthorized for {url}, refreshing token...") # If we have a token manager, try to refresh - if self.token_manager: + if self.token_provider: try: - # Force refresh the token - new_token = await self.token_manager.refresh_on_unauthorized() + new_token = await self.token_provider.force_refresh() headers = {"Authorization": f"Bearer {new_token}"} # Retry the request with the new token diff --git a/backend/airweave/platform/sources/asana.py b/backend/airweave/platform/sources/asana.py index ed3231869..47cdd61f1 100644 --- a/backend/airweave/platform/sources/asana.py +++ b/backend/airweave/platform/sources/asana.py @@ -112,10 +112,9 @@ async def _get_with_auth( self.logger.warning(f"Received 401 Unauthorized for {url}, refreshing token...") # If we have a token manager, try to refresh - if self.token_manager: + if self.token_provider: try: - # Force refresh the token - new_token = await 
self.token_manager.refresh_on_unauthorized() + new_token = await self.token_provider.force_refresh() headers = {"Authorization": f"Bearer {new_token}"} # Retry the request with the new token diff --git a/backend/airweave/platform/sources/box.py b/backend/airweave/platform/sources/box.py index 08f76b03d..5388303d3 100644 --- a/backend/airweave/platform/sources/box.py +++ b/backend/airweave/platform/sources/box.py @@ -151,10 +151,9 @@ async def _get_with_auth( self.logger.warning(f"Received 401 Unauthorized for {url}, refreshing token...") # If we have a token manager, try to refresh - if self.token_manager: + if self.token_provider: try: - # Force refresh the token - new_token = await self.token_manager.refresh_on_unauthorized() + new_token = await self.token_provider.force_refresh() headers = { "Authorization": f"Bearer {new_token}", "Accept": "application/json", diff --git a/backend/airweave/platform/sources/clickup.py b/backend/airweave/platform/sources/clickup.py index a52e941ea..883851646 100644 --- a/backend/airweave/platform/sources/clickup.py +++ b/backend/airweave/platform/sources/clickup.py @@ -137,10 +137,10 @@ async def _get_with_auth( self.logger.warning(f"Received 401 Unauthorized for {url}, refreshing token...") # If we have a token manager, try to refresh - if self.token_manager: + if self.token_provider: try: # Force refresh the token - new_token = await self.token_manager.refresh_on_unauthorized() + new_token = await self.token_provider.force_refresh() headers = {"Authorization": f"Bearer {new_token}"} # Retry the request with the new token diff --git a/backend/airweave/platform/sources/confluence.py b/backend/airweave/platform/sources/confluence.py index f0153daaf..6a5e7910a 100644 --- a/backend/airweave/platform/sources/confluence.py +++ b/backend/airweave/platform/sources/confluence.py @@ -142,12 +142,12 @@ async def _get_with_auth(self, client: httpx.AsyncClient, url: str) -> Any: return response.json() except httpx.HTTPStatusError as e: # 
Handle 401 Unauthorized - try refreshing token - if e.response.status_code == 401 and self._token_manager: + if e.response.status_code == 401 and self._token_provider: self.logger.warning( "🔐 Received 401 Unauthorized from Confluence - attempting token refresh" ) try: - refreshed = await self._token_manager.refresh_on_unauthorized() + refreshed = await self._token_provider.force_refresh() if refreshed: # Retry with new token (the retry decorator will handle this) diff --git a/backend/airweave/platform/sources/dropbox.py b/backend/airweave/platform/sources/dropbox.py index 30e26c4ce..0f08a6417 100644 --- a/backend/airweave/platform/sources/dropbox.py +++ b/backend/airweave/platform/sources/dropbox.py @@ -108,9 +108,9 @@ async def _post_with_auth( except httpx.HTTPStatusError as e: # Handle 401 Unauthorized - try refreshing token - if e.response.status_code == 401 and self._token_manager: + if e.response.status_code == 401 and self._token_provider: self.logger.debug("Received 401 error, attempting to refresh token") - refreshed = await self._token_manager.refresh_on_unauthorized() + refreshed = await self._token_provider.force_refresh() if refreshed: # Retry with new token (the retry decorator will handle this) diff --git a/backend/airweave/platform/sources/gitlab.py b/backend/airweave/platform/sources/gitlab.py index 8dfea78fe..33d743bbd 100644 --- a/backend/airweave/platform/sources/gitlab.py +++ b/backend/airweave/platform/sources/gitlab.py @@ -139,12 +139,12 @@ async def _get_with_auth( if response.status_code == 401: self.logger.warning(f"Received 401 Unauthorized for {url}, refreshing token...") - if self.token_manager: + if self.token_provider: try: # Force refresh the token from airweave.core.exceptions import TokenRefreshError - new_token = await self.token_manager.refresh_on_unauthorized() + new_token = await self.token_provider.force_refresh() headers = {"Authorization": f"Bearer {new_token}"} # Retry with new token @@ -205,12 +205,12 @@ async def 
_get_paginated_results( if response.status_code == 401: self.logger.warning(f"Received 401 Unauthorized for {url}, refreshing token...") - if self.token_manager: + if self.token_provider: try: # Force refresh the token from airweave.core.exceptions import TokenRefreshError - new_token = await self.token_manager.refresh_on_unauthorized() + new_token = await self.token_provider.force_refresh() headers = { "Authorization": f"Bearer {new_token}", "Accept": "application/json", diff --git a/backend/airweave/platform/sources/google_drive.py b/backend/airweave/platform/sources/google_drive.py index 4855f0d7b..08329b958 100644 --- a/backend/airweave/platform/sources/google_drive.py +++ b/backend/airweave/platform/sources/google_drive.py @@ -142,10 +142,10 @@ async def _get_with_auth( # noqa: C901 self.logger.warning(f"Received 401 Unauthorized for {url}, refreshing token...") # If we have a token manager, try to refresh - if self.token_manager: + if self.token_provider: try: # Force refresh the token - new_token = await self.token_manager.refresh_on_unauthorized() + new_token = await self.token_provider.force_refresh() headers = {"Authorization": f"Bearer {new_token}"} # Retry the request with the new token diff --git a/backend/airweave/platform/sources/intercom.py b/backend/airweave/platform/sources/intercom.py index e6d07dc38..e3d5786d9 100644 --- a/backend/airweave/platform/sources/intercom.py +++ b/backend/airweave/platform/sources/intercom.py @@ -130,20 +130,9 @@ async def create( ) return instance - def _headers(self) -> Dict[str, str]: - """Return auth and version headers for Intercom API.""" - return { - "Authorization": f"Bearer {getattr(self, 'access_token', '')}", - "Accept": "application/json", - "Content-Type": "application/json", - "Intercom-Version": INTERCOM_VERSION, - } - async def _get_auth_headers(self) -> Dict[str, str]: - """Get OAuth headers (with token from token_manager if set).""" + """Get OAuth headers via the token provider.""" token = await 
self.get_access_token() - if not token: - raise ValueError("No access token available for authentication") return { "Authorization": f"Bearer {token}", "Accept": "application/json", @@ -166,10 +155,10 @@ async def _get_with_auth( headers = await self._get_auth_headers() try: response = await client.get(url, headers=headers, params=params, timeout=30.0) - if response.status_code == 401 and self.token_manager: + if response.status_code == 401 and self.token_provider: self.logger.warning(f"Received 401 for {url}, refreshing token...") try: - new_token = await self.token_manager.refresh_on_unauthorized() + new_token = await self.token_provider.force_refresh() headers["Authorization"] = f"Bearer {new_token}" response = await client.get(url, headers=headers, params=params, timeout=30.0) except TokenRefreshError as e: @@ -194,10 +183,10 @@ async def _post_with_auth( headers = await self._get_auth_headers() try: response = await client.post(url, headers=headers, json=json_body, timeout=30.0) - if response.status_code == 401 and self.token_manager: + if response.status_code == 401 and self.token_provider: self.logger.warning(f"Received 401 for {url}, refreshing token...") try: - new_token = await self.token_manager.refresh_on_unauthorized() + new_token = await self.token_provider.force_refresh() headers["Authorization"] = f"Bearer {new_token}" response = await client.post(url, headers=headers, json=json_body, timeout=30.0) except TokenRefreshError as e: diff --git a/backend/airweave/platform/sources/jira.py b/backend/airweave/platform/sources/jira.py index 17fceaf4d..ea4c24d36 100644 --- a/backend/airweave/platform/sources/jira.py +++ b/backend/airweave/platform/sources/jira.py @@ -162,12 +162,12 @@ async def _get_with_auth(self, client: httpx.AsyncClient, url: str) -> Any: return data except httpx.HTTPStatusError as e: # Handle 401 Unauthorized - try refreshing token - if e.response.status_code == 401 and self._token_manager: + if e.response.status_code == 401 and 
self._token_provider: self.logger.warning( "🔐 Received 401 Unauthorized from Jira - attempting token refresh" ) try: - refreshed = await self._token_manager.refresh_on_unauthorized() + refreshed = await self._token_provider.force_refresh() if refreshed: # Retry with new token (the retry decorator will handle this) @@ -217,9 +217,9 @@ async def _post_with_auth( response.raise_for_status() return response.json() except httpx.HTTPStatusError as e: - if e.response.status_code == 401 and self._token_manager: + if e.response.status_code == 401 and self._token_provider: self.logger.info("Received 401 error, attempting to refresh token") - refreshed = await self._token_manager.refresh_on_unauthorized() + refreshed = await self._token_provider.force_refresh() if refreshed: self.logger.info("Token refreshed, retrying request") raise @@ -471,7 +471,7 @@ async def _generate_issue_entities( self.logger.info(f"Completed fetching all issues for project {project_key}") break - async def generate_entities(self) -> AsyncGenerator[BaseEntity, None]: + async def generate_entities(self) -> AsyncGenerator[BaseEntity, None]: # noqa: C901 """Generate all entities from Jira and optionally Zephyr Scale.""" self.logger.info("Starting Jira entity generation process") diff --git a/backend/airweave/platform/sources/linear.py b/backend/airweave/platform/sources/linear.py index c561fd12b..cbec2440a 100644 --- a/backend/airweave/platform/sources/linear.py +++ b/backend/airweave/platform/sources/linear.py @@ -162,11 +162,12 @@ async def _post_with_auth(self, client: httpx.AsyncClient, query: str) -> Dict: self._stats["api_calls"] += 1 try: + access_token = await self.get_access_token() response = await client.post( "https://api.linear.app/graphql", headers={ "Content-Type": "application/json", - "Authorization": f"Bearer {self.access_token}", + "Authorization": f"Bearer {access_token}", }, json={"query": query}, ) diff --git a/backend/airweave/platform/sources/monday.py 
b/backend/airweave/platform/sources/monday.py index 149ebfcd5..9fbac8c22 100644 --- a/backend/airweave/platform/sources/monday.py +++ b/backend/airweave/platform/sources/monday.py @@ -90,8 +90,9 @@ async def _graphql_query( self, client: httpx.AsyncClient, query: str, variables: Optional[Dict[str, Any]] = None ) -> Dict[str, Any]: """Execute a single GraphQL query against the Monday.com API.""" + access_token = await self.get_access_token() headers = { - "Authorization": self.access_token, + "Authorization": access_token, "Content-Type": "application/json", } payload = {"query": query} diff --git a/backend/airweave/platform/sources/notion.py b/backend/airweave/platform/sources/notion.py index b6c584d68..b412d65f6 100644 --- a/backend/airweave/platform/sources/notion.py +++ b/backend/airweave/platform/sources/notion.py @@ -182,8 +182,9 @@ async def _get_with_auth(self, client: httpx.AsyncClient, url: str) -> dict: self.logger.debug(f"GET request to {url}") self._stats["api_calls"] += 1 + access_token = await self.get_access_token() headers = { - "Authorization": f"Bearer {self.access_token}", + "Authorization": f"Bearer {access_token}", "Notion-Version": "2022-06-28", } @@ -244,8 +245,9 @@ async def _post_with_auth(self, client: httpx.AsyncClient, url: str, json_data: self.logger.debug(f"POST request to {url}") self._stats["api_calls"] += 1 + access_token = await self.get_access_token() headers = { - "Authorization": f"Bearer {self.access_token}", + "Authorization": f"Bearer {access_token}", "Notion-Version": "2022-06-28", } diff --git a/backend/airweave/platform/sources/salesforce.py b/backend/airweave/platform/sources/salesforce.py index 61d95843c..b45a4cd8b 100644 --- a/backend/airweave/platform/sources/salesforce.py +++ b/backend/airweave/platform/sources/salesforce.py @@ -535,7 +535,8 @@ async def validate(self) -> bool: # Just validate that we have an access token return bool(getattr(self, "access_token", None)) - if not getattr(self, "access_token", None): + 
async def _get_headers(self) -> Dict[str, str]:
    """Build the headers sent with authenticated API requests.

    The access token is obtained through ``get_access_token()`` on every
    call and placed in the ``X-Shopify-Access-Token`` header.

    Returns:
        A dict with the JSON content type and the access-token header.
    """
    token_value = await self.get_access_token()
    auth_headers: Dict[str, str] = {"Content-Type": "application/json"}
    auth_headers["X-Shopify-Access-Token"] = token_value
    return auth_headers
headers=headers, timeout=30.0) response.raise_for_status() return response @@ -1141,9 +1144,10 @@ async def _generate_file_entities( variables["after"] = cursor try: + headers = await self._get_headers() response = await client.post( graphql_url, - headers=self._get_headers(), + headers=headers, json={"query": query, "variables": variables}, timeout=30.0, ) @@ -1324,8 +1328,8 @@ async def validate(self) -> bool: return False try: - # Get access token if not already obtained - if not self.access_token: + # Ensure we have an access token (exchange credentials if needed) + if not await self.get_access_token(): self.access_token = await self._get_access_token() async with self.http_client(timeout=10.0) as client: diff --git a/backend/airweave/platform/sources/slack.py b/backend/airweave/platform/sources/slack.py index 1daee7a35..659d4a558 100644 --- a/backend/airweave/platform/sources/slack.py +++ b/backend/airweave/platform/sources/slack.py @@ -94,10 +94,10 @@ async def _get_with_auth( self.logger.warning(f"Received 401 Unauthorized for {url}, refreshing token...") # If we have a token manager, try to refresh - if self.token_manager: + if self.token_provider: try: # Force refresh the token - new_token = await self.token_manager.refresh_on_unauthorized() + new_token = await self.token_provider.force_refresh() headers = {"Authorization": f"Bearer {new_token}"} # Retry the request with the new token diff --git a/backend/airweave/platform/sources/todoist.py b/backend/airweave/platform/sources/todoist.py index 27580ff71..cc93a44eb 100644 --- a/backend/airweave/platform/sources/todoist.py +++ b/backend/airweave/platform/sources/todoist.py @@ -69,7 +69,8 @@ async def _get_with_auth( Returns the JSON response (dict or list). If a 404 error is encountered, returns None instead of raising an exception. 
""" - headers = {"Authorization": f"Bearer {self.access_token}"} + access_token = await self.get_access_token() + headers = {"Authorization": f"Bearer {access_token}"} try: response = await client.get(url, headers=headers, params=params) response.raise_for_status() diff --git a/backend/airweave/platform/sources/zendesk.py b/backend/airweave/platform/sources/zendesk.py index ab7451d08..58cc73da6 100644 --- a/backend/airweave/platform/sources/zendesk.py +++ b/backend/airweave/platform/sources/zendesk.py @@ -109,10 +109,10 @@ async def _get_with_auth( if response.status_code == 401: self.logger.warning(f"Received 401 Unauthorized for {url}") - if self.token_manager: + if self.token_provider: try: # Force refresh the token - new_token = await self.token_manager.refresh_on_unauthorized() + new_token = await self.token_provider.force_refresh() headers = {"Authorization": f"Bearer {new_token}"} # Retry the request with the new token @@ -140,19 +140,9 @@ async def _get_with_auth( async def _get_auth_headers(self) -> Dict[str, str]: """Get OAuth authentication headers.""" - # Use get_access_token method to avoid sending 'Bearer None' token = await self.get_access_token() - if not token: - raise ValueError("No access token available for authentication") return {"Authorization": f"Bearer {token}"} - async def get_access_token(self) -> Optional[str]: - """Get the current access token.""" - if self.token_manager: - # Token manager handles token retrieval - return getattr(self, "access_token", None) - return getattr(self, "access_token", None) - @staticmethod def _parse_datetime(value: Optional[str]) -> Optional[datetime]: """Parse Zendesk ISO8601 timestamps into timezone-aware datetimes.""" diff --git a/backend/airweave/platform/sources/zoho_crm.py b/backend/airweave/platform/sources/zoho_crm.py index ab8079d31..ba9114072 100644 --- a/backend/airweave/platform/sources/zoho_crm.py +++ b/backend/airweave/platform/sources/zoho_crm.py @@ -1208,18 +1208,9 @@ async def validate(self) 
-> bool: Note: Zoho uses 'Zoho-oauthtoken' header format, not standard 'Bearer'. """ - # DEBUG: Log token state - self.logger.info( - f"🔍 validate() - self.access_token exists: {bool(getattr(self, 'access_token', None))}" - ) - if getattr(self, "access_token", None): - self.logger.info( - f"🔍 validate() - self.access_token preview: {self.access_token[:20]}..." - ) - token = await self.get_access_token() self.logger.info( - (f"🔍 validate() - get_access_token() returned: {token[:20] if token else 'None'}...") + f"🔍 validate() - get_access_token() returned: {token[:20] if token else 'None'}..." ) if not token: diff --git a/backend/airweave/platform/sources/zoom.py b/backend/airweave/platform/sources/zoom.py index f7b8fe5f0..daddc175c 100644 --- a/backend/airweave/platform/sources/zoom.py +++ b/backend/airweave/platform/sources/zoom.py @@ -113,9 +113,9 @@ async def _get_with_auth( self.logger.warning( f"Got 401 Unauthorized from Zoom API at {url}, refreshing token..." ) - if self.token_manager: + if self.token_provider: try: - new_token = await self.token_manager.refresh_on_unauthorized() + new_token = await self.token_provider.force_refresh() headers["Authorization"] = f"Bearer {new_token}" self.logger.debug(f"Retrying with refreshed token: {url}") response = await client.get(url, headers=headers, params=params) diff --git a/backend/airweave/platform/sync/token_manager.py b/backend/airweave/platform/sync/token_manager.py deleted file mode 100644 index 786d4757d..000000000 --- a/backend/airweave/platform/sync/token_manager.py +++ /dev/null @@ -1,480 +0,0 @@ -"""Token manager for handling OAuth2 token refresh during sync operations.""" - -import asyncio -import time -from typing import Any, Dict, Optional - -import httpx -from sqlalchemy.ext.asyncio import AsyncSession - -import airweave.core.container as _container_module # TODO(code-blue): inject via constructor -from airweave import crud, schemas -from airweave.api.context import ApiContext -from airweave.core import 
credentials -from airweave.core.exceptions import TokenRefreshError -from airweave.core.logging import logger - - -class TokenManager: - """Manages OAuth2 token refresh for sources during sync operations. - - This class provides centralized token management to ensure sources always - have valid access tokens during long-running sync jobs. It handles: - - Automatic token refresh before expiry - - Concurrent refresh prevention - - Direct token injection scenarios - - Auth provider token refresh - """ - - # Token refresh interval (25 minutes to be safe with 1-hour tokens) - REFRESH_INTERVAL_SECONDS = 25 * 60 - - def __init__( - self, - db: AsyncSession, - source_short_name: str, - source_connection: schemas.SourceConnection, - ctx: ApiContext, - initial_credentials: Any, - is_direct_injection: bool = False, - logger_instance=None, - auth_provider_instance: Optional[Any] = None, - ): - """Initialize the token manager. - - Args: - db: Database session - source_short_name: Short name of the source - source_connection: Source connection configuration - ctx: The API context - initial_credentials: The initial credentials (dict, string token, or auth config object) - is_direct_injection: Whether token was directly injected (no refresh) - logger_instance: Optional logger instance for contextual logging - auth_provider_instance: Optional auth provider instance for token refresh - """ - self.db = db - self.source_short_name = source_short_name - self.connection_id = source_connection.id - self.integration_credential_id = source_connection.integration_credential_id - self.ctx = ctx - - self.is_direct_injection = is_direct_injection - self.logger = logger_instance or logger - - # Auth provider instance - self.auth_provider_instance = auth_provider_instance - - # NEW: Store config fields for token refresh (needed for templated backend URLs) - self.config_fields = getattr(source_connection, "config_fields", None) - - # Log if config_fields available - if self.config_fields and 
self.logger: - self.logger.debug( - f"TokenManager initialized with config_fields: {list(self.config_fields.keys())}" - ) - - # Extract the token from credentials - self._current_token = self._extract_token_from_credentials(initial_credentials) - if not self._current_token: - raise ValueError( - f"No token found in credentials for source '{source_short_name}'. " - f"TokenManager requires a token to manage." - ) - - # Check if credentials have a refresh token (needed for refresh capability check) - self._has_refresh_token = self._check_has_refresh_token(initial_credentials) - - # Set last refresh time to 0 to force an immediate refresh on first use - # This ensures we always start a sync with a fresh token, even if the stored - # token was issued hours/days ago and has since expired - self._last_refresh_time = 0 - self._refresh_lock = asyncio.Lock() - - # Cache for tokens obtained for alternative resource scopes - # (e.g. SharePoint REST API token vs Graph API token) - # Stores (token, fetch_timestamp) tuples for TTL enforcement - self._resource_tokens: Dict[str, tuple] = {} - - # For sources without refresh tokens, we can't refresh - self._can_refresh = self._determine_refresh_capability() - - def _determine_refresh_capability(self) -> bool: - """Determine if this source supports token refresh.""" - # Direct injection tokens should not be refreshed - if self.is_direct_injection: - self.logger.debug( - f"Token refresh disabled for {self.source_short_name}: direct injection mode" - ) - return False - - # If auth provider instance is available, we can always refresh through it - if self.auth_provider_instance: - return True - - # Check if credentials contain a refresh token - # This handles OAuthTokenAuthentication where user provided access_token only - if not self._has_refresh_token: - self.logger.debug( - f"Token refresh disabled for {self.source_short_name}: no refresh " - "token in credentials" - ) - return False - - # For standard OAuth with refresh token, 
refresh is possible - return True - - def _check_has_refresh_token(self, credentials: Any) -> bool: - """Check if credentials contain a refresh token. - - This is used to determine if token refresh is possible for OAuth sources - created via direct token injection (OAuthTokenAuthentication). - """ - if isinstance(credentials, dict): - refresh_token = credentials.get("refresh_token") - return bool(refresh_token and str(refresh_token).strip()) - - if hasattr(credentials, "refresh_token"): - refresh_token = credentials.refresh_token - return bool(refresh_token and str(refresh_token).strip()) - - return False - - async def get_valid_token(self) -> str: - """Get a valid access token, refreshing if necessary. - - This method ensures the token is fresh and handles refresh logic - with proper concurrency control. - - Returns: - A valid access token - - Raises: - TokenRefreshError: If token refresh fails - """ - # If we can't refresh, just return the current token - if not self._can_refresh: - return self._current_token - - # Check if token needs refresh (proactive refresh before expiry) - current_time = time.time() - time_since_refresh = current_time - self._last_refresh_time - - if time_since_refresh < self.REFRESH_INTERVAL_SECONDS: - return self._current_token - - # Token needs refresh - use lock to prevent concurrent refreshes - async with self._refresh_lock: - # Double-check after acquiring lock (another worker might have refreshed) - current_time = time.time() - time_since_refresh = current_time - self._last_refresh_time - - if time_since_refresh < self.REFRESH_INTERVAL_SECONDS: - return self._current_token - - # Perform the refresh - if self._last_refresh_time == 0: - self.logger.info( - f"🔄 Performing initial token refresh for {self.source_short_name} " - f"(ensuring fresh token at sync start)" - ) - else: - self.logger.debug( - f"Refreshing token for {self.source_short_name} " - f"(last refresh: {time_since_refresh:.0f}s ago)" - ) - - try: - new_token = await 
self._refresh_token() - self._current_token = new_token - self._last_refresh_time = current_time - - self.logger.debug(f"Successfully refreshed token for {self.source_short_name}") - return new_token - - except Exception as e: - self.logger.warning( - f"Token refresh failed for {self.source_short_name}, " - f"falling back to current token: {str(e)}" - ) - self._can_refresh = False - return self._current_token - - async def refresh_on_unauthorized(self) -> str: - """Force a token refresh after receiving an unauthorized error. - - This method is called when a source receives a 401 error, indicating - the token has expired unexpectedly. - - Returns: - A fresh access token - - Raises: - TokenRefreshError: If token refresh fails or is not supported - """ - if not self._can_refresh: - raise TokenRefreshError(f"Token refresh not supported for {self.source_short_name}") - - async with self._refresh_lock: - self.logger.warning( - f"Forcing token refresh for {self.source_short_name} due to 401 error" - ) - - try: - new_token = await self._refresh_token() - self._current_token = new_token - self._last_refresh_time = time.time() - self._resource_tokens.clear() - - self.logger.debug( - f"Successfully refreshed token for {self.source_short_name} after 401" - ) - return new_token - - except Exception as e: - self.logger.error( - f"Failed to refresh token for {self.source_short_name} after 401: {str(e)}" - ) - raise TokenRefreshError(f"Token refresh failed after 401: {str(e)}") from e - - async def get_token_for_resource(self, resource_scope: str) -> str: - """Get a token for a different resource scope using the stored refresh token. - - Used for cross-resource access, e.g. obtaining a SharePoint REST API token - when the primary token is scoped to Microsoft Graph. - - Args: - resource_scope: The target scope, e.g. "https://tenant.sharepoint.com/.default" - - Returns: - An access token scoped to the requested resource. - - Raises: - TokenRefreshError: If the token exchange fails. 
- """ - cache_key = resource_scope.lower() - if cache_key in self._resource_tokens: - cached_token, fetch_time = self._resource_tokens[cache_key] - if (time.time() - fetch_time) < self.REFRESH_INTERVAL_SECONDS: - return cached_token - self.logger.debug(f"Resource token for {resource_scope} expired, refreshing") - del self._resource_tokens[cache_key] - - if not self._has_refresh_token: - raise TokenRefreshError( - f"Cannot get token for resource {resource_scope}: no refresh token available" - ) - - try: - from airweave.db.session import get_db_context - from airweave.platform.auth.settings import integration_settings - - async with get_db_context() as refresh_db: - credential = await crud.integration_credential.get( - refresh_db, self.integration_credential_id, self.ctx - ) - if not credential: - raise TokenRefreshError("Integration credential not found") - - decrypted_credential = credentials.decrypt(credential.encrypted_credentials) - refresh_token = decrypted_credential.get("refresh_token") - if not refresh_token: - raise TokenRefreshError("No refresh token for resource token exchange") - - config = await integration_settings.get_by_short_name(self.source_short_name) - - async with httpx.AsyncClient() as client: - response = await client.post( - config.backend_url, - data={ - "grant_type": "refresh_token", - "refresh_token": refresh_token, - "client_id": config.client_id, - "client_secret": config.client_secret, - "scope": resource_scope, - }, - timeout=30.0, - ) - response.raise_for_status() - data = response.json() - - token = data.get("access_token") - if not token: - raise TokenRefreshError( - f"No access_token in response for resource scope {resource_scope}" - ) - - self._resource_tokens[cache_key] = (token, time.time()) - self.logger.info( - f"Obtained token for resource scope {resource_scope} " - f"(source: {self.source_short_name})" - ) - return token - - except TokenRefreshError: - raise - except httpx.HTTPStatusError as e: - self.logger.error( - 
f"Resource token exchange failed ({e.response.status_code}): " - f"{e.response.text[:300]}" - ) - raise TokenRefreshError( - f"Resource token exchange failed: {e.response.status_code}" - ) from e - except Exception as e: - raise TokenRefreshError(f"Resource token exchange failed: {str(e)}") from e - - async def _refresh_token(self) -> str: - """Internal method to perform the actual token refresh. - - Returns: - The new access token - - Raises: - Exception: If refresh fails - """ - # If auth provider instance is available, refresh through it - if self.auth_provider_instance: - return await self._refresh_via_auth_provider() - - # Otherwise use standard OAuth refresh - return await self._refresh_via_oauth() - - async def _refresh_via_auth_provider(self) -> str: - """Refresh token using auth provider instance. - - Returns: - The new access token - - Raises: - TokenRefreshError: If refresh fails - """ - self.logger.debug( - f"Refreshing token via auth provider instance for source '{self.source_short_name}'" - ) - - try: - if _container_module.container is None: - raise RuntimeError("Container not initialized") - entry = _container_module.container.source_registry.get(self.source_short_name) - - # Get fresh credentials from auth provider instance - fresh_credentials = await self.auth_provider_instance.get_creds_for_source( - source_short_name=self.source_short_name, - source_auth_config_fields=entry.runtime_auth_all_fields, - optional_fields=entry.runtime_auth_optional_fields, - ) - - # Extract access token - access_token = fresh_credentials.get("access_token") - if not access_token: - raise TokenRefreshError("No access token in credentials from auth provider") - - # Update the stored credentials in the database - if self.integration_credential_id: - credential_update = schemas.IntegrationCredentialUpdate( - encrypted_credentials=credentials.encrypt(fresh_credentials) - ) - - # Use a separate database session for the update to avoid transaction issues - from 
airweave.db.session import get_db_context - - try: - async with get_db_context() as update_db: - # Get the credential in the new session - credential = await crud.integration_credential.get( - update_db, self.integration_credential_id, self.ctx - ) - if credential: - await crud.integration_credential.update( - update_db, - db_obj=credential, - obj_in=credential_update, - ctx=self.ctx, - ) - except Exception as db_error: - self.logger.error(f"Failed to update credentials in database: {str(db_error)}") - # Continue anyway - we have the token, just couldn't persist it - - return access_token - - except Exception as e: - # Ensure the main session is rolled back if it's in a bad state - try: - await self.db.rollback() - except Exception: - # Session might not be in a transaction, that's OK - pass - - self.logger.error(f"Failed to refresh token via auth provider instance: {str(e)}") - raise TokenRefreshError(f"Auth provider refresh failed: {str(e)}") from e - - async def _refresh_via_oauth(self) -> str: - """Refresh token using standard OAuth flow. 
- - Returns: - The new access token - - Raises: - TokenRefreshError: If refresh fails - """ - try: - # Use a separate database session to avoid transaction issues - from airweave.db.session import get_db_context - - async with get_db_context() as refresh_db: - # Get the stored credentials - if not self.integration_credential_id: - raise TokenRefreshError("No integration credential found for token refresh") - - credential = await crud.integration_credential.get( - refresh_db, self.integration_credential_id, self.ctx - ) - if not credential: - raise TokenRefreshError("Integration credential not found") - - decrypted_credential = credentials.decrypt(credential.encrypted_credentials) - - oauth2_response = ( - await _container_module.container.oauth2_service.refresh_access_token( - db=refresh_db, - integration_short_name=self.source_short_name, - ctx=self.ctx, - connection_id=self.connection_id, - decrypted_credential=decrypted_credential, - config_fields=self.config_fields, - ) - ) - - return oauth2_response.access_token - - except Exception as e: - # Ensure the main session is rolled back if it's in a bad state - try: - await self.db.rollback() - except Exception: - # Session might not be in a transaction, that's OK - pass - - # Re-raise the original error - if isinstance(e, TokenRefreshError): - raise - raise TokenRefreshError(f"OAuth refresh failed: {str(e)}") from e - - def _extract_token_from_credentials(self, credentials: Any) -> Optional[str]: - """Extract OAuth access token from credentials. - - This method only handles OAuth tokens, not API keys or other auth types. 
- """ - # If it's already a string, assume it's the token - if isinstance(credentials, str): - return credentials - - # If it's a dict, look for access_token (OAuth standard) - if isinstance(credentials, dict): - return credentials.get("access_token") - - # If it's an object with attributes, try to get access_token - if hasattr(credentials, "access_token"): - return credentials.access_token - - return None diff --git a/backend/tests/unit/platform/sync/test_token_manager.py b/backend/tests/unit/platform/sync/test_token_manager.py deleted file mode 100644 index ba167d164..000000000 --- a/backend/tests/unit/platform/sync/test_token_manager.py +++ /dev/null @@ -1,306 +0,0 @@ -""" -Unit tests for TokenManager. - -Tests the token manager's ability to: -- Detect refresh token presence in credentials -- Determine refresh capability based on credentials -- Handle direct token injection (access_token only, no refresh_token) -""" - -import pytest -from unittest.mock import MagicMock, AsyncMock -from uuid import uuid4 - -from airweave.platform.sync.token_manager import TokenManager - - -class TestCheckHasRefreshToken: - """Tests for TokenManager._check_has_refresh_token method.""" - - def _create_minimal_token_manager(self, credentials): - """Create a TokenManager instance for testing _check_has_refresh_token.""" - # Create mock dependencies - mock_db = MagicMock() - mock_source_connection = MagicMock() - mock_source_connection.id = uuid4() - mock_source_connection.integration_credential_id = uuid4() - mock_source_connection.config_fields = None - mock_ctx = MagicMock() - mock_ctx.logger = MagicMock() - - # TokenManager constructor will call _check_has_refresh_token - # We need to ensure credentials has access_token - if isinstance(credentials, dict) and "access_token" not in credentials: - credentials["access_token"] = "test_access_token" - elif not isinstance(credentials, dict): - # For non-dict credentials, we'll test the method directly - pass - - manager = TokenManager( - 
db=mock_db, - source_short_name="test_source", - source_connection=mock_source_connection, - ctx=mock_ctx, - initial_credentials=credentials if isinstance(credentials, dict) else {"access_token": "test"}, - is_direct_injection=False, - logger_instance=MagicMock(), - ) - - return manager - - def test_dict_with_refresh_token(self): - """Test detection of refresh token in dict credentials.""" - credentials = { - "access_token": "test_access_token", - "refresh_token": "test_refresh_token", - } - manager = self._create_minimal_token_manager(credentials) - assert manager._has_refresh_token is True - - def test_dict_without_refresh_token(self): - """Test no refresh token in dict credentials (direct token injection).""" - credentials = { - "access_token": "test_access_token", - } - manager = self._create_minimal_token_manager(credentials) - assert manager._has_refresh_token is False - - def test_dict_with_empty_refresh_token(self): - """Test empty refresh token is treated as no refresh token.""" - credentials = { - "access_token": "test_access_token", - "refresh_token": "", - } - manager = self._create_minimal_token_manager(credentials) - assert manager._has_refresh_token is False - - def test_dict_with_whitespace_refresh_token(self): - """Test whitespace-only refresh token is treated as no refresh token.""" - credentials = { - "access_token": "test_access_token", - "refresh_token": " ", - } - manager = self._create_minimal_token_manager(credentials) - assert manager._has_refresh_token is False - - def test_dict_with_none_refresh_token(self): - """Test None refresh token is treated as no refresh token.""" - credentials = { - "access_token": "test_access_token", - "refresh_token": None, - } - manager = self._create_minimal_token_manager(credentials) - assert manager._has_refresh_token is False - - -class TestDetermineRefreshCapability: - """Tests for TokenManager._determine_refresh_capability method.""" - - def _create_token_manager( - self, - credentials, - 
is_direct_injection=False, - auth_provider_instance=None, - ): - """Create a TokenManager instance for testing _determine_refresh_capability.""" - mock_db = MagicMock() - mock_source_connection = MagicMock() - mock_source_connection.id = uuid4() - mock_source_connection.integration_credential_id = uuid4() - mock_source_connection.config_fields = None - mock_ctx = MagicMock() - mock_ctx.logger = MagicMock() - - manager = TokenManager( - db=mock_db, - source_short_name="test_source", - source_connection=mock_source_connection, - ctx=mock_ctx, - initial_credentials=credentials, - is_direct_injection=is_direct_injection, - logger_instance=MagicMock(), - auth_provider_instance=auth_provider_instance, - ) - - return manager - - def test_direct_injection_disables_refresh(self): - """Test that direct injection flag disables refresh capability.""" - credentials = { - "access_token": "test_access_token", - "refresh_token": "test_refresh_token", # Even with refresh token - } - manager = self._create_token_manager( - credentials=credentials, - is_direct_injection=True, # Direct injection - ) - assert manager._can_refresh is False - - def test_auth_provider_enables_refresh(self): - """Test that auth provider instance enables refresh capability.""" - credentials = { - "access_token": "test_access_token", - # No refresh token - } - mock_auth_provider = MagicMock() - manager = self._create_token_manager( - credentials=credentials, - auth_provider_instance=mock_auth_provider, - ) - assert manager._can_refresh is True - - def test_no_refresh_token_disables_refresh(self): - """Test that missing refresh token disables refresh capability. - - This is the key fix for direct token injection via OAuthTokenAuthentication. 
- """ - credentials = { - "access_token": "test_access_token", - # No refresh token - simulates OAuthTokenAuthentication - } - manager = self._create_token_manager( - credentials=credentials, - is_direct_injection=False, # Not flagged as direct injection - auth_provider_instance=None, # No auth provider - ) - # Should detect missing refresh token and disable refresh - assert manager._can_refresh is False - - def test_refresh_token_present_enables_refresh(self): - """Test that present refresh token enables refresh capability.""" - credentials = { - "access_token": "test_access_token", - "refresh_token": "test_refresh_token", - } - manager = self._create_token_manager( - credentials=credentials, - is_direct_injection=False, - auth_provider_instance=None, - ) - # Should detect refresh token and enable refresh - assert manager._can_refresh is True - - -class TestGetValidToken: - """Tests for TokenManager.get_valid_token method.""" - - @pytest.mark.asyncio - async def test_returns_current_token_when_refresh_disabled(self): - """Test that get_valid_token returns current token when refresh is disabled.""" - credentials = { - "access_token": "test_access_token", - # No refresh token - refresh should be disabled - } - - mock_db = MagicMock() - mock_source_connection = MagicMock() - mock_source_connection.id = uuid4() - mock_source_connection.integration_credential_id = uuid4() - mock_source_connection.config_fields = None - mock_ctx = MagicMock() - mock_ctx.logger = MagicMock() - - manager = TokenManager( - db=mock_db, - source_short_name="test_source", - source_connection=mock_source_connection, - ctx=mock_ctx, - initial_credentials=credentials, - is_direct_injection=False, - logger_instance=MagicMock(), - ) - - # Refresh should be disabled - assert manager._can_refresh is False - - # get_valid_token should return current token without attempting refresh - token = await manager.get_valid_token() - assert token == "test_access_token" - - @pytest.mark.asyncio - async def 
test_direct_injection_returns_token_without_refresh(self): - """Test direct injection mode returns token without refresh attempt.""" - credentials = { - "access_token": "direct_injected_token", - "refresh_token": "should_not_be_used", - } - - mock_db = MagicMock() - mock_source_connection = MagicMock() - mock_source_connection.id = uuid4() - mock_source_connection.integration_credential_id = uuid4() - mock_source_connection.config_fields = None - mock_ctx = MagicMock() - mock_ctx.logger = MagicMock() - - manager = TokenManager( - db=mock_db, - source_short_name="test_source", - source_connection=mock_source_connection, - ctx=mock_ctx, - initial_credentials=credentials, - is_direct_injection=True, # Explicit direct injection - logger_instance=MagicMock(), - ) - - # Should be disabled due to direct_injection flag - assert manager._can_refresh is False - - token = await manager.get_valid_token() - assert token == "direct_injected_token" - - -class TestCredentialFormats: - """Tests for different credential formats.""" - - def _create_manager_with_credentials(self, credentials): - """Helper to create manager with various credential formats.""" - mock_db = MagicMock() - mock_source_connection = MagicMock() - mock_source_connection.id = uuid4() - mock_source_connection.integration_credential_id = uuid4() - mock_source_connection.config_fields = None - mock_ctx = MagicMock() - mock_ctx.logger = MagicMock() - - return TokenManager( - db=mock_db, - source_short_name="test_source", - source_connection=mock_source_connection, - ctx=mock_ctx, - initial_credentials=credentials, - is_direct_injection=False, - logger_instance=MagicMock(), - ) - - def test_string_credentials_no_refresh(self): - """Test that string credentials (just token) have no refresh capability.""" - # String credentials are just the access token - manager = self._create_manager_with_credentials("string_access_token") - - # String credentials don't have refresh token - assert manager._has_refresh_token is False 
- assert manager._can_refresh is False - - def test_object_credentials_with_refresh_token(self): - """Test object credentials with refresh_token attribute.""" - # Create an object with access_token and refresh_token attributes - class CredentialsObj: - access_token = "obj_access_token" - refresh_token = "obj_refresh_token" - - manager = self._create_manager_with_credentials(CredentialsObj()) - - assert manager._has_refresh_token is True - assert manager._can_refresh is True - - def test_object_credentials_without_refresh_token(self): - """Test object credentials without refresh_token attribute.""" - class CredentialsObj: - access_token = "obj_access_token" - # No refresh_token attribute - - manager = self._create_manager_with_credentials(CredentialsObj()) - - assert manager._has_refresh_token is False - assert manager._can_refresh is False diff --git a/backend/tests/unit/platform/sync/test_token_providers.py b/backend/tests/unit/platform/sync/test_token_providers.py new file mode 100644 index 000000000..85803dee4 --- /dev/null +++ b/backend/tests/unit/platform/sync/test_token_providers.py @@ -0,0 +1,242 @@ +"""Unit tests for TokenProvider implementations. 
+ +Tests: +- OAuthTokenProvider: timer/cache behavior, get_token, force_refresh +- StaticTokenProvider: get_token, force_refresh raises +- AuthProviderTokenProvider: delegates to auth provider +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock +from uuid import uuid4 + +from airweave.core.logging import ContextualLogger +from airweave.domains.oauth.types import RefreshResult +from airweave.domains.sources.token_providers.auth_provider import AuthProviderTokenProvider +from airweave.domains.sources.token_providers.exceptions import ( + TokenProviderError, + TokenRefreshNotSupportedError, +) +from airweave.domains.sources.token_providers.oauth import OAuthTokenProvider +from airweave.domains.sources.token_providers.static import StaticTokenProvider + + +def _mock_logger() -> ContextualLogger: + """Create a mock ContextualLogger.""" + return MagicMock(spec=ContextualLogger) + + +def _oauth_provider( + credentials="test_token", + oauth_type=None, + **overrides, +): + """Create an OAuthTokenProvider with sensible defaults. + + Args: + credentials: Raw credentials (str, dict, or object). + oauth_type: OAuth type string — use "with_refresh" to enable refresh. 
+ """ + return OAuthTokenProvider( + credentials=credentials, + oauth_type=oauth_type, + oauth2_service=overrides.get("oauth2_service", MagicMock()), + source_short_name=overrides.get("source_short_name", "test_source"), + connection_id=overrides.get("connection_id", uuid4()), + ctx=overrides.get("ctx", MagicMock()), + logger=overrides.get("logger", _mock_logger()), + config_fields=overrides.get("config_fields", None), + ) + + +# --------------------------------------------------------------------------- +# OAuthTokenProvider +# --------------------------------------------------------------------------- + + +class TestOAuthGetToken: + """Tests for OAuthTokenProvider.get_token.""" + + @pytest.mark.asyncio + async def test_returns_token_when_no_refresh(self): + """When oauth_type doesn't support refresh, returns the initial token.""" + p = _oauth_provider("my_token", oauth_type="access_only") + assert await p.get_token() == "my_token" + + @pytest.mark.asyncio + async def test_returns_token_when_recently_refreshed(self): + """When last refresh was recent, returns cached token without calling service.""" + mock_service = MagicMock() + mock_service.refresh_and_persist = AsyncMock( + return_value=RefreshResult(access_token="refreshed", expires_in=3600) + ) + + creds = {"access_token": "initial", "refresh_token": "rt"} + p = _oauth_provider(creds, oauth_type="with_refresh", oauth2_service=mock_service) + + token = await p.get_token() + assert token == "refreshed" + assert mock_service.refresh_and_persist.call_count == 1 + + token2 = await p.get_token() + assert token2 == "refreshed" + assert mock_service.refresh_and_persist.call_count == 1 + + @pytest.mark.asyncio + async def test_raises_on_refresh_failure(self): + """When refresh fails, raises TokenProviderError.""" + from airweave.core.exceptions import TokenRefreshError + + mock_service = MagicMock() + mock_service.refresh_and_persist = AsyncMock( + side_effect=TokenRefreshError("network error") + ) + + creds = 
{"access_token": "tok", "refresh_token": "rt"} + p = _oauth_provider(creds, oauth_type="with_refresh", oauth2_service=mock_service) + + with pytest.raises(TokenProviderError): + await p.get_token() + + +class TestOAuthForceRefresh: + """Tests for OAuthTokenProvider.force_refresh.""" + + @pytest.mark.asyncio + async def test_raises_when_no_refresh(self): + """force_refresh raises TokenRefreshNotSupportedError when refresh not possible.""" + p = _oauth_provider("tok", oauth_type="access_only") + with pytest.raises(TokenRefreshNotSupportedError): + await p.force_refresh() + + @pytest.mark.asyncio + async def test_returns_fresh_token(self): + """force_refresh calls service and returns new token.""" + mock_service = MagicMock() + mock_service.refresh_and_persist = AsyncMock( + return_value=RefreshResult(access_token="forced_token", expires_in=3600) + ) + + creds = {"access_token": "old", "refresh_token": "rt"} + p = _oauth_provider(creds, oauth_type="with_refresh", oauth2_service=mock_service) + token = await p.force_refresh() + assert token == "forced_token" + + @pytest.mark.asyncio + async def test_raises_on_failure(self): + """force_refresh raises TokenProviderError when service fails.""" + from airweave.core.exceptions import TokenRefreshError + + mock_service = MagicMock() + mock_service.refresh_and_persist = AsyncMock( + side_effect=TokenRefreshError("fail") + ) + + creds = {"access_token": "old", "refresh_token": "rt"} + p = _oauth_provider(creds, oauth_type="with_refresh", oauth2_service=mock_service) + with pytest.raises(TokenProviderError): + await p.force_refresh() + + +class TestOAuthConstructor: + """Tests for OAuthTokenProvider credential handling.""" + + def test_extracts_token_from_string(self): + p = _oauth_provider("raw_token") + assert p._token == "raw_token" + + def test_extracts_token_from_dict(self): + p = _oauth_provider({"access_token": "dict_tok"}) + assert p._token == "dict_tok" + + def test_extracts_token_from_object(self): + class C: + 
access_token = "obj_tok" + p = _oauth_provider(C()) + assert p._token == "obj_tok" + + def test_raises_on_missing_token(self): + with pytest.raises(ValueError, match="No access token"): + _oauth_provider({"not_a_token": "x"}) + + def test_can_refresh_when_type_and_token_present(self): + creds = {"access_token": "at", "refresh_token": "rt"} + p = _oauth_provider(creds, oauth_type="with_refresh") + assert p._can_refresh is True + + def test_no_refresh_when_type_is_access_only(self): + creds = {"access_token": "at", "refresh_token": "rt"} + p = _oauth_provider(creds, oauth_type="access_only") + assert p._can_refresh is False + + def test_no_refresh_when_no_refresh_token(self): + creds = {"access_token": "at"} + p = _oauth_provider(creds, oauth_type="with_refresh") + assert p._can_refresh is False + + def test_no_refresh_when_refresh_token_empty(self): + creds = {"access_token": "at", "refresh_token": " "} + p = _oauth_provider(creds, oauth_type="with_refresh") + assert p._can_refresh is False + + +# --------------------------------------------------------------------------- +# StaticTokenProvider +# --------------------------------------------------------------------------- + + +class TestStaticTokenProvider: + """Tests for StaticTokenProvider.""" + + @pytest.mark.asyncio + async def test_get_token_returns_value(self): + p = StaticTokenProvider("api_key_123") + assert await p.get_token() == "api_key_123" + + @pytest.mark.asyncio + async def test_force_refresh_raises(self): + p = StaticTokenProvider("api_key_123", source_short_name="attio") + with pytest.raises(TokenRefreshNotSupportedError): + await p.force_refresh() + + def test_empty_token_raises(self): + with pytest.raises(ValueError): + StaticTokenProvider("") + + +# --------------------------------------------------------------------------- +# AuthProviderTokenProvider +# --------------------------------------------------------------------------- + + +class TestAuthProviderTokenProvider: + """Tests for 
AuthProviderTokenProvider.""" + + def _make_provider(self, access_token: str = "fresh_token"): + mock_auth_provider = MagicMock() + mock_auth_provider.get_creds_for_source = AsyncMock( + return_value={"access_token": access_token} + ) + + mock_registry = MagicMock() + entry = MagicMock() + entry.runtime_auth_all_fields = ["access_token"] + entry.runtime_auth_optional_fields = [] + mock_registry.get.return_value = entry + + return AuthProviderTokenProvider( + auth_provider_instance=mock_auth_provider, + source_short_name="test_source", + source_registry=mock_registry, + logger=_mock_logger(), + ) + + @pytest.mark.asyncio + async def test_get_token_delegates_to_provider(self): + p = self._make_provider("fresh_at") + assert await p.get_token() == "fresh_at" + + @pytest.mark.asyncio + async def test_force_refresh_same_as_get_token(self): + p = self._make_provider("refreshed_at") + assert await p.force_refresh() == "refreshed_at"