Merged
73 changes: 34 additions & 39 deletions .cursor/rules/auth-providers.mdc
@@ -50,51 +50,46 @@ SLUG_NAME_MAPPING = {
```


## Integration with TokenManager
## Integration with TokenProviderProtocol

The `TokenManager` class orchestrates token refresh during long-running syncs:
Auth providers integrate via `AuthProviderTokenProvider`, one of three `TokenProviderProtocol` implementations:

### Initialization
### TokenProviderProtocol
```python
# SyncFactory creates TokenManager with optional auth provider
token_manager = TokenManager(
db=db,
source_connection=connection,
auth_provider_instance=auth_provider, # Optional
initial_credentials=credentials
class TokenProviderProtocol(Protocol):
async def get_token(self) -> str: ...
async def force_refresh(self) -> str: ...
```
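Because `TokenProviderProtocol` is a `typing.Protocol`, any class with matching async methods satisfies it structurally, without inheriting from it. The following is a minimal sketch, not code from the repository: `FixedTokenProvider` and `build_auth_header` are hypothetical names, and `@runtime_checkable` is added here only so the example can use `isinstance`.

```python
import asyncio
from typing import Dict, Protocol, runtime_checkable


@runtime_checkable  # added for this sketch only, so isinstance() works
class TokenProviderProtocol(Protocol):
    async def get_token(self) -> str: ...
    async def force_refresh(self) -> str: ...


class FixedTokenProvider:
    """Hypothetical provider that always returns the same token."""

    def __init__(self, token: str) -> None:
        self._token = token

    async def get_token(self) -> str:
        return self._token

    async def force_refresh(self) -> str:
        # Nothing to refresh in this sketch; hand back the same token.
        return self._token


async def build_auth_header(provider: TokenProviderProtocol) -> Dict[str, str]:
    """Hypothetical consumer that depends only on the protocol."""
    token = await provider.get_token()
    return {"Authorization": f"Bearer {token}"}


header = asyncio.run(build_auth_header(FixedTokenProvider("example-token")))
```

The consumer never names a concrete provider class, which is what would let the three implementations be swapped without touching source code.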

### AuthProviderTokenProvider
Created by `SourceLifecycleService._configure_token_provider()` when a source connection has an `auth_provider_connection_id`.

```python
# Lifecycle builds the provider with the auth provider instance
provider = AuthProviderTokenProvider(
auth_provider=auth_provider_instance,
source_short_name="slack",
auth_config_fields=["access_token", "refresh_token"],
logger=logger,
)
```

### Refresh Flow
1. **Check refresh capability** (`_determine_refresh_capability()`):
- Direct injection tokens → no refresh
- Auth provider present → always refreshable
- Standard OAuth → attempt refresh

2. **Proactive refresh** (`get_valid_token()`):
- Refreshes tokens every 25 minutes (before 1-hour expiry)
- Uses async lock to prevent concurrent refreshes
- Falls back to stored token if refresh fails

3. **Auth provider refresh** (`_refresh_via_auth_provider()`):
```python
# TokenManager calls auth provider for fresh credentials
fresh_creds = await auth_provider.get_creds_for_source(
source_short_name="slack",
source_auth_config_fields=["access_token", "refresh_token"]
)
# Updates database with new credentials
await crud.integration_credential.update(db, credential, fresh_creds)
```

4. **Fallback OAuth refresh** (`_refresh_via_oauth()`):
- Uses oauth2_service if no auth provider
- Creates separate DB session to avoid transaction issues

### Credential Priority (in SyncFactory)
1. Direct token injection (highest)
2. Auth provider instance
3. Database credentials with OAuth refresh
#### Refresh Flow
1. **`get_token()`** — calls `auth_provider.get_creds_for_source()` each time to fetch fresh credentials from the external service (Pipedream/Composio)
2. **`force_refresh()`** — same as `get_token()` (auth providers always return fresh creds)
3. **Retry** — uses tenacity (3 attempts, exponential backoff) on `AuthProviderServerError` / `AuthProviderRateLimitError`
4. **Error translation** — `AuthProviderError` subtypes are mapped to `TokenProviderError` subtypes:
- `AuthProviderAuthError` → `TokenCredentialsInvalidError`
- `AuthProviderAccountNotFoundError` → `TokenProviderAccountGoneError`
- `AuthProviderMissingFieldsError` → `TokenProviderMissingCredsError`
- `AuthProviderConfigError` → `TokenProviderConfigError`
- `AuthProviderRateLimitError` → `TokenProviderRateLimitError`
- `AuthProviderServerError` → `TokenProviderServerError`
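The retry and error-translation steps above can be sketched roughly as follows. This is an illustration under stated assumptions, not the actual `AuthProviderTokenProvider` code: the exception classes are local stand-ins that reuse the names listed above (only three of the six pairs, for brevity), the hand-written loop stands in for tenacity so the block stays dependency-free, and `fetch_with_retry`, `translate`, and `flaky_fetch` are hypothetical names.

```python
import asyncio
from typing import Awaitable, Callable, Dict, Type


# Local stand-ins for the exception names listed above (not the real imports).
class AuthProviderError(Exception): ...
class AuthProviderAuthError(AuthProviderError): ...
class AuthProviderRateLimitError(AuthProviderError): ...
class AuthProviderServerError(AuthProviderError): ...

class TokenProviderError(Exception): ...
class TokenCredentialsInvalidError(TokenProviderError): ...
class TokenProviderRateLimitError(TokenProviderError): ...
class TokenProviderServerError(TokenProviderError): ...


ERROR_MAP: Dict[Type[AuthProviderError], Type[TokenProviderError]] = {
    AuthProviderAuthError: TokenCredentialsInvalidError,
    AuthProviderRateLimitError: TokenProviderRateLimitError,
    AuthProviderServerError: TokenProviderServerError,
}
RETRYABLE = (AuthProviderServerError, AuthProviderRateLimitError)


def translate(exc: AuthProviderError) -> TokenProviderError:
    """Map an auth-provider error to its token-provider counterpart."""
    target = ERROR_MAP.get(type(exc), TokenProviderError)
    return target(str(exc))


async def fetch_with_retry(
    fetch: Callable[[], Awaitable[str]],
    attempts: int = 3,
    base_delay: float = 0.0,  # assumption: 0 keeps the demo instant
) -> str:
    """Retry transient errors with exponential backoff, then translate."""
    if attempts < 1:
        raise ValueError("attempts must be >= 1")
    for attempt in range(1, attempts + 1):
        try:
            return await fetch()
        except RETRYABLE as exc:
            if attempt == attempts:
                raise translate(exc) from exc
            await asyncio.sleep(base_delay * 2 ** (attempt - 1))
        except AuthProviderError as exc:
            # Non-transient errors are translated immediately, no retry.
            raise translate(exc) from exc
    raise AssertionError("unreachable")


calls = {"n": 0}


async def flaky_fetch() -> str:
    """Hypothetical fetch that fails twice before succeeding."""
    calls["n"] += 1
    if calls["n"] < 3:
        raise AuthProviderServerError("upstream 503")
    return "fresh-token"


token = asyncio.run(fetch_with_retry(flaky_fetch))
```

Translating at the boundary means callers only need to handle `TokenProviderError` subtypes, regardless of which external service produced the failure.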

### Provider Resolution (in SourceLifecycleService)
1. Auth provider connection present → `AuthProviderTokenProvider`
2. OAuth credentials with `oauth_type` → `OAuthTokenProvider`
3. Direct token injection → `StaticTokenProvider`
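The resolution order above amounts to a first-match-wins check. A minimal sketch, assuming a hypothetical helper `resolve_token_provider_name` that returns the class name rather than building a real provider. The `direct_token` parameter name and the `ValueError` for the no-credentials case are assumptions, since the list above does not say what happens when nothing matches.

```python
from typing import Optional


def resolve_token_provider_name(
    auth_provider_connection_id: Optional[str] = None,
    oauth_type: Optional[str] = None,
    direct_token: Optional[str] = None,  # hypothetical parameter name
) -> str:
    """Return the provider class name chosen by the priority order above."""
    if auth_provider_connection_id:
        return "AuthProviderTokenProvider"
    if oauth_type:
        return "OAuthTokenProvider"
    if direct_token:
        return "StaticTokenProvider"
    # Assumption: the source does not specify this case.
    raise ValueError("no credentials available to build a token provider")


# An auth provider connection wins even when OAuth details are also present.
chosen = resolve_token_provider_name(
    auth_provider_connection_id="conn-123",
    oauth_type="with_refresh",
)
```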

## Database Schema
- **auth_providers**: Provider definitions from decorators
87 changes: 39 additions & 48 deletions .cursor/rules/source-connector-implementation.mdc
@@ -403,7 +403,7 @@ from typing import Any, AsyncGenerator, Dict, List, Optional
import httpx
from tenacity import retry, stop_after_attempt, wait_exponential

from airweave.core.exceptions import TokenRefreshError
from airweave.domains.sources.exceptions import SourceAuthError, SourceRateLimitError, SourceServerError
from airweave.platform.decorators import source
from airweave.platform.entities._base import Breadcrumb, ChunkEntity
from airweave.platform.entities.{short_name} import (
@@ -436,21 +436,24 @@ class MyConnectorSource(BaseSource):

@classmethod
async def create(
cls, access_token: str, config: Optional[Dict[str, Any]] = None
cls,
*,
credentials: Optional[Dict[str, Any]] = None,
config: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> "MyConnectorSource":
"""Create and configure the source.

Args:
access_token: OAuth access token or API key
credentials: Decrypted auth credentials (access_token, refresh_token, etc.)
config: Optional configuration (e.g., workspace filters)
**kwargs: Additional keyword arguments (token_provider, logger, etc.)

Returns:
Configured source instance
"""
instance = cls()
instance.access_token = access_token

# Store config as instance attributes
if config:
instance.workspace_id = config.get("workspace_id")
instance.exclude_pattern = config.get("exclude_pattern", "")
@@ -464,13 +467,12 @@ class MyConnectorSource(BaseSource):
"""Generate all entities from the source.

This is the main entry point called by the sync engine.
Token is obtained via self.get_access_token() (delegates to the injected TokenProvider).
"""
async with self.http_client() as client:
# Generate entities hierarchically
async for top_level in self._generate_top_level(client):
yield top_level

# Generate children with breadcrumb tracking
async for child in self._generate_children(client, top_level):
yield child

@@ -491,20 +493,21 @@ class MyConnectorSource(BaseSource):

#### 1. The `create()` Classmethod

This is called once when a sync starts:
This is called once when a sync starts. The `SourceLifecycleService` passes keyword arguments:

```python
@classmethod
async def create(
cls, access_token: str, config: Optional[Dict[str, Any]] = None
cls,
*,
credentials: Optional[Dict[str, Any]] = None,
config: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> "MyConnectorSource":
"""Create and configure the source."""
instance = cls()
instance.access_token = access_token

# Parse config fields
if config:
# Store as instance attributes for use in generate_entities()
instance.workspace_filter = config.get("workspace_filter", "")
instance.include_archived = config.get("include_archived", False)
else:
@@ -514,6 +517,8 @@ async def create(
return instance
```

**Note:** Do NOT store `access_token` as an instance attribute. Tokens are managed by the injected `TokenProvider` and accessed via `self.get_access_token()`.

#### 2. The `generate_entities()` Method

This is an async generator that yields entities:
@@ -536,7 +541,7 @@ async def generate_entities(self) -> AsyncGenerator[ChunkEntity, None]:
workspace_breadcrumb = Breadcrumb(
entity_id=workspace.entity_id,
name=workspace.name,
type="workspace"
entity_type="WorkspaceEntity",
)

# Child entities
@@ -546,7 +551,7 @@ async def generate_entities(self) -> AsyncGenerator[ChunkEntity, None]:
project_breadcrumb = Breadcrumb(
entity_id=project.entity_id,
name=project.name,
type="project"
entity_type="ProjectEntity",
)
breadcrumbs = [workspace_breadcrumb, project_breadcrumb]

@@ -557,7 +562,7 @@ async def generate_entities(self) -> AsyncGenerator[ChunkEntity, None]:

#### 3. Making API Requests with Token Refresh

Always use this pattern for authenticated requests:
Use `self.get_access_token()` (delegates to the injected `TokenProvider`) and `self.refresh_on_unauthorized()` for 401 recovery:

```python
@retry(
@@ -571,55 +576,41 @@ async def _get_with_auth(
url: str,
params: Optional[Dict[str, Any]] = None
) -> Dict:
"""Make authenticated GET request with automatic token refresh.

This method handles:
- Token refresh on 401 errors
- Retries with exponential backoff
- Proper error logging
"""
# Get a valid token (will refresh if needed)
"""Make authenticated GET request with automatic token refresh."""
access_token = await self.get_access_token()
if not access_token:
raise ValueError("No access token available")
raise SourceAuthError(source=self.short_name, message="No access token available")

headers = {"Authorization": f"Bearer {access_token}"}

try:
response = await client.get(url, headers=headers, params=params)

# Handle 401 Unauthorized - token might have expired
if response.status_code == 401:
self.logger.warning(f"Received 401 for {url}, refreshing token...")

if self.token_manager:
try:
# Force refresh the token
new_token = await self.token_manager.refresh_on_unauthorized()
headers = {"Authorization": f"Bearer {new_token}"}

# Retry with new token
self.logger.info(f"Retrying with refreshed token: {url}")
response = await client.get(url, headers=headers, params=params)

except TokenRefreshError as e:
self.logger.error(f"Failed to refresh token: {str(e)}")
response.raise_for_status()
else:
self.logger.error("No token manager available")
response.raise_for_status()
self.logger.warning(f"Received 401 for {url}, forcing token refresh...")
new_token = await self.refresh_on_unauthorized()
headers = {"Authorization": f"Bearer {new_token}"}
response = await client.get(url, headers=headers, params=params)

if response.status_code == 429:
retry_after = response.headers.get("Retry-After")
raise SourceRateLimitError(
source=self.short_name,
retry_after=int(retry_after) if retry_after else None,
)

response.raise_for_status()
return response.json()

except httpx.HTTPStatusError as e:
if e.response.status_code >= 500:
raise SourceServerError(source=self.short_name, status_code=e.response.status_code)
self.logger.error(f"HTTP error: {e.response.status_code} for {url}")
raise
except Exception as e:
self.logger.error(f"Unexpected error: {url}, {str(e)}")
raise
```

**Note:** `self.refresh_on_unauthorized()` calls `self._token_provider.force_refresh()` under the hood. If the token provider doesn't support refresh (e.g., `StaticTokenProvider`), it raises `TokenRefreshNotSupportedError`.
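The delegation described in the note can be pictured with a small stand-in. This is a sketch under assumptions, not the real `BaseSource`: `SketchSource` and `CountingTokenProvider` are hypothetical, `StaticTokenProvider` and `TokenRefreshNotSupportedError` are local stand-ins reusing the names above, and the sketch assumes the provider itself raises when refresh is unsupported.

```python
import asyncio


class TokenRefreshNotSupportedError(Exception):
    """Local stand-in for the exception named in the note above."""


class StaticTokenProvider:
    """Stand-in: serves one fixed token and cannot refresh it."""

    def __init__(self, token: str) -> None:
        self._token = token

    async def get_token(self) -> str:
        return self._token

    async def force_refresh(self) -> str:
        # Assumption: the provider is the one that raises here.
        raise TokenRefreshNotSupportedError("static tokens cannot be refreshed")


class CountingTokenProvider:
    """Hypothetical refreshable provider: token-0, token-1, ..."""

    def __init__(self) -> None:
        self._version = 0

    async def get_token(self) -> str:
        return f"token-{self._version}"

    async def force_refresh(self) -> str:
        self._version += 1
        return await self.get_token()


class SketchSource:
    """Hypothetical source that holds a provider instead of a raw token."""

    def __init__(self, token_provider) -> None:
        self._token_provider = token_provider

    async def get_access_token(self) -> str:
        return await self._token_provider.get_token()

    async def refresh_on_unauthorized(self) -> str:
        return await self._token_provider.force_refresh()


async def demo():
    source = SketchSource(CountingTokenProvider())
    before = await source.get_access_token()
    after = await source.refresh_on_unauthorized()
    return before, after


before, after = asyncio.run(demo())
```

Since the source never caches the token itself, a 401 handler that calls `refresh_on_unauthorized()` always gets whatever the provider currently considers valid.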

### Handling Hierarchical Data

Use breadcrumbs to track entity relationships:
@@ -1049,7 +1040,7 @@ async for task in self._generate_tasks(client, project, breadcrumbs):
task_breadcrumb = Breadcrumb(
entity_id=task.entity_id,
name=task.name,
type="task"
entity_type="TaskEntity",
)
task_breadcrumbs = [*breadcrumbs, task_breadcrumb]

@@ -1170,7 +1161,7 @@ async def _generate_projects(self, client, workspace):
- [ ] Auth config class added to `platform/configs/auth.py`
- [ ] Auth config referenced in source `@source` decorator
- [ ] Source implements `create()`, `generate_entities()`, and `validate()`
- [ ] Token refresh is handled via `_get_with_auth()` pattern
- [ ] Token refresh handled via `self.get_access_token()` + `self.refresh_on_unauthorized()` pattern
- [ ] File entities use `process_file_entity()`
- [ ] Logging uses proper levels (INFO for milestones, DEBUG for details)
- [ ] OAuth config is in `dev.integrations.yaml` (human already set this up)