Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
2956815
Enable OAuth browser flow for GitHub connector
felixschmetz Mar 9, 2026
8c3ac6f
Fix OAuth callback token validation for GitHub
felixschmetz Mar 9, 2026
a40f465
Exclude local_storage from uvicorn reload watch
felixschmetz Mar 9, 2026
7ff8ec9
Respect backend auth_methods ordering in frontend UI
felixschmetz Mar 9, 2026
f4a0eab
Refactor GitHub dual-auth: separate token validation from source crea…
felixschmetz Mar 9, 2026
97d86c7
Simplify GitHubAuthConfig: collapse two token fields into single `tok…
felixschmetz Mar 9, 2026
bffd35c
Remove validate_token; use SourceLifecycleService.validate for OAuth …
felixschmetz Mar 9, 2026
d427761
Fix lifecycle.validate: convert dict credentials to auth config befor…
felixschmetz Mar 9, 2026
5f0ba64
Add test coverage for GitHub OAuth paths: auth config, lifecycle vali…
felixschmetz Mar 9, 2026
4d7c191
Rename GitHubSource.personal_access_token to .token
felixschmetz Mar 9, 2026
0cd9af3
Add --reload-exclude local_storage to VSCode FastAPI launch config
felixschmetz Mar 9, 2026
2a18b7d
Fix import sorting in test_callback_service and test_lifecycle
felixschmetz Mar 9, 2026
4fe316b
Fix lint errors in test files (I001, F401, F841, E501)
felixschmetz Mar 9, 2026
6cd68b0
Update auth provider field mappings and monke bongo for token rename
felixschmetz Mar 9, 2026
273b98a
Normalize access_token → token in monke ComposioBroker credentials
felixschmetz Mar 9, 2026
58ff9bf
Update form-validation cursor rule for gho_ token prefix and token fi…
felixschmetz Mar 9, 2026
6fd0dc4
Fix token validation: pass raw string to lifecycle.validate, not dict
felixschmetz Mar 9, 2026
10f3d6f
Format test files with ruff
felixschmetz Mar 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .cursor/rules/form-validation.mdc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Minimal, non-intrusive validation that guides without blocking. No success messa
#### Source-Specific Validators

**GitHub**
- **`githubTokenValidation`**: Validates format (ghp_, github_pat_, or 40-char hex)
- **`githubTokenValidation`**: Validates format (ghp_, github_pat_, gho_, or 40-char hex)
- **`repoNameValidation`**: Enforces owner/repo format (e.g., "airweave-ai/airweave")

**Stripe**
Expand Down Expand Up @@ -91,7 +91,7 @@ The `getAuthFieldValidation(fieldType: string, sourceShortName?: string)` functi
- **API Keys & Tokens**
- `api_key` → Generic API key validation (placeholder detection)
- `token`, `access_token` → API key validation
- `personal_access_token` → GitHub token validation
- `personal_access_token`, `token` (GitHub source only) → GitHub token validation

- **URLs**
- `url`, `endpoint`, `base_url`, `cluster_url`, `uri` → URL validation (requires http:// or https://)
Expand Down
4 changes: 4 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
"--reload",
"--reload-dir",
"backend/airweave",
"--reload-exclude",
"local_storage",
"--reload-exclude",
"backend/local_storage",
"--host",
"127.0.0.1",
"--port",
Expand Down
1 change: 1 addition & 0 deletions backend/airweave/core/container/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ def create_container(settings: Settings) -> Container:
init_session_repo=init_session_repo,
response_builder=sync_deps["response_builder"],
source_registry=source_deps["source_registry"],
source_lifecycle=source_deps["source_lifecycle_service"],
sync_lifecycle=sync_deps["sync_lifecycle"],
sync_record_service=sync_deps["sync_record_service"],
temporal_workflow_service=sync_deps["temporal_workflow_service"],
Expand Down
33 changes: 18 additions & 15 deletions backend/airweave/domains/oauth/callback_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,15 @@
ResponseBuilderProtocol,
SourceConnectionRepositoryProtocol,
)
from airweave.domains.sources.protocols import SourceRegistryProtocol
from airweave.domains.sources.exceptions import (
SourceCreationError,
SourceNotFoundError,
SourceValidationError,
)
from airweave.domains.sources.protocols import (
SourceLifecycleServiceProtocol,
SourceRegistryProtocol,
)
from airweave.domains.sources.types import SourceRegistryEntry
from airweave.domains.syncs.protocols import (
SyncJobRepositoryProtocol,
Expand Down Expand Up @@ -78,6 +86,7 @@ def __init__(
init_session_repo: OAuthInitSessionRepositoryProtocol,
response_builder: ResponseBuilderProtocol,
source_registry: SourceRegistryProtocol,
source_lifecycle: SourceLifecycleServiceProtocol,
sync_lifecycle: SyncLifecycleServiceProtocol,
sync_record_service: SyncRecordServiceProtocol,
temporal_workflow_service: TemporalWorkflowServiceProtocol,
Expand All @@ -97,6 +106,7 @@ def __init__(
self._init_session_repo = init_session_repo
self._response_builder = response_builder
self._source_registry = source_registry
self._source_lifecycle = source_lifecycle
self._sync_lifecycle = sync_lifecycle
self._sync_record_service = sync_record_service
self._temporal_workflow_service = temporal_workflow_service
Expand Down Expand Up @@ -181,7 +191,6 @@ async def complete_oauth2_callback(
await self._validate_oauth2_token_or_raise(
source=source,
access_token=token_response.access_token,
ctx=ctx,
)

source_conn = await self._complete_oauth2_connection(
Expand Down Expand Up @@ -548,24 +557,18 @@ async def _validate_oauth2_token_or_raise(
*,
source: Source | None,
access_token: str,
ctx: ApiContext,
) -> None:
"""Validate OAuth2 token using source implementation; fail callback if invalid."""
"""Validate OAuth2 token via SourceLifecycleService.validate (create → validate)."""
if not source:
return

try:
source_cls = self._source_registry.get(source.short_name).source_class_ref

source_instance = await source_cls.create(access_token=access_token, config=None)
source_instance.set_logger(ctx.logger)

if hasattr(source_instance, "validate"):
is_valid = await source_instance.validate()
if not is_valid:
raise HTTPException(status_code=400, detail="OAuth token is invalid")
except HTTPException:
raise
await self._source_lifecycle.validate(
source.short_name,
access_token,
)
except (SourceNotFoundError, SourceCreationError, SourceValidationError) as e:
raise HTTPException(status_code=400, detail=f"Token validation failed: {e}") from e
except Exception as e:
raise HTTPException(status_code=400, detail=f"Token validation failed: {e}") from e

Expand Down
1 change: 1 addition & 0 deletions backend/airweave/domains/oauth/oauth2_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,7 @@ async def _exchange_code(
"""
headers = {
"Content-Type": integration_config.content_type,
"Accept": "application/json",
}

payload = {
Expand Down
95 changes: 41 additions & 54 deletions backend/airweave/domains/oauth/tests/test_callback_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
FakeOAuthInitSessionRepository,
FakeOAuthSourceRepository,
)
from airweave.domains.oauth.types import OAuth1TokenResponse
from airweave.domains.organizations.fakes.repository import FakeOrganizationRepository
from airweave.domains.source_connections.fakes.repository import FakeSourceConnectionRepository
from airweave.domains.syncs.fakes.sync_job_repository import FakeSyncJobRepository
Expand All @@ -32,9 +33,8 @@
from airweave.models.organization import Organization
from airweave.models.source_connection import SourceConnection
from airweave.platform.auth.schemas import OAuth2TokenResponse
from airweave.domains.oauth.types import OAuth1TokenResponse
from airweave.schemas.source_connection import AuthenticationMethod
from airweave.schemas.organization import Organization as OrganizationSchema
from airweave.schemas.source_connection import AuthenticationMethod

NOW = datetime.now(timezone.utc)
ORG_ID = uuid4()
Expand Down Expand Up @@ -115,6 +115,7 @@
oauth_flow_service=None,
response_builder=None,
source_registry=None,
source_lifecycle=None,
sync_lifecycle=None,
sync_record_service=None,
temporal_workflow_service=None,
Expand All @@ -125,6 +126,7 @@
init_session_repo=init_session_repo or FakeOAuthInitSessionRepository(),
response_builder=response_builder or AsyncMock(),
source_registry=source_registry or MagicMock(),
source_lifecycle=source_lifecycle or AsyncMock(),
sync_lifecycle=sync_lifecycle or AsyncMock(),
sync_record_service=sync_record_service or AsyncMock(),
temporal_workflow_service=temporal_workflow_service or AsyncMock(),
Expand Down Expand Up @@ -229,6 +231,8 @@
assert "shell" in exc.value.detail.lower()

async def test_invalid_oauth2_token_fails_fast_with_400(self):
from airweave.domains.sources.exceptions import SourceValidationError

init_repo = FakeOAuthInitSessionRepository()
session = _init_session()
init_repo.seed_by_state("state-abc", session)
Expand All @@ -252,37 +256,29 @@
OAuth2TokenResponse(access_token="bad-token", token_type="bearer")
)

class _InvalidSource:
def set_logger(self, _logger):
return None

async def validate(self):
return False

class _SourceClass:
@staticmethod
async def create(access_token, config): # noqa: ARG004
return _InvalidSource()

registry = MagicMock()
registry.get.return_value = SimpleNamespace(source_class_ref=_SourceClass, short_name="github")
lifecycle = AsyncMock()
lifecycle.validate = AsyncMock(
side_effect=SourceValidationError("github", "validate() returned False")
)

svc = _service(
init_session_repo=init_repo,
organization_repo=org_repo,
sc_repo=sc_repo,
source_repo=source_repo,
oauth_flow_service=oauth_flow,
source_registry=registry,
source_lifecycle=lifecycle,
)
with pytest.raises(HTTPException) as exc:
await svc.complete_oauth2_callback(DB, state="state-abc", code="c")

assert exc.value.status_code == 400
assert "token" in exc.value.detail.lower()
assert "token validation failed" in exc.value.detail.lower()
assert all(call[0] != "mark_completed" for call in init_repo._calls)

async def test_validation_exception_fails_fast_with_400(self):
from airweave.domains.sources.exceptions import SourceCreationError

init_repo = FakeOAuthInitSessionRepository()
session = _init_session()
init_repo.seed_by_state("state-abc", session)
Expand All @@ -306,34 +302,22 @@
OAuth2TokenResponse(access_token="token", token_type="bearer")
)

class _BrokenSource:
def set_logger(self, _logger):
return None

async def validate(self):
raise RuntimeError("provider error")

class _SourceClass:
@staticmethod
async def create(access_token, config): # noqa: ARG004
return _BrokenSource()

registry = MagicMock()
registry.get.return_value = SimpleNamespace(source_class_ref=_SourceClass, short_name="github")
lifecycle = AsyncMock()
lifecycle.validate = AsyncMock(side_effect=SourceCreationError("github", "provider error"))

svc = _service(
init_session_repo=init_repo,
organization_repo=org_repo,
sc_repo=sc_repo,
source_repo=source_repo,
oauth_flow_service=oauth_flow,
source_registry=registry,
source_lifecycle=lifecycle,
)
with pytest.raises(HTTPException) as exc:
await svc.complete_oauth2_callback(DB, state="state-abc", code="c")

assert exc.value.status_code == 400
assert "validation failed" in exc.value.detail.lower()
assert "token validation failed" in exc.value.detail.lower()
assert all(call[0] != "mark_completed" for call in init_repo._calls)

async def test_happy_path_delegates_and_finalizes(self):
Expand Down Expand Up @@ -373,7 +357,10 @@
return _SourceOk()

registry = MagicMock()
registry.get.return_value = SimpleNamespace(source_class_ref=_SourceClass, short_name="github")
registry.get.return_value = SimpleNamespace(
source_class_ref=_SourceClass,
short_name="github",
)

svc = _service(
init_session_repo=init_repo,
Expand Down Expand Up @@ -493,7 +480,9 @@
)

oauth_flow = FakeOAuthFlowService()
oauth_flow.seed_oauth1_response(OAuth1TokenResponse(oauth_token="at", oauth_token_secret="as"))
oauth_flow.seed_oauth1_response(
OAuth1TokenResponse(oauth_token="at", oauth_token_secret="as"),
)
svc = _service(
init_session_repo=init_repo,
organization_repo=org_repo,
Expand Down Expand Up @@ -663,9 +652,9 @@
)
source_repo.seed("salesforce", source)

session = _init_session(short_name="salesforce")
_session = _init_session(short_name="salesforce") # noqa: F841

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable _session is not used.

token = SimpleNamespace(
_token = SimpleNamespace( # noqa: F841

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable _token is not used.
model_dump=lambda: {
"access_token": "tok",
"instance_url": "https://my.salesforce.com",
Expand Down Expand Up @@ -823,7 +812,9 @@
svc._collection_repo.get_by_readable_id = AsyncMock(
return_value=SimpleNamespace(id=uuid4(), readable_id="col-abc")
)
svc._sc_repo.update = AsyncMock(return_value=SimpleNamespace(id=uuid4(), connection_id=uuid4()))
svc._sc_repo.update = AsyncMock(
return_value=SimpleNamespace(id=uuid4(), connection_id=uuid4()),
)
svc._init_session_repo.mark_completed = AsyncMock()
svc._source_registry.get = MagicMock(
return_value=SimpleNamespace(source_class_ref=SimpleNamespace(federated_search=True))
Expand Down Expand Up @@ -883,13 +874,17 @@
svc._collection_repo.get_by_readable_id = AsyncMock(
return_value=SimpleNamespace(id=uuid4(), readable_id="col-abc")
)
svc._sc_repo.update = AsyncMock(return_value=SimpleNamespace(id=uuid4(), connection_id=conn_id))
svc._sc_repo.update = AsyncMock(
return_value=SimpleNamespace(id=uuid4(), connection_id=conn_id),
)
svc._init_session_repo.mark_completed = AsyncMock()
svc._source_registry.get = MagicMock(
return_value=SimpleNamespace(source_class_ref=SimpleNamespace(federated_search=False))
)
svc._sync_record_service.resolve_destination_ids = AsyncMock(return_value=[uuid4()])
svc._sync_lifecycle.provision_sync = AsyncMock(return_value=SimpleNamespace(sync_id=uuid4()))
svc._sync_lifecycle.provision_sync = AsyncMock(
return_value=SimpleNamespace(sync_id=uuid4()),
)

from airweave.domains.oauth import callback_service as callback_module

Expand Down Expand Up @@ -944,9 +939,7 @@

class TestFinalizeCallback:
async def test_no_sync_id_just_returns_response(self):
response = MagicMock(
id=uuid4(), short_name="github", readable_collection_id="col-abc"
)
response = MagicMock(id=uuid4(), short_name="github", readable_collection_id="col-abc")
builder = AsyncMock()
builder.build_response = AsyncMock(return_value=response)
event_bus = AsyncMock()
Expand All @@ -961,9 +954,7 @@
builder.build_response.assert_awaited_once()

async def test_triggers_workflow_when_pending_job_exists(self):
response = MagicMock(
id=uuid4(), short_name="github", readable_collection_id="col-abc"
)
response = MagicMock(id=uuid4(), short_name="github", readable_collection_id="col-abc")
builder = AsyncMock()
builder.build_response = AsyncMock(return_value=response)

Expand Down Expand Up @@ -1063,9 +1054,7 @@
temporal_svc.run_source_connection_workflow.assert_awaited_once()

async def test_no_pending_jobs_skips_workflow(self):
response = MagicMock(
id=uuid4(), short_name="github", readable_collection_id="col-abc"
)
response = MagicMock(id=uuid4(), short_name="github", readable_collection_id="col-abc")
builder = AsyncMock()
builder.build_response = AsyncMock(return_value=response)

Expand Down Expand Up @@ -1098,9 +1087,7 @@
temporal_svc.run_source_connection_workflow.assert_not_awaited()

async def test_running_job_skips_workflow(self):
response = MagicMock(
id=uuid4(), short_name="github", readable_collection_id="col-abc"
)
response = MagicMock(id=uuid4(), short_name="github", readable_collection_id="col-abc")
builder = AsyncMock()
builder.build_response = AsyncMock(return_value=response)

Expand Down Expand Up @@ -1351,7 +1338,7 @@
class TestTokenValidation:
async def test_validate_token_returns_early_when_source_missing(self):
svc = _service()
await svc._validate_oauth2_token_or_raise(source=None, access_token="x", ctx=_ctx())
await svc._validate_oauth2_token_or_raise(source=None, access_token="x")


# ---------------------------------------------------------------------------
Expand Down
8 changes: 8 additions & 0 deletions backend/airweave/domains/sources/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,14 @@ async def validate(

source_class = entry.source_class_ref

if entry.auth_config_ref:
if isinstance(credentials, str):
credentials = entry.auth_config_ref.model_validate(
{"access_token": credentials},
)
elif isinstance(credentials, dict):
credentials = entry.auth_config_ref.model_validate(credentials)

try:
source = await source_class.create(credentials, config=config)
except Exception as exc:
Expand Down
Loading
Loading