Skip to content
Open
19 changes: 14 additions & 5 deletions backend/airweave/api/v1/endpoints/file_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from airweave import crud
from airweave.api import deps
from airweave.api.context import ApiContext
from airweave.api.inject import Inject
from airweave.api.router import TrailingSlashRouter
from airweave.platform.storage import sync_file_manager
from airweave.domains.storage.protocols import SyncFileManagerProtocol

router = TrailingSlashRouter()

Expand Down Expand Up @@ -55,13 +56,15 @@ async def download_file(
entity_id: str,
ctx: ApiContext = Depends(deps.get_context),
db: AsyncSession = Depends(deps.get_db),
sfm: SyncFileManagerProtocol = Inject(SyncFileManagerProtocol),
) -> FileResponse:
"""Download a file by entity ID.

Args:
entity_id: The entity ID
ctx: The current authentication context
db: Database session
sfm: Sync file manager (injected)

Returns:
FileResponse: The file content
Expand All @@ -74,7 +77,7 @@ async def download_file(

try:
# Download to temp file
content, file_path = await sync_file_manager.download_ctti_file(
content, file_path = await sfm.download_ctti_file(
ctx.logger,
entity_id,
output_path=f"/tmp/{entity_id.replace(':', '_').replace('/', '_')}.md",
Expand Down Expand Up @@ -108,13 +111,15 @@ async def get_file_content(
entity_id: str,
ctx: ApiContext = Depends(deps.get_context),
db: AsyncSession = Depends(deps.get_db),
sfm: SyncFileManagerProtocol = Inject(SyncFileManagerProtocol),
) -> dict:
"""Get file content as JSON response.

Args:
entity_id: The entity ID
ctx: The current authentication context
db: Database session
sfm: Sync file manager (injected)

Returns:
dict: JSON response with the file content
Expand All @@ -126,7 +131,7 @@ async def get_file_content(
await verify_picnic_health_access(ctx, db)

try:
content = await sync_file_manager.get_ctti_file_content(ctx.logger, entity_id)
content = await sfm.get_ctti_file_content(ctx.logger, entity_id)

if content is None:
raise HTTPException(
Expand Down Expand Up @@ -158,13 +163,15 @@ async def download_files_batch(
entity_ids: List[str],
ctx: ApiContext = Depends(deps.get_context),
db: AsyncSession = Depends(deps.get_db),
sfm: SyncFileManagerProtocol = Inject(SyncFileManagerProtocol),
) -> StreamingResponse:
"""Download multiple files as a ZIP archive.

Args:
entity_ids: List of entity IDs to download
ctx: The current authentication context
db: Database session
sfm: Sync file manager (injected)

Returns:
StreamingResponse: ZIP file containing all requested files
Expand All @@ -183,7 +190,7 @@ async def download_files_batch(

try:
# Download all files
results = await sync_file_manager.download_ctti_files_batch(
results = await sfm.download_ctti_files_batch(
ctx.logger, entity_ids, continue_on_error=True
)

Expand Down Expand Up @@ -242,13 +249,15 @@ async def check_files_exist(
entity_ids: List[str] = Query(..., description="List of entity IDs to check"),
ctx: ApiContext = Depends(deps.get_context),
db: AsyncSession = Depends(deps.get_db),
sfm: SyncFileManagerProtocol = Inject(SyncFileManagerProtocol),
) -> dict:
"""Check which files exist in storage.

Args:
entity_ids: List of entity IDs to check
ctx: The current authentication context
db: Database session
sfm: Sync file manager (injected)

Returns:
dict: Dictionary with entity_ids as keys and existence status as values
Expand All @@ -268,7 +277,7 @@ async def check_files_exist(

for entity_id in entity_ids:
try:
exists = await sync_file_manager.check_ctti_file_exists(ctx.logger, entity_id)
exists = await sfm.check_ctti_file_exists(ctx.logger, entity_id)
results[entity_id] = exists
except Exception as e:
ctx.logger.warning(f"Error checking file {entity_id}: {e}")
Expand Down
7 changes: 4 additions & 3 deletions backend/airweave/core/admin_sync_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def build(self):
class AdminSyncService:
"""Service for admin sync operations with optimized bulk fetching."""

MAX_CONCURRENT_DESTINATION_QUERIES = 10 # Reduced from 20 to prevent overload
MAX_CONCURRENT_DESTINATION_QUERIES = 10

async def list_syncs_with_metadata(
self,
Expand Down Expand Up @@ -367,9 +367,10 @@ async def _fetch_arf_counts(
return {s.id: None for s in syncs}

start = time.monotonic()
from airweave.platform.sync.arf.service import ArfService
# [code blue] todo: inject arf_service once admin domain is extracted
from airweave.core import container as container_mod

arf_service = ArfService()
arf_service = container_mod.container.arf_service
Comment on lines +370 to +373
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

temporary - let's defer restructuring the admin stuff


async def get_arf_count_safe(sync_id: UUID) -> Optional[int]:
try:
Expand Down
11 changes: 11 additions & 0 deletions backend/airweave/core/container/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)
from airweave.core.protocols.identity import IdentityProvider
from airweave.core.protocols.payment import PaymentGatewayProtocol
from airweave.domains.arf.protocols import ArfServiceProtocol
from airweave.domains.auth_provider.protocols import (
AuthProviderRegistryProtocol,
AuthProviderServiceProtocol,
Expand Down Expand Up @@ -75,6 +76,7 @@
SourceRegistryProtocol,
SourceServiceProtocol,
)
from airweave.domains.storage.protocols import StorageBackend, SyncFileManagerProtocol
from airweave.domains.syncs.protocols import (
SyncCursorRepositoryProtocol,
SyncJobRepositoryProtocol,
Expand Down Expand Up @@ -224,6 +226,15 @@ async def my_endpoint(event_bus: EventBus = Inject(EventBus)):
# Connect domain service (session-based frontend integration flows)
connect_service: ConnectServiceProtocol

# Storage domain — unified backend for file/object storage
storage_backend: StorageBackend

# Storage domain — sync-aware file manager (CTTI, metadata, caching)
sync_file_manager: SyncFileManagerProtocol

# ARF domain — raw entity capture / replay service
arf_service: ArfServiceProtocol

# OCR provider (with fallback chain + circuit breaking)
# Optional: None when no OCR backend (Mistral/Docling) is configured
ocr_provider: Optional[OcrProvider] = None
Expand Down
15 changes: 15 additions & 0 deletions backend/airweave/core/container/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from airweave.core.protocols.webhooks import WebhookPublisher
from airweave.core.redis_client import redis_client
from airweave.db.session import health_check_engine
from airweave.domains.arf.service import ArfService
from airweave.domains.auth_provider.registry import AuthProviderRegistry
from airweave.domains.auth_provider.service import AuthProviderService
from airweave.domains.browse_tree.repository import NodeSelectionRepository
Expand Down Expand Up @@ -89,6 +90,8 @@
from airweave.domains.sources.registry import SourceRegistry
from airweave.domains.sources.service import SourceService
from airweave.domains.sources.validation import SourceValidationService
from airweave.domains.storage.factory import get_storage_backend
from airweave.domains.storage.sync_file_manager import SyncFileManager
from airweave.domains.syncs.sync_cursor_repository import SyncCursorRepository
from airweave.domains.syncs.sync_job_repository import SyncJobRepository
from airweave.domains.syncs.sync_job_service import SyncJobService
Expand Down Expand Up @@ -425,11 +428,23 @@ def create_container(settings: Settings) -> Container:
temporal_workflow_service=sync_deps["temporal_workflow_service"],
)

# Storage domain
# -----------------------------------------------------------------
storage_backend = get_storage_backend()
sync_file_manager = SyncFileManager(backend=storage_backend)

# ARF domain service (raw entity capture / replay)
# -----------------------------------------------------------------
arf_service = ArfService(storage=storage_backend)

# -----------------------------------------------------------------
# Usage billing listener
# -----------------------------------------------------------------

return Container(
storage_backend=storage_backend,
sync_file_manager=sync_file_manager,
arf_service=arf_service,
context_cache=context_cache,
rate_limiter=rate_limiter,
billing_service=billing_services["billing_service"],
Expand Down
14 changes: 14 additions & 0 deletions backend/airweave/domains/arf/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""ARF (Airweave Raw Format) domain.

Raw entity capture for replay, debugging, and evals.
"""

from airweave.domains.arf.protocols import ArfReaderProtocol, ArfServiceProtocol
from airweave.domains.arf.types import EntitySerializationMeta, SyncManifest

__all__ = [
"ArfServiceProtocol",
"ArfReaderProtocol",
"SyncManifest",
"EntitySerializationMeta",
]
Empty file.
77 changes: 77 additions & 0 deletions backend/airweave/domains/arf/fakes/reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""In-memory fake for ArfReaderProtocol."""

from typing import Any, AsyncGenerator, Dict, List, Optional

from airweave.platform.entities._base import BaseEntity


class FakeArfReader:
"""In-memory fake for ArfReaderProtocol.

Returns pre-seeded entities and manifest data.
"""

def __init__(self) -> None:
self._entities: List[BaseEntity] = []
self._manifest: Optional[Dict[str, Any]] = None
self._valid: bool = True
self._calls: List[tuple] = []
self._should_raise: Optional[Exception] = None

# -- Test helpers ----------------------------------------------------------

def seed_entities(self, entities: List[BaseEntity]) -> None:
"""Pre-populate entities to return during iteration."""
self._entities = list(entities)

def seed_manifest(self, manifest: Dict[str, Any]) -> None:
"""Pre-populate the manifest dict."""
self._manifest = manifest

def set_valid(self, valid: bool) -> None:
"""Control what validate() returns."""
self._valid = valid

def set_error(self, error: Exception) -> None:
"""Configure next call to raise."""
self._should_raise = error

def get_calls(self, method: str) -> List[tuple]:
"""Return calls for a specific method."""
return [c for c in self._calls if c[0] == method]

# -- Protocol implementation -----------------------------------------------

async def validate(self) -> bool:
self._calls.append(("validate",))
if self._should_raise:
exc, self._should_raise = self._should_raise, None
raise exc
return self._valid

async def read_manifest(self) -> Dict[str, Any]:
self._calls.append(("read_manifest",))
if self._should_raise:
exc, self._should_raise = self._should_raise, None
raise exc
if self._manifest is None:
raise FileNotFoundError("No manifest seeded")
return self._manifest

async def get_entity_count(self) -> int:
self._calls.append(("get_entity_count",))
if self._should_raise:
exc, self._should_raise = self._should_raise, None
raise exc
return len(self._entities)

async def iter_entities(self) -> AsyncGenerator[BaseEntity, None]:
self._calls.append(("iter_entities",))
if self._should_raise:
exc, self._should_raise = self._should_raise, None
raise exc
for entity in self._entities:
yield entity

def cleanup(self) -> None:
self._calls.append(("cleanup",))
Loading
Loading