diff --git a/.cursor/rules/form-validation.mdc b/.cursor/rules/form-validation.mdc index d653ea510..03cff31cb 100644 --- a/.cursor/rules/form-validation.mdc +++ b/.cursor/rules/form-validation.mdc @@ -28,7 +28,7 @@ Minimal, non-intrusive validation that guides without blocking. No success messa #### Source-Specific Validators **GitHub** -- **`githubTokenValidation`**: Validates format (ghp_, github_pat_, or 40-char hex) +- **`githubTokenValidation`**: Validates format (ghp_, github_pat_, gho_, or 40-char hex) - **`repoNameValidation`**: Enforces owner/repo format (e.g., "airweave-ai/airweave") **Stripe** @@ -91,7 +91,7 @@ The `getAuthFieldValidation(fieldType: string, sourceShortName?: string)` functi - **API Keys & Tokens** - `api_key` → Generic API key validation (placeholder detection) - `token`, `access_token` → API key validation - - `personal_access_token` → GitHub token validation + - `personal_access_token`, `token` → GitHub token validation - **URLs** - `url`, `endpoint`, `base_url`, `cluster_url`, `uri` → URL validation (requires http:// or https://) diff --git a/.vscode/launch.json b/.vscode/launch.json index 9e267e189..d2b00f7ed 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -10,6 +10,10 @@ "--reload", "--reload-dir", "backend/airweave", + "--reload-exclude", + "local_storage", + "--reload-exclude", + "backend/local_storage", "--host", "127.0.0.1", "--port", diff --git a/backend/airweave/core/container/factory.py b/backend/airweave/core/container/factory.py index 40a5cafaa..35611cacf 100644 --- a/backend/airweave/core/container/factory.py +++ b/backend/airweave/core/container/factory.py @@ -383,6 +383,7 @@ def create_container(settings: Settings) -> Container: init_session_repo=init_session_repo, response_builder=sync_deps["response_builder"], source_registry=source_deps["source_registry"], + source_lifecycle=source_deps["source_lifecycle_service"], sync_lifecycle=sync_deps["sync_lifecycle"], sync_record_service=sync_deps["sync_record_service"], temporal_workflow_service=sync_deps["temporal_workflow_service"], diff --git a/backend/airweave/domains/oauth/callback_service.py b/backend/airweave/domains/oauth/callback_service.py index 6cf62e368..71bd11621 100644 --- a/backend/airweave/domains/oauth/callback_service.py +++ b/backend/airweave/domains/oauth/callback_service.py @@ -35,7 +35,15 @@ ResponseBuilderProtocol, SourceConnectionRepositoryProtocol, ) -from airweave.domains.sources.protocols import SourceRegistryProtocol +from airweave.domains.sources.exceptions import ( + SourceCreationError, + SourceNotFoundError, + SourceValidationError, +) +from airweave.domains.sources.protocols import ( + SourceLifecycleServiceProtocol, + SourceRegistryProtocol, +) from airweave.domains.sources.types import SourceRegistryEntry from airweave.domains.syncs.protocols import ( SyncJobRepositoryProtocol, @@ -78,6 +86,7 @@ def __init__( init_session_repo: OAuthInitSessionRepositoryProtocol, response_builder: ResponseBuilderProtocol, source_registry: SourceRegistryProtocol, + source_lifecycle: SourceLifecycleServiceProtocol, sync_lifecycle: SyncLifecycleServiceProtocol, sync_record_service: SyncRecordServiceProtocol, temporal_workflow_service: TemporalWorkflowServiceProtocol, @@ -97,6 +106,7 @@ def __init__( self._init_session_repo = init_session_repo self._response_builder = response_builder self._source_registry = source_registry + self._source_lifecycle = source_lifecycle self._sync_lifecycle = sync_lifecycle self._sync_record_service = sync_record_service self._temporal_workflow_service = temporal_workflow_service @@ -181,7 +191,6 @@ async def complete_oauth2_callback( await self._validate_oauth2_token_or_raise( source=source, access_token=token_response.access_token, - ctx=ctx, ) source_conn = await self._complete_oauth2_connection( @@ -548,24 +557,18 @@ async def _validate_oauth2_token_or_raise( *, source: Source | None, access_token: str, - ctx: ApiContext, ) -> None: - """Validate OAuth2 token using source implementation; fail callback if invalid.""" + """Validate OAuth2 token via SourceLifecycleService.validate (create → validate).""" if not source: return try: - source_cls = self._source_registry.get(source.short_name).source_class_ref - - source_instance = await source_cls.create(access_token=access_token, config=None) - source_instance.set_logger(ctx.logger) - - if hasattr(source_instance, "validate"): - is_valid = await source_instance.validate() - if not is_valid: - raise HTTPException(status_code=400, detail="OAuth token is invalid") - except HTTPException: - raise + await self._source_lifecycle.validate( + source.short_name, + access_token, + ) + except (SourceNotFoundError, SourceCreationError, SourceValidationError) as e: + raise HTTPException(status_code=400, detail=f"Token validation failed: {e}") from e except Exception as e: raise HTTPException(status_code=400, detail=f"Token validation failed: {e}") from e diff --git a/backend/airweave/domains/oauth/oauth2_service.py b/backend/airweave/domains/oauth/oauth2_service.py index f035c8582..e127d1dde 100644 --- a/backend/airweave/domains/oauth/oauth2_service.py +++ b/backend/airweave/domains/oauth/oauth2_service.py @@ -726,6 +726,7 @@ async def _exchange_code( """ headers = { "Content-Type": integration_config.content_type, + "Accept": "application/json", } payload = { diff --git a/backend/airweave/domains/oauth/tests/test_callback_service.py b/backend/airweave/domains/oauth/tests/test_callback_service.py index dcac9c419..46090cbf2 100644 --- a/backend/airweave/domains/oauth/tests/test_callback_service.py +++ b/backend/airweave/domains/oauth/tests/test_callback_service.py @@ -24,6 +24,7 @@ FakeOAuthInitSessionRepository, FakeOAuthSourceRepository, ) +from airweave.domains.oauth.types import OAuth1TokenResponse from airweave.domains.organizations.fakes.repository import FakeOrganizationRepository from airweave.domains.source_connections.fakes.repository import FakeSourceConnectionRepository from airweave.domains.syncs.fakes.sync_job_repository import FakeSyncJobRepository @@ -32,9 +33,8 @@ from airweave.models.organization import Organization from airweave.models.source_connection import SourceConnection from airweave.platform.auth.schemas import OAuth2TokenResponse -from airweave.domains.oauth.types import OAuth1TokenResponse -from airweave.schemas.source_connection import AuthenticationMethod from airweave.schemas.organization import Organization as OrganizationSchema +from airweave.schemas.source_connection import AuthenticationMethod NOW = datetime.now(timezone.utc) ORG_ID = uuid4() @@ -115,6 +115,7 @@ def _service( oauth_flow_service=None, response_builder=None, source_registry=None, + source_lifecycle=None, sync_lifecycle=None, sync_record_service=None, temporal_workflow_service=None, @@ -125,6 +126,7 @@ def _service( init_session_repo=init_session_repo or FakeOAuthInitSessionRepository(), response_builder=response_builder or AsyncMock(), source_registry=source_registry or MagicMock(), + source_lifecycle=source_lifecycle or AsyncMock(), sync_lifecycle=sync_lifecycle or AsyncMock(), sync_record_service=sync_record_service or AsyncMock(), temporal_workflow_service=temporal_workflow_service or AsyncMock(), @@ -229,6 +231,8 @@ async def test_missing_source_conn_shell_raises_404(self): assert "shell" in exc.value.detail.lower() async def test_invalid_oauth2_token_fails_fast_with_400(self): + from airweave.domains.sources.exceptions import SourceValidationError + init_repo = FakeOAuthInitSessionRepository() session = _init_session() init_repo.seed_by_state("state-abc", session) @@ -252,20 +256,10 @@ async def test_invalid_oauth2_token_fails_fast_with_400(self): OAuth2TokenResponse(access_token="bad-token", token_type="bearer") ) - class _InvalidSource: - def set_logger(self, _logger): - return None - - async def validate(self): - return False - - class _SourceClass: - @staticmethod - async def create(access_token, config): # noqa: ARG004 - return _InvalidSource() - - registry = MagicMock() - registry.get.return_value = SimpleNamespace(source_class_ref=_SourceClass, short_name="github") + lifecycle = AsyncMock() + lifecycle.validate = AsyncMock( + side_effect=SourceValidationError("github", "validate() returned False") + ) svc = _service( init_session_repo=init_repo, @@ -273,16 +267,18 @@ async def create(access_token, config): # noqa: ARG004 sc_repo=sc_repo, source_repo=source_repo, oauth_flow_service=oauth_flow, - source_registry=registry, + source_lifecycle=lifecycle, ) with pytest.raises(HTTPException) as exc: await svc.complete_oauth2_callback(DB, state="state-abc", code="c") assert exc.value.status_code == 400 - assert "token" in exc.value.detail.lower() + assert "token validation failed" in exc.value.detail.lower() assert all(call[0] != "mark_completed" for call in init_repo._calls) async def test_validation_exception_fails_fast_with_400(self): + from airweave.domains.sources.exceptions import SourceCreationError + init_repo = FakeOAuthInitSessionRepository() session = _init_session() init_repo.seed_by_state("state-abc", session) @@ -306,20 +302,8 @@ async def test_validation_exception_fails_fast_with_400(self): OAuth2TokenResponse(access_token="token", token_type="bearer") ) - class _BrokenSource: - def set_logger(self, _logger): - return None - - async def validate(self): - raise RuntimeError("provider error") - - class _SourceClass: - @staticmethod - async def create(access_token, config): # noqa: ARG004 - return _BrokenSource() - - registry = MagicMock() - registry.get.return_value = SimpleNamespace(source_class_ref=_SourceClass, short_name="github") + lifecycle = AsyncMock() + lifecycle.validate = AsyncMock(side_effect=SourceCreationError("github", "provider error")) svc = _service( init_session_repo=init_repo, @@ -327,13 +311,13 @@ async def create(access_token, config): # noqa: ARG004 sc_repo=sc_repo, source_repo=source_repo, oauth_flow_service=oauth_flow, - source_registry=registry, + source_lifecycle=lifecycle, ) with pytest.raises(HTTPException) as exc: await svc.complete_oauth2_callback(DB, state="state-abc", code="c") assert exc.value.status_code == 400 - assert "validation failed" in exc.value.detail.lower() + assert "token validation failed" in exc.value.detail.lower() assert all(call[0] != "mark_completed" for call in init_repo._calls) async def test_happy_path_delegates_and_finalizes(self): @@ -373,7 +357,10 @@ async def create(access_token, config): # noqa: ARG004 return _SourceOk() registry = MagicMock() - registry.get.return_value = SimpleNamespace(source_class_ref=_SourceClass, short_name="github") + registry.get.return_value = SimpleNamespace( + source_class_ref=_SourceClass, + short_name="github", + ) svc = _service( init_session_repo=init_repo, @@ -493,7 +480,9 @@ async def get_by_short_name(_short_name): ) oauth_flow = FakeOAuthFlowService() - oauth_flow.seed_oauth1_response(OAuth1TokenResponse(oauth_token="at", oauth_token_secret="as")) + oauth_flow.seed_oauth1_response( + OAuth1TokenResponse(oauth_token="at", oauth_token_secret="as"), + ) svc = _service( init_session_repo=init_repo, organization_repo=org_repo, @@ -663,9 +652,9 @@ async def test_salesforce_extracts_instance_url(self): ) source_repo.seed("salesforce", source) - session = _init_session(short_name="salesforce") + _session = _init_session(short_name="salesforce") # noqa: F841 - token = SimpleNamespace( + _token = SimpleNamespace( # noqa: F841 model_dump=lambda: { "access_token": "tok", "instance_url": "https://my.salesforce.com", @@ -823,7 +812,9 @@ async def test_federated_source_skips_sync_provisioning(self): svc._collection_repo.get_by_readable_id = AsyncMock( return_value=SimpleNamespace(id=uuid4(), readable_id="col-abc") ) - svc._sc_repo.update = AsyncMock(return_value=SimpleNamespace(id=uuid4(), connection_id=uuid4())) + svc._sc_repo.update = AsyncMock( + return_value=SimpleNamespace(id=uuid4(), connection_id=uuid4()), + ) svc._init_session_repo.mark_completed = AsyncMock() svc._source_registry.get = MagicMock( return_value=SimpleNamespace(source_class_ref=SimpleNamespace(federated_search=True)) @@ -883,13 +874,17 @@ async def test_non_federated_source_provisions_sync_with_cron_schedule(self): svc._collection_repo.get_by_readable_id = AsyncMock( return_value=SimpleNamespace(id=uuid4(), readable_id="col-abc") ) - svc._sc_repo.update = AsyncMock(return_value=SimpleNamespace(id=uuid4(), connection_id=conn_id)) + svc._sc_repo.update = AsyncMock( + return_value=SimpleNamespace(id=uuid4(), connection_id=conn_id), + ) svc._init_session_repo.mark_completed = AsyncMock() svc._source_registry.get = MagicMock( return_value=SimpleNamespace(source_class_ref=SimpleNamespace(federated_search=False)) ) svc._sync_record_service.resolve_destination_ids = AsyncMock(return_value=[uuid4()]) - svc._sync_lifecycle.provision_sync = AsyncMock(return_value=SimpleNamespace(sync_id=uuid4())) + svc._sync_lifecycle.provision_sync = AsyncMock( + return_value=SimpleNamespace(sync_id=uuid4()), + ) from airweave.domains.oauth import callback_service as callback_module @@ -944,9 +939,7 @@ async def commit(self): class TestFinalizeCallback: async def test_no_sync_id_just_returns_response(self): - response = MagicMock( - id=uuid4(), short_name="github", readable_collection_id="col-abc" - ) + response = MagicMock(id=uuid4(), short_name="github", readable_collection_id="col-abc") builder = AsyncMock() builder.build_response = AsyncMock(return_value=response) event_bus = AsyncMock() @@ -961,9 +954,7 @@ async def test_no_sync_id_just_returns_response(self): builder.build_response.assert_awaited_once() async def test_triggers_workflow_when_pending_job_exists(self): - response = MagicMock( - id=uuid4(), short_name="github", readable_collection_id="col-abc" - ) + response = MagicMock(id=uuid4(), short_name="github", readable_collection_id="col-abc") builder = AsyncMock() builder.build_response = AsyncMock(return_value=response) @@ -1063,9 +1054,7 @@ async def test_triggers_workflow_when_pending_job_exists(self): temporal_svc.run_source_connection_workflow.assert_awaited_once() async def test_no_pending_jobs_skips_workflow(self): - response = MagicMock( - id=uuid4(), short_name="github", readable_collection_id="col-abc" - ) + response = MagicMock(id=uuid4(), short_name="github", readable_collection_id="col-abc") builder = AsyncMock() builder.build_response = AsyncMock(return_value=response) @@ -1098,9 +1087,7 @@ async def test_no_pending_jobs_skips_workflow(self): temporal_svc.run_source_connection_workflow.assert_not_awaited() async def test_running_job_skips_workflow(self): - response = MagicMock( - id=uuid4(), short_name="github", readable_collection_id="col-abc" - ) + response = MagicMock(id=uuid4(), short_name="github", readable_collection_id="col-abc") builder = AsyncMock() builder.build_response = AsyncMock(return_value=response) @@ -1351,7 +1338,7 @@ async def test_pending_job_with_connection_executes_event_payload_path(self): class TestTokenValidation: async def test_validate_token_returns_early_when_source_missing(self): svc = _service() - await svc._validate_oauth2_token_or_raise(source=None, access_token="x", ctx=_ctx()) + await svc._validate_oauth2_token_or_raise(source=None, access_token="x") # --------------------------------------------------------------------------- diff --git a/backend/airweave/domains/sources/lifecycle.py b/backend/airweave/domains/sources/lifecycle.py index 04261661d..20c75791a 100644 --- a/backend/airweave/domains/sources/lifecycle.py +++ b/backend/airweave/domains/sources/lifecycle.py @@ -180,6 +180,14 @@ async def validate( source_class = entry.source_class_ref + if entry.auth_config_ref: + if isinstance(credentials, str): + credentials = entry.auth_config_ref.model_validate( + {"access_token": credentials}, + ) + elif isinstance(credentials, dict): + credentials = entry.auth_config_ref.model_validate(credentials) + try: source = await source_class.create(credentials, config=config) except Exception as exc: diff --git a/backend/airweave/domains/sources/tests/test_lifecycle.py b/backend/airweave/domains/sources/tests/test_lifecycle.py index f82e7431f..7ad5b1114 100644 --- a/backend/airweave/domains/sources/tests/test_lifecycle.py +++ b/backend/airweave/domains/sources/tests/test_lifecycle.py @@ -8,17 +8,17 @@ from dataclasses import dataclass, field from typing import Any, Optional from unittest.mock import AsyncMock, MagicMock, patch -from uuid import UUID, uuid4 +from uuid import uuid4 import pytest +from airweave.core.exceptions import NotFoundException +from airweave.domains.auth_provider.fake import FakeAuthProviderRegistry +from airweave.domains.auth_provider.types import AuthProviderRegistryEntry from airweave.domains.connections.fakes.repository import FakeConnectionRepository from airweave.domains.credentials.fakes.repository import FakeIntegrationCredentialRepository from airweave.domains.oauth.fakes.oauth2_service import FakeOAuth2Service from airweave.domains.source_connections.fakes.repository import FakeSourceConnectionRepository -from airweave.core.exceptions import NotFoundException -from airweave.domains.auth_provider.fake import FakeAuthProviderRegistry -from airweave.domains.auth_provider.types import AuthProviderRegistryEntry from airweave.domains.sources.exceptions import ( SourceCreationError, SourceNotFoundError, @@ -31,7 +31,6 @@ from airweave.platform.auth_providers.auth_result import AuthProviderMode from airweave.platform.configs._base import Fields - # --------------------------------------------------------------------------- # Stub source classes # --------------------------------------------------------------------------- @@ -243,25 +242,40 @@ class ValidateCase: VALIDATE_TABLE = [ ValidateCase(id="happy-string-creds", short_name="github"), - ValidateCase(id="happy-dict-creds", short_name="github", - credentials={"access_token": "tok", "refresh_token": "ref"}), - ValidateCase(id="happy-with-config", short_name="github", - config={"repo": "owner/repo"}), - ValidateCase(id="not-in-registry", short_name="nonexistent", seed=False, - expect_error=SourceNotFoundError, - error_substring="nonexistent"), - ValidateCase(id="create-raises", short_name="bad_source", - source_class=_StubSourceCreateRaises, - expect_error=SourceCreationError, - error_substring="bad credentials format"), - ValidateCase(id="validate-raises", short_name="unreachable", - source_class=_StubSourceValidateRaises, - expect_error=SourceValidationError, - error_substring="validation raised"), - ValidateCase(id="validate-returns-false", short_name="invalid_creds", - source_class=_StubSourceValidateFalse, - expect_error=SourceValidationError, - error_substring="validate() returned False"), + ValidateCase( + id="happy-dict-creds", + short_name="github", + credentials={"access_token": "tok", "refresh_token": "ref"}, + ), + ValidateCase(id="happy-with-config", short_name="github", config={"repo": "owner/repo"}), + ValidateCase( + id="not-in-registry", + short_name="nonexistent", + seed=False, + expect_error=SourceNotFoundError, + error_substring="nonexistent", + ), + ValidateCase( + id="create-raises", + short_name="bad_source", + source_class=_StubSourceCreateRaises, + expect_error=SourceCreationError, + error_substring="bad credentials format", + ), + ValidateCase( + id="validate-raises", + short_name="unreachable", + source_class=_StubSourceValidateRaises, + expect_error=SourceValidationError, + error_substring="validation raised", + ), + ValidateCase( + id="validate-returns-false", + short_name="invalid_creds", + source_class=_StubSourceValidateFalse, + expect_error=SourceValidationError, + error_substring="validate() returned False", + ), ] @@ -313,6 +327,34 @@ async def test_validate_validation_error_attributes(): assert exc_info.value.reason == "validate() returned False" +@pytest.mark.asyncio +async def test_validate_converts_dict_credentials_via_auth_config_ref(): + """When auth_config_ref is set, validate() should model_validate dict credentials.""" + from pydantic import BaseModel + + class _StubAuthConfig(BaseModel): + token: str + + class _StubSourceWithAuth: + @classmethod + async def create(cls, credentials, config=None): + assert isinstance(credentials, _StubAuthConfig), ( + f"Expected _StubAuthConfig, got {type(credentials)}" + ) + instance = cls() + instance._credentials = credentials + return instance + + async def validate(self): + return True + + entry = _entry_with_class("auth_src", _StubSourceWithAuth) + object.__setattr__(entry, "auth_config_ref", _StubAuthConfig) + service = _make_service(source_entries=[entry]) + + await service.validate("auth_src", {"token": "test-tok"}) + + # =========================================================================== # _load_source_connection_data() — table-driven # =========================================================================== @@ -334,16 +376,32 @@ class LoadSCDataCase: LOAD_SC_DATA_TABLE = [ LoadSCDataCase(id="happy-path"), - LoadSCDataCase(id="sc-not-found", sc_exists=False, - expect_error=NotFoundException, error_match="not found"), - LoadSCDataCase(id="not-in-registry", source_in_registry=False, - expect_error=SourceNotFoundError, error_match="github"), - LoadSCDataCase(id="conn-not-found", conn_exists=False, - expect_error=NotFoundException, error_match="Connection not found"), - LoadSCDataCase(id="no-cred-no-auth-provider", has_cred_id=False, - expect_error=NotFoundException, error_match="no integration credential"), - LoadSCDataCase(id="auth-provider-skips-cred-check", - has_cred_id=False, readable_auth_provider_id="pipedream-123"), + LoadSCDataCase( + id="sc-not-found", sc_exists=False, expect_error=NotFoundException, error_match="not found" + ), + LoadSCDataCase( + id="not-in-registry", + source_in_registry=False, + expect_error=SourceNotFoundError, + error_match="github", + ), + LoadSCDataCase( + id="conn-not-found", + conn_exists=False, + expect_error=NotFoundException, + error_match="Connection not found", + ), + LoadSCDataCase( + id="no-cred-no-auth-provider", + has_cred_id=False, + expect_error=NotFoundException, + error_match="no integration credential", + ), + LoadSCDataCase( + id="auth-provider-skips-cred-check", + has_cred_id=False, + readable_auth_provider_id="pipedream-123", + ), LoadSCDataCase(id="preserves-config", config_fields={"repo": "o/r", "branch": "main"}), ] @@ -409,18 +467,26 @@ class AuthConfigRoutingCase: AUTH_CONFIG_ROUTING_TABLE = [ - AuthConfigRoutingCase(id="direct-token", access_token="tok-123", - expected_route="direct"), - AuthConfigRoutingCase(id="direct-token-beats-auth-provider", - access_token="tok", readable_auth_provider_id="pd-1", - auth_provider_config={"k": "v"}, expected_route="direct"), - AuthConfigRoutingCase(id="auth-provider", readable_auth_provider_id="pd-1", - auth_provider_config={"k": "v"}, - expected_route="auth_provider"), + AuthConfigRoutingCase(id="direct-token", access_token="tok-123", expected_route="direct"), + AuthConfigRoutingCase( + id="direct-token-beats-auth-provider", + access_token="tok", + readable_auth_provider_id="pd-1", + auth_provider_config={"k": "v"}, + expected_route="direct", + ), + AuthConfigRoutingCase( + id="auth-provider", + readable_auth_provider_id="pd-1", + auth_provider_config={"k": "v"}, + expected_route="auth_provider", + ), AuthConfigRoutingCase(id="database-fallthrough", expected_route="database"), - AuthConfigRoutingCase(id="auth-provider-id-but-no-config", - readable_auth_provider_id="pd-1", - expected_route="database"), + AuthConfigRoutingCase( + id="auth-provider-id-but-no-config", + readable_auth_provider_id="pd-1", + expected_route="database", + ), ] @@ -436,34 +502,47 @@ async def test_get_auth_configuration_routing(case: AuthConfigRoutingCase): if case.expected_route == "direct": result = await service._get_auth_configuration( - db=MagicMock(), source_connection_data=data, ctx=ctx, - logger=ctx.logger, access_token=case.access_token, + db=MagicMock(), + source_connection_data=data, + ctx=ctx, + logger=ctx.logger, + access_token=case.access_token, ) assert isinstance(result, AuthConfig) assert result.credentials == case.access_token assert result.auth_mode == AuthProviderMode.DIRECT elif case.expected_route == "auth_provider": - with patch.object(service, "_get_auth_provider_configuration", - new_callable=AsyncMock) as mock_ap: + with patch.object( + service, "_get_auth_provider_configuration", new_callable=AsyncMock + ) as mock_ap: mock_ap.return_value = AuthConfig( - credentials="ap", auth_mode=AuthProviderMode.DIRECT, - http_client_factory=None, auth_provider_instance=None, + credentials="ap", + auth_mode=AuthProviderMode.DIRECT, + http_client_factory=None, + auth_provider_instance=None, ) result = await service._get_auth_configuration( - db=MagicMock(), source_connection_data=data, ctx=ctx, - logger=ctx.logger, access_token=case.access_token, + db=MagicMock(), + source_connection_data=data, + ctx=ctx, + logger=ctx.logger, + access_token=case.access_token, ) mock_ap.assert_called_once() else: - with patch.object(service, "_get_database_credentials", - new_callable=AsyncMock) as mock_db: + with patch.object(service, "_get_database_credentials", new_callable=AsyncMock) as mock_db: mock_db.return_value = AuthConfig( - credentials="db", auth_mode=AuthProviderMode.DIRECT, - http_client_factory=None, auth_provider_instance=None, + credentials="db", + auth_mode=AuthProviderMode.DIRECT, + http_client_factory=None, + auth_provider_instance=None, ) result = await service._get_auth_configuration( - db=MagicMock(), source_connection_data=data, ctx=ctx, - logger=ctx.logger, access_token=case.access_token, + db=MagicMock(), + source_connection_data=data, + ctx=ctx, + logger=ctx.logger, + access_token=case.access_token, ) mock_db.assert_called_once() @@ -485,10 +564,18 @@ class DBCredCase: DB_CRED_TABLE = [ DBCredCase(id="happy-no-auth-config"), - DBCredCase(id="no-cred-id", has_cred_id=False, - expect_error=NotFoundException, error_match="no integration credential"), - DBCredCase(id="cred-not-found", cred_found=False, - expect_error=NotFoundException, error_match="credential not found"), + DBCredCase( + id="no-cred-id", + has_cred_id=False, + expect_error=NotFoundException, + error_match="no integration credential", + ), + DBCredCase( + id="cred-not-found", + cred_found=False, + expect_error=NotFoundException, + error_match="credential not found", + ), DBCredCase(id="with-auth-config-delegates", auth_config_class="StripeAuthConfig"), ] @@ -515,8 +602,9 @@ async def test_get_database_credentials(case: DBCredCase): elif case.auth_config_class: with ( patch("airweave.domains.sources.lifecycle.credentials") as mock_creds, - patch.object(service, "_handle_auth_config_credentials", - new_callable=AsyncMock) as mock_handle, + patch.object( + service, "_handle_auth_config_credentials", new_callable=AsyncMock + ) as mock_handle, ): mock_creds.decrypt.return_value = {"api_key": "sk"} mock_handle.return_value = {"api_key": "sk"} @@ -552,13 +640,18 @@ class HandleAuthConfigCase: HANDLE_AUTH_CONFIG_TABLE = [ - HandleAuthConfigCase(id="no-auth-config-ref-passthrough", - has_auth_config_ref=False, expect_raw_passthrough=True), - HandleAuthConfigCase(id="no-refresh-token-returns-validated", - has_auth_config_ref=True, has_refresh_token=False), - HandleAuthConfigCase(id="with-refresh-token-refreshes", - has_auth_config_ref=True, has_refresh_token=True, - refresh_token_value="ref-tok"), + HandleAuthConfigCase( + id="no-auth-config-ref-passthrough", has_auth_config_ref=False, expect_raw_passthrough=True + ), + HandleAuthConfigCase( + id="no-refresh-token-returns-validated", has_auth_config_ref=True, has_refresh_token=False + ), + HandleAuthConfigCase( + id="with-refresh-token-refreshes", + has_auth_config_ref=True, + has_refresh_token=True, + refresh_token_value="ref-tok", + ), ] @@ -586,8 +679,11 @@ async def test_handle_auth_config_credentials(case: HandleAuthConfigCase): data = _sc_data(short_name="src", auth_config_class="Cfg") result = await service._handle_auth_config_credentials( - db=MagicMock(), source_connection_data=data, - decrypted_credential=decrypted, ctx=_make_ctx(), connection_id=uuid4(), + db=MagicMock(), + source_connection_data=data, + decrypted_credential=decrypted, + ctx=_make_ctx(), + connection_id=uuid4(), ) if case.expect_raw_passthrough: @@ -619,9 +715,11 @@ async def test_handle_auth_config_oauth2_error_propagates(): with pytest.raises(RuntimeError, match="token server down"): await service._handle_auth_config_credentials( - db=MagicMock(), source_connection_data=data, + db=MagicMock(), + source_connection_data=data, decrypted_credential={"access_token": "x", "refresh_token": "ref"}, - ctx=_make_ctx(), connection_id=uuid4(), + ctx=_make_ctx(), + connection_id=uuid4(), ) @@ -641,19 +739,39 @@ class ProcessCredsCase: PROCESS_CREDS_TABLE = [ - ProcessCredsCase(id="passthrough-no-oauth-no-config", - raw_credentials="plain-token", expected="plain-token"), - ProcessCredsCase(id="oauth-extract-access-token", oauth_type="with_refresh", - raw_credentials={"access_token": "tok", "refresh_token": "r"}, - expected="tok"), - ProcessCredsCase(id="oauth-string-passthrough", oauth_type="access_only", - raw_credentials="already-string", expected="already-string"), - ProcessCredsCase(id="oauth-unexpected-format-passthrough", oauth_type="access_only", - raw_credentials=12345, expected=12345), - ProcessCredsCase(id="auth-config-dict-conversion", has_auth_config_ref=True, - raw_credentials={"api_key": "sk"}, expected="VALIDATED"), - ProcessCredsCase(id="auth-config-conversion-error", has_auth_config_ref=True, - raw_credentials={"bad": "x"}, expect_error=ValueError), + ProcessCredsCase( + id="passthrough-no-oauth-no-config", raw_credentials="plain-token", expected="plain-token" + ), + ProcessCredsCase( + id="oauth-extract-access-token", + oauth_type="with_refresh", + raw_credentials={"access_token": "tok", "refresh_token": "r"}, + expected="tok", + ), + ProcessCredsCase( + id="oauth-string-passthrough", + oauth_type="access_only", + raw_credentials="already-string", + expected="already-string", + ), + ProcessCredsCase( + id="oauth-unexpected-format-passthrough", + oauth_type="access_only", + raw_credentials=12345, + expected=12345, + ), + ProcessCredsCase( + id="auth-config-dict-conversion", + has_auth_config_ref=True, + raw_credentials={"api_key": "sk"}, + expected="VALIDATED", + ), + ProcessCredsCase( + id="auth-config-conversion-error", + has_auth_config_ref=True, + raw_credentials={"bad": "x"}, + expect_error=ValueError, + ), ] @@ -673,19 +791,24 @@ def test_process_credentials_for_source(case: ProcessCredsCase): service = _make_service(source_entries=[entry]) ctx = _make_ctx() - data = _sc_data(short_name="src", oauth_type=case.oauth_type, - auth_config_class="TestAuthConfig" if case.has_auth_config_ref else None) + data = _sc_data( + short_name="src", + oauth_type=case.oauth_type, + auth_config_class="TestAuthConfig" if case.has_auth_config_ref else None, + ) if case.expect_error: with pytest.raises(case.expect_error): service._process_credentials_for_source( raw_credentials=case.raw_credentials, - source_connection_data=data, logger=ctx.logger, + source_connection_data=data, + logger=ctx.logger, ) else: result = service._process_credentials_for_source( raw_credentials=case.raw_credentials, - source_connection_data=data, logger=ctx.logger, + source_connection_data=data, + logger=ctx.logger, ) assert result == case.expected @@ -709,10 +832,10 @@ class TokenManagerCase: TokenManagerCase(id="skip-proxy-mode", auth_mode=AuthProviderMode.PROXY), TokenManagerCase(id="skip-no-oauth-type", oauth_type=None), TokenManagerCase(id="skip-access-only-oauth", oauth_type="access_only"), - TokenManagerCase(id="happy-with-refresh", oauth_type="with_refresh", - expect_tm_set=True), - TokenManagerCase(id="happy-rotating-refresh", oauth_type="with_rotating_refresh", - expect_tm_set=True), + TokenManagerCase(id="happy-with-refresh", oauth_type="with_refresh", expect_tm_set=True), + TokenManagerCase( + id="happy-rotating-refresh", oauth_type="with_rotating_refresh", expect_tm_set=True + ), ] @@ -733,16 +856,26 @@ async def test_configure_token_manager(case: TokenManagerCase): with patch("airweave.domains.sources.lifecycle.TokenManager") as mock_tm: mock_tm.return_value = MagicMock() await SourceLifecycleService._configure_token_manager( - db=MagicMock(), source=source, source_connection_data=data, - source_credentials="tok", ctx=ctx, logger=ctx.logger, - access_token=case.access_token, auth_config=auth_config, + db=MagicMock(), + source=source, + source_connection_data=data, + source_credentials="tok", + ctx=ctx, + logger=ctx.logger, + access_token=case.access_token, + auth_config=auth_config, ) source.set_token_manager.assert_called_once() else: await SourceLifecycleService._configure_token_manager( - db=MagicMock(), source=source, source_connection_data=data, - source_credentials="tok", ctx=ctx, logger=ctx.logger, - access_token=case.access_token, auth_config=auth_config, + db=MagicMock(), + source=source, + source_connection_data=data, + source_credentials="tok", + ctx=ctx, + logger=ctx.logger, + access_token=case.access_token, + auth_config=auth_config, ) assert source._token_manager is None @@ -769,8 +902,10 @@ def test_sets_when_present(self): source = MagicMock() factory = MagicMock() ac = AuthConfig( - credentials="x", http_client_factory=factory, - auth_provider_instance=None, auth_mode=AuthProviderMode.DIRECT, + credentials="x", + http_client_factory=factory, + auth_provider_instance=None, + auth_mode=AuthProviderMode.DIRECT, ) SourceLifecycleService._configure_http_client_factory(source, ac) source.set_http_client_factory.assert_called_once_with(factory) @@ -778,8 +913,10 @@ def test_sets_when_present(self): def test_noop_when_none(self): source = MagicMock() ac = AuthConfig( - credentials="x", http_client_factory=None, - auth_provider_instance=None, auth_mode=AuthProviderMode.DIRECT, + credentials="x", + http_client_factory=None, + auth_provider_instance=None, + auth_mode=AuthProviderMode.DIRECT, ) SourceLifecycleService._configure_http_client_factory(source, ac) source.set_http_client_factory.assert_not_called() @@ -824,8 +961,11 @@ def test_wrap_source_with_airweave_client(case: WrapClientCase): ctx = _make_ctx() SourceLifecycleService._wrap_source_with_airweave_client( - source=source, source_short_name="src", - source_connection_id=uuid4(), ctx=ctx, logger=ctx.logger, + source=source, + source_short_name="src", + source_connection_id=uuid4(), + ctx=ctx, + logger=ctx.logger, ) assert source._http_client_factory is not None @@ -847,14 +987,30 @@ class MergeConfigCase: MERGE_CONFIG_TABLE = [ - MergeConfigCase(id="new-key-added", existing={}, - provider={"k": "v"}, expected_key="k", expected_value="v"), - MergeConfigCase(id="user-value-preserved", existing={"k": "user"}, - provider={"k": "provider"}, expected_key="k", expected_value="user"), - MergeConfigCase(id="none-value-overwritten", existing={"k": None}, - provider={"k": "provider"}, expected_key="k", expected_value="provider"), - MergeConfigCase(id="null-config-fields", existing=None, - provider={"k": "v"}, expected_key="k", expected_value="v"), + MergeConfigCase( + id="new-key-added", existing={}, provider={"k": "v"}, expected_key="k", expected_value="v" + ), + MergeConfigCase( + id="user-value-preserved", + existing={"k": "user"}, + provider={"k": "provider"}, + expected_key="k", + expected_value="user", + ), + MergeConfigCase( + id="none-value-overwritten", + existing={"k": None}, + provider={"k": "provider"}, + expected_key="k", + expected_value="provider", + ), + MergeConfigCase( + id="null-config-fields", + existing=None, + provider={"k": "v"}, + expected_key="k", + expected_value="v", + ), ] @@ -885,11 +1041,17 @@ class ConfigMappingCase: ConfigMappingCase(id="no-short-name", short_name=None), ConfigMappingCase(id="not-in-registry", short_name="unknown"), ConfigMappingCase(id="no-config-ref", short_name="src", in_registry=True), - ConfigMappingCase(id="config-no-auth-fields", short_name="src", - in_registry=True, has_config_ref=True), - ConfigMappingCase(id="with-auth-field", short_name="src", - in_registry=True, has_config_ref=True, has_auth_field=True, - expected={"org": "org_slug"}), + ConfigMappingCase( + id="config-no-auth-fields", short_name="src", in_registry=True, has_config_ref=True + ), + ConfigMappingCase( + id="with-auth-field", + short_name="src", + in_registry=True, + has_config_ref=True, + has_auth_field=True, + expected={"org": "org_slug"}, + ), ] @@ -934,14 +1096,30 @@ class AuthProviderInstanceCase: AUTH_PROVIDER_INSTANCE_TABLE = [ AuthProviderInstanceCase(id="happy-path"), - AuthProviderInstanceCase(id="conn-not-found", conn_found=False, - expect_error=NotFoundException, error_match="readable_id"), - AuthProviderInstanceCase(id="no-cred-on-conn", has_cred_id=False, - expect_error=NotFoundException, error_match="no integration credential"), - AuthProviderInstanceCase(id="cred-not-found", cred_found=False, - expect_error=NotFoundException, error_match="credential not found"), - AuthProviderInstanceCase(id="not-in-ap-registry", in_ap_registry=False, - expect_error=NotFoundException, error_match="not found in registry"), + AuthProviderInstanceCase( + id="conn-not-found", + conn_found=False, + expect_error=NotFoundException, + error_match="readable_id", + ), + AuthProviderInstanceCase( + id="no-cred-on-conn", + has_cred_id=False, + expect_error=NotFoundException, + error_match="no integration credential", + ), + AuthProviderInstanceCase( + id="cred-not-found", + cred_found=False, + expect_error=NotFoundException, + error_match="credential not found", + ), + AuthProviderInstanceCase( + id="not-in-ap-registry", + in_ap_registry=False, + expect_error=NotFoundException, + error_match="not found in registry", + ), ] @@ -968,11 +1146,13 @@ class _StubProvider: ap_entries = [] if case.in_ap_registry: - ap_entries.append(_make_auth_provider_entry( - short_name=conn.short_name, provider_class_ref=_StubProvider)) + ap_entries.append( + _make_auth_provider_entry(short_name=conn.short_name, provider_class_ref=_StubProvider) + ) - service = _make_service(conn_repo=conn_repo, cred_repo=cred_repo, - auth_provider_entries=ap_entries) + service = _make_service( + conn_repo=conn_repo, cred_repo=cred_repo, auth_provider_entries=ap_entries + ) ctx = _make_ctx() if case.expect_error: @@ -980,15 +1160,21 @@ class _StubProvider: mock_creds.decrypt.return_value = {"token": "d"} with pytest.raises(case.expect_error, match=case.error_match): await service._create_auth_provider_instance( - db=MagicMock(), readable_auth_provider_id="pd-1", - auth_provider_config={}, ctx=ctx, logger=ctx.logger, + db=MagicMock(), + readable_auth_provider_id="pd-1", + auth_provider_config={}, + ctx=ctx, + logger=ctx.logger, ) else: with patch("airweave.domains.sources.lifecycle.credentials") as mock_creds: mock_creds.decrypt.return_value = {"token": "d"} result = await service._create_auth_provider_instance( - db=MagicMock(), readable_auth_provider_id="pd-1", - auth_provider_config={"env": "prd"}, ctx=ctx, logger=ctx.logger, + db=MagicMock(), + readable_auth_provider_id="pd-1", + auth_provider_config={"env": "prd"}, + ctx=ctx, + logger=ctx.logger, ) assert result is mock_provider_instance _StubProvider.create.assert_called_once() @@ -1028,15 +1214,18 @@ async def test_create(case: CreateCase): cred_repo = FakeIntegrationCredentialRepository() cred_repo.seed(conn.integration_credential_id, cred) - service = _make_service(source_entries=[entry], sc_repo=sc_repo, - conn_repo=conn_repo, cred_repo=cred_repo) + service = _make_service( + source_entries=[entry], sc_repo=sc_repo, conn_repo=conn_repo, cred_repo=cred_repo + ) ctx = _make_ctx() with patch("airweave.domains.sources.lifecycle.credentials") as mock_creds: mock_creds.decrypt.return_value = {"access_token": "tok"} source = await service.create( - db=MagicMock(), source_connection_id=sc.id, ctx=ctx, + db=MagicMock(), + source_connection_id=sc.id, + ctx=ctx, access_token=case.access_token, ) diff --git a/backend/airweave/platform/auth/oauth2_service.py b/backend/airweave/platform/auth/oauth2_service.py index 318306072..d853b4910 100644 --- a/backend/airweave/platform/auth/oauth2_service.py +++ b/backend/airweave/platform/auth/oauth2_service.py @@ -891,6 +891,7 @@ async def _exchange_code( """ headers = { "Content-Type": integration_config.content_type, + "Accept": "application/json", } payload = { diff --git a/backend/airweave/platform/auth/yaml/dev.integrations.yaml b/backend/airweave/platform/auth/yaml/dev.integrations.yaml index c09ef835f..1cd769dfa 100644 --- a/backend/airweave/platform/auth/yaml/dev.integrations.yaml +++ b/backend/airweave/platform/auth/yaml/dev.integrations.yaml @@ -74,6 +74,17 @@ integrations: token_access_type: "offline" github: + oauth_type: "access_only" + url: "https://github.com/login/oauth/authorize" + backend_url: "https://github.com/login/oauth/access_token" + grant_type: "authorization_code" + client_id: "Ov23li1erGjcd7Zhju1L" + client_secret: "149d080e175ddb08c17023d1922e8b55384615ff" + content_type: "application/x-www-form-urlencoded" + client_credential_location: "body" + scope: "repo read:user" + additional_frontend_params: + accept: "json" gitlab: oauth_type: "with_refresh" diff --git a/backend/airweave/platform/auth/yaml/self-hosted.integrations.yaml b/backend/airweave/platform/auth/yaml/self-hosted.integrations.yaml index dd5d0791d..13c338c06 100644 --- a/backend/airweave/platform/auth/yaml/self-hosted.integrations.yaml +++ b/backend/airweave/platform/auth/yaml/self-hosted.integrations.yaml @@ -75,6 +75,17 @@ integrations: token_access_type: "offline" github: + oauth_type: "access_only" + url: "https://github.com/login/oauth/authorize" + backend_url: "https://github.com/login/oauth/access_token" + grant_type: "authorization_code" + client_id: "placeholder_client_id" + client_secret: "placeholder_client_secret" + content_type: "application/x-www-form-urlencoded" + client_credential_location: "body" + scope: "repo read:user" + additional_frontend_params: + accept: "json" gitlab: oauth_type: "with_refresh" diff --git a/backend/airweave/platform/auth_providers/composio.py b/backend/airweave/platform/auth_providers/composio.py index cc401eff9..3910280ed 100644 --- a/backend/airweave/platform/auth_providers/composio.py +++ b/backend/airweave/platform/auth_providers/composio.py @@ -37,8 +37,7 @@ class ComposioAuthProvider(BaseAuthProvider): # Key: Airweave field name, Value: Composio field name FIELD_NAME_MAPPING = { "api_key": "generic_api_key", # Stripe and other API key sources - "personal_access_token": "access_token", # GitHub PAT mapping - # Add more mappings as needed + "token": "access_token", # GitHub unified token field } # Mapping of Airweave source short names to Composio toolkit slugs diff --git a/backend/airweave/platform/auth_providers/pipedream.py b/backend/airweave/platform/auth_providers/pipedream.py index 03b5025a2..62ca10155 100644 --- a/backend/airweave/platform/auth_providers/pipedream.py +++ b/backend/airweave/platform/auth_providers/pipedream.py @@ -83,8 +83,7 @@ class PipedreamAuthProvider(BaseAuthProvider): "refresh_token": "oauth_refresh_token", "client_id": "oauth_client_id", "client_secret": "oauth_client_secret", - "personal_access_token": "oauth_access_token", # GitHub PAT mapping - # Add more mappings as discovered + "token": "oauth_access_token", # GitHub unified token field } # Mapping of Airweave source short names to Pipedream app names diff --git a/backend/airweave/platform/configs/auth.py b/backend/airweave/platform/configs/auth.py index 49998fb9b..7b0ce0263 100644 --- a/backend/airweave/platform/configs/auth.py +++ b/backend/airweave/platform/configs/auth.py @@ -456,23 +456,37 @@ def validate_host(cls, v: str) -> str: class GitHubAuthConfig(AuthConfig): - """GitHub authentication credentials schema.""" + """GitHub authentication credentials schema. - personal_access_token: str = Field( + Accepts credentials from either the Direct (PAT) or OAuth browser flow. + A ``model_validator(before)`` normalises ``personal_access_token`` and + ``access_token`` inputs into the single ``token`` field. + """ + + token: str = Field( title="Personal Access Token", description="GitHub PAT with read rights (code, contents, metadata) to the repository", - min_length=4, ) - @field_validator("personal_access_token") + @model_validator(mode="before") @classmethod - def validate_personal_access_token(cls, v: str) -> str: - """Validate GitHub personal access token format.""" - if not v or not v.strip(): - raise ValueError("Personal access token is required") + def _normalise_token(cls, data: Any) -> Any: + """Map legacy / OAuth credential keys into the canonical ``token`` field.""" + if isinstance(data, dict) and "token" not in data: + t = data.pop("personal_access_token", None) or data.pop("access_token", None) + if t: + data["token"] = t + return data + + @field_validator("token", mode="before") + @classmethod + def _validate_token_format(cls, v: str) -> str: + """Validate GitHub token format.""" v = v.strip() - # GitHub classic tokens start with ghp_, fine-grained tokens start with github_pat_ - # Also allow legacy tokens (40 char hex) + if not v: + raise ValueError("Token must not be empty") + # ghp_ = classic PAT, github_pat_ = fine-grained PAT, + # gho_ = OAuth app token, 40-char hex = legacy token if not ( v.startswith("ghp_") or v.startswith("github_pat_") @@ -480,7 +494,7 @@ def validate_personal_access_token(cls, v: str) -> str: or (len(v) == 40 and all(c in "0123456789abcdef" for c in v.lower())) ): raise ValueError( - "Invalid token format. Expected format: " + "Invalid token format. Expected: " "ghp_... or github_pat_... or gho_... or 40-character hex" ) return v diff --git a/backend/airweave/platform/configs/tests/__init__.py b/backend/airweave/platform/configs/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/airweave/platform/configs/tests/test_github_auth.py b/backend/airweave/platform/configs/tests/test_github_auth.py new file mode 100644 index 000000000..ff9b07a1e --- /dev/null +++ b/backend/airweave/platform/configs/tests/test_github_auth.py @@ -0,0 +1,59 @@ +"""Unit tests for GitHubAuthConfig token normalisation and validation.""" + +import pytest +from pydantic import ValidationError + +from airweave.platform.configs.auth import GitHubAuthConfig + + +class TestNormaliseToken: + def test_pat_field_maps_to_token(self): + cfg = GitHubAuthConfig.model_validate({"personal_access_token": "ghp_abc123def456"}) + assert cfg.token == "ghp_abc123def456" + + def test_access_token_field_maps_to_token(self): + cfg = GitHubAuthConfig.model_validate({"access_token": "gho_abc123def456"}) + assert cfg.token == "gho_abc123def456" + + def test_canonical_token_field_passes_through(self): + cfg = GitHubAuthConfig.model_validate({"token": "ghp_abc123def456"}) + assert cfg.token == "ghp_abc123def456" + + def test_missing_all_fields_raises(self): + with pytest.raises(ValidationError): + GitHubAuthConfig.model_validate({}) + + def test_empty_string_raises(self): + with pytest.raises(ValidationError): + GitHubAuthConfig.model_validate({"token": ""}) + + def test_whitespace_only_raises(self): + with pytest.raises(ValidationError): + GitHubAuthConfig.model_validate({"token": " "}) + + +class TestTokenFormatValidation: + def test_classic_pat(self): + cfg = GitHubAuthConfig(token="ghp_abcdef1234567890abcdef1234567890abcdef") + assert cfg.token.startswith("ghp_") + + def test_fine_grained_pat(self): + cfg = GitHubAuthConfig(token="github_pat_abcdef1234567890") + assert cfg.token.startswith("github_pat_") + + def test_oauth_app_token(self): + cfg = GitHubAuthConfig(token="gho_abcdef1234567890") + assert cfg.token.startswith("gho_") + + def test_legacy_hex_token(self): + hex_token = "a" * 40 + cfg = GitHubAuthConfig(token=hex_token) + assert cfg.token == hex_token + + def test_invalid_format_raises(self): + with pytest.raises(ValidationError, match="Invalid token format"): + GitHubAuthConfig(token="bad-token-format") + + def test_strips_whitespace(self): + cfg = GitHubAuthConfig(token=" ghp_abc123def456 ") + assert cfg.token == "ghp_abc123def456" diff --git a/backend/airweave/platform/sources/github.py b/backend/airweave/platform/sources/github.py index 69c6e293f..00a8f67d4 100644 --- a/backend/airweave/platform/sources/github.py +++ b/backend/airweave/platform/sources/github.py @@ -31,14 +31,18 @@ get_language_for_extension, is_text_file, ) -from airweave.schemas.source_connection import AuthenticationMethod +from airweave.schemas.source_connection import AuthenticationMethod, OAuthType @source( name="GitHub", short_name="github", - auth_methods=[AuthenticationMethod.DIRECT, AuthenticationMethod.AUTH_PROVIDER], - oauth_type=None, + auth_methods=[ + AuthenticationMethod.OAUTH_BROWSER, + AuthenticationMethod.DIRECT, + AuthenticationMethod.AUTH_PROVIDER, + ], + oauth_type=OAuthType.ACCESS_ONLY, auth_config_class=GitHubAuthConfig, config_class=GitHubConfig, labels=["Code"], @@ -92,30 +96,26 @@ def validate_cursor_field(self, cursor_field: str) -> None: @classmethod async def create( - cls, credentials: GitHubAuthConfig, config: Optional[Dict[str, Any]] = None + cls, + credentials: GitHubAuthConfig, + config: Optional[Dict[str, Any]] = None, ) -> "GitHubSource": """Create a new source instance with authentication. Args: - credentials: GitHubAuthConfig instance containing authentication details - config: Optional source configuration parameters + credentials: GitHubAuthConfig with a validated token. + config: Source configuration (must include ``repo_name``). Returns: - Configured GitHub source instance + Configured GitHub source instance. """ instance = cls() + instance.token = credentials.token - instance.personal_access_token = credentials.personal_access_token - - # Repository name is always read from config (source configuration) - if not config or "repo_name" not in config: - raise ValueError("Repository name must be specified in source configuration") - - instance.repo_name = config["repo_name"] - - instance.branch = config.get("branch", None) - - instance.max_file_size = config.get("max_file_size", 10 * 1024 * 1024) + if config and "repo_name" in config: + instance.repo_name = config["repo_name"] + instance.branch = config.get("branch", None) + instance.max_file_size = config.get("max_file_size", 10 * 1024 * 1024) return instance @@ -139,7 +139,7 @@ async def _get_with_auth( JSON response """ headers = { - "Authorization": f"token {self.personal_access_token}", + "Authorization": f"token {self.token}", "Accept": "application/vnd.github.v3+json", "X-GitHub-Api-Version": "2022-11-28", } @@ -172,7 +172,7 @@ async def _get_paginated_results( while True: params["page"] = page headers = { - "Authorization": f"token {self.personal_access_token}", + "Authorization": f"token {self.token}", "Accept": "application/vnd.github.v3+json", "X-GitHub-Api-Version": "2022-11-28", } @@ -908,12 +908,12 @@ async def generate_entities(self) -> AsyncGenerator[BaseEntity, None]: async def validate(self) -> bool: """Verify GitHub PAT and repo/branch access with lightweight pings.""" - if not getattr(self, "personal_access_token", None): - self.logger.error("GitHub validation failed: missing personal_access_token.") + if not getattr(self, "token", None): + self.logger.error("GitHub validation failed: missing token.") return False headers = { - "Authorization": f"token {self.personal_access_token}", + "Authorization": f"token {self.token}", "Accept": "application/vnd.github.v3+json", "X-GitHub-Api-Version": "2022-11-28", } diff --git a/backend/entrypoint.sh b/backend/entrypoint.sh index e31f23b0e..a9167560f 100644 --- a/backend/entrypoint.sh +++ b/backend/entrypoint.sh @@ -53,4 +53,4 @@ cd /app && poetry run alembic upgrade heads # Start application echo "Starting application..." -poetry run uvicorn airweave.main:app --host 0.0.0.0 --port 8001 --reload +poetry run uvicorn airweave.main:app --host 0.0.0.0 --port 8001 --reload --reload-exclude 'local_storage' --reload-exclude 'backend/local_storage' diff --git a/frontend/src/components/creation-views/SourceConfigView.tsx b/frontend/src/components/creation-views/SourceConfigView.tsx index 867458be3..29c351329 100644 --- a/frontend/src/components/creation-views/SourceConfigView.tsx +++ b/frontend/src/components/creation-views/SourceConfigView.tsx @@ -151,29 +151,26 @@ export const SourceConfigView: React.FC = ({ humanReadabl return sourceDetails?.auth_methods?.includes('direct'); }; - // Determine available auth methods based on source + // Determine available auth methods based on source, preserving backend ordering const getAvailableAuthMethods = (): AuthMode[] => { if (!sourceDetails || !sourceDetails.auth_methods) return []; const methods: AuthMode[] = []; - // Check for direct auth (API keys, passwords, config) - if (sourceDetails.auth_methods.includes('direct')) { - methods.push('direct_auth'); - } - - // Check for OAuth browser flow - if (sourceDetails.auth_methods.includes('oauth_browser') || - sourceDetails.auth_methods.includes('oauth_token')) { - methods.push('oauth2'); - } - - // Add external provider if any are connected and source supports it - if (authProviderConnections.length > 0 && - sourceDetails.auth_methods.includes('auth_provider') && - sourceDetails.supported_auth_providers && - sourceDetails.supported_auth_providers.length > 0) { - methods.push('external_provider'); + for (const method of sourceDetails.auth_methods) { + if (method === 'direct' && !methods.includes('direct_auth')) { + methods.push('direct_auth'); + } else if ((method === 'oauth_browser' || method === 'oauth_token') && !methods.includes('oauth2')) { + methods.push('oauth2'); + } else if ( + method === 'auth_provider' && + !methods.includes('external_provider') && + authProviderConnections.length > 0 && + sourceDetails.supported_auth_providers && + sourceDetails.supported_auth_providers.length > 0 + ) { + methods.push('external_provider'); + } } return methods; @@ -224,25 +221,16 @@ export const SourceConfigView: React.FC = ({ humanReadabl fetchSourceDetails(); }, [selectedSource]); - // Set default auth mode based on available methods - // This runs after source details and auth providers are loaded + // Set default auth mode to the first available method (respects backend ordering) useEffect(() => { if (!sourceDetails || !sourceDetails.auth_methods || authMode) return; - // Determine default auth mode based on available methods - if (sourceDetails.auth_methods.includes('direct')) { - // Prefer direct auth if available - setAuthMode('direct_auth'); - } else if (sourceDetails.auth_methods.includes('oauth_browser')) { - // Then OAuth - setAuthMode('oauth2'); - // Auto-enable custom credentials for sources that require BYOC - if (sourceDetails.requires_byoc) { + const available = getAvailableAuthMethods(); + if (available.length > 0) { + setAuthMode(available[0]); + if (available[0] === 'oauth2' && sourceDetails.requires_byoc) { setUseOwnCredentials(true); } - } else if (authProviderConnections.length > 0 && sourceDetails.auth_methods.includes('auth_provider')) { - // Auth provider only if providers are connected - setAuthMode('external_provider'); } }, [sourceDetails, authProviderConnections, authMode, setAuthMode]); diff --git a/frontend/src/lib/validation/rules.ts b/frontend/src/lib/validation/rules.ts index 6c945014b..41bcc02c8 100644 --- a/frontend/src/lib/validation/rules.ts +++ b/frontend/src/lib/validation/rules.ts @@ -669,12 +669,13 @@ export const githubTokenValidation: FieldValidation = { // Check for valid GitHub token formats const isClassicToken = trimmed.startsWith('ghp_'); const isFineGrainedToken = trimmed.startsWith('github_pat_'); + const isOAuthToken = trimmed.startsWith('gho_'); const isLegacyToken = trimmed.length === 40 && /^[0-9a-fA-F]+$/.test(trimmed); - if (!isClassicToken && !isFineGrainedToken && !isLegacyToken) { + if (!isClassicToken && !isFineGrainedToken && !isOAuthToken && !isLegacyToken) { return { isValid: false, - hint: 'GitHub token should start with "ghp_" (classic) or "github_pat_" (fine-grained)', + hint: 'GitHub token should start with "ghp_" (classic), "github_pat_" (fine-grained), or "gho_" (OAuth)', severity: 'warning' }; } diff --git a/monke/auth/broker.py b/monke/auth/broker.py index ab205e922..9e240c300 100644 --- a/monke/auth/broker.py +++ b/monke/auth/broker.py @@ -182,4 +182,7 @@ async def get_credentials( } creds = {k: v for k, v in creds.items() if k in allowed} + if "token" not in creds and "access_token" in creds: + creds["token"] = creds["access_token"] + return creds diff --git a/monke/bongos/github.py b/monke/bongos/github.py index 492e621b2..d16b03b78 100644 --- a/monke/bongos/github.py +++ b/monke/bongos/github.py @@ -23,22 +23,16 @@ def __init__(self, credentials: Dict[str, Any], **kwargs): """Initialize the GitHub bongo. Args: - credentials: GitHub credentials with personal_access_token + credentials: GitHub credentials with token **kwargs: Additional configuration including repo_name (required), entity_count, file_types """ super().__init__(credentials) - # GitHub authentication - support both direct and Composio auth - self.personal_access_token = ( - credentials.get("personal_access_token") # Direct auth - or credentials.get("access_token") # Composio OAuth - or credentials.get("token") # Alternative token field - ) - - if not self.personal_access_token: + self.token = credentials.get("token") + if not self.token: available_fields = list(credentials.keys()) raise ValueError( - f"Missing GitHub authentication. Expected 'personal_access_token' (direct) or " - f"'access_token' (Composio). Available fields: {available_fields}" + f"Missing GitHub authentication. Expected 'token'. " + f"Available fields: {available_fields}" ) # repo_name is now in config_fields (kwargs) after migration @@ -270,7 +264,7 @@ async def _cleanup_orphaned_test_files(self, stats: Dict[str, Any]): response = await client.get( f"https://api.github.com/repos/{self.repo_name}/contents/", headers={ - "Authorization": f"token {self.personal_access_token}", + "Authorization": f"token {self.token}", "Accept": "application/vnd.github.v3+json", }, params={"ref": self.branch}, @@ -314,7 +308,7 @@ async def delete_orphaned_file(file_info): "DELETE", f"https://api.github.com/repos/{self.repo_name}/contents/{file_info['path']}", headers={ - "Authorization": f"token {self.personal_access_token}", + "Authorization": f"token {self.token}", "Accept": "application/vnd.github.v3+json", "Content-Type": "application/json", }, @@ -364,7 +358,7 @@ async def _create_test_file(self, filename: str, content: str) -> Dict[str, Any] self.logger.info(f"🔍 Creating file: {filename}") self.logger.info(f" Repository: {self.repo_name}") - self.logger.info(f" Token: {self.personal_access_token[:8]}...") + self.logger.info(f" Token: {self.token[:8]}...") async with httpx.AsyncClient() as client: # First check if file exists to get current SHA @@ -373,7 +367,7 @@ async def _create_test_file(self, filename: str, content: str) -> Dict[str, Any] check_response = await client.get( f"https://api.github.com/repos/{self.repo_name}/contents/{filename}", headers={ - "Authorization": f"token {self.personal_access_token}", + "Authorization": f"token {self.token}", "Accept": "application/vnd.github.v3+json", }, params={"ref": self.branch}, @@ -420,7 +414,7 @@ async def _create_test_file(self, filename: str, content: str) -> Dict[str, Any] response = await client.put( f"https://api.github.com/repos/{self.repo_name}/contents/{filename}", headers={ - "Authorization": f"token {self.personal_access_token}", + "Authorization": f"token {self.token}", "Accept": "application/vnd.github.v3+json", "Content-Type": "application/json", }, @@ -455,7 +449,7 @@ async def _update_test_file( response = await client.put( f"https://api.github.com/repos/{self.repo_name}/contents/{filename}", headers={ - "Authorization": f"token {self.personal_access_token}", + "Authorization": f"token {self.token}", "Accept": "application/vnd.github.v3+json", "Content-Type": "application/json", }, @@ -508,7 +502,7 @@ async def _get_file_sha(self, filename: str) -> Optional[str]: response = await client.get( f"https://api.github.com/repos/{self.repo_name}/contents/{filename}", headers={ - "Authorization": f"token {self.personal_access_token}", + "Authorization": f"token {self.token}", "Accept": "application/vnd.github.v3+json", }, params={"ref": self.branch}, @@ -533,7 +527,7 @@ async def _delete_test_file(self, filename: str, sha: str): "DELETE", f"https://api.github.com/repos/{self.repo_name}/contents/{filename}", headers={ - "Authorization": f"token {self.personal_access_token}", + "Authorization": f"token {self.token}", "Accept": "application/vnd.github.v3+json", "Content-Type": "application/json", }, @@ -558,7 +552,7 @@ async def _verify_file_deleted(self, filename: str) -> bool: response = await client.get( f"https://api.github.com/repos/{self.repo_name}/contents/{filename}", headers={ - "Authorization": f"token {self.personal_access_token}", + "Authorization": f"token {self.token}", "Accept": "application/vnd.github.v3+json", }, params={"ref": self.branch}, @@ -589,7 +583,7 @@ async def _force_delete_file(self, filename: str): response = await client.get( f"https://api.github.com/repos/{self.repo_name}/contents/{filename}", headers={ - "Authorization": f"token {self.personal_access_token}", + "Authorization": f"token {self.token}", "Accept": "application/vnd.github.v3+json", }, params={"ref": self.branch}, @@ -604,7 +598,7 @@ async def _force_delete_file(self, filename: str): "DELETE", f"https://api.github.com/repos/{self.repo_name}/contents/{filename}", headers={ - "Authorization": f"token {self.personal_access_token}", + "Authorization": f"token {self.token}", "Accept": "application/vnd.github.v3+json", "Content-Type": "application/json", },