diff --git a/backend/airweave/domains/oauth/fakes/flow_service.py b/backend/airweave/domains/oauth/fakes/flow_service.py index ba67d52db..d6d51fb6d 100644 --- a/backend/airweave/domains/oauth/fakes/flow_service.py +++ b/backend/airweave/domains/oauth/fakes/flow_service.py @@ -209,6 +209,6 @@ async def create_proxy_url( self._calls.append(("create_proxy_url", provider_auth_url)) return ( "https://api.example.com/source-connections/authorize/abc12345", - datetime.now(timezone.utc) + timedelta(hours=24), + datetime.now(timezone.utc) + timedelta(minutes=5), uuid4(), ) diff --git a/backend/airweave/domains/oauth/flow_service.py b/backend/airweave/domains/oauth/flow_service.py index 3cc49d39a..b4353a4ca 100644 --- a/backend/airweave/domains/oauth/flow_service.py +++ b/backend/airweave/domains/oauth/flow_service.py @@ -4,7 +4,7 @@ Does NOT know about source connections, credentials, or syncs. """ -from datetime import datetime, timedelta, timezone +from datetime import datetime from typing import Any, Dict, Optional, Tuple from uuid import UUID @@ -22,7 +22,8 @@ ) from airweave.domains.oauth.types import OAuth1TokenResponse, OAuthBrowserInitiationResult from airweave.models.connection_init_session import ConnectionInitSession, ConnectionInitStatus -from airweave.platform.auth.schemas import OAuth1Settings, OAuth2TokenResponse +from airweave.models.redirect_session import RedirectSession +from airweave.platform.auth.schemas import OAuth1Settings, OAuth2Settings, OAuth2TokenResponse from airweave.platform.auth.settings import IntegrationSettings @@ -80,6 +81,12 @@ async def initiate_oauth2( detail=f"OAuth not configured for source: {short_name}", ) + if not isinstance(oauth_settings, OAuth2Settings): + raise HTTPException( + status_code=400, + detail=f"Source {short_name} is not configured for OAuth2", + ) + api_callback = f"{self._settings.api_url}/source-connections/callback" try: @@ -131,6 +138,11 @@ async def initiate_oauth1( api_callback = f"{self._settings.api_url}/source-connections/callback" effective_consumer_key = consumer_key or oauth_settings.consumer_key effective_consumer_secret = consumer_secret or oauth_settings.consumer_secret + if not effective_consumer_secret: + raise HTTPException( + status_code=400, + detail=f"Missing consumer_secret for OAuth1 source: {short_name}", + ) request_token_response = await self._oauth1_service.get_request_token( request_token_url=oauth_settings.request_token_url, @@ -252,6 +264,12 @@ async def complete_oauth1_callback( """Exchange OAuth1 verifier for access token.""" ctx.logger.info(f"Exchanging OAuth1 verifier for access token: {short_name}") + if not oauth_settings.consumer_secret: + raise HTTPException( + status_code=400, + detail=f"Missing consumer_secret for OAuth1 source: {short_name}", + ) + return await self._oauth1_service.exchange_token( access_token_url=oauth_settings.access_token_url, consumer_key=oauth_settings.consumer_key, @@ -303,7 +321,7 @@ async def create_init_session( if additional_overrides: overrides.update(additional_overrides) - expires_at = datetime.now(timezone.utc) + timedelta(minutes=30) + expires_at = ConnectionInitSession.default_expires_at() return await self._init_session_repo.create( db, @@ -333,8 +351,7 @@ async def create_proxy_url( Returns: (proxy_url, proxy_expires, redirect_session_id) """ - proxy_ttl = 1440 # 24 hours - proxy_expires = datetime.now(timezone.utc) + timedelta(minutes=proxy_ttl) + proxy_expires = RedirectSession.default_expires_at() code8 = await self._redirect_session_repo.generate_unique_code(db, length=8) redirect_sess = await self._redirect_session_repo.create( diff --git a/backend/airweave/domains/oauth/tests/test_flow_service.py b/backend/airweave/domains/oauth/tests/test_flow_service.py index cd0ebfc09..717e84837 100644 --- a/backend/airweave/domains/oauth/tests/test_flow_service.py +++ b/backend/airweave/domains/oauth/tests/test_flow_service.py @@ -25,8 +25,9 @@ ) from airweave.domains.oauth.flow_service import OAuthFlowService from airweave.domains.oauth.types import OAuth1TokenResponse -from airweave.models.connection_init_session import ConnectionInitStatus -from airweave.platform.auth.schemas import OAuth1Settings, OAuth2TokenResponse +from airweave.models.connection_init_session import ConnectionInitSession, ConnectionInitStatus +from airweave.models.redirect_session import RedirectSession +from airweave.platform.auth.schemas import OAuth1Settings, OAuth2Settings, OAuth2TokenResponse from airweave.schemas.organization import Organization NOW = datetime.now(timezone.utc) @@ -49,15 +50,20 @@ def _ctx() -> ApiContext: def _settings(**overrides): - defaults = dict(api_url="https://api.test.com", app_url="https://app.test.com") + defaults = {"api_url": "https://api.test.com", "app_url": "https://app.test.com"} defaults.update(overrides) return SimpleNamespace(**defaults) def _oauth2_settings(): - return SimpleNamespace( - authorization_url="https://provider.com/auth", - token_url="https://provider.com/token", + return OAuth2Settings( + integration_short_name="github", + url="https://provider.com/auth", + backend_url="https://provider.com/token", + grant_type="authorization_code", + client_id="test-client-id", + content_type="application/json", + client_credential_location="body", scope="read", ) @@ -176,7 +182,10 @@ async def test_redirect_uri_uses_api_url(self): await svc.initiate_oauth2("github", "state", ctx=_ctx()) call_kwargs = oauth2_svc.generate_auth_url_with_redirect.call_args - assert call_kwargs.kwargs["redirect_uri"] == "https://custom-api.com/source-connections/callback" + assert ( + call_kwargs.kwargs["redirect_uri"] + == "https://custom-api.com/source-connections/callback" + ) async def test_value_error_from_oauth2_service_maps_to_422(self): oauth2_svc = AsyncMock() @@ -204,7 +213,9 @@ async def test_happy_path_returns_url_and_overrides(self): oauth1_svc.get_request_token = AsyncMock( return_value=OAuth1TokenResponse(oauth_token="req_tok", oauth_token_secret="req_sec") ) - oauth1_svc.build_authorization_url = MagicMock(return_value="https://provider.com/auth?oauth_token=req_tok") + oauth1_svc.build_authorization_url = MagicMock( + return_value="https://provider.com/auth?oauth_token=req_tok" + ) int_settings = AsyncMock() int_settings.get_by_short_name = AsyncMock(return_value=_oauth1_settings()) @@ -237,9 +248,7 @@ async def test_non_oauth1_settings_raises_400(self): svc = _service(integration_settings=int_settings) with pytest.raises(HTTPException) as exc_info: - await svc.initiate_oauth1( - "github", consumer_key="ck", consumer_secret="cs", ctx=_ctx() - ) + await svc.initiate_oauth1("github", consumer_key="ck", consumer_secret="cs", ctx=_ctx()) assert exc_info.value.status_code == 400 assert "not configured for OAuth1" in exc_info.value.detail @@ -341,7 +350,8 @@ async def test_delegates_to_oauth2_service(self): ) assert result.access_token == "at" - call_kwargs = oauth2_svc.exchange_authorization_code_for_token_with_redirect.call_args.kwargs + exchange = oauth2_svc.exchange_authorization_code_for_token_with_redirect + call_kwargs = exchange.call_args.kwargs assert call_kwargs["redirect_uri"] == "https://custom/cb" assert call_kwargs["code"] == "code123" @@ -357,7 +367,8 @@ async def test_falls_back_to_api_url_when_no_override(self): ) await svc.complete_oauth2_callback("github", "code123", {}, _ctx()) - call_kwargs = oauth2_svc.exchange_authorization_code_for_token_with_redirect.call_args.kwargs + exchange = oauth2_svc.exchange_authorization_code_for_token_with_redirect + call_kwargs = exchange.call_args.kwargs assert call_kwargs["redirect_uri"] == "https://fallback-api.com/source-connections/callback" async def test_passes_pkce_verifier_and_template_configs(self): @@ -375,7 +386,8 @@ async def test_passes_pkce_verifier_and_template_configs(self): } await svc.complete_oauth2_callback("github", "code123", overrides, _ctx()) - call_kwargs = oauth2_svc.exchange_authorization_code_for_token_with_redirect.call_args.kwargs + exchange = oauth2_svc.exchange_authorization_code_for_token_with_redirect + call_kwargs = exchange.call_args.kwargs assert call_kwargs["code_verifier"] == "pkce_v" assert call_kwargs["template_configs"] == {"domain": "acme"} assert call_kwargs["client_id"] == "cid" @@ -397,7 +409,9 @@ async def test_delegates_to_oauth1_service(self): settings = _oauth1_settings() overrides = {"oauth_token": "req_tok", "oauth_token_secret": "req_sec"} - result = await svc.complete_oauth1_callback("twitter", "verifier", overrides, settings, _ctx()) + result = await svc.complete_oauth1_callback( + "twitter", "verifier", overrides, settings, _ctx() + ) assert result.oauth_token == "access_tok" assert result.oauth_token_secret == "access_sec" @@ -512,6 +526,27 @@ async def test_redirect_url_defaults_to_none_when_not_provided(self): _, obj_in = init_repo._calls[0] assert obj_in["overrides"]["redirect_url"] is None + async def test_expires_at_within_five_minutes(self): + init_repo = FakeOAuthInitSessionRepository() + svc = _service(init_session_repo=init_repo) + db = AsyncMock() + uow = MagicMock() + + before = datetime.now(timezone.utc) + await svc.create_init_session( + db, + short_name="github", + state="state-1", + payload={}, + ctx=_ctx(), + uow=uow, + ) + after = datetime.now(timezone.utc) + + _, obj_in = init_repo._calls[0] + expires_at = obj_in["expires_at"] + assert before + timedelta(minutes=5) <= expires_at <= after + timedelta(minutes=5) + # --------------------------------------------------------------------------- # create_proxy_url @@ -534,3 +569,52 @@ async def test_returns_proxy_url_with_code(self): assert proxy_url.startswith("https://api.test.com/source-connections/authorize/") assert expires > datetime.now(timezone.utc) assert session_id is not None + + async def test_proxy_expires_within_five_minutes(self): + redirect_repo = FakeOAuthRedirectSessionRepository() + svc = _service( + redirect_session_repo=redirect_repo, + settings=_settings(api_url="https://api.test.com"), + ) + db = AsyncMock() + + before = datetime.now(timezone.utc) + _, expires, _ = await svc.create_proxy_url(db, "https://provider.com/auth?tok=1", _ctx()) + after = datetime.now(timezone.utc) + + assert before + timedelta(minutes=5) <= expires <= after + timedelta(minutes=5) + + +# --------------------------------------------------------------------------- +# default_expires_at model methods +# --------------------------------------------------------------------------- + + +class TestDefaultExpiresAt: + def test_connection_init_session_defaults_to_five_minutes(self): + before = datetime.now(timezone.utc) + result = ConnectionInitSession.default_expires_at() + after = datetime.now(timezone.utc) + + assert before + timedelta(minutes=5) <= result <= after + timedelta(minutes=5) + + def test_redirect_session_defaults_to_five_minutes(self): + before = datetime.now(timezone.utc) + result = RedirectSession.default_expires_at() + after = datetime.now(timezone.utc) + + assert before + timedelta(minutes=5) <= result <= after + timedelta(minutes=5) + + def test_custom_override_minutes(self): + before = datetime.now(timezone.utc) + result = ConnectionInitSession.default_expires_at(minutes=10) + after = datetime.now(timezone.utc) + + assert before + timedelta(minutes=10) <= result <= after + timedelta(minutes=10) + + def test_returns_utc_aware_datetime(self): + init_result = ConnectionInitSession.default_expires_at() + redirect_result = RedirectSession.default_expires_at() + + assert init_result.tzinfo is timezone.utc + assert redirect_result.tzinfo is timezone.utc diff --git a/backend/airweave/models/connection_init_session.py b/backend/airweave/models/connection_init_session.py index 1739a468d..7b19306e3 100644 --- a/backend/airweave/models/connection_init_session.py +++ b/backend/airweave/models/connection_init_session.py @@ -51,7 +51,7 @@ class ConnectionInitSession(OrganizationBase): String, nullable=False, default=ConnectionInitStatus.PENDING ) - # Expiration for security; default TTL ~30 minutes can be applied at creation + # Expiration for security; default TTL ~5 minutes can be applied at creation expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) # Set when finalized (optional) @@ -73,6 +73,6 @@ class ConnectionInitSession(OrganizationBase): __table_args__ = (Index("idx_connection_init_session_expires_at", "expires_at"),) @staticmethod - def default_expires_at(minutes: int = 30) -> datetime: + def default_expires_at(minutes: int = 5) -> datetime: """Return a UTC expiry timestamp ``minutes`` from now.""" return datetime.now(timezone.utc) + timedelta(minutes=minutes)