Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/airweave/domains/oauth/fakes/flow_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,6 @@ async def create_proxy_url(
self._calls.append(("create_proxy_url", provider_auth_url))
return (
"https://api.example.com/source-connections/authorize/abc12345",
datetime.now(timezone.utc) + timedelta(hours=24),
datetime.now(timezone.utc) + timedelta(minutes=5),
uuid4(),
)
27 changes: 22 additions & 5 deletions backend/airweave/domains/oauth/flow_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Does NOT know about source connections, credentials, or syncs.
"""

from datetime import datetime, timedelta, timezone
from datetime import datetime
from typing import Any, Dict, Optional, Tuple
from uuid import UUID

Expand All @@ -22,7 +22,8 @@
)
from airweave.domains.oauth.types import OAuth1TokenResponse, OAuthBrowserInitiationResult
from airweave.models.connection_init_session import ConnectionInitSession, ConnectionInitStatus
from airweave.platform.auth.schemas import OAuth1Settings, OAuth2TokenResponse
from airweave.models.redirect_session import RedirectSession
from airweave.platform.auth.schemas import OAuth1Settings, OAuth2Settings, OAuth2TokenResponse
from airweave.platform.auth.settings import IntegrationSettings


Expand Down Expand Up @@ -80,6 +81,12 @@ async def initiate_oauth2(
detail=f"OAuth not configured for source: {short_name}",
)

if not isinstance(oauth_settings, OAuth2Settings):
raise HTTPException(
status_code=400,
detail=f"Source {short_name} is not configured for OAuth2",
)

api_callback = f"{self._settings.api_url}/source-connections/callback"

try:
Expand Down Expand Up @@ -131,6 +138,11 @@ async def initiate_oauth1(
api_callback = f"{self._settings.api_url}/source-connections/callback"
effective_consumer_key = consumer_key or oauth_settings.consumer_key
effective_consumer_secret = consumer_secret or oauth_settings.consumer_secret
if not effective_consumer_secret:
raise HTTPException(
status_code=400,
detail=f"Missing consumer_secret for OAuth1 source: {short_name}",
)

request_token_response = await self._oauth1_service.get_request_token(
request_token_url=oauth_settings.request_token_url,
Expand Down Expand Up @@ -252,6 +264,12 @@ async def complete_oauth1_callback(
"""Exchange OAuth1 verifier for access token."""
ctx.logger.info(f"Exchanging OAuth1 verifier for access token: {short_name}")

if not oauth_settings.consumer_secret:
raise HTTPException(
status_code=400,
detail=f"Missing consumer_secret for OAuth1 source: {short_name}",
)

return await self._oauth1_service.exchange_token(
access_token_url=oauth_settings.access_token_url,
consumer_key=oauth_settings.consumer_key,
Expand Down Expand Up @@ -303,7 +321,7 @@ async def create_init_session(
if additional_overrides:
overrides.update(additional_overrides)

expires_at = datetime.now(timezone.utc) + timedelta(minutes=30)
expires_at = ConnectionInitSession.default_expires_at()

return await self._init_session_repo.create(
db,
Expand Down Expand Up @@ -333,8 +351,7 @@ async def create_proxy_url(
Returns:
(proxy_url, proxy_expires, redirect_session_id)
"""
proxy_ttl = 1440 # 24 hours
proxy_expires = datetime.now(timezone.utc) + timedelta(minutes=proxy_ttl)
proxy_expires = RedirectSession.default_expires_at()
Comment on lines 324 to +354
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

code8 = await self._redirect_session_repo.generate_unique_code(db, length=8)

redirect_sess = await self._redirect_session_repo.create(
Expand Down
114 changes: 99 additions & 15 deletions backend/airweave/domains/oauth/tests/test_flow_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
)
from airweave.domains.oauth.flow_service import OAuthFlowService
from airweave.domains.oauth.types import OAuth1TokenResponse
from airweave.models.connection_init_session import ConnectionInitStatus
from airweave.platform.auth.schemas import OAuth1Settings, OAuth2TokenResponse
from airweave.models.connection_init_session import ConnectionInitSession, ConnectionInitStatus
from airweave.models.redirect_session import RedirectSession
from airweave.platform.auth.schemas import OAuth1Settings, OAuth2Settings, OAuth2TokenResponse
from airweave.schemas.organization import Organization

NOW = datetime.now(timezone.utc)
Expand All @@ -49,15 +50,20 @@ def _ctx() -> ApiContext:


def _settings(**overrides):
defaults = dict(api_url="https://api.test.com", app_url="https://app.test.com")
defaults = {"api_url": "https://api.test.com", "app_url": "https://app.test.com"}
defaults.update(overrides)
return SimpleNamespace(**defaults)


def _oauth2_settings():
return SimpleNamespace(
authorization_url="https://provider.com/auth",
token_url="https://provider.com/token",
return OAuth2Settings(
integration_short_name="github",
url="https://provider.com/auth",
backend_url="https://provider.com/token",
grant_type="authorization_code",
client_id="test-client-id",
content_type="application/json",
client_credential_location="body",
scope="read",
)

Expand Down Expand Up @@ -176,7 +182,10 @@ async def test_redirect_uri_uses_api_url(self):
await svc.initiate_oauth2("github", "state", ctx=_ctx())

call_kwargs = oauth2_svc.generate_auth_url_with_redirect.call_args
assert call_kwargs.kwargs["redirect_uri"] == "https://custom-api.com/source-connections/callback"
assert (
call_kwargs.kwargs["redirect_uri"]
== "https://custom-api.com/source-connections/callback"
)

async def test_value_error_from_oauth2_service_maps_to_422(self):
oauth2_svc = AsyncMock()
Expand Down Expand Up @@ -204,7 +213,9 @@ async def test_happy_path_returns_url_and_overrides(self):
oauth1_svc.get_request_token = AsyncMock(
return_value=OAuth1TokenResponse(oauth_token="req_tok", oauth_token_secret="req_sec")
)
oauth1_svc.build_authorization_url = MagicMock(return_value="https://provider.com/auth?oauth_token=req_tok")
oauth1_svc.build_authorization_url = MagicMock(
return_value="https://provider.com/auth?oauth_token=req_tok"
)

int_settings = AsyncMock()
int_settings.get_by_short_name = AsyncMock(return_value=_oauth1_settings())
Expand Down Expand Up @@ -237,9 +248,7 @@ async def test_non_oauth1_settings_raises_400(self):

svc = _service(integration_settings=int_settings)
with pytest.raises(HTTPException) as exc_info:
await svc.initiate_oauth1(
"github", consumer_key="ck", consumer_secret="cs", ctx=_ctx()
)
await svc.initiate_oauth1("github", consumer_key="ck", consumer_secret="cs", ctx=_ctx())
assert exc_info.value.status_code == 400
assert "not configured for OAuth1" in exc_info.value.detail

Expand Down Expand Up @@ -341,7 +350,8 @@ async def test_delegates_to_oauth2_service(self):
)

assert result.access_token == "at"
call_kwargs = oauth2_svc.exchange_authorization_code_for_token_with_redirect.call_args.kwargs
exchange = oauth2_svc.exchange_authorization_code_for_token_with_redirect
call_kwargs = exchange.call_args.kwargs
assert call_kwargs["redirect_uri"] == "https://custom/cb"
assert call_kwargs["code"] == "code123"

Expand All @@ -357,7 +367,8 @@ async def test_falls_back_to_api_url_when_no_override(self):
)
await svc.complete_oauth2_callback("github", "code123", {}, _ctx())

call_kwargs = oauth2_svc.exchange_authorization_code_for_token_with_redirect.call_args.kwargs
exchange = oauth2_svc.exchange_authorization_code_for_token_with_redirect
call_kwargs = exchange.call_args.kwargs
assert call_kwargs["redirect_uri"] == "https://fallback-api.com/source-connections/callback"

async def test_passes_pkce_verifier_and_template_configs(self):
Expand All @@ -375,7 +386,8 @@ async def test_passes_pkce_verifier_and_template_configs(self):
}
await svc.complete_oauth2_callback("github", "code123", overrides, _ctx())

call_kwargs = oauth2_svc.exchange_authorization_code_for_token_with_redirect.call_args.kwargs
exchange = oauth2_svc.exchange_authorization_code_for_token_with_redirect
call_kwargs = exchange.call_args.kwargs
assert call_kwargs["code_verifier"] == "pkce_v"
assert call_kwargs["template_configs"] == {"domain": "acme"}
assert call_kwargs["client_id"] == "cid"
Expand All @@ -397,7 +409,9 @@ async def test_delegates_to_oauth1_service(self):
settings = _oauth1_settings()
overrides = {"oauth_token": "req_tok", "oauth_token_secret": "req_sec"}

result = await svc.complete_oauth1_callback("twitter", "verifier", overrides, settings, _ctx())
result = await svc.complete_oauth1_callback(
"twitter", "verifier", overrides, settings, _ctx()
)

assert result.oauth_token == "access_tok"
assert result.oauth_token_secret == "access_sec"
Expand Down Expand Up @@ -512,6 +526,27 @@ async def test_redirect_url_defaults_to_none_when_not_provided(self):
_, obj_in = init_repo._calls[0]
assert obj_in["overrides"]["redirect_url"] is None

async def test_expires_at_within_five_minutes(self):
init_repo = FakeOAuthInitSessionRepository()
svc = _service(init_session_repo=init_repo)
db = AsyncMock()
uow = MagicMock()

before = datetime.now(timezone.utc)
await svc.create_init_session(
db,
short_name="github",
state="state-1",
payload={},
ctx=_ctx(),
uow=uow,
)
after = datetime.now(timezone.utc)

_, obj_in = init_repo._calls[0]
expires_at = obj_in["expires_at"]
assert before + timedelta(minutes=5) <= expires_at <= after + timedelta(minutes=5)


# ---------------------------------------------------------------------------
# create_proxy_url
Expand All @@ -534,3 +569,52 @@ async def test_returns_proxy_url_with_code(self):
assert proxy_url.startswith("https://api.test.com/source-connections/authorize/")
assert expires > datetime.now(timezone.utc)
assert session_id is not None

async def test_proxy_expires_within_five_minutes(self):
redirect_repo = FakeOAuthRedirectSessionRepository()
svc = _service(
redirect_session_repo=redirect_repo,
settings=_settings(api_url="https://api.test.com"),
)
db = AsyncMock()

before = datetime.now(timezone.utc)
_, expires, _ = await svc.create_proxy_url(db, "https://provider.com/auth?tok=1", _ctx())
after = datetime.now(timezone.utc)

assert before + timedelta(minutes=5) <= expires <= after + timedelta(minutes=5)


# ---------------------------------------------------------------------------
# default_expires_at model methods
# ---------------------------------------------------------------------------


class TestDefaultExpiresAt:
def test_connection_init_session_defaults_to_five_minutes(self):
before = datetime.now(timezone.utc)
result = ConnectionInitSession.default_expires_at()
after = datetime.now(timezone.utc)

assert before + timedelta(minutes=5) <= result <= after + timedelta(minutes=5)

def test_redirect_session_defaults_to_five_minutes(self):
before = datetime.now(timezone.utc)
result = RedirectSession.default_expires_at()
after = datetime.now(timezone.utc)

assert before + timedelta(minutes=5) <= result <= after + timedelta(minutes=5)

def test_custom_override_minutes(self):
before = datetime.now(timezone.utc)
result = ConnectionInitSession.default_expires_at(minutes=10)
after = datetime.now(timezone.utc)

assert before + timedelta(minutes=10) <= result <= after + timedelta(minutes=10)

def test_returns_utc_aware_datetime(self):
init_result = ConnectionInitSession.default_expires_at()
redirect_result = RedirectSession.default_expires_at()

assert init_result.tzinfo is timezone.utc
assert redirect_result.tzinfo is timezone.utc
4 changes: 2 additions & 2 deletions backend/airweave/models/connection_init_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class ConnectionInitSession(OrganizationBase):
String, nullable=False, default=ConnectionInitStatus.PENDING
)

# Expiration for security; default TTL ~30 minutes can be applied at creation
# Expiration for security; default TTL ~5 minutes can be applied at creation
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)

# Set when finalized (optional)
Expand All @@ -73,6 +73,6 @@ class ConnectionInitSession(OrganizationBase):
__table_args__ = (Index("idx_connection_init_session_expires_at", "expires_at"),)

@staticmethod
def default_expires_at(minutes: int = 30) -> datetime:
def default_expires_at(minutes: int = 5) -> datetime:
"""Return a UTC expiry timestamp ``minutes`` from now."""
return datetime.now(timezone.utc) + timedelta(minutes=minutes)
Loading