From 760fca147f4e440c95c6e7dfa483933b2bccf53a Mon Sep 17 00:00:00 2001 From: Felix Schmetz Date: Thu, 12 Mar 2026 16:14:05 -0700 Subject: [PATCH 01/13] refactor: protocolize sync pipeline and move to domains/sync_pipeline Move SyncFactory, EntityPipeline, EntityActionResolver, EntityActionDispatcher to domains/sync_pipeline/ with protocols. Convert SyncFactory from classmethods to instance with injected deps (sc_repo, event_bus, embedders, entity_repo). Eliminate get_source_connection_id indirection chain. Remove direct embedder and access_token passing from SyncService/RunSyncActivity. Wire into DI container. --- backend/airweave/core/container/container.py | 10 +- backend/airweave/core/container/factory.py | 30 ++++- .../domains/entities/entity_repository.py | 29 +++++ .../airweave/domains/entities/protocols.py | 21 +++- .../domains/sync_pipeline/__init__.py | 1 + .../entity_action_dispatcher.py} | 60 +--------- .../sync_pipeline/entity_action_resolver.py} | 104 +++-------------- .../sync_pipeline}/entity_pipeline.py | 54 +++------ .../sync => domains/sync_pipeline}/factory.py | 89 ++++++++------ .../domains/sync_pipeline/fakes/__init__.py | 0 .../sync_pipeline/fakes/entity_repository.py | 27 +++++ .../domains/sync_pipeline/fakes/factory.py | 32 +++++ .../domains/sync_pipeline/protocols.py | 109 ++++++++++++++++++ .../domains/sync_pipeline/tests/__init__.py | 0 .../domains/syncs/fakes/sync_service.py | 6 - backend/airweave/domains/syncs/protocols.py | 4 - backend/airweave/domains/syncs/service.py | 38 ++---- .../domains/syncs/tests/test_sync_service.py | 96 +++++++-------- backend/airweave/platform/builders/source.py | 24 ---- backend/airweave/platform/builders/sync.py | 12 -- .../platform/sync/actions/entity/__init__.py | 4 +- .../platform/sync/actions/entity/builder.py | 2 +- .../airweave/platform/sync/orchestrator.py | 2 +- .../platform/temporal/activities/sync.py | 6 - .../platform/temporal/worker/wiring.py | 4 - backend/conftest.py | 20 ++++ 26 files changed, 412 insertions(+), 372 deletions(-) create mode 100644 backend/airweave/domains/entities/entity_repository.py create mode 100644 backend/airweave/domains/sync_pipeline/__init__.py rename backend/airweave/{platform/sync/actions/entity/dispatcher.py => domains/sync_pipeline/entity_action_dispatcher.py} (78%) rename backend/airweave/{platform/sync/actions/entity/resolver.py => domains/sync_pipeline/entity_action_resolver.py} (78%) rename backend/airweave/{platform/sync => domains/sync_pipeline}/entity_pipeline.py (92%) rename backend/airweave/{platform/sync => domains/sync_pipeline}/factory.py (79%) create mode 100644 backend/airweave/domains/sync_pipeline/fakes/__init__.py create mode 100644 backend/airweave/domains/sync_pipeline/fakes/entity_repository.py create mode 100644 backend/airweave/domains/sync_pipeline/fakes/factory.py create mode 100644 backend/airweave/domains/sync_pipeline/protocols.py create mode 100644 backend/airweave/domains/sync_pipeline/tests/__init__.py diff --git a/backend/airweave/core/container/container.py b/backend/airweave/core/container/container.py index bcbff5222..60ae1780a 100644 --- a/backend/airweave/core/container/container.py +++ b/backend/airweave/core/container/container.py @@ -52,7 +52,10 @@ SparseEmbedderProtocol, SparseEmbedderRegistryProtocol, ) -from airweave.domains.entities.protocols import EntityDefinitionRegistryProtocol +from airweave.domains.entities.protocols import ( + EntityDefinitionRegistryProtocol, + EntityRepositoryProtocol, +) from airweave.domains.oauth.protocols import ( OAuth1ServiceProtocol, OAuth2ServiceProtocol, @@ -75,6 +78,7 @@ SourceRegistryProtocol, SourceServiceProtocol, ) +from airweave.domains.sync_pipeline.protocols import SyncFactoryProtocol from airweave.domains.syncs.protocols import ( SyncCursorRepositoryProtocol, SyncJobRepositoryProtocol, @@ -186,6 +190,10 @@ async def my_endpoint(event_bus: EventBus = Inject(EventBus)): sync_job_service: SyncJobServiceProtocol sync_service: SyncServiceProtocol sync_lifecycle: SyncLifecycleServiceProtocol + sync_factory: SyncFactoryProtocol + + # Entity repository (used by sync pipeline) + entity_repo: EntityRepositoryProtocol # Temporal domain temporal_workflow_service: TemporalWorkflowServiceProtocol diff --git a/backend/airweave/core/container/factory.py b/backend/airweave/core/container/factory.py index f95e91d3f..f34fb4f0c 100644 --- a/backend/airweave/core/container/factory.py +++ b/backend/airweave/core/container/factory.py @@ -67,6 +67,7 @@ FastEmbedSparseEmbedder as DomainFastEmbedSparseEmbedder, ) from airweave.domains.entities.entity_count_repository import EntityCountRepository +from airweave.domains.entities.entity_repository import EntityRepository from airweave.domains.entities.registry import EntityDefinitionRegistry from airweave.domains.oauth.callback_service import OAuthCallbackService from airweave.domains.oauth.flow_service import OAuthFlowService @@ -89,6 +90,8 @@ from airweave.domains.sources.registry import SourceRegistry from airweave.domains.sources.service import SourceService from airweave.domains.sources.validation import SourceValidationService +from airweave.domains.sync_pipeline.factory import SyncFactory +from airweave.domains.syncs.service import SyncService from airweave.domains.syncs.sync_cursor_repository import SyncCursorRepository from airweave.domains.syncs.sync_job_repository import SyncJobRepository from airweave.domains.syncs.sync_job_service import SyncJobService @@ -375,6 +378,23 @@ def create_container(settings: Settings) -> Container: dense_embedder = _create_dense_embedder(settings, dense_embedder_registry) sparse_embedder = _create_sparse_embedder(sparse_embedder_registry) + # ----------------------------------------------------------------- + # Sync factory + service (needs embedders, built after embedder init) + # ----------------------------------------------------------------- + sync_factory = SyncFactory( + sc_repo=source_deps["sc_repo"], + event_bus=event_bus, + usage_checker=usage_checker, + dense_embedder=dense_embedder, + sparse_embedder=sparse_embedder, + entity_repo=sync_deps["entity_repo"], + ) + + sync_service = SyncService( + sync_job_service=sync_deps["sync_job_service"], + sync_factory=sync_factory, + ) + # ----------------------------------------------------------------- # Collection service (needs collection_repo, sc_repo, sync_lifecycle, dense_registry) # ----------------------------------------------------------------- @@ -477,8 +497,10 @@ def create_container(settings: Settings) -> Container: payment_gateway=billing_services["payment_gateway"], sync_record_service=sync_deps["sync_record_service"], sync_job_service=sync_deps["sync_job_service"], - sync_service=sync_deps["sync_service"], + sync_service=sync_service, sync_lifecycle=sync_deps["sync_lifecycle"], + sync_factory=sync_factory, + entity_repo=sync_deps["entity_repo"], temporal_workflow_service=sync_deps["temporal_workflow_service"], temporal_schedule_service=sync_deps["temporal_schedule_service"], usage_checker=usage_checker, @@ -852,12 +874,10 @@ def _create_sync_services( 4. SyncLifecycleService (needs everything above) """ entity_count_repo = EntityCountRepository() + entity_repo = EntityRepository() sync_job_service = SyncJobService(sync_job_repo=sync_job_repo) - from airweave.domains.syncs.service import SyncService - - sync_service = SyncService(sync_job_service=sync_job_service) temporal_workflow_service = TemporalWorkflowService() sync_record_service = SyncRecordService( @@ -899,11 +919,11 @@ def _create_sync_services( return { "sync_record_service": sync_record_service, "sync_job_service": sync_job_service, - "sync_service": sync_service, "sync_lifecycle": sync_lifecycle, "temporal_workflow_service": temporal_workflow_service, "temporal_schedule_service": temporal_schedule_service, "response_builder": response_builder, + "entity_repo": entity_repo, } diff --git a/backend/airweave/domains/entities/entity_repository.py b/backend/airweave/domains/entities/entity_repository.py new file mode 100644 index 000000000..35fa3f25b --- /dev/null +++ b/backend/airweave/domains/entities/entity_repository.py @@ -0,0 +1,29 @@ +"""Entity repository wrapping crud.entity for sync pipeline usage.""" + +from typing import Dict, List, Tuple +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + +from airweave import crud +from airweave.models.entity import Entity + + +class EntityRepository: + """Delegates to the crud.entity singleton.""" + + async def get_by_sync_id(self, db: AsyncSession, sync_id: UUID) -> List[Entity]: + """Get all entities for a specific sync.""" + return await crud.entity.get_by_sync_id(db, sync_id) + + async def bulk_get_by_entity_sync_and_definition( + self, + db: AsyncSession, + *, + sync_id: UUID, + entity_requests: list[Tuple[str, str]], + ) -> Dict[Tuple[str, str], Entity]: + """Bulk-fetch entities by (entity_id, definition_short_name).""" + return await crud.entity.bulk_get_by_entity_sync_and_definition( + db, sync_id=sync_id, entity_requests=entity_requests + ) diff --git a/backend/airweave/domains/entities/protocols.py b/backend/airweave/domains/entities/protocols.py index 792df5b46..5d7617f13 100644 --- a/backend/airweave/domains/entities/protocols.py +++ b/backend/airweave/domains/entities/protocols.py @@ -1,12 +1,13 @@ """Protocols for the entities domain.""" -from typing import List, Protocol +from typing import Dict, List, Protocol, Tuple from uuid import UUID from sqlalchemy.ext.asyncio import AsyncSession from airweave.core.protocols.registry import RegistryProtocol from airweave.domains.entities.types import EntityDefinitionEntry +from airweave.models.entity import Entity from airweave.schemas.entity_count import EntityCountWithDefinition @@ -26,3 +27,21 @@ async def get_counts_per_sync_and_type( ) -> List[EntityCountWithDefinition]: """Get entity counts for a sync grouped by entity definition.""" ... + + +class EntityRepositoryProtocol(Protocol): + """Entity read access used by the sync pipeline.""" + + async def get_by_sync_id(self, db: AsyncSession, sync_id: UUID) -> List[Entity]: + """Get all entities for a specific sync.""" + ... + + async def bulk_get_by_entity_sync_and_definition( + self, + db: AsyncSession, + *, + sync_id: UUID, + entity_requests: list[Tuple[str, str]], + ) -> Dict[Tuple[str, str], Entity]: + """Bulk-fetch entities by (entity_id, entity_definition_short_name) for a sync.""" + ... diff --git a/backend/airweave/domains/sync_pipeline/__init__.py b/backend/airweave/domains/sync_pipeline/__init__.py new file mode 100644 index 000000000..67742611f --- /dev/null +++ b/backend/airweave/domains/sync_pipeline/__init__.py @@ -0,0 +1 @@ +"""Sync pipeline domain — protocols and implementations for sync execution.""" diff --git a/backend/airweave/platform/sync/actions/entity/dispatcher.py b/backend/airweave/domains/sync_pipeline/entity_action_dispatcher.py similarity index 78% rename from backend/airweave/platform/sync/actions/entity/dispatcher.py rename to backend/airweave/domains/sync_pipeline/entity_action_dispatcher.py index a5b9ec8bb..b37bde8ef 100644 --- a/backend/airweave/platform/sync/actions/entity/dispatcher.py +++ b/backend/airweave/domains/sync_pipeline/entity_action_dispatcher.py @@ -33,14 +33,7 @@ class EntityActionDispatcher: """ def __init__(self, handlers: List[EntityActionHandler]): - """Initialize dispatcher with handlers. - - Args: - handlers: List of handlers to dispatch to (configured at factory time) - EntityPostgresHandler is automatically separated for - sequential execution after other handlers. - """ - # Separate postgres handler from destination handlers + """Initialize with handler list, separating Postgres from destinations.""" self._destination_handlers: List[EntityActionHandler] = [] self._postgres_handler: EntityPostgresHandler | None = None @@ -62,16 +55,6 @@ async def dispatch( ) -> None: """Dispatch action batch to all handlers. - Execution order: - 1. All destination handlers concurrently (Qdrant, RawData, etc.) - 2. If all succeed → PostgreSQL metadata handler - 3. If any fails → SyncFailureError propagates - - Args: - batch: Resolved and processed action batch (with chunk_entities populated) - sync_context: Sync context - runtime: Sync runtime with entity_tracker, source, etc. - Raises: SyncFailureError: If any handler fails """ @@ -84,10 +67,8 @@ async def dispatch( f"[EntityDispatcher] Dispatching {batch.summary()} to handlers: {handler_names}" ) - # Step 1: Execute destination handlers concurrently await self._dispatch_to_destinations(batch, sync_context, runtime) - # Step 2: Execute postgres handler (only after destinations succeed) if self._postgres_handler: await self._dispatch_to_postgres(batch, sync_context, runtime) @@ -100,25 +81,12 @@ async def dispatch_orphan_cleanup( ) -> None: """Dispatch orphan cleanup to ALL handlers concurrently. - Called at the end of sync for entities that exist in DB but were not - encountered during this sync run. - - Each handler independently cleans up its own storage: - - DestinationHandler → vector stores (Qdrant, Vespa) - - ArfHandler → ARF storage - - EntityPostgresHandler → postgres DB - - Args: - orphan_entity_ids: Entity IDs to clean up - sync_context: Sync context - Raises: SyncFailureError: If any handler fails cleanup """ if not orphan_entity_ids: return - # Collect ALL handlers all_handlers = list(self._destination_handlers) if self._postgres_handler: all_handlers.append(self._postgres_handler) @@ -131,7 +99,6 @@ async def dispatch_orphan_cleanup( f"to {len(all_handlers)} handlers" ) - # Execute all handlers concurrently tasks = [ asyncio.create_task( self._dispatch_orphan_to_handler(handler, orphan_entity_ids, sync_context), @@ -142,7 +109,6 @@ async def dispatch_orphan_cleanup( results = await asyncio.gather(*tasks, return_exceptions=True) - # Check for failures failures = [] for handler, result in zip(all_handlers, results, strict=False): if isinstance(result, Exception): @@ -166,18 +132,12 @@ async def _dispatch_to_destinations( ) -> None: """Dispatch to all destination handlers concurrently. - Args: - batch: Action batch - sync_context: Sync context - runtime: Sync runtime - Raises: SyncFailureError: If any destination handler fails """ if not self._destination_handlers: return - # Create tasks for all destination handlers tasks = [ asyncio.create_task( self._dispatch_to_handler(handler, batch, sync_context, runtime), @@ -186,10 +146,8 @@ async def _dispatch_to_destinations( for handler in self._destination_handlers ] - # Wait for all - if any fails, collect errors results = await asyncio.gather(*tasks, return_exceptions=True) - # Check for failures failures = [] for handler, result in zip(self._destination_handlers, results, strict=False): if isinstance(result, Exception): @@ -210,11 +168,6 @@ async def _dispatch_to_postgres( ) -> None: """Dispatch to PostgreSQL metadata handler (after destinations succeed). - Args: - batch: Action batch - sync_context: Sync context - runtime: Sync runtime - Raises: SyncFailureError: If postgres handler fails """ @@ -237,12 +190,6 @@ async def _dispatch_to_handler( ) -> None: """Dispatch to single handler with error wrapping. - Args: - handler: Handler to dispatch to - batch: Action batch - sync_context: Sync context - runtime: Sync runtime - Raises: SyncFailureError: If handler fails """ @@ -264,11 +211,6 @@ async def _dispatch_orphan_to_handler( ) -> None: """Dispatch orphan cleanup to single handler. - Args: - handler: Handler to dispatch to - orphan_entity_ids: Entity IDs to clean up - sync_context: Sync context - Raises: SyncFailureError: If handler fails """ diff --git a/backend/airweave/platform/sync/actions/entity/resolver.py b/backend/airweave/domains/sync_pipeline/entity_action_resolver.py similarity index 78% rename from backend/airweave/platform/sync/actions/entity/resolver.py rename to backend/airweave/domains/sync_pipeline/entity_action_resolver.py index 179addb6a..c6c25793d 100644 --- a/backend/airweave/platform/sync/actions/entity/resolver.py +++ b/backend/airweave/domains/sync_pipeline/entity_action_resolver.py @@ -7,8 +7,9 @@ import time from typing import TYPE_CHECKING, Dict, List, Optional, Tuple -from airweave import crud, models +from airweave import models from airweave.db.session import get_db_context +from airweave.domains.entities.protocols import EntityRepositoryProtocol from airweave.platform.entities._base import BaseEntity, DeletionEntity from airweave.platform.sync.actions.entity.types import ( EntityActionBatch, @@ -30,13 +31,14 @@ class EntityActionResolver: what operation is needed for each entity. """ - def __init__(self, entity_map: Dict[type, str]): - """Initialize the action resolver. - - Args: - entity_map: Mapping of entity class to entity_definition_short_name - """ + def __init__( + self, + entity_map: Dict[type, str], + entity_repo: EntityRepositoryProtocol, + ): + """Initialize with entity-type-to-short-name map and repository.""" self.entity_map = entity_map + self._entity_repo = entity_repo # ------------------------------------------------------------------------- # Public API @@ -49,17 +51,9 @@ async def resolve( ) -> EntityActionBatch: """Resolve entities to their appropriate actions. - Args: - entities: Entities to resolve (must have hash set in metadata) - sync_context: Sync context with logger - - Returns: - EntityActionBatch containing all resolved actions - Raises: SyncFailureError: If entity type not found in entity_map or missing hash """ - # Check if skip_hash_comparison is enabled if ( sync_context.execution_config and sync_context.execution_config.behavior @@ -70,17 +64,13 @@ async def resolve( ) return self._force_all_inserts(entities, sync_context) - # Step 1: Separate deletions from non-deletions delete_entities, non_delete_entities = self._separate_deletions(entities) - # Step 2: Build entity requests for DB lookup all_entities = non_delete_entities + delete_entities entity_requests = self._build_entity_requests(all_entities, sync_context) - # Step 3: Fetch existing entities from database existing_map = await self._fetch_existing_entities(entity_requests, sync_context) - # Step 4: Create actions for each entity batch = self._create_actions( non_delete_entities, delete_entities, @@ -88,29 +78,19 @@ async def resolve( sync_context, ) - # Log summary sync_context.logger.debug(f"Action resolution: {batch.summary()}") return batch def resolve_entity_definition_short_name(self, entity: BaseEntity) -> Optional[str]: - """Resolve entity definition short_name with polymorphic fallback. - - Args: - entity: Entity to resolve definition short_name for - - Returns: - Entity definition short_name, or None if not found - """ + """Resolve entity definition short_name with polymorphic fallback.""" entity_class = entity.__class__ - # Handle DeletionEntity - resolve to target class if issubclass(entity_class, DeletionEntity): target_class = getattr(entity_class, "deletes_entity_class", None) if target_class: entity_class = target_class - # Try direct lookup return self.entity_map.get(entity_class) # ------------------------------------------------------------------------- @@ -120,14 +100,7 @@ def resolve_entity_definition_short_name(self, entity: BaseEntity) -> Optional[s def _separate_deletions( self, entities: List[BaseEntity] ) -> Tuple[List[BaseEntity], List[BaseEntity]]: - """Separate deletion entities from non-deletions. - - Args: - entities: All entities to process - - Returns: - Tuple of (delete_entities, non_delete_entities) - """ + """Separate deletion entities from non-deletions.""" delete_entities = [] non_delete_entities = [] @@ -146,13 +119,6 @@ def _build_entity_requests( ) -> List[Tuple[str, str]]: """Build entity requests for database lookup. - Args: - entities: Entities to build requests for - sync_context: Sync context for logging - - Returns: - List of (entity_id, entity_definition_short_name) tuples - Raises: SyncFailureError: If entity type not found in entity_map """ @@ -176,13 +142,6 @@ async def _fetch_existing_entities( ) -> Dict[Tuple[str, str], models.Entity]: """Bulk fetch existing entity records from database. - Args: - entity_requests: List of (entity_id, entity_definition_short_name) tuples - sync_context: Sync context with logger - - Returns: - Dict mapping (entity_id, entity_definition_short_name) -> Entity model - Raises: SyncFailureError: If database lookup fails """ @@ -197,7 +156,7 @@ async def _fetch_existing_entities( ) async with get_db_context() as db: - existing_map = await crud.entity.bulk_get_by_entity_sync_and_definition( + existing_map = await self._entity_repo.bulk_get_by_entity_sync_and_definition( db, sync_id=sync_context.sync.id, entity_requests=entity_requests, @@ -224,15 +183,6 @@ def _create_actions( ) -> EntityActionBatch: """Create action objects for all entities. - Args: - non_delete_entities: Entities that are not deletions - delete_entities: DeletionEntity instances - existing_map: Map of existing DB records - sync_context: Sync context for error handling - - Returns: - EntityActionBatch with all resolved actions - Raises: SyncFailureError: If entity has no hash or type not in entity_map """ @@ -241,7 +191,6 @@ def _create_actions( keeps: List[EntityKeepAction] = [] deletes: List[EntityDeleteAction] = [] - # Process non-delete entities for entity in non_delete_entities: action = self._resolve_non_delete_action(entity, existing_map, sync_context) if isinstance(action, EntityInsertAction): @@ -251,7 +200,6 @@ def _create_actions( elif isinstance(action, EntityKeepAction): keeps.append(action) - # Process delete entities for entity in delete_entities: action = self._create_delete_action(entity, existing_map, sync_context) deletes.append(action) @@ -272,14 +220,6 @@ def _resolve_non_delete_action( ) -> EntityInsertAction | EntityUpdateAction | EntityKeepAction: """Resolve a non-delete entity to its action type. - Args: - entity: Entity to resolve - existing_map: Map of existing DB records - sync_context: Sync context - - Returns: - EntityInsertAction, EntityUpdateAction, or EntityKeepAction - Raises: SyncFailureError: If entity has no hash or type not in entity_map """ @@ -324,14 +264,6 @@ def _create_delete_action( ) -> EntityDeleteAction: """Create a delete action for a DeletionEntity. - Args: - entity: DeletionEntity to process - existing_map: Map of existing DB records - sync_context: Sync context - - Returns: - EntityDeleteAction with db_id if entity exists in DB - Raises: SyncFailureError: If entity type not in entity_map """ @@ -353,17 +285,7 @@ def _force_all_inserts( entities: List[BaseEntity], sync_context: "SyncContext", ) -> EntityActionBatch: - """Force all entities as INSERT actions (skip hash comparison). - - Used for ARF replay or when execution_config.behavior.skip_hash_comparison is True. - - Args: - entities: Entities to process - sync_context: Sync context - - Returns: - EntityActionBatch with all entities as inserts - """ + """Force all entities as INSERT actions (skip hash comparison).""" inserts: List[EntityInsertAction] = [] deletes: List[EntityDeleteAction] = [] diff --git a/backend/airweave/platform/sync/entity_pipeline.py b/backend/airweave/domains/sync_pipeline/entity_pipeline.py similarity index 92% rename from backend/airweave/platform/sync/entity_pipeline.py rename to backend/airweave/domains/sync_pipeline/entity_pipeline.py index 0d97e19e7..53b29405f 100644 --- a/backend/airweave/platform/sync/entity_pipeline.py +++ b/backend/airweave/domains/sync_pipeline/entity_pipeline.py @@ -17,14 +17,15 @@ from airweave.core.events.sync import EntityBatchProcessedEvent, TypeActionCounts from airweave.core.shared_models import AirweaveFieldFlag +from airweave.domains.entities.protocols import EntityRepositoryProtocol +from airweave.domains.sync_pipeline.protocols import ( + EntityActionDispatcherProtocol, + EntityActionResolverProtocol, +) from airweave.platform.contexts import SyncContext from airweave.platform.contexts.runtime import SyncRuntime from airweave.platform.entities._base import BaseEntity -from airweave.platform.sync.actions import ( - EntityActionBatch, - EntityActionDispatcher, - EntityActionResolver, -) +from airweave.platform.sync.actions.entity.types import EntityActionBatch from airweave.platform.sync.exceptions import SyncFailureError from airweave.platform.sync.pipeline.cleanup_service import cleanup_service from airweave.platform.sync.pipeline.entity_tracker import EntityTracker @@ -44,21 +45,16 @@ def __init__( self, entity_tracker: EntityTracker, event_bus: "EventBus", - action_resolver: EntityActionResolver, - action_dispatcher: EntityActionDispatcher, + action_resolver: EntityActionResolverProtocol, + action_dispatcher: EntityActionDispatcherProtocol, + entity_repo: EntityRepositoryProtocol, ): - """Initialize pipeline with injected dependencies. - - Args: - entity_tracker: Centralized entity state tracker - event_bus: Per-sync event bus for EntityBatchProcessedEvent fan-out - action_resolver: Resolves entities to actions - action_dispatcher: Dispatches actions to handlers - """ + """Initialize with per-sync tracker, event bus, and action components.""" self._tracker = entity_tracker self._event_bus = event_bus self._resolver = action_resolver self._dispatcher = action_dispatcher + self._entity_repo = entity_repo self._batch_seq = 0 # ------------------------------------------------------------------------- @@ -71,13 +67,7 @@ async def process( sync_context: SyncContext, runtime: SyncRuntime, ) -> None: - """Process a batch of entities through the full pipeline. - - Args: - entities: Entities to process - sync_context: Sync context (frozen data) - runtime: Sync runtime (live services) - """ + """Process a batch of entities through the full pipeline.""" batch_start = time.monotonic() unique_entities = await self._track_and_dedupe(entities, sync_context) @@ -103,12 +93,7 @@ async def process( async def cleanup_orphaned_entities( self, sync_context: SyncContext, runtime: SyncRuntime ) -> None: - """Remove entities from database/destinations that were not encountered during sync. - - Args: - sync_context: Sync context - runtime: Sync runtime - """ + """Remove entities from database/destinations that were not encountered during sync.""" orphans_by_definition = await self._identify_orphans(sync_context) if not orphans_by_definition: return @@ -123,12 +108,7 @@ async def cleanup_orphaned_entities( await self._tracker.record_deletes(definition_id, len(entity_ids)) async def cleanup_temp_files(self, sync_context: SyncContext, runtime: SyncRuntime) -> None: - """Remove entire sync_job_id directory (final cleanup safety net). - - Args: - sync_context: Sync context - runtime: Sync runtime (provides source.file_downloader) - """ + """Remove entire sync_job_id directory (final cleanup safety net).""" await cleanup_service.cleanup_temp_files(sync_context, runtime) # ------------------------------------------------------------------------- @@ -258,7 +238,6 @@ async def _emit_batch_event( type_breakdown = self._build_type_breakdown(batch) - # TODO: wrap this into a class or similar skip_guardrails = ( sync_context.execution_config and sync_context.execution_config.behavior @@ -309,13 +288,14 @@ def _build_type_breakdown(batch: EntityActionBatch) -> Dict[str, TypeActionCount async def _identify_orphans(self, sync_context: SyncContext) -> Dict[str, List[str]]: """Identify orphaned entity IDs (in DB but not encountered), grouped by definition.""" - from airweave import crud from airweave.db.session import get_db_context encountered_ids = self._tracker.get_all_encountered_ids_flat() async with get_db_context() as db: - stored_entities = await crud.entity.get_by_sync_id(db=db, sync_id=sync_context.sync.id) + stored_entities = await self._entity_repo.get_by_sync_id( + db=db, sync_id=sync_context.sync.id + ) orphans_by_definition: Dict[str, List[str]] = defaultdict(list) for entity in stored_entities: diff --git a/backend/airweave/platform/sync/factory.py b/backend/airweave/domains/sync_pipeline/factory.py similarity index 79% rename from backend/airweave/platform/sync/factory.py rename to backend/airweave/domains/sync_pipeline/factory.py index f80116ec2..bf4a6a54d 100644 --- a/backend/airweave/platform/sync/factory.py +++ b/backend/airweave/domains/sync_pipeline/factory.py @@ -6,6 +6,8 @@ 3. Building per-sync event emitter with subscribers (progress relay, billing) 4. Assembling SyncRuntime from the services 5. Wiring everything into SyncOrchestrator + +Instance-based with injected deps (code blue architecture). """ import asyncio @@ -15,10 +17,13 @@ from sqlalchemy.ext.asyncio import AsyncSession from airweave import schemas -from airweave.core import container as container_mod # [code blue] todo from airweave.core.context import BaseContext from airweave.core.logging import LoggerConfigurator, logger +from airweave.core.protocols.event_bus import EventBus from airweave.domains.embedders.protocols import DenseEmbedderProtocol, SparseEmbedderProtocol +from airweave.domains.entities.protocols import EntityRepositoryProtocol +from airweave.domains.source_connections.protocols import SourceConnectionRepositoryProtocol +from airweave.domains.usage.protocols import UsageLimitCheckerProtocol from airweave.platform.builders import SyncContextBuilder from airweave.platform.builders.tracking import TrackingContextBuilder from airweave.platform.contexts.runtime import SyncRuntime @@ -26,11 +31,9 @@ from airweave.platform.sync.actions import ( ACActionDispatcher, ACActionResolver, - EntityActionResolver, EntityDispatcherBuilder, ) from airweave.platform.sync.config import SyncConfig, SyncConfigBuilder -from airweave.platform.sync.entity_pipeline import EntityPipeline from airweave.platform.sync.handlers import ACPostgresHandler from airweave.platform.sync.orchestrator import SyncOrchestrator from airweave.platform.sync.pipeline.acl_membership_tracker import ACLMembershipTracker @@ -38,6 +41,9 @@ from airweave.platform.sync.stream import AsyncSourceStream from airweave.platform.sync.worker_pool import AsyncWorkerPool +from .entity_action_resolver import EntityActionResolver +from .entity_pipeline import EntityPipeline + class SyncFactory: """Factory for sync orchestrator. @@ -46,18 +52,31 @@ class SyncFactory: into the orchestrator and pipeline components. """ - @classmethod + def __init__( + self, + sc_repo: SourceConnectionRepositoryProtocol, + event_bus: EventBus, + usage_checker: UsageLimitCheckerProtocol, + dense_embedder: DenseEmbedderProtocol, + sparse_embedder: SparseEmbedderProtocol, + entity_repo: EntityRepositoryProtocol, + ) -> None: + """Initialize with all deployment-wide dependencies.""" + self._sc_repo = sc_repo + self._event_bus = event_bus + self._usage_checker = usage_checker + self._dense_embedder = dense_embedder + self._sparse_embedder = sparse_embedder + self._entity_repo = entity_repo + async def create_orchestrator( - cls, + self, db: AsyncSession, sync: schemas.Sync, sync_job: schemas.SyncJob, collection: schemas.CollectionRecord, connection: schemas.Connection, ctx: BaseContext, - dense_embedder: DenseEmbedderProtocol, - sparse_embedder: SparseEmbedderProtocol, - access_token: Optional[str] = None, force_full_sync: bool = False, execution_config: Optional[SyncConfig] = None, ) -> SyncOrchestrator: @@ -65,7 +84,6 @@ async def create_orchestrator( init_start = time.time() logger.info("Creating sync orchestrator...") - # Step 0: Build layered sync configuration resolved_config = SyncConfigBuilder.build( collection_overrides=collection.sync_config, sync_overrides=sync.sync_config, @@ -76,28 +94,31 @@ async def create_orchestrator( f"destinations={resolved_config.destinations.model_dump()}" ) - # Step 1: Get source connection ID (needed before parallel build) - source_connection_id = await SyncContextBuilder.get_source_connection_id(db, sync, ctx) + # Direct repo call — replaces SyncContextBuilder -> SourceContextBuilder chain + sc = await self._sc_repo.get_by_sync_id(db, sync_id=sync.id, ctx=ctx) + if not sc: + from airweave.core.exceptions import NotFoundException + + raise NotFoundException(f"Source connection record not found for sync {sync.id}") + source_connection_id = sc.id - # Step 2: Build services in parallel source_result, destinations_result, entity_tracker_result = await asyncio.gather( - cls._build_source( + self._build_source( db=db, sync=sync, sync_job=sync_job, ctx=ctx, - access_token=access_token, force_full_sync=force_full_sync, execution_config=resolved_config, ), - cls._build_destinations( + self._build_destinations( db=db, sync=sync, collection=collection, ctx=ctx, execution_config=resolved_config, ), - cls._build_tracking( + self._build_tracking( db=db, sync=sync, sync_job=sync_job, @@ -108,7 +129,6 @@ async def create_orchestrator( source, cursor = source_result destinations, entity_map = destinations_result - # Step 3: Build SyncContext (data only) sync_context = await SyncContextBuilder.build( db=db, sync=sync, @@ -123,34 +143,36 @@ async def create_orchestrator( execution_config=resolved_config, ) - # Step 4: Assemble SyncRuntime (live services) runtime = SyncRuntime( source=source, cursor=cursor, - dense_embedder=dense_embedder, - sparse_embedder=sparse_embedder, + dense_embedder=self._dense_embedder, + sparse_embedder=self._sparse_embedder, destinations=destinations, entity_tracker=entity_tracker_result, - event_bus=container_mod.container.event_bus, - usage_checker=container_mod.container.usage_checker, + event_bus=self._event_bus, + usage_checker=self._usage_checker, ) logger.debug(f"Context + runtime built in {time.time() - init_start:.2f}s") - # Step 6: Build pipelines using runtime services dispatcher = EntityDispatcherBuilder.build( destinations=runtime.destinations, execution_config=resolved_config, logger=sync_context.logger, ) - action_resolver = EntityActionResolver(entity_map=sync_context.entity_map) + action_resolver = EntityActionResolver( + entity_map=sync_context.entity_map, + entity_repo=self._entity_repo, + ) entity_pipeline = EntityPipeline( entity_tracker=runtime.entity_tracker, - event_bus=container_mod.container.event_bus, + event_bus=self._event_bus, action_resolver=action_resolver, action_dispatcher=dispatcher, + entity_repo=self._entity_repo, ) access_control_pipeline = AccessControlPipeline( @@ -171,7 +193,6 @@ async def create_orchestrator( logger=sync_context.logger, ) - # Step 7: Create orchestrator orchestrator = SyncOrchestrator( entity_pipeline=entity_pipeline, worker_pool=worker_pool, @@ -188,12 +209,9 @@ async def create_orchestrator( # Private: Service builders (delegate to sub-builders) # ------------------------------------------------------------------------- - @classmethod - async def _build_source( - cls, db, sync, sync_job, ctx, access_token, force_full_sync, execution_config - ): + @staticmethod + async def _build_source(db, sync, sync_job, ctx, force_full_sync, execution_config): """Build source and cursor. Returns (source, cursor) tuple.""" - from airweave.core.logging import LoggerConfigurator from airweave.platform.builders.source import SourceContextBuilder from airweave.platform.contexts.infra import InfraContext @@ -211,16 +229,14 @@ async def _build_source( sync=sync, sync_job=sync_job, infra=infra, - access_token=access_token, force_full_sync=force_full_sync, execution_config=execution_config, ) return source_ctx.source, source_ctx.cursor - @classmethod - async def _build_destinations(cls, db, sync, collection, ctx, execution_config): + @staticmethod + async def _build_destinations(db, sync, collection, ctx, execution_config): """Build destinations and entity map. Returns (destinations, entity_map) tuple.""" - from airweave.core.logging import LoggerConfigurator from airweave.platform.builders.destinations import DestinationsContextBuilder dest_logger = LoggerConfigurator.configure_logger( @@ -240,9 +256,8 @@ async def _build_destinations(cls, db, sync, collection, ctx, execution_config): execution_config=execution_config, ) - @classmethod + @staticmethod async def _build_tracking( - cls, db: AsyncSession, sync: schemas.Sync, sync_job: schemas.SyncJob, diff --git a/backend/airweave/domains/sync_pipeline/fakes/__init__.py b/backend/airweave/domains/sync_pipeline/fakes/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/airweave/domains/sync_pipeline/fakes/entity_repository.py b/backend/airweave/domains/sync_pipeline/fakes/entity_repository.py new file mode 100644 index 000000000..ee7bfc62d --- /dev/null +++ b/backend/airweave/domains/sync_pipeline/fakes/entity_repository.py @@ -0,0 +1,27 @@ +"""Fake entity repository for testing.""" + +from typing import Dict, List, Tuple +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + +from airweave.models.entity import Entity + + +class FakeEntityRepository: + """In-memory fake for EntityRepositoryProtocol.""" + + def __init__(self) -> None: + self._entities: List[Entity] = [] + + async def get_by_sync_id(self, db: AsyncSession, sync_id: UUID) -> List[Entity]: + return [e for e in self._entities if e.sync_id == sync_id] + + async def bulk_get_by_entity_sync_and_definition( + self, + db: AsyncSession, + *, + sync_id: UUID, + entity_requests: list[Tuple[str, str]], + ) -> Dict[Tuple[str, str], Entity]: + return {} diff --git a/backend/airweave/domains/sync_pipeline/fakes/factory.py b/backend/airweave/domains/sync_pipeline/fakes/factory.py new file mode 100644 index 000000000..011e5c5d7 --- /dev/null +++ b/backend/airweave/domains/sync_pipeline/fakes/factory.py @@ -0,0 +1,32 @@ +"""Fake sync factory for testing.""" + +from typing import Optional +from unittest.mock import AsyncMock + +from sqlalchemy.ext.asyncio import AsyncSession + +from airweave import schemas +from airweave.core.context import BaseContext +from airweave.platform.sync.config import SyncConfig + + +class FakeSyncFactory: + """In-memory fake for SyncFactoryProtocol.""" + + def __init__(self) -> None: + self._calls: list[tuple] = [] + self._orchestrator = AsyncMock() + + async def create_orchestrator( + self, + db: AsyncSession, + sync: schemas.Sync, + sync_job: schemas.SyncJob, + collection: schemas.CollectionRecord, + connection: schemas.Connection, + ctx: BaseContext, + force_full_sync: bool = False, + execution_config: Optional[SyncConfig] = None, + ): + self._calls.append(("create_orchestrator", sync.id, sync_job.id)) + return self._orchestrator diff --git a/backend/airweave/domains/sync_pipeline/protocols.py b/backend/airweave/domains/sync_pipeline/protocols.py new file mode 100644 index 000000000..5e11fdf08 --- /dev/null +++ b/backend/airweave/domains/sync_pipeline/protocols.py @@ -0,0 +1,109 @@ +"""Protocols for the sync pipeline domain.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Optional, Protocol + +from sqlalchemy.ext.asyncio import AsyncSession + +from airweave import schemas +from airweave.platform.entities._base import BaseEntity +from airweave.platform.sync.actions.entity.types import EntityActionBatch + +if TYPE_CHECKING: + from airweave.core.context import BaseContext + from airweave.platform.contexts import SyncContext + from airweave.platform.contexts.runtime import SyncRuntime + from airweave.platform.sync.config import SyncConfig + from airweave.platform.sync.orchestrator import SyncOrchestrator + + +class ChunkEmbedProcessorProtocol(Protocol): + """Chunks text and computes dense/sparse embeddings.""" + + async def process( + self, + entities: List[BaseEntity], + sync_context: SyncContext, + runtime: SyncRuntime, + ) -> List[BaseEntity]: + """Chunk text and compute embeddings for entities.""" + ... + + +class EntityActionResolverProtocol(Protocol): + """Resolves entities to INSERT/UPDATE/DELETE/KEEP actions.""" + + async def resolve( + self, + entities: List[BaseEntity], + sync_context: SyncContext, + ) -> EntityActionBatch: + """Compare entity hashes and determine needed actions.""" + ... + + def resolve_entity_definition_short_name(self, entity: BaseEntity) -> Optional[str]: + """Return the short name for an entity's definition, if mapped.""" + ... + + +class EntityActionDispatcherProtocol(Protocol): + """Dispatches resolved entity actions to handlers.""" + + async def dispatch( + self, + batch: EntityActionBatch, + sync_context: SyncContext, + runtime: SyncRuntime, + ) -> None: + """Execute a batch of entity actions against all handlers.""" + ... + + async def dispatch_orphan_cleanup( + self, + orphan_entity_ids: List[str], + sync_context: SyncContext, + ) -> None: + """Delete orphaned entities from all handlers.""" + ... + + +class EntityPipelineProtocol(Protocol): + """Orchestrates entity processing through sync stages.""" + + async def process( + self, + entities: List[BaseEntity], + sync_context: SyncContext, + runtime: SyncRuntime, + ) -> None: + """Process a batch of entities through the full pipeline.""" + ... + + async def cleanup_orphaned_entities( + self, sync_context: SyncContext, runtime: SyncRuntime + ) -> None: + """Remove entities no longer present in the source.""" + ... + + async def cleanup_temp_files(self, sync_context: SyncContext, runtime: SyncRuntime) -> None: + """Clean up temporary files created during the sync.""" + ... + + +class SyncFactoryProtocol(Protocol): + """Builds a SyncOrchestrator for a given sync run.""" + + async def create_orchestrator( + self, + db: AsyncSession, + sync: schemas.Sync, + sync_job: schemas.SyncJob, + collection: schemas.CollectionRecord, + connection: schemas.Connection, + ctx: BaseContext, + force_full_sync: bool = False, + execution_config: Optional[SyncConfig] = None, + ) -> SyncOrchestrator: + """Create and return a fully-wired SyncOrchestrator.""" + ... diff --git a/backend/airweave/domains/sync_pipeline/tests/__init__.py b/backend/airweave/domains/sync_pipeline/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/airweave/domains/syncs/fakes/sync_service.py b/backend/airweave/domains/syncs/fakes/sync_service.py index 3736d0972..b0bacfd02 100644 --- a/backend/airweave/domains/syncs/fakes/sync_service.py +++ b/backend/airweave/domains/syncs/fakes/sync_service.py @@ -4,7 +4,6 @@ from airweave import schemas from airweave.api.context import ApiContext -from airweave.domains.embedders.protocols import DenseEmbedderProtocol, SparseEmbedderProtocol from airweave.platform.sync.config import SyncConfig @@ -12,7 +11,6 @@ class FakeSyncService: """In-memory fake for SyncServiceProtocol.""" def __init__(self) -> None: - """Initialize with empty call log.""" self._calls: list[tuple] = [] async def run( @@ -22,12 +20,8 @@ async def run( collection: schemas.CollectionRecord, source_connection: schemas.Connection, ctx: ApiContext, - dense_embedder: DenseEmbedderProtocol, - sparse_embedder: SparseEmbedderProtocol, - access_token: Optional[str] = None, force_full_sync: bool = False, execution_config: Optional[SyncConfig] = None, ) -> schemas.Sync: - """Record call and return the sync as-is.""" self._calls.append(("run", sync, sync_job)) return sync diff --git a/backend/airweave/domains/syncs/protocols.py b/backend/airweave/domains/syncs/protocols.py index f0d7e2483..cb33446e0 100644 --- a/backend/airweave/domains/syncs/protocols.py +++ b/backend/airweave/domains/syncs/protocols.py @@ -10,7 +10,6 @@ from airweave.api.context import ApiContext from airweave.core.shared_models import SyncJobStatus from airweave.db.unit_of_work import UnitOfWork -from airweave.domains.embedders.protocols import DenseEmbedderProtocol, SparseEmbedderProtocol from airweave.domains.sources.types import SourceRegistryEntry from airweave.domains.syncs.types import SyncProvisionResult from airweave.models.sync import Sync @@ -179,9 +178,6 @@ async def run( collection: schemas.CollectionRecord, source_connection: schemas.Connection, ctx: ApiContext, - dense_embedder: DenseEmbedderProtocol, - sparse_embedder: SparseEmbedderProtocol, - access_token: Optional[str] = None, force_full_sync: bool = False, execution_config: Optional[SyncConfig] = None, ) -> schemas.Sync: diff --git a/backend/airweave/domains/syncs/service.py b/backend/airweave/domains/syncs/service.py index cff119f8d..4e4889208 100644 --- a/backend/airweave/domains/syncs/service.py +++ b/backend/airweave/domains/syncs/service.py @@ -10,10 +10,9 @@ from airweave.core.datetime_utils import utc_now_naive from airweave.core.shared_models import SyncJobStatus from airweave.db.session import get_db_context -from airweave.domains.embedders.protocols import DenseEmbedderProtocol, SparseEmbedderProtocol +from airweave.domains.sync_pipeline.protocols import SyncFactoryProtocol from airweave.domains.syncs.protocols import SyncJobServiceProtocol, SyncServiceProtocol from airweave.platform.sync.config import SyncConfig -from airweave.platform.sync.factory import SyncFactory class SyncService(SyncServiceProtocol): @@ -22,9 +21,14 @@ class SyncService(SyncServiceProtocol): Stateless — the only production caller is RunSyncActivity. """ - def __init__(self, sync_job_service: SyncJobServiceProtocol) -> None: - """Initialize with injected sync job service.""" + def __init__( + self, + sync_job_service: SyncJobServiceProtocol, + sync_factory: SyncFactoryProtocol, + ) -> None: + """Initialize with job service and factory dependencies.""" self._sync_job_service = sync_job_service + self._sync_factory = sync_factory async def run( self, @@ -33,43 +37,21 @@ async def run( collection: schemas.CollectionRecord, source_connection: schemas.Connection, ctx: ApiContext, - dense_embedder: DenseEmbedderProtocol, - sparse_embedder: SparseEmbedderProtocol, - access_token: Optional[str] = None, force_full_sync: bool = False, execution_config: Optional[SyncConfig] = None, ) -> schemas.Sync: - """Run a sync. - - Args: - sync: The sync to run. - sync_job: The sync job to run. - collection: The collection to sync. - source_connection: The source connection to sync. - ctx: The API context. - dense_embedder: Domain dense embedder instance. - sparse_embedder: Domain sparse embedder instance. - access_token: Optional access token instead of stored credentials. - force_full_sync: If True, forces a full sync with orphaned entity deletion. - execution_config: Optional execution config for sync behavior. - - Returns: - The sync. - """ + """Run a sync.""" try: async with get_db_context() as db: - orchestrator = await SyncFactory.create_orchestrator( + orchestrator = await self._sync_factory.create_orchestrator( db=db, sync=sync, sync_job=sync_job, collection=collection, connection=source_connection, ctx=ctx, - access_token=access_token, force_full_sync=force_full_sync, execution_config=execution_config, - dense_embedder=dense_embedder, - sparse_embedder=sparse_embedder, ) except Exception as e: ctx.logger.error(f"Error during sync orchestrator creation: {e}") diff --git a/backend/airweave/domains/syncs/tests/test_sync_service.py b/backend/airweave/domains/syncs/tests/test_sync_service.py index 689676848..0912eb088 100644 --- a/backend/airweave/domains/syncs/tests/test_sync_service.py +++ b/backend/airweave/domains/syncs/tests/test_sync_service.py @@ -17,7 +17,6 @@ def _mock_ctx(): - """Minimal mock that satisfies ApiContext duck-typing.""" ctx = MagicMock() ctx.organization = MagicMock() ctx.organization.id = uuid4() @@ -52,7 +51,6 @@ class RunCase: expect_raises: bool = False def __post_init__(self): - """Default orchestrator_result to a MagicMock when no factory error.""" if self.orchestrator_result is None and self.factory_error is None: self.orchestrator_result = MagicMock() @@ -80,41 +78,39 @@ def __post_init__(self): @pytest.mark.parametrize("case", RUN_CASES, ids=lambda c: c.name) async def test_run(case: RunCase): fake_job_svc = FakeSyncJobService() - svc = SyncService(sync_job_service=fake_job_svc) + fake_factory = MagicMock() + + mock_orchestrator = MagicMock() + mock_orchestrator.run = AsyncMock(return_value=case.orchestrator_result) + + if case.factory_error: + fake_factory.create_orchestrator = AsyncMock( + side_effect=case.factory_error, + ) + else: + fake_factory.create_orchestrator = AsyncMock( + return_value=mock_orchestrator, + ) + + svc = SyncService( + sync_job_service=fake_job_svc, + sync_factory=fake_factory, + ) sync = _mock_sync() sync_job = _mock_sync_job() collection = MagicMock() source_connection = MagicMock() ctx = _mock_ctx() - dense_embedder = MagicMock() - sparse_embedder = MagicMock() - - mock_orchestrator = MagicMock() - mock_orchestrator.run = AsyncMock(return_value=case.orchestrator_result) mock_db = AsyncMock() - with ( - patch( - "airweave.domains.syncs.service.get_db_context", - ) as mock_db_ctx, - patch( - "airweave.domains.syncs.service.SyncFactory", - ) as mock_factory_cls, - ): + with patch( + "airweave.domains.syncs.service.get_db_context", + ) as mock_db_ctx: mock_db_ctx.return_value.__aenter__ = AsyncMock(return_value=mock_db) mock_db_ctx.return_value.__aexit__ = AsyncMock(return_value=False) - if case.factory_error: - mock_factory_cls.create_orchestrator = AsyncMock( - side_effect=case.factory_error, - ) - else: - mock_factory_cls.create_orchestrator = AsyncMock( - return_value=mock_orchestrator, - ) - if case.expect_raises: with pytest.raises(type(case.factory_error)): await svc.run( @@ -123,8 +119,6 @@ async def test_run(case: RunCase): collection=collection, source_connection=source_connection, ctx=ctx, - dense_embedder=dense_embedder, - sparse_embedder=sparse_embedder, ) else: result = await svc.run( @@ -133,8 +127,6 @@ async def test_run(case: RunCase): collection=collection, source_connection=source_connection, ctx=ctx, - dense_embedder=dense_embedder, - sparse_embedder=sparse_embedder, ) assert result is case.orchestrator_result mock_orchestrator.run.assert_awaited_once() @@ -157,29 +149,29 @@ async def test_run(case: RunCase): @pytest.mark.asyncio async def test_run_forwards_optional_kwargs(): - """access_token, force_full_sync, execution_config reach the factory.""" + """force_full_sync, execution_config reach the factory.""" fake_job_svc = FakeSyncJobService() - svc = SyncService(sync_job_service=fake_job_svc) + fake_factory = MagicMock() mock_orchestrator = MagicMock() mock_orchestrator.run = AsyncMock(return_value=_mock_sync()) + fake_factory.create_orchestrator = AsyncMock( + return_value=mock_orchestrator, + ) + + svc = SyncService( + sync_job_service=fake_job_svc, + sync_factory=fake_factory, + ) + mock_db = AsyncMock() + exec_config = MagicMock() - with ( - patch( - "airweave.domains.syncs.service.get_db_context", - ) as mock_db_ctx, - patch( - "airweave.domains.syncs.service.SyncFactory", - ) as mock_factory_cls, - ): + with patch( + "airweave.domains.syncs.service.get_db_context", + ) as mock_db_ctx: mock_db_ctx.return_value.__aenter__ = AsyncMock(return_value=mock_db) mock_db_ctx.return_value.__aexit__ = AsyncMock(return_value=False) - mock_factory_cls.create_orchestrator = AsyncMock( - return_value=mock_orchestrator, - ) - - exec_config = MagicMock() await svc.run( sync=_mock_sync(), @@ -187,15 +179,11 @@ async def test_run_forwards_optional_kwargs(): collection=MagicMock(), source_connection=MagicMock(), ctx=_mock_ctx(), - dense_embedder=MagicMock(), - sparse_embedder=MagicMock(), - access_token="tok-123", force_full_sync=True, execution_config=exec_config, ) - _, kwargs = mock_factory_cls.create_orchestrator.call_args - assert kwargs["access_token"] == "tok-123" + _, kwargs = fake_factory.create_orchestrator.call_args assert kwargs["force_full_sync"] is True assert kwargs["execution_config"] is exec_config @@ -205,7 +193,9 @@ async def test_run_forwards_optional_kwargs(): # --------------------------------------------------------------------------- -def test_stores_sync_job_service(): - fake = FakeSyncJobService() - svc = SyncService(sync_job_service=fake) - assert svc._sync_job_service is fake +def test_stores_injected_deps(): + fake_job = FakeSyncJobService() + fake_factory = MagicMock() + svc = SyncService(sync_job_service=fake_job, sync_factory=fake_factory) + assert svc._sync_job_service is fake_job + assert svc._sync_factory is fake_factory diff --git a/backend/airweave/platform/builders/source.py b/backend/airweave/platform/builders/source.py index 3e3447328..b8590a244 100644 --- a/backend/airweave/platform/builders/source.py +++ b/backend/airweave/platform/builders/source.py @@ -193,30 +193,6 @@ async def _build_arf_replay_context( return SourceContext(source=source, cursor=cursor) - @classmethod - async def get_source_connection_id( - cls, - db: AsyncSession, - sync: schemas.Sync, - ctx: ApiContext, - ) -> UUID: - """Get user-facing source connection ID for logging and scoping. - - Args: - db: Database session - sync: Sync configuration - ctx: API context - - Returns: - User-facing SourceConnection UUID (not internal Connection ID). - """ - source_connection_obj = await crud.source_connection.get_by_sync_id( - db, sync_id=sync.id, ctx=ctx - ) - if not source_connection_obj: - raise NotFoundException(f"Source connection record not found for sync {sync.id}") - return UUID(str(source_connection_obj.id)) - # ------------------------------------------------------------------------- # Private helpers # ------------------------------------------------------------------------- diff --git a/backend/airweave/platform/builders/sync.py b/backend/airweave/platform/builders/sync.py index b9801b29f..d1db27996 100644 --- a/backend/airweave/platform/builders/sync.py +++ b/backend/airweave/platform/builders/sync.py @@ -102,15 +102,3 @@ def _build_logger( "scheduled": str(sync_job.scheduled), }, ) - - @classmethod - async def get_source_connection_id( - cls, - db: AsyncSession, - sync: schemas.Sync, - ctx: BaseContext, - ) -> UUID: - """Get user-facing source connection ID for logging and scoping.""" - from airweave.platform.builders.source import SourceContextBuilder - - return await SourceContextBuilder.get_source_connection_id(db, sync, ctx) diff --git a/backend/airweave/platform/sync/actions/entity/__init__.py b/backend/airweave/platform/sync/actions/entity/__init__.py index 6f0257ef7..caa8ab9a1 100644 --- a/backend/airweave/platform/sync/actions/entity/__init__.py +++ b/backend/airweave/platform/sync/actions/entity/__init__.py @@ -3,9 +3,9 @@ Entity-specific action pipeline for sync operations. """ +from airweave.domains.sync_pipeline.entity_action_dispatcher import EntityActionDispatcher +from airweave.domains.sync_pipeline.entity_action_resolver import EntityActionResolver from airweave.platform.sync.actions.entity.builder import EntityDispatcherBuilder -from airweave.platform.sync.actions.entity.dispatcher import EntityActionDispatcher -from airweave.platform.sync.actions.entity.resolver import EntityActionResolver from airweave.platform.sync.actions.entity.types import ( EntityActionBatch, EntityDeleteAction, diff --git a/backend/airweave/platform/sync/actions/entity/builder.py b/backend/airweave/platform/sync/actions/entity/builder.py index 95731f560..225d32ca3 100644 --- a/backend/airweave/platform/sync/actions/entity/builder.py +++ b/backend/airweave/platform/sync/actions/entity/builder.py @@ -3,8 +3,8 @@ from typing import List, Optional from airweave.core.logging import ContextualLogger +from airweave.domains.sync_pipeline.entity_action_dispatcher import EntityActionDispatcher from airweave.platform.destinations._base import BaseDestination -from airweave.platform.sync.actions.entity.dispatcher import EntityActionDispatcher from airweave.platform.sync.config import SyncConfig from airweave.platform.sync.handlers.arf import ArfHandler from airweave.platform.sync.handlers.destination import DestinationHandler diff --git a/backend/airweave/platform/sync/orchestrator.py b/backend/airweave/platform/sync/orchestrator.py index eb29d8932..02c98be40 100644 --- a/backend/airweave/platform/sync/orchestrator.py +++ b/backend/airweave/platform/sync/orchestrator.py @@ -12,6 +12,7 @@ from airweave.core.sync_cursor_service import sync_cursor_service from airweave.core.sync_job_service import sync_job_service from airweave.db.session import get_db_context +from airweave.domains.sync_pipeline.entity_pipeline import EntityPipeline from airweave.domains.usage.exceptions import ( PaymentRequiredError, UsageLimitExceededError, @@ -20,7 +21,6 @@ from airweave.platform.contexts import SyncContext from airweave.platform.contexts.runtime import SyncRuntime from airweave.platform.sync.access_control_pipeline import AccessControlPipeline -from airweave.platform.sync.entity_pipeline import EntityPipeline from airweave.platform.sync.exceptions import EntityProcessingError, SyncFailureError from airweave.platform.sync.stream import AsyncSourceStream from airweave.platform.sync.worker_pool import AsyncWorkerPool diff --git a/backend/airweave/platform/temporal/activities/sync.py b/backend/airweave/platform/temporal/activities/sync.py index 0451dd54a..da2f93372 100644 --- a/backend/airweave/platform/temporal/activities/sync.py +++ b/backend/airweave/platform/temporal/activities/sync.py @@ -28,7 +28,6 @@ from airweave.core.redis_client import redis_client from airweave.domains.collections.protocols import CollectionRepositoryProtocol from airweave.domains.connections.protocols import ConnectionRepositoryProtocol -from airweave.domains.embedders.protocols import DenseEmbedderProtocol, SparseEmbedderProtocol from airweave.domains.source_connections.protocols import SourceConnectionRepositoryProtocol from airweave.domains.syncs.protocols import ( SyncJobRepositoryProtocol, @@ -61,8 +60,6 @@ class RunSyncActivity: """ event_bus: EventBus - dense_embedder: DenseEmbedderProtocol - sparse_embedder: SparseEmbedderProtocol sync_service: SyncServiceProtocol sync_job_service: SyncJobServiceProtocol collection_repo: CollectionRepositoryProtocol @@ -444,11 +441,8 @@ async def _run_sync_task( collection=collection, source_connection=connection, ctx=ctx, - access_token=access_token, force_full_sync=force_full_sync, execution_config=execution_config, - dense_embedder=self.dense_embedder, - sparse_embedder=self.sparse_embedder, ) except NotFoundException as e: if "Source connection record not found" in str(e) or "Connection not found" in str(e): diff --git a/backend/airweave/platform/temporal/worker/wiring.py b/backend/airweave/platform/temporal/worker/wiring.py index 522e861f4..f452d7433 100644 --- a/backend/airweave/platform/temporal/worker/wiring.py +++ b/backend/airweave/platform/temporal/worker/wiring.py @@ -30,8 +30,6 @@ def create_activities() -> list: ) event_bus = container.event_bus - dense_embedder = container.dense_embedder - sparse_embedder = container.sparse_embedder email_service = container.email_service sync_service = container.sync_service sync_job_service = container.sync_job_service @@ -48,8 +46,6 @@ def create_activities() -> list: return [ RunSyncActivity( event_bus=event_bus, - dense_embedder=dense_embedder, - sparse_embedder=sparse_embedder, sync_service=sync_service, sync_job_service=sync_job_service, collection_repo=collection_repo, diff --git a/backend/conftest.py b/backend/conftest.py index d917d4924..cfd17d47f 100644 --- a/backend/conftest.py +++ b/backend/conftest.py @@ -385,6 +385,22 @@ def fake_sync_lifecycle(): return FakeSyncLifecycleService() +@pytest.fixture +def fake_sync_factory(): + """Fake SyncFactory.""" + from airweave.domains.sync_pipeline.fakes.factory import FakeSyncFactory + + return FakeSyncFactory() + + +@pytest.fixture +def fake_entity_repo(): + """Fake EntityRepository.""" + from airweave.domains.sync_pipeline.fakes.entity_repository import FakeEntityRepository + + return FakeEntityRepository() + + @pytest.fixture def fake_billing_webhook(): """Fake BillingWebhookProcessor.""" @@ -625,6 +641,8 @@ def test_container( fake_connect_service, fake_browse_tree_service, fake_selection_repo, + fake_sync_factory, + fake_entity_repo, ): """A Container with all dependencies replaced by fakes. @@ -694,4 +712,6 @@ def test_container( organization_service=fake_organization_service, email_service=fake_email_service, user_service=fake_user_service, + sync_factory=fake_sync_factory, + entity_repo=fake_entity_repo, ) From fd3fc362e28739b1dfce1b09da849832fbb77dad Mon Sep 17 00:00:00 2001 From: Felix Schmetz Date: Thu, 12 Mar 2026 16:21:01 -0700 Subject: [PATCH 02/13] fix: add missing docstring to __post_init__ in test_sync_service --- backend/airweave/domains/syncs/tests/test_sync_service.py | 1 + 1 file changed, 1 insertion(+) diff --git a/backend/airweave/domains/syncs/tests/test_sync_service.py b/backend/airweave/domains/syncs/tests/test_sync_service.py index 0912eb088..322da47ff 100644 --- a/backend/airweave/domains/syncs/tests/test_sync_service.py +++ b/backend/airweave/domains/syncs/tests/test_sync_service.py @@ -51,6 +51,7 @@ class RunCase: expect_raises: bool = False def __post_init__(self): + """Default orchestrator_result to a MagicMock when no factory error.""" if self.orchestrator_result is None and self.factory_error is None: self.orchestrator_result = MagicMock() From 19b1a41ab19072044bf689768be88d99f9d3218f Mon Sep 17 00:00:00 2001 From: Felix Schmetz Date: Thu, 12 Mar 2026 16:29:42 -0700 Subject: [PATCH 03/13] test: add sync pipeline unit tests and fix circular import Add 15 tests for SyncFactory, EntityActionResolver, EntityPipeline covering DI wiring, action resolution (INSERT/UPDATE/KEEP), and orphan identification. Fix circular import in actions/__init__.py and actions/entity/__init__.py by converting eager re-exports to lazy __getattr__. --- .../tests/test_entity_action_resolver.py | 194 ++++++++++++++++++ .../tests/test_entity_pipeline.py | 120 +++++++++++ .../sync_pipeline/tests/test_factory.py | 155 ++++++++++++++ .../platform/sync/actions/__init__.py | 26 +-- .../platform/sync/actions/entity/__init__.py | 59 ++++-- 5 files changed, 530 insertions(+), 24 deletions(-) create mode 100644 backend/airweave/domains/sync_pipeline/tests/test_entity_action_resolver.py create mode 100644 backend/airweave/domains/sync_pipeline/tests/test_entity_pipeline.py create mode 100644 backend/airweave/domains/sync_pipeline/tests/test_factory.py diff --git a/backend/airweave/domains/sync_pipeline/tests/test_entity_action_resolver.py b/backend/airweave/domains/sync_pipeline/tests/test_entity_action_resolver.py new file mode 100644 index 000000000..68ae540d6 --- /dev/null +++ b/backend/airweave/domains/sync_pipeline/tests/test_entity_action_resolver.py @@ -0,0 +1,194 @@ +"""Tests for EntityActionResolver — DI wiring and action resolution.""" + +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 + +import pytest + +from airweave.domains.sync_pipeline.entity_action_resolver import EntityActionResolver +from airweave.platform.entities._airweave_field import AirweaveField +from airweave.platform.entities._base import ( + AirweaveSystemMetadata, + BaseEntity, + DeletionEntity, +) +from airweave.platform.sync.actions.entity.types import ( + EntityInsertAction, + EntityKeepAction, + EntityUpdateAction, +) +from airweave.platform.sync.exceptions import SyncFailureError + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class _StubEntity(BaseEntity): + """Minimal concrete entity for testing.""" + + stub_id: str = AirweaveField(..., is_entity_id=True) + stub_name: str = AirweaveField(..., is_name=True) + + +class _StubDeletion(DeletionEntity): + """Minimal deletion entity for testing.""" + + deletes_entity_class = _StubEntity + stub_id: str = AirweaveField(..., is_entity_id=True) + stub_name: str = AirweaveField(..., is_name=True) + + +def _entity(entity_id="e-1", hash_val="abc123"): + """Create a _StubEntity with entity_id set (normally done by pipeline enrichment).""" + e = _StubEntity(stub_id=entity_id, stub_name="test", breadcrumbs=[]) + e.entity_id = entity_id + e.airweave_system_metadata = AirweaveSystemMetadata(hash=hash_val) + return e + + +def _sync_context(): + ctx = MagicMock() + ctx.sync = MagicMock() + ctx.sync.id = uuid4() + ctx.logger = MagicMock() + ctx.execution_config = None + return ctx + + +# --------------------------------------------------------------------------- +# Constructor +# --------------------------------------------------------------------------- + + +def test_constructor_stores_entity_repo(): + """entity_repo is stored on the instance.""" + repo = MagicMock() + entity_map = {_StubEntity: "stub"} + resolver = EntityActionResolver(entity_map=entity_map, entity_repo=repo) + assert resolver._entity_repo is repo + + +# --------------------------------------------------------------------------- +# resolve_entity_definition_short_name +# --------------------------------------------------------------------------- + + +def test_resolve_short_name_direct(): + """Direct class lookup returns short_name.""" + repo = MagicMock() + resolver = EntityActionResolver(entity_map={_StubEntity: "stub"}, entity_repo=repo) + e = _entity() + assert resolver.resolve_entity_definition_short_name(e) == "stub" + + +def test_resolve_short_name_deletion_entity(): + """DeletionEntity falls back to deletes_entity_class.""" + repo = MagicMock() + resolver = EntityActionResolver(entity_map={_StubEntity: "stub"}, entity_repo=repo) + d = _StubDeletion(stub_id="del-1", stub_name="del", deletion_status="removed", breadcrumbs=[]) + assert resolver.resolve_entity_definition_short_name(d) == "stub" + + +def test_resolve_short_name_unmapped(): + """Unknown entity type returns None.""" + repo = MagicMock() + resolver = EntityActionResolver(entity_map={}, entity_repo=repo) + e = _entity() + assert resolver.resolve_entity_definition_short_name(e) is None + + +# --------------------------------------------------------------------------- +# resolve — delegates to entity_repo +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_resolve_insert_when_no_existing(): + """New entity (not in DB) → INSERT action.""" + repo = MagicMock() + repo.bulk_get_by_entity_sync_and_definition = AsyncMock(return_value={}) + + resolver = EntityActionResolver(entity_map={_StubEntity: "stub"}, entity_repo=repo) + ctx = _sync_context() + e = _entity(entity_id="new-1", hash_val="h1") + + with patch("airweave.db.session.get_db_context") as mock_db_ctx: + mock_db = AsyncMock() + mock_db_ctx.return_value.__aenter__ = AsyncMock(return_value=mock_db) + mock_db_ctx.return_value.__aexit__ = AsyncMock(return_value=False) + + batch = await resolver.resolve([e], ctx) + + assert len(batch.inserts) == 1 + assert isinstance(batch.inserts[0], EntityInsertAction) + assert batch.inserts[0].entity is e + repo.bulk_get_by_entity_sync_and_definition.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_resolve_update_when_hash_changed(): + """Existing entity with changed hash → UPDATE action.""" + db_entity = MagicMock() + db_entity.id = uuid4() + db_entity.hash = "old-hash" + + repo = MagicMock() + repo.bulk_get_by_entity_sync_and_definition = AsyncMock( + return_value={("e-1", "stub"): db_entity} + ) + + resolver = EntityActionResolver(entity_map={_StubEntity: "stub"}, entity_repo=repo) + ctx = _sync_context() + e = _entity(entity_id="e-1", hash_val="new-hash") + + with patch("airweave.db.session.get_db_context") as mock_db_ctx: + mock_db = AsyncMock() + mock_db_ctx.return_value.__aenter__ = AsyncMock(return_value=mock_db) + mock_db_ctx.return_value.__aexit__ = AsyncMock(return_value=False) + + batch = await resolver.resolve([e], ctx) + + assert len(batch.updates) == 1 + assert isinstance(batch.updates[0], EntityUpdateAction) + assert batch.updates[0].db_id == db_entity.id + + +@pytest.mark.asyncio +async def test_resolve_keep_when_hash_matches(): + """Existing entity with same hash → KEEP action.""" + db_entity = MagicMock() + db_entity.id = uuid4() + db_entity.hash = "same-hash" + + repo = MagicMock() + repo.bulk_get_by_entity_sync_and_definition = AsyncMock( + return_value={("e-1", "stub"): db_entity} + ) + + resolver = EntityActionResolver(entity_map={_StubEntity: "stub"}, entity_repo=repo) + ctx = _sync_context() + e = _entity(entity_id="e-1", hash_val="same-hash") + + with patch("airweave.db.session.get_db_context") as mock_db_ctx: + mock_db = AsyncMock() + mock_db_ctx.return_value.__aenter__ = AsyncMock(return_value=mock_db) + mock_db_ctx.return_value.__aexit__ = AsyncMock(return_value=False) + + batch = await resolver.resolve([e], ctx) + + assert len(batch.keeps) == 1 + assert isinstance(batch.keeps[0], EntityKeepAction) + + +@pytest.mark.asyncio +async def test_resolve_raises_on_unmapped_entity(): + """Entity type not in entity_map → SyncFailureError.""" + repo = MagicMock() + resolver = EntityActionResolver(entity_map={}, entity_repo=repo) + ctx = _sync_context() + e = _entity() + + with pytest.raises(SyncFailureError, match="not in entity_map"): + await resolver.resolve([e], ctx) diff --git a/backend/airweave/domains/sync_pipeline/tests/test_entity_pipeline.py b/backend/airweave/domains/sync_pipeline/tests/test_entity_pipeline.py new file mode 100644 index 000000000..4aa665100 --- /dev/null +++ b/backend/airweave/domains/sync_pipeline/tests/test_entity_pipeline.py @@ -0,0 +1,120 @@ +"""Tests for EntityPipeline — DI wiring and orphan identification.""" + +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 + +import pytest + +from airweave.domains.sync_pipeline.entity_pipeline import EntityPipeline + + +# --------------------------------------------------------------------------- +# Constructor +# --------------------------------------------------------------------------- + + +def test_constructor_stores_entity_repo(): + """entity_repo is stored on the instance.""" + repo = MagicMock() + pipeline = EntityPipeline( + entity_tracker=MagicMock(), + event_bus=MagicMock(), + action_resolver=MagicMock(), + action_dispatcher=MagicMock(), + entity_repo=repo, + ) + assert pipeline._entity_repo is repo + + +def test_constructor_initializes_batch_seq(): + """_batch_seq starts at 0.""" + pipeline = EntityPipeline( + entity_tracker=MagicMock(), + event_bus=MagicMock(), + action_resolver=MagicMock(), + action_dispatcher=MagicMock(), + entity_repo=MagicMock(), + ) + assert pipeline._batch_seq == 0 + + +# --------------------------------------------------------------------------- +# _identify_orphans — uses entity_repo +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_identify_orphans_uses_entity_repo(): + """_identify_orphans calls entity_repo.get_by_sync_id, not crud.""" + sync_id = uuid4() + repo = MagicMock() + + stored_entity_1 = MagicMock() + stored_entity_1.entity_id = "kept-1" + stored_entity_1.entity_definition_short_name = "stub" + + stored_entity_2 = MagicMock() + stored_entity_2.entity_id = "orphan-1" + stored_entity_2.entity_definition_short_name = "stub" + + repo.get_by_sync_id = AsyncMock(return_value=[stored_entity_1, stored_entity_2]) + + tracker = MagicMock() + tracker.get_all_encountered_ids_flat.return_value = {"kept-1"} + + pipeline = EntityPipeline( + entity_tracker=tracker, + event_bus=MagicMock(), + action_resolver=MagicMock(), + action_dispatcher=MagicMock(), + entity_repo=repo, + ) + + sync_context = MagicMock() + sync_context.sync = MagicMock() + sync_context.sync.id = sync_id + + with patch("airweave.db.session.get_db_context") as mock_db_ctx: + mock_db = AsyncMock() + mock_db_ctx.return_value.__aenter__ = AsyncMock(return_value=mock_db) + mock_db_ctx.return_value.__aexit__ = AsyncMock(return_value=False) + + orphans = await pipeline._identify_orphans(sync_context) + + repo.get_by_sync_id.assert_awaited_once_with(db=mock_db, sync_id=sync_id) + assert orphans == {"stub": ["orphan-1"]} + + +@pytest.mark.asyncio +async def test_identify_orphans_empty_when_all_encountered(): + """No orphans when all stored entities were encountered.""" + repo = MagicMock() + + stored = MagicMock() + stored.entity_id = "e-1" + stored.entity_definition_short_name = "stub" + repo.get_by_sync_id = AsyncMock(return_value=[stored]) + + tracker = MagicMock() + tracker.get_all_encountered_ids_flat.return_value = {"e-1"} + + pipeline = EntityPipeline( + entity_tracker=tracker, + event_bus=MagicMock(), + action_resolver=MagicMock(), + action_dispatcher=MagicMock(), + entity_repo=repo, + ) + + sync_context = MagicMock() + sync_context.sync = MagicMock() + sync_context.sync.id = uuid4() + + with patch("airweave.db.session.get_db_context") as mock_db_ctx: + mock_db = AsyncMock() + mock_db_ctx.return_value.__aenter__ = AsyncMock(return_value=mock_db) + mock_db_ctx.return_value.__aexit__ = AsyncMock(return_value=False) + + orphans = await pipeline._identify_orphans(sync_context) + + assert orphans == {} diff --git a/backend/airweave/domains/sync_pipeline/tests/test_factory.py b/backend/airweave/domains/sync_pipeline/tests/test_factory.py new file mode 100644 index 000000000..9c0b35fe7 --- /dev/null +++ b/backend/airweave/domains/sync_pipeline/tests/test_factory.py @@ -0,0 +1,155 @@ +"""Tests for SyncFactory — DI wiring and create_orchestrator edge cases.""" + +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 + +import pytest + +from airweave.domains.sync_pipeline.factory import SyncFactory + + +def _build_factory(**overrides): + """Build a SyncFactory with mock deps, accepting per-test overrides.""" + defaults = { + "sc_repo": MagicMock(), + "event_bus": MagicMock(), + "usage_checker": MagicMock(), + "dense_embedder": MagicMock(), + "sparse_embedder": MagicMock(), + "entity_repo": MagicMock(), + } + defaults.update(overrides) + return SyncFactory(**defaults) + + +# --------------------------------------------------------------------------- +# Constructor +# --------------------------------------------------------------------------- + + +def test_constructor_stores_all_deps(): + """All six injected deps are stored on the instance.""" + deps = { + "sc_repo": MagicMock(), + "event_bus": MagicMock(), + "usage_checker": MagicMock(), + "dense_embedder": MagicMock(), + "sparse_embedder": MagicMock(), + "entity_repo": MagicMock(), + } + f = SyncFactory(**deps) + assert f._sc_repo is deps["sc_repo"] + assert f._event_bus is deps["event_bus"] + assert f._usage_checker is deps["usage_checker"] + assert f._dense_embedder is deps["dense_embedder"] + assert f._sparse_embedder is deps["sparse_embedder"] + assert f._entity_repo is deps["entity_repo"] + + +# --------------------------------------------------------------------------- +# create_orchestrator — sc_repo returns None → NotFoundException +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_create_orchestrator_raises_when_source_connection_missing(): + """NotFoundException when sc_repo.get_by_sync_id returns None.""" + from airweave.core.exceptions import NotFoundException + + sc_repo = MagicMock() + sc_repo.get_by_sync_id = AsyncMock(return_value=None) + + factory = _build_factory(sc_repo=sc_repo) + + sync = MagicMock() + sync.id = uuid4() + sync.sync_config = None + sync_job = MagicMock() + sync_job.sync_config = None + collection = MagicMock() + collection.sync_config = None + connection = MagicMock() + ctx = MagicMock() + ctx.organization = MagicMock() + ctx.organization.id = uuid4() + db = AsyncMock() + + with pytest.raises(NotFoundException, match="Source connection record not found"): + await factory.create_orchestrator( + db=db, + sync=sync, + sync_job=sync_job, + collection=collection, + connection=connection, + ctx=ctx, + ) + + +# --------------------------------------------------------------------------- +# create_orchestrator — happy path wires entity_repo into resolver/pipeline +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_create_orchestrator_passes_entity_repo_to_pipeline(): + """entity_repo is forwarded to EntityActionResolver and EntityPipeline.""" + entity_repo = MagicMock() + sc_repo = MagicMock() + sc = MagicMock() + sc.id = uuid4() + sc_repo.get_by_sync_id = AsyncMock(return_value=sc) + + factory = _build_factory(sc_repo=sc_repo, entity_repo=entity_repo) + + sync = MagicMock() + sync.id = uuid4() + sync.sync_config = None + sync_job = MagicMock() + sync_job.id = uuid4() + sync_job.sync_config = None + collection = MagicMock() + collection.sync_config = None + collection.readable_id = uuid4() + connection = MagicMock() + ctx = MagicMock() + ctx.organization = MagicMock() + ctx.organization.id = uuid4() + db = AsyncMock() + + with ( + patch( + "airweave.domains.sync_pipeline.factory.SyncContextBuilder" + ) as mock_sc_builder, + patch( + "airweave.domains.sync_pipeline.factory.EntityDispatcherBuilder" + ) as mock_disp_builder, + patch( + "airweave.domains.sync_pipeline.factory.TrackingContextBuilder" + ) as mock_track_builder, + patch( + "airweave.domains.sync_pipeline.factory.SyncFactory._build_source", + new_callable=AsyncMock, + ) as mock_build_source, + patch( + "airweave.domains.sync_pipeline.factory.SyncFactory._build_destinations", + new_callable=AsyncMock, + ) as mock_build_destinations, + ): + mock_build_source.return_value = (MagicMock(), MagicMock()) + mock_build_destinations.return_value = ([], {}) + mock_track_builder.build = AsyncMock(return_value=MagicMock()) + mock_sc_builder.build = AsyncMock(return_value=MagicMock()) + mock_disp_builder.build = MagicMock(return_value=MagicMock()) + + orchestrator = await factory.create_orchestrator( + db=db, + sync=sync, + sync_job=sync_job, + collection=collection, + connection=connection, + ctx=ctx, + ) + + assert orchestrator is not None + assert orchestrator.entity_pipeline._entity_repo is entity_repo + assert orchestrator.entity_pipeline._resolver._entity_repo is entity_repo diff --git a/backend/airweave/platform/sync/actions/__init__.py b/backend/airweave/platform/sync/actions/__init__.py index 1d7879ff9..fef9b0a07 100644 --- a/backend/airweave/platform/sync/actions/__init__.py +++ b/backend/airweave/platform/sync/actions/__init__.py @@ -5,6 +5,8 @@ - access_control/: Access control action types, resolver, dispatcher Each domain has its own types, resolver, and dispatcher tailored to its needs. + +Entity re-exports are lazy to avoid circular imports with domains/sync_pipeline. """ from airweave.platform.sync.actions.access_control import ( @@ -15,16 +17,16 @@ ACKeepAction, ACUpdateAction, ) -from airweave.platform.sync.actions.entity import ( - EntityActionBatch, - EntityActionDispatcher, - EntityActionResolver, - EntityDeleteAction, - EntityDispatcherBuilder, - EntityInsertAction, - EntityKeepAction, - EntityUpdateAction, -) + + +def __getattr__(name: str): + """Lazy re-exports for entity action symbols.""" + from airweave.platform.sync.actions import entity as _entity_pkg + + if name in _entity_pkg.__all__: + return getattr(_entity_pkg, name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + __all__ = [ # Access control types @@ -35,13 +37,13 @@ # Access control resolver and dispatcher "ACActionResolver", "ACActionDispatcher", - # Entity types + # Entity types (lazy) "EntityActionBatch", "EntityDeleteAction", "EntityInsertAction", "EntityKeepAction", "EntityUpdateAction", - # Entity resolver and dispatcher + # Entity resolver and dispatcher (lazy) "EntityActionResolver", "EntityActionDispatcher", "EntityDispatcherBuilder", diff --git a/backend/airweave/platform/sync/actions/entity/__init__.py b/backend/airweave/platform/sync/actions/entity/__init__.py index caa8ab9a1..ec47afb06 100644 --- a/backend/airweave/platform/sync/actions/entity/__init__.py +++ b/backend/airweave/platform/sync/actions/entity/__init__.py @@ -1,27 +1,62 @@ """Entity action types, resolver, and dispatcher. Entity-specific action pipeline for sync operations. + +Re-exports are lazy to avoid circular imports: entity_action_resolver imports +from .types, which triggers this __init__. Using __getattr__ breaks the cycle. """ -from airweave.domains.sync_pipeline.entity_action_dispatcher import EntityActionDispatcher -from airweave.domains.sync_pipeline.entity_action_resolver import EntityActionResolver -from airweave.platform.sync.actions.entity.builder import EntityDispatcherBuilder -from airweave.platform.sync.actions.entity.types import ( - EntityActionBatch, - EntityDeleteAction, - EntityInsertAction, - EntityKeepAction, - EntityUpdateAction, -) + +def __getattr__(name: str): + """Lazy re-exports to avoid circular imports.""" + _map = { + "EntityActionDispatcher": ( + "airweave.domains.sync_pipeline.entity_action_dispatcher", + "EntityActionDispatcher", + ), + "EntityActionResolver": ( + "airweave.domains.sync_pipeline.entity_action_resolver", + "EntityActionResolver", + ), + "EntityDispatcherBuilder": ( + "airweave.platform.sync.actions.entity.builder", + "EntityDispatcherBuilder", + ), + "EntityActionBatch": ( + "airweave.platform.sync.actions.entity.types", + "EntityActionBatch", + ), + "EntityDeleteAction": ( + "airweave.platform.sync.actions.entity.types", + "EntityDeleteAction", + ), + "EntityInsertAction": ( + "airweave.platform.sync.actions.entity.types", + "EntityInsertAction", + ), + "EntityKeepAction": ( + "airweave.platform.sync.actions.entity.types", + "EntityKeepAction", + ), + "EntityUpdateAction": ( + "airweave.platform.sync.actions.entity.types", + "EntityUpdateAction", + ), + } + if name in _map: + import importlib + + module_path, attr = _map[name] + return getattr(importlib.import_module(module_path), attr) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + __all__ = [ - # Types "EntityActionBatch", "EntityDeleteAction", "EntityInsertAction", "EntityKeepAction", "EntityUpdateAction", - # Resolver and Dispatcher "EntityActionResolver", "EntityActionDispatcher", "EntityDispatcherBuilder", From d695b6cca90148f040110b2a03fd7653a3b7f8e2 Mon Sep 17 00:00:00 2001 From: Felix Schmetz Date: Thu, 12 Mar 2026 16:41:53 -0700 Subject: [PATCH 04/13] remove unnecessary comments --- backend/airweave/core/container/container.py | 1 - backend/airweave/core/container/factory.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/backend/airweave/core/container/container.py b/backend/airweave/core/container/container.py index 60ae1780a..02da8978e 100644 --- a/backend/airweave/core/container/container.py +++ b/backend/airweave/core/container/container.py @@ -192,7 +192,6 @@ async def my_endpoint(event_bus: EventBus = Inject(EventBus)): sync_lifecycle: SyncLifecycleServiceProtocol sync_factory: SyncFactoryProtocol - # Entity repository (used by sync pipeline) entity_repo: EntityRepositoryProtocol # Temporal domain diff --git a/backend/airweave/core/container/factory.py b/backend/airweave/core/container/factory.py index f34fb4f0c..87176a303 100644 --- a/backend/airweave/core/container/factory.py +++ b/backend/airweave/core/container/factory.py @@ -379,7 +379,7 @@ def create_container(settings: Settings) -> Container: sparse_embedder = _create_sparse_embedder(sparse_embedder_registry) # ----------------------------------------------------------------- - # Sync factory + service (needs embedders, built after embedder init) + # Sync factory + service # ----------------------------------------------------------------- sync_factory = SyncFactory( sc_repo=source_deps["sc_repo"], @@ -396,7 +396,7 @@ def create_container(settings: Settings) -> Container: ) # ----------------------------------------------------------------- - # Collection service (needs collection_repo, sc_repo, sync_lifecycle, dense_registry) + # Collection service # ----------------------------------------------------------------- collection_service = CollectionService( collection_repo=source_deps["collection_repo"], From fa12be4bc9cc4c70bd0e8a85a1b2611a1d0bf023 Mon Sep 17 00:00:00 2001 From: Felix Schmetz Date: Thu, 12 Mar 2026 17:30:47 -0700 Subject: [PATCH 05/13] refactor: consolidate sync pipeline into domain with DI Move all remaining platform/sync/ modules (handlers, processors, contexts, builders, config, types, actions, subscribers, tests) into domains/sync_pipeline/. Replace direct crud.* calls with injected repository protocols (EntityRepository, AccessControlMembershipRepository) and inject ChunkEmbedProcessor via constructor DI through the factory chain. - New domains/access_control/ domain with repo + protocol + fakes - EntityRepository extended with 4 bulk write methods - All tests updated to use injected mocks instead of module patches - Circular imports eliminated; only platform/sync/{arf,token_manager,web_fetcher} remain --- backend/airweave/api/v1/endpoints/admin.py | 2 +- backend/airweave/core/container/factory.py | 12 +- backend/airweave/core/sync_job_service.py | 2 +- .../domains/access_control/__init__.py | 1 + .../domains/access_control/fakes/__init__.py | 1 + .../access_control/fakes/repository.py | 137 ++++++++++++++++ .../domains/access_control/protocols.py | 100 ++++++++++++ .../domains/access_control/repository.py | 147 ++++++++++++++++++ .../airweave/domains/browse_tree/service.py | 2 +- .../domains/entities/entity_repository.py | 42 ++++- .../airweave/domains/entities/protocols.py | 41 ++++- .../access_control_dispatcher.py} | 8 +- .../sync_pipeline}/access_control_pipeline.py | 38 ++--- .../sync_pipeline/access_control_resolver.py} | 4 +- .../sync_pipeline}/async_helpers.py | 0 .../sync_pipeline/builders/__init__.py | 5 + .../sync_pipeline}/builders/destinations.py | 2 +- .../sync_pipeline}/builders/source.py | 10 +- .../sync_pipeline}/builders/sync.py | 4 +- .../sync_pipeline}/builders/tracking.py | 2 +- .../domains/sync_pipeline/config/__init__.py | 6 + .../sync_pipeline}/config/base.py | 0 .../sync_pipeline}/config/builder.py | 2 +- .../sync_pipeline/contexts/__init__.py | 5 + .../sync_pipeline}/contexts/infra.py | 0 .../sync_pipeline}/contexts/runtime.py | 4 +- .../sync_pipeline}/contexts/source.py | 2 +- .../sync_pipeline}/contexts/sync.py | 2 +- .../sync => domains/sync_pipeline}/cursor.py | 0 .../sync_pipeline/entity_action_dispatcher.py | 12 +- .../sync_pipeline/entity_action_resolver.py | 6 +- .../entity_dispatcher_builder.py} | 77 ++++----- .../domains/sync_pipeline/entity_pipeline.py | 14 +- .../sync_pipeline}/exceptions.py | 0 .../airweave/domains/sync_pipeline/factory.py | 54 ++++--- .../sync_pipeline/fakes/entity_repository.py | 27 +++- .../domains/sync_pipeline/fakes/factory.py | 2 +- .../sync_pipeline}/file_types.py | 0 .../sync_pipeline/handlers/__init__.py | 10 ++ .../handlers/access_control_postgres.py | 19 ++- .../sync_pipeline}/handlers/arf.py | 10 +- .../sync_pipeline}/handlers/destination.py | 24 +-- .../handlers/entity_postgres.py | 37 ++--- .../sync_pipeline}/handlers/protocol.py | 8 +- .../sync_pipeline}/orchestrator.py | 12 +- .../sync_pipeline/pipeline/__init__.py | 1 + .../pipeline/acl_membership_tracker.py | 0 .../pipeline/cleanup_service.py | 6 +- .../sync_pipeline}/pipeline/entity_tracker.py | 0 .../sync_pipeline}/pipeline/hash_computer.py | 8 +- .../sync_pipeline}/pipeline/text_builder.py | 8 +- .../sync_pipeline/processors/__init__.py | 5 + .../sync_pipeline}/processors/chunk_embed.py | 10 +- .../sync_pipeline}/processors/utils.py | 4 +- .../domains/sync_pipeline/protocols.py | 10 +- .../sync => domains/sync_pipeline}/stream.py | 0 .../sync_pipeline/subscribers/__init__.py | 5 + .../subscribers/progress_relay.py | 0 .../tests}/test_acl_membership_tracker.py | 2 +- .../tests}/test_acl_reconciliation.py | 85 +++++----- .../sync_pipeline/tests}/test_chunk_embed.py | 12 +- .../tests}/test_cleanup_service.py | 4 +- .../sync_pipeline/tests/test_config_base.py} | 2 +- .../tests/test_config_builder.py} | 4 +- .../tests}/test_destination_handler.py | 34 ++-- .../tests/test_entity_action_resolver.py | 4 +- .../sync_pipeline/tests/test_factory.py | 8 +- .../tests/test_progress_relay.py | 2 +- .../domains/sync_pipeline/types/__init__.py | 1 + .../types/access_control_actions.py} | 0 .../sync_pipeline/types/entity_actions.py} | 0 .../sync_pipeline}/worker_pool.py | 0 .../domains/syncs/fakes/sync_job_service.py | 2 +- .../domains/syncs/fakes/sync_service.py | 2 +- backend/airweave/domains/syncs/protocols.py | 4 +- backend/airweave/domains/syncs/service.py | 2 +- .../domains/syncs/sync_job_service.py | 2 +- .../syncs/tests/test_sync_job_service.py | 2 +- .../platform/access_control/broker.py | 29 ++-- .../airweave/platform/builders/__init__.py | 11 -- backend/airweave/platform/chunkers/code.py | 4 +- .../airweave/platform/chunkers/semantic.py | 4 +- .../airweave/platform/contexts/__init__.py | 14 -- .../platform/converters/html_converter.py | 4 +- .../converters/text_extractors/docx.py | 2 +- .../converters/text_extractors/pdf.py | 2 +- .../converters/text_extractors/pptx.py | 2 +- .../platform/converters/txt_converter.py | 4 +- .../platform/converters/web_converter.py | 2 +- .../platform/converters/xlsx_converter.py | 4 +- .../platform/ocr/mistral/compressor.py | 2 +- .../platform/ocr/mistral/converter.py | 2 +- .../platform/ocr/mistral/ocr_client.py | 2 +- .../platform/ocr/mistral/splitters.py | 2 +- .../airweave/platform/sources/google_drive.py | 2 +- .../sources/sharepoint2019v2/builders.py | 2 +- .../sources/sharepoint2019v2/source.py | 2 +- .../sources/sharepoint_online/builders.py | 2 +- .../sources/sharepoint_online/source.py | 2 +- backend/airweave/platform/sources/slack.py | 2 +- .../airweave/platform/storage/file_service.py | 2 +- .../platform/sync/actions/__init__.py | 50 ------ .../sync/actions/access_control/__init__.py | 28 ---- .../platform/sync/actions/entity/__init__.py | 63 -------- backend/airweave/platform/sync/arf/service.py | 4 +- .../airweave/platform/sync/config/__init__.py | 57 ------- .../platform/sync/handlers/__init__.py | 40 ----- .../platform/sync/pipeline/__init__.py | 39 ----- .../platform/sync/processors/__init__.py | 13 -- .../platform/sync/subscribers/__init__.py | 7 - .../sync/subscribers/tests/__init__.py | 1 - backend/airweave/platform/sync/web_fetcher.py | 2 +- .../platform/temporal/activities/sync.py | 2 +- .../temporal/worker/control_server.py | 2 +- backend/airweave/schemas/collection.py | 2 +- backend/airweave/schemas/sync.py | 2 +- backend/airweave/schemas/sync_job.py | 2 +- .../platform/converters/test_txt_converter.py | 2 +- 118 files changed, 859 insertions(+), 683 deletions(-) create mode 100644 backend/airweave/domains/access_control/__init__.py create mode 100644 backend/airweave/domains/access_control/fakes/__init__.py create mode 100644 backend/airweave/domains/access_control/fakes/repository.py create mode 100644 backend/airweave/domains/access_control/protocols.py create mode 100644 backend/airweave/domains/access_control/repository.py rename backend/airweave/{platform/sync/actions/access_control/dispatcher.py => domains/sync_pipeline/access_control_dispatcher.py} (87%) rename backend/airweave/{platform/sync => domains/sync_pipeline}/access_control_pipeline.py (93%) rename backend/airweave/{platform/sync/actions/access_control/resolver.py => domains/sync_pipeline/access_control_resolver.py} (93%) rename backend/airweave/{platform/sync => domains/sync_pipeline}/async_helpers.py (100%) create mode 100644 backend/airweave/domains/sync_pipeline/builders/__init__.py rename backend/airweave/{platform => domains/sync_pipeline}/builders/destinations.py (99%) rename backend/airweave/{platform => domains/sync_pipeline}/builders/source.py (97%) rename backend/airweave/{platform => domains/sync_pipeline}/builders/sync.py (96%) rename backend/airweave/{platform => domains/sync_pipeline}/builders/tracking.py (94%) create mode 100644 backend/airweave/domains/sync_pipeline/config/__init__.py rename backend/airweave/{platform/sync => domains/sync_pipeline}/config/base.py (100%) rename backend/airweave/{platform/sync => domains/sync_pipeline}/config/builder.py (95%) create mode 100644 backend/airweave/domains/sync_pipeline/contexts/__init__.py rename backend/airweave/{platform => domains/sync_pipeline}/contexts/infra.py (100%) rename backend/airweave/{platform => domains/sync_pipeline}/contexts/runtime.py (89%) rename backend/airweave/{platform => domains/sync_pipeline}/contexts/source.py (87%) rename backend/airweave/{platform => domains/sync_pipeline}/contexts/sync.py (96%) rename backend/airweave/{platform/sync => domains/sync_pipeline}/cursor.py (100%) rename backend/airweave/{platform/sync/actions/entity/builder.py => domains/sync_pipeline/entity_dispatcher_builder.py} (61%) rename backend/airweave/{platform/sync => domains/sync_pipeline}/exceptions.py (100%) rename backend/airweave/{platform/sync => domains/sync_pipeline}/file_types.py (100%) create mode 100644 backend/airweave/domains/sync_pipeline/handlers/__init__.py rename backend/airweave/{platform/sync => domains/sync_pipeline}/handlers/access_control_postgres.py (88%) rename backend/airweave/{platform/sync => domains/sync_pipeline}/handlers/arf.py (93%) rename backend/airweave/{platform/sync => domains/sync_pipeline}/handlers/destination.py (92%) rename backend/airweave/{platform/sync => domains/sync_pipeline}/handlers/entity_postgres.py (90%) rename backend/airweave/{platform/sync => domains/sync_pipeline}/handlers/protocol.py (92%) rename backend/airweave/{platform/sync => domains/sync_pipeline}/orchestrator.py (98%) create mode 100644 backend/airweave/domains/sync_pipeline/pipeline/__init__.py rename backend/airweave/{platform/sync => domains/sync_pipeline}/pipeline/acl_membership_tracker.py (100%) rename backend/airweave/{platform/sync => domains/sync_pipeline}/pipeline/cleanup_service.py (95%) rename backend/airweave/{platform/sync => domains/sync_pipeline}/pipeline/entity_tracker.py (100%) rename backend/airweave/{platform/sync => domains/sync_pipeline}/pipeline/hash_computer.py (97%) rename backend/airweave/{platform/sync => domains/sync_pipeline}/pipeline/text_builder.py (98%) create mode 100644 backend/airweave/domains/sync_pipeline/processors/__init__.py rename backend/airweave/{platform/sync => domains/sync_pipeline}/processors/chunk_embed.py (96%) rename backend/airweave/{platform/sync => domains/sync_pipeline}/processors/utils.py (88%) rename backend/airweave/{platform/sync => domains/sync_pipeline}/stream.py (100%) create mode 100644 backend/airweave/domains/sync_pipeline/subscribers/__init__.py rename backend/airweave/{platform/sync => domains/sync_pipeline}/subscribers/progress_relay.py (100%) rename backend/{tests/unit/platform/sync/pipeline => airweave/domains/sync_pipeline/tests}/test_acl_membership_tracker.py (99%) rename backend/{tests/unit/platform/sync/pipeline => airweave/domains/sync_pipeline/tests}/test_acl_reconciliation.py (76%) rename backend/{tests/unit/platform/sync/processors => airweave/domains/sync_pipeline/tests}/test_chunk_embed.py (96%) rename backend/{tests/unit/platform/sync/pipeline => airweave/domains/sync_pipeline/tests}/test_cleanup_service.py (98%) rename backend/{tests/unit/platform/sync/sync_config/test_base.py => airweave/domains/sync_pipeline/tests/test_config_base.py} (99%) rename backend/{tests/unit/platform/sync/sync_config/test_builder.py => airweave/domains/sync_pipeline/tests/test_config_builder.py} (97%) rename backend/{tests/unit/platform/sync/handlers => airweave/domains/sync_pipeline/tests}/test_destination_handler.py (88%) rename backend/airweave/{platform/sync/subscribers => domains/sync_pipeline}/tests/test_progress_relay.py (99%) create mode 100644 backend/airweave/domains/sync_pipeline/types/__init__.py rename backend/airweave/{platform/sync/actions/access_control/types.py => domains/sync_pipeline/types/access_control_actions.py} (100%) rename backend/airweave/{platform/sync/actions/entity/types.py => domains/sync_pipeline/types/entity_actions.py} (100%) rename backend/airweave/{platform/sync => domains/sync_pipeline}/worker_pool.py (100%) delete mode 100644 backend/airweave/platform/builders/__init__.py delete mode 100644 backend/airweave/platform/contexts/__init__.py delete mode 100644 backend/airweave/platform/sync/actions/__init__.py delete mode 100644 backend/airweave/platform/sync/actions/access_control/__init__.py delete mode 100644 backend/airweave/platform/sync/actions/entity/__init__.py delete mode 100644 backend/airweave/platform/sync/config/__init__.py delete mode 100644 backend/airweave/platform/sync/handlers/__init__.py delete mode 100644 backend/airweave/platform/sync/pipeline/__init__.py delete mode 100644 backend/airweave/platform/sync/processors/__init__.py delete mode 100644 backend/airweave/platform/sync/subscribers/__init__.py delete mode 100644 backend/airweave/platform/sync/subscribers/tests/__init__.py diff --git a/backend/airweave/api/v1/endpoints/admin.py b/backend/airweave/api/v1/endpoints/admin.py index 8b894ca6b..d7c0b2d94 100644 --- a/backend/airweave/api/v1/endpoints/admin.py +++ b/backend/airweave/api/v1/endpoints/admin.py @@ -49,7 +49,7 @@ from airweave.models.organization import Organization from airweave.models.organization_billing import OrganizationBilling from airweave.models.user_organization import UserOrganization -from airweave.platform.sync.config import SyncConfig +from airweave.domains.sync_pipeline.config import SyncConfig from airweave.schemas.organization_billing import BillingPlan, BillingStatus router = TrailingSlashRouter() diff --git a/backend/airweave/core/container/factory.py b/backend/airweave/core/container/factory.py index 87176a303..f3bfa0ee2 100644 --- a/backend/airweave/core/container/factory.py +++ b/backend/airweave/core/container/factory.py @@ -90,7 +90,9 @@ from airweave.domains.sources.registry import SourceRegistry from airweave.domains.sources.service import SourceService from airweave.domains.sources.validation import SourceValidationService +from airweave.domains.access_control.repository import AccessControlMembershipRepository from airweave.domains.sync_pipeline.factory import SyncFactory +from airweave.domains.sync_pipeline.processors.chunk_embed import ChunkEmbedProcessor from airweave.domains.syncs.service import SyncService from airweave.domains.syncs.sync_cursor_repository import SyncCursorRepository from airweave.domains.syncs.sync_job_repository import SyncJobRepository @@ -108,7 +110,7 @@ from airweave.domains.webhooks.service import WebhookServiceImpl from airweave.domains.webhooks.subscribers import WebhookEventSubscriber from airweave.platform.auth.settings import integration_settings -from airweave.platform.sync.subscribers.progress_relay import SyncProgressRelay +from airweave.domains.sync_pipeline.subscribers.progress_relay import SyncProgressRelay from airweave.platform.temporal.client import TemporalClient @@ -378,6 +380,12 @@ def create_container(settings: Settings) -> Container: dense_embedder = _create_dense_embedder(settings, dense_embedder_registry) sparse_embedder = _create_sparse_embedder(sparse_embedder_registry) + # ----------------------------------------------------------------- + # Access control membership repo + chunk embed processor + # ----------------------------------------------------------------- + acl_membership_repo = AccessControlMembershipRepository() + chunk_embed_processor = ChunkEmbedProcessor() + # ----------------------------------------------------------------- # Sync factory + service # ----------------------------------------------------------------- @@ -388,6 +396,8 @@ def create_container(settings: Settings) -> Container: dense_embedder=dense_embedder, sparse_embedder=sparse_embedder, entity_repo=sync_deps["entity_repo"], + acl_repo=acl_membership_repo, + processor=chunk_embed_processor, ) sync_service = SyncService( diff --git a/backend/airweave/core/sync_job_service.py b/backend/airweave/core/sync_job_service.py index c3e263083..1777b270f 100644 --- a/backend/airweave/core/sync_job_service.py +++ b/backend/airweave/core/sync_job_service.py @@ -10,7 +10,7 @@ from airweave.core.logging import logger from airweave.core.shared_models import SyncJobStatus from airweave.db.session import get_db_context -from airweave.platform.sync.pipeline.entity_tracker import SyncStats +from airweave.domains.sync_pipeline.pipeline.entity_tracker import SyncStats class SyncJobService: diff --git a/backend/airweave/domains/access_control/__init__.py b/backend/airweave/domains/access_control/__init__.py new file mode 100644 index 000000000..f49bc4585 --- /dev/null +++ b/backend/airweave/domains/access_control/__init__.py @@ -0,0 +1 @@ +"""Access control domain — membership repository and protocols.""" diff --git a/backend/airweave/domains/access_control/fakes/__init__.py b/backend/airweave/domains/access_control/fakes/__init__.py new file mode 100644 index 000000000..3ef98e3ee --- /dev/null +++ b/backend/airweave/domains/access_control/fakes/__init__.py @@ -0,0 +1 @@ +"""Fakes for the access control domain.""" diff --git a/backend/airweave/domains/access_control/fakes/repository.py b/backend/airweave/domains/access_control/fakes/repository.py new file mode 100644 index 000000000..9c7fe8913 --- /dev/null +++ b/backend/airweave/domains/access_control/fakes/repository.py @@ -0,0 +1,137 @@ +"""Fake access control membership repository for testing.""" + +from typing import List +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + +from airweave.models.access_control_membership import AccessControlMembership + + +class FakeAccessControlMembershipRepository: + """In-memory fake for AccessControlMembershipRepositoryProtocol.""" + + def __init__(self) -> None: + self._memberships: List[AccessControlMembership] = [] + + async def bulk_create( + self, + db: AsyncSession, + memberships: List, + organization_id: UUID, + source_connection_id: UUID, + source_name: str, + ) -> int: + return len(memberships) + + async def upsert( + self, + db: AsyncSession, + *, + member_id: str, + member_type: str, + group_id: str, + group_name: str, + organization_id: UUID, + source_connection_id: UUID, + source_name: str, + ) -> None: + pass + + async def delete_by_key( + self, + db: AsyncSession, + *, + member_id: str, + member_type: str, + group_id: str, + source_connection_id: UUID, + organization_id: UUID, + ) -> int: + return 0 + + async def delete_by_group( + self, + db: AsyncSession, + *, + group_id: str, + source_connection_id: UUID, + organization_id: UUID, + ) -> int: + return 0 + + async def get_by_source_connection( + self, + db: AsyncSession, + source_connection_id: UUID, + organization_id: UUID, + ) -> List[AccessControlMembership]: + return [ + m + for m in self._memberships + if m.source_connection_id == source_connection_id + and m.organization_id == organization_id + ] + + async def bulk_delete(self, db: AsyncSession, ids: List[UUID]) -> int: + before = len(self._memberships) + self._memberships = [m for m in self._memberships if m.id not in ids] + return before - len(self._memberships) + + async def get_by_member( + self, + db: AsyncSession, + member_id: str, + member_type: str, + organization_id: UUID, + ) -> List[AccessControlMembership]: + return [ + m + for m in self._memberships + if m.member_id == member_id + and m.member_type == member_type + and m.organization_id == organization_id + ] + + async def get_by_member_and_collection( + self, + db: AsyncSession, + member_id: str, + member_type: str, + readable_collection_id: str, + organization_id: UUID, + ) -> List[AccessControlMembership]: + return [] + + async def get_memberships_by_groups( + self, + db: AsyncSession, + *, + group_ids: List[str], + source_connection_id: UUID, + organization_id: UUID, + ) -> List[AccessControlMembership]: + return [ + m + for m in self._memberships + if m.group_id in group_ids + and m.source_connection_id == source_connection_id + and m.organization_id == organization_id + ] + + async def delete_by_source_connection( + self, + db: AsyncSession, + source_connection_id: UUID, + organization_id: UUID, + ) -> int: + before = len(self._memberships) + self._memberships = [ + m + for m in self._memberships + if not ( + m.source_connection_id == source_connection_id + and m.organization_id == organization_id + ) + ] + return before - len(self._memberships) diff --git a/backend/airweave/domains/access_control/protocols.py b/backend/airweave/domains/access_control/protocols.py new file mode 100644 index 000000000..02c4a9440 --- /dev/null +++ b/backend/airweave/domains/access_control/protocols.py @@ -0,0 +1,100 @@ +"""Protocols for the access control domain.""" + +from typing import List, Protocol +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + +from airweave.models.access_control_membership import AccessControlMembership + + +class AccessControlMembershipRepositoryProtocol(Protocol): + """Data access for access control memberships.""" + + async def bulk_create( + self, + db: AsyncSession, + memberships: List, + organization_id: UUID, + source_connection_id: UUID, + source_name: str, + ) -> int: ... + + async def upsert( + self, + db: AsyncSession, + *, + member_id: str, + member_type: str, + group_id: str, + group_name: str, + organization_id: UUID, + source_connection_id: UUID, + source_name: str, + ) -> None: ... + + async def delete_by_key( + self, + db: AsyncSession, + *, + member_id: str, + member_type: str, + group_id: str, + source_connection_id: UUID, + organization_id: UUID, + ) -> int: ... + + async def delete_by_group( + self, + db: AsyncSession, + *, + group_id: str, + source_connection_id: UUID, + organization_id: UUID, + ) -> int: ... + + async def get_by_source_connection( + self, + db: AsyncSession, + source_connection_id: UUID, + organization_id: UUID, + ) -> List[AccessControlMembership]: ... + + async def bulk_delete( + self, + db: AsyncSession, + ids: List[UUID], + ) -> int: ... + + async def get_by_member( + self, + db: AsyncSession, + member_id: str, + member_type: str, + organization_id: UUID, + ) -> List[AccessControlMembership]: ... + + async def get_by_member_and_collection( + self, + db: AsyncSession, + member_id: str, + member_type: str, + readable_collection_id: str, + organization_id: UUID, + ) -> List[AccessControlMembership]: ... + + async def get_memberships_by_groups( + self, + db: AsyncSession, + *, + group_ids: List[str], + source_connection_id: UUID, + organization_id: UUID, + ) -> List[AccessControlMembership]: ... + + async def delete_by_source_connection( + self, + db: AsyncSession, + source_connection_id: UUID, + organization_id: UUID, + ) -> int: ... diff --git a/backend/airweave/domains/access_control/repository.py b/backend/airweave/domains/access_control/repository.py new file mode 100644 index 000000000..b07fe6adc --- /dev/null +++ b/backend/airweave/domains/access_control/repository.py @@ -0,0 +1,147 @@ +"""Access control membership repository wrapping crud.access_control_membership.""" + +from typing import List +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + +from airweave import crud +from airweave.models.access_control_membership import AccessControlMembership + + +class AccessControlMembershipRepository: + """Delegates to the crud.access_control_membership singleton.""" + + async def bulk_create( + self, + db: AsyncSession, + memberships: List, + organization_id: UUID, + source_connection_id: UUID, + source_name: str, + ) -> int: + return await crud.access_control_membership.bulk_create( + db, memberships, organization_id, source_connection_id, source_name + ) + + async def upsert( + self, + db: AsyncSession, + *, + member_id: str, + member_type: str, + group_id: str, + group_name: str, + organization_id: UUID, + source_connection_id: UUID, + source_name: str, + ) -> None: + return await crud.access_control_membership.upsert( + db, + member_id=member_id, + member_type=member_type, + group_id=group_id, + group_name=group_name, + organization_id=organization_id, + source_connection_id=source_connection_id, + source_name=source_name, + ) + + async def delete_by_key( + self, + db: AsyncSession, + *, + member_id: str, + member_type: str, + group_id: str, + source_connection_id: UUID, + organization_id: UUID, + ) -> int: + return await crud.access_control_membership.delete_by_key( + db, + member_id=member_id, + member_type=member_type, + group_id=group_id, + source_connection_id=source_connection_id, + organization_id=organization_id, + ) + + async def delete_by_group( + self, + db: AsyncSession, + *, + group_id: str, + source_connection_id: UUID, + organization_id: UUID, + ) -> int: + return await crud.access_control_membership.delete_by_group( + db, + group_id=group_id, + source_connection_id=source_connection_id, + organization_id=organization_id, + ) + + async def get_by_source_connection( + self, + db: AsyncSession, + source_connection_id: UUID, + organization_id: UUID, + ) -> List[AccessControlMembership]: + return await crud.access_control_membership.get_by_source_connection( + db, source_connection_id, organization_id + ) + + async def bulk_delete( + self, + db: AsyncSession, + ids: List[UUID], + ) -> int: + return await crud.access_control_membership.bulk_delete(db, ids) + + async def get_by_member( + self, + db: AsyncSession, + member_id: str, + member_type: str, + organization_id: UUID, + ) -> List[AccessControlMembership]: + return await crud.access_control_membership.get_by_member( + db, member_id, member_type, organization_id + ) + + async def get_by_member_and_collection( + self, + db: AsyncSession, + member_id: str, + member_type: str, + readable_collection_id: str, + organization_id: UUID, + ) -> List[AccessControlMembership]: + return await crud.access_control_membership.get_by_member_and_collection( + db, member_id, member_type, readable_collection_id, organization_id + ) + + async def get_memberships_by_groups( + self, + db: AsyncSession, + *, + group_ids: List[str], + source_connection_id: UUID, + organization_id: UUID, + ) -> List[AccessControlMembership]: + return await crud.access_control_membership.get_memberships_by_groups( + db, + group_ids=group_ids, + source_connection_id=source_connection_id, + organization_id=organization_id, + ) + + async def delete_by_source_connection( + self, + db: AsyncSession, + source_connection_id: UUID, + organization_id: UUID, + ) -> int: + return await crud.access_control_membership.delete_by_source_connection( + db, source_connection_id, organization_id + ) diff --git a/backend/airweave/domains/browse_tree/service.py b/backend/airweave/domains/browse_tree/service.py index 921fa3603..708c1d30e 100644 --- a/backend/airweave/domains/browse_tree/service.py +++ b/backend/airweave/domains/browse_tree/service.py @@ -25,7 +25,7 @@ from airweave.domains.sources.protocols import SourceLifecycleServiceProtocol from airweave.domains.syncs.protocols import SyncJobRepositoryProtocol, SyncRepositoryProtocol from airweave.domains.temporal.protocols import TemporalWorkflowServiceProtocol -from airweave.platform.sync.config import SyncConfig +from airweave.domains.sync_pipeline.config import SyncConfig from airweave.schemas.sync_job import SyncJobCreate, SyncJobStatus diff --git a/backend/airweave/domains/entities/entity_repository.py b/backend/airweave/domains/entities/entity_repository.py index 35fa3f25b..fd43a3deb 100644 --- a/backend/airweave/domains/entities/entity_repository.py +++ b/backend/airweave/domains/entities/entity_repository.py @@ -1,19 +1,19 @@ """Entity repository wrapping crud.entity for sync pipeline usage.""" -from typing import Dict, List, Tuple +from typing import Any, Dict, List, Tuple from uuid import UUID from sqlalchemy.ext.asyncio import AsyncSession from airweave import crud from airweave.models.entity import Entity +from airweave.schemas.entity import EntityCreate class EntityRepository: """Delegates to the crud.entity singleton.""" async def get_by_sync_id(self, db: AsyncSession, sync_id: UUID) -> List[Entity]: - """Get all entities for a specific sync.""" return await crud.entity.get_by_sync_id(db, sync_id) async def bulk_get_by_entity_sync_and_definition( @@ -23,7 +23,43 @@ async def bulk_get_by_entity_sync_and_definition( sync_id: UUID, entity_requests: list[Tuple[str, str]], ) -> Dict[Tuple[str, str], Entity]: - """Bulk-fetch entities by (entity_id, definition_short_name).""" return await crud.entity.bulk_get_by_entity_sync_and_definition( db, sync_id=sync_id, entity_requests=entity_requests ) + + async def bulk_create( + self, + db: AsyncSession, + *, + objs: List[EntityCreate], + ctx: Any, + ) -> List[Entity]: + return await crud.entity.bulk_create(db, objs=objs, ctx=ctx) + + async def bulk_update_hash( + self, + db: AsyncSession, + *, + rows: List[Tuple[UUID, str]], + ) -> None: + return await crud.entity.bulk_update_hash(db, rows=rows) + + async def bulk_remove( + self, + db: AsyncSession, + *, + ids: List[UUID], + ctx: Any, + ) -> List[Entity]: + return await crud.entity.bulk_remove(db, ids=ids, ctx=ctx) + + async def bulk_get_by_entity_and_sync( + self, + db: AsyncSession, + *, + sync_id: UUID, + entity_ids: List[str], + ) -> Dict[str, Entity]: + return await crud.entity.bulk_get_by_entity_and_sync( + db, sync_id=sync_id, entity_ids=entity_ids + ) diff --git a/backend/airweave/domains/entities/protocols.py b/backend/airweave/domains/entities/protocols.py index 5d7617f13..e6f71d857 100644 --- a/backend/airweave/domains/entities/protocols.py +++ b/backend/airweave/domains/entities/protocols.py @@ -1,6 +1,6 @@ """Protocols for the entities domain.""" -from typing import Dict, List, Protocol, Tuple +from typing import Any, Dict, List, Protocol, Tuple from uuid import UUID from sqlalchemy.ext.asyncio import AsyncSession @@ -30,10 +30,9 @@ async def get_counts_per_sync_and_type( class EntityRepositoryProtocol(Protocol): - """Entity read access used by the sync pipeline.""" + """Entity data access used by the sync pipeline.""" async def get_by_sync_id(self, db: AsyncSession, sync_id: UUID) -> List[Entity]: - """Get all entities for a specific sync.""" ... async def bulk_get_by_entity_sync_and_definition( @@ -43,5 +42,39 @@ async def bulk_get_by_entity_sync_and_definition( sync_id: UUID, entity_requests: list[Tuple[str, str]], ) -> Dict[Tuple[str, str], Entity]: - """Bulk-fetch entities by (entity_id, entity_definition_short_name) for a sync.""" + ... + + async def bulk_create( + self, + db: AsyncSession, + *, + objs: list, + ctx: Any, + ) -> List[Entity]: + ... + + async def bulk_update_hash( + self, + db: AsyncSession, + *, + rows: List[Tuple[UUID, str]], + ) -> None: + ... + + async def bulk_remove( + self, + db: AsyncSession, + *, + ids: List[UUID], + ctx: Any, + ) -> List[Entity]: + ... + + async def bulk_get_by_entity_and_sync( + self, + db: AsyncSession, + *, + sync_id: UUID, + entity_ids: List[str], + ) -> Dict[str, Entity]: ... diff --git a/backend/airweave/platform/sync/actions/access_control/dispatcher.py b/backend/airweave/domains/sync_pipeline/access_control_dispatcher.py similarity index 87% rename from backend/airweave/platform/sync/actions/access_control/dispatcher.py rename to backend/airweave/domains/sync_pipeline/access_control_dispatcher.py index df82e30b4..45e6a3889 100644 --- a/backend/airweave/platform/sync/actions/access_control/dispatcher.py +++ b/backend/airweave/domains/sync_pipeline/access_control_dispatcher.py @@ -6,12 +6,12 @@ from typing import TYPE_CHECKING, List -from airweave.platform.sync.actions.access_control.types import ACActionBatch -from airweave.platform.sync.exceptions import SyncFailureError +from airweave.domains.sync_pipeline.types.access_control_actions import ACActionBatch +from airweave.domains.sync_pipeline.exceptions import SyncFailureError if TYPE_CHECKING: - from airweave.platform.contexts import SyncContext - from airweave.platform.sync.handlers.protocol import ACActionHandler + from airweave.domains.sync_pipeline.contexts import SyncContext + from airweave.domains.sync_pipeline.handlers.protocol import ACActionHandler class ACActionDispatcher: diff --git a/backend/airweave/platform/sync/access_control_pipeline.py b/backend/airweave/domains/sync_pipeline/access_control_pipeline.py similarity index 93% rename from backend/airweave/platform/sync/access_control_pipeline.py rename to backend/airweave/domains/sync_pipeline/access_control_pipeline.py index 82d84635c..41c00ebd8 100644 --- a/backend/airweave/platform/sync/access_control_pipeline.py +++ b/backend/airweave/domains/sync_pipeline/access_control_pipeline.py @@ -11,46 +11,36 @@ from datetime import datetime from typing import TYPE_CHECKING, List, Set, Tuple -from airweave import crud from airweave.db.session import get_db_context +from airweave.domains.access_control.protocols import AccessControlMembershipRepositoryProtocol from airweave.platform.access_control.schemas import ( ACLChangeType, MembershipTuple, ) -from airweave.platform.sync.actions.access_control import ACActionDispatcher, ACActionResolver -from airweave.platform.sync.pipeline.acl_membership_tracker import ACLMembershipTracker +from airweave.domains.sync_pipeline.access_control_dispatcher import ACActionDispatcher +from airweave.domains.sync_pipeline.access_control_resolver import ACActionResolver +from airweave.domains.sync_pipeline.pipeline.acl_membership_tracker import ACLMembershipTracker from airweave.platform.utils.error_utils import get_error_message if TYPE_CHECKING: - from airweave.platform.contexts import SyncContext - from airweave.platform.contexts.runtime import SyncRuntime + from airweave.domains.sync_pipeline.contexts import SyncContext + from airweave.domains.sync_pipeline.contexts.runtime import SyncRuntime class AccessControlPipeline: - """Orchestrates membership processing with full and incremental sync support. - - Full sync path: - 1. Collect all memberships from source - 2. Dedupe + resolve + dispatch (upsert to Postgres) - 3. Delete orphan memberships (revoked permissions) - 4. Seed DirSync cookie for future incremental syncs - - Incremental sync path: - 1. Call source.get_acl_changes() with DirSync cookie - 2. Apply ADD changes (upsert) and REMOVE changes (delete by key) - 3. Update cursor with new DirSync cookie - """ + """Orchestrates membership processing with full and incremental sync support.""" def __init__( self, resolver: ACActionResolver, dispatcher: ACActionDispatcher, tracker: ACLMembershipTracker, + acl_repo: AccessControlMembershipRepositoryProtocol, ): - """Initialize pipeline with injected components.""" self._resolver = resolver self._dispatcher = dispatcher self._tracker = tracker + self._acl_repo = acl_repo async def process( self, @@ -329,7 +319,7 @@ async def _process_incremental( # Handle deleted AD groups -- immediately remove all memberships # so that revoked access is reflected without waiting for a full sync for group_id in deleted_group_ids: - deleted = await crud.access_control_membership.delete_by_group( + deleted = await self._acl_repo.delete_by_group( db, group_id=group_id, source_connection_id=sync_context.source_connection_id, @@ -370,7 +360,7 @@ async def _apply_membership_changes( for change in result.changes: if change.change_type == ACLChangeType.ADD: - await crud.access_control_membership.upsert( + await self._acl_repo.upsert( db, member_id=change.member_id, member_type=change.member_type, @@ -383,7 +373,7 @@ async def _apply_membership_changes( adds += 1 elif change.change_type == ACLChangeType.REMOVE: - await crud.access_control_membership.delete_by_key( + await self._acl_repo.delete_by_key( db, member_id=change.member_id, member_type=change.member_type, @@ -472,7 +462,7 @@ async def _cleanup_orphan_memberships( be deleted to prevent unauthorized access. """ async with get_db_context() as db: - stored_memberships = await crud.access_control_membership.get_by_source_connection( + stored_memberships = await self._acl_repo.get_by_source_connection( db=db, source_connection_id=sync_context.source_connection_id, organization_id=sync_context.organization_id, @@ -497,7 +487,7 @@ async def _cleanup_orphan_memberships( ) async with get_db_context() as db: - deleted_count = await crud.access_control_membership.bulk_delete( + deleted_count = await self._acl_repo.bulk_delete( db=db, ids=[m.id for m in orphans], ) diff --git a/backend/airweave/platform/sync/actions/access_control/resolver.py b/backend/airweave/domains/sync_pipeline/access_control_resolver.py similarity index 93% rename from backend/airweave/platform/sync/actions/access_control/resolver.py rename to backend/airweave/domains/sync_pipeline/access_control_resolver.py index bd62b87e1..2fe41b1b3 100644 --- a/backend/airweave/platform/sync/actions/access_control/resolver.py +++ b/backend/airweave/domains/sync_pipeline/access_control_resolver.py @@ -7,13 +7,13 @@ from typing import TYPE_CHECKING, List from airweave.platform.access_control.schemas import MembershipTuple -from airweave.platform.sync.actions.access_control.types import ( +from airweave.domains.sync_pipeline.types.access_control_actions import ( ACActionBatch, ACUpsertAction, ) if TYPE_CHECKING: - from airweave.platform.contexts import SyncContext + from airweave.domains.sync_pipeline.contexts import SyncContext class ACActionResolver: diff --git a/backend/airweave/platform/sync/async_helpers.py b/backend/airweave/domains/sync_pipeline/async_helpers.py similarity index 100% rename from backend/airweave/platform/sync/async_helpers.py rename to backend/airweave/domains/sync_pipeline/async_helpers.py diff --git a/backend/airweave/domains/sync_pipeline/builders/__init__.py b/backend/airweave/domains/sync_pipeline/builders/__init__.py new file mode 100644 index 000000000..861e8f0f1 --- /dev/null +++ b/backend/airweave/domains/sync_pipeline/builders/__init__.py @@ -0,0 +1,5 @@ +"""Sync pipeline builders — factory internals for building sync components.""" + +from airweave.domains.sync_pipeline.builders.sync import SyncContextBuilder + +__all__ = ["SyncContextBuilder"] diff --git a/backend/airweave/platform/builders/destinations.py b/backend/airweave/domains/sync_pipeline/builders/destinations.py similarity index 99% rename from backend/airweave/platform/builders/destinations.py rename to backend/airweave/domains/sync_pipeline/builders/destinations.py index fb85aaf39..f4f387f9b 100644 --- a/backend/airweave/platform/builders/destinations.py +++ b/backend/airweave/domains/sync_pipeline/builders/destinations.py @@ -16,7 +16,7 @@ from airweave.platform.destinations._base import BaseDestination from airweave.platform.destinations.vespa import VespaDestination from airweave.platform.entities._base import BaseEntity -from airweave.platform.sync.config import SyncConfig +from airweave.domains.sync_pipeline.config import SyncConfig class DestinationsContextBuilder: diff --git a/backend/airweave/platform/builders/source.py b/backend/airweave/domains/sync_pipeline/builders/source.py similarity index 97% rename from backend/airweave/platform/builders/source.py rename to backend/airweave/domains/sync_pipeline/builders/source.py index b8590a244..51bd131e8 100644 --- a/backend/airweave/platform/builders/source.py +++ b/backend/airweave/domains/sync_pipeline/builders/source.py @@ -23,11 +23,11 @@ from airweave.core.sync_cursor_service import sync_cursor_service from airweave.domains.browse_tree.repository import NodeSelectionRepository from airweave.domains.browse_tree.types import NodeSelectionData -from airweave.platform.contexts.infra import InfraContext -from airweave.platform.contexts.source import SourceContext +from airweave.domains.sync_pipeline.contexts.infra import InfraContext +from airweave.domains.sync_pipeline.contexts.source import SourceContext from airweave.platform.sources._base import BaseSource -from airweave.platform.sync.config import SyncConfig -from airweave.platform.sync.cursor import SyncCursor +from airweave.domains.sync_pipeline.config import SyncConfig +from airweave.domains.sync_pipeline.cursor import SyncCursor class SourceContextBuilder: @@ -209,7 +209,7 @@ def _validate_not_completed_snapshot(source_connection_obj) -> None: SnapshotConfig(**(source_connection_obj.config_fields or {})) # Config is a valid SnapshotConfig but short_name is not "snapshot" # → this is a completed snapshot, can't re-sync - from airweave.platform.sync.exceptions import SyncFailureError + from airweave.domains.sync_pipeline.exceptions import SyncFailureError raise SyncFailureError( f"Cannot re-sync a completed snapshot source connection " diff --git a/backend/airweave/platform/builders/sync.py b/backend/airweave/domains/sync_pipeline/builders/sync.py similarity index 96% rename from backend/airweave/platform/builders/sync.py rename to backend/airweave/domains/sync_pipeline/builders/sync.py index d1db27996..5662a715c 100644 --- a/backend/airweave/platform/builders/sync.py +++ b/backend/airweave/domains/sync_pipeline/builders/sync.py @@ -12,8 +12,8 @@ from airweave import schemas from airweave.core.context import BaseContext from airweave.core.logging import ContextualLogger, LoggerConfigurator -from airweave.platform.contexts.sync import SyncContext -from airweave.platform.sync.config import SyncConfig +from airweave.domains.sync_pipeline.contexts.sync import SyncContext +from airweave.domains.sync_pipeline.config import SyncConfig class SyncContextBuilder: diff --git a/backend/airweave/platform/builders/tracking.py b/backend/airweave/domains/sync_pipeline/builders/tracking.py similarity index 94% rename from backend/airweave/platform/builders/tracking.py rename to backend/airweave/domains/sync_pipeline/builders/tracking.py index fb73ca1f0..bbd70e161 100644 --- a/backend/airweave/platform/builders/tracking.py +++ b/backend/airweave/domains/sync_pipeline/builders/tracking.py @@ -8,7 +8,7 @@ from airweave import crud, schemas from airweave.core.context import BaseContext from airweave.core.logging import ContextualLogger -from airweave.platform.sync.pipeline.entity_tracker import EntityTracker +from airweave.domains.sync_pipeline.pipeline.entity_tracker import EntityTracker class TrackingContextBuilder: diff --git a/backend/airweave/domains/sync_pipeline/config/__init__.py b/backend/airweave/domains/sync_pipeline/config/__init__.py new file mode 100644 index 000000000..baedb7ce7 --- /dev/null +++ b/backend/airweave/domains/sync_pipeline/config/__init__.py @@ -0,0 +1,6 @@ +"""Sync pipeline configuration.""" + +from airweave.domains.sync_pipeline.config.base import SyncConfig +from airweave.domains.sync_pipeline.config.builder import SyncConfigBuilder + +__all__ = ["SyncConfig", "SyncConfigBuilder"] diff --git a/backend/airweave/platform/sync/config/base.py b/backend/airweave/domains/sync_pipeline/config/base.py similarity index 100% rename from backend/airweave/platform/sync/config/base.py rename to backend/airweave/domains/sync_pipeline/config/base.py diff --git a/backend/airweave/platform/sync/config/builder.py b/backend/airweave/domains/sync_pipeline/config/builder.py similarity index 95% rename from backend/airweave/platform/sync/config/builder.py rename to backend/airweave/domains/sync_pipeline/config/builder.py index d1aec186d..f6726fd05 100644 --- a/backend/airweave/platform/sync/config/builder.py +++ b/backend/airweave/domains/sync_pipeline/config/builder.py @@ -10,7 +10,7 @@ from typing import Optional -from airweave.platform.sync.config.base import SyncConfig +from airweave.domains.sync_pipeline.config.base import SyncConfig class SyncConfigBuilder: diff --git a/backend/airweave/domains/sync_pipeline/contexts/__init__.py b/backend/airweave/domains/sync_pipeline/contexts/__init__.py new file mode 100644 index 000000000..8db0f4068 --- /dev/null +++ b/backend/airweave/domains/sync_pipeline/contexts/__init__.py @@ -0,0 +1,5 @@ +"""Sync pipeline contexts — data containers for sync execution.""" + +from airweave.domains.sync_pipeline.contexts.sync import SyncContext + +__all__ = ["SyncContext"] diff --git a/backend/airweave/platform/contexts/infra.py b/backend/airweave/domains/sync_pipeline/contexts/infra.py similarity index 100% rename from backend/airweave/platform/contexts/infra.py rename to backend/airweave/domains/sync_pipeline/contexts/infra.py diff --git a/backend/airweave/platform/contexts/runtime.py b/backend/airweave/domains/sync_pipeline/contexts/runtime.py similarity index 89% rename from backend/airweave/platform/contexts/runtime.py rename to backend/airweave/domains/sync_pipeline/contexts/runtime.py index 41205fb89..c0e3a15fd 100644 --- a/backend/airweave/platform/contexts/runtime.py +++ b/backend/airweave/domains/sync_pipeline/contexts/runtime.py @@ -14,8 +14,8 @@ from airweave.domains.embedders.protocols import DenseEmbedderProtocol, SparseEmbedderProtocol from airweave.platform.destinations._base import BaseDestination from airweave.platform.sources._base import BaseSource - from airweave.platform.sync.cursor import SyncCursor - from airweave.platform.sync.pipeline.entity_tracker import EntityTracker + from airweave.domains.sync_pipeline.cursor import SyncCursor + from airweave.domains.sync_pipeline.pipeline.entity_tracker import EntityTracker @dataclass diff --git a/backend/airweave/platform/contexts/source.py b/backend/airweave/domains/sync_pipeline/contexts/source.py similarity index 87% rename from backend/airweave/platform/contexts/source.py rename to backend/airweave/domains/sync_pipeline/contexts/source.py index 2fae9b749..85eaa27b1 100644 --- a/backend/airweave/platform/contexts/source.py +++ b/backend/airweave/domains/sync_pipeline/contexts/source.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from airweave.platform.sources._base import BaseSource - from airweave.platform.sync.cursor import SyncCursor + from airweave.domains.sync_pipeline.cursor import SyncCursor @dataclass diff --git a/backend/airweave/platform/contexts/sync.py b/backend/airweave/domains/sync_pipeline/contexts/sync.py similarity index 96% rename from backend/airweave/platform/contexts/sync.py rename to backend/airweave/domains/sync_pipeline/contexts/sync.py index 1b54277ff..03050a357 100644 --- a/backend/airweave/platform/contexts/sync.py +++ b/backend/airweave/domains/sync_pipeline/contexts/sync.py @@ -7,7 +7,7 @@ from airweave import schemas from airweave.core.context import BaseContext from airweave.platform.entities._base import BaseEntity -from airweave.platform.sync.config.base import SyncConfig +from airweave.domains.sync_pipeline.config.base import SyncConfig @dataclass diff --git a/backend/airweave/platform/sync/cursor.py b/backend/airweave/domains/sync_pipeline/cursor.py similarity index 100% rename from backend/airweave/platform/sync/cursor.py rename to backend/airweave/domains/sync_pipeline/cursor.py diff --git a/backend/airweave/domains/sync_pipeline/entity_action_dispatcher.py b/backend/airweave/domains/sync_pipeline/entity_action_dispatcher.py index b37bde8ef..f320533eb 100644 --- a/backend/airweave/domains/sync_pipeline/entity_action_dispatcher.py +++ b/backend/airweave/domains/sync_pipeline/entity_action_dispatcher.py @@ -7,14 +7,14 @@ import asyncio from typing import TYPE_CHECKING, List -from airweave.platform.sync.actions.entity.types import EntityActionBatch -from airweave.platform.sync.exceptions import SyncFailureError -from airweave.platform.sync.handlers.entity_postgres import EntityPostgresHandler -from airweave.platform.sync.handlers.protocol import EntityActionHandler +from airweave.domains.sync_pipeline.types.entity_actions import EntityActionBatch +from airweave.domains.sync_pipeline.exceptions import SyncFailureError +from airweave.domains.sync_pipeline.handlers.entity_postgres import EntityPostgresHandler +from airweave.domains.sync_pipeline.handlers.protocol import EntityActionHandler if TYPE_CHECKING: - from airweave.platform.contexts import SyncContext - from airweave.platform.contexts.runtime import SyncRuntime + from airweave.domains.sync_pipeline.contexts import SyncContext + from airweave.domains.sync_pipeline.contexts.runtime import SyncRuntime class EntityActionDispatcher: diff --git a/backend/airweave/domains/sync_pipeline/entity_action_resolver.py b/backend/airweave/domains/sync_pipeline/entity_action_resolver.py index c6c25793d..f9e5a8c4f 100644 --- a/backend/airweave/domains/sync_pipeline/entity_action_resolver.py +++ b/backend/airweave/domains/sync_pipeline/entity_action_resolver.py @@ -11,17 +11,17 @@ from airweave.db.session import get_db_context from airweave.domains.entities.protocols import EntityRepositoryProtocol from airweave.platform.entities._base import BaseEntity, DeletionEntity -from airweave.platform.sync.actions.entity.types import ( +from airweave.domains.sync_pipeline.types.entity_actions import ( EntityActionBatch, EntityDeleteAction, EntityInsertAction, EntityKeepAction, EntityUpdateAction, ) -from airweave.platform.sync.exceptions import SyncFailureError +from airweave.domains.sync_pipeline.exceptions import SyncFailureError if TYPE_CHECKING: - from airweave.platform.contexts import SyncContext + from airweave.domains.sync_pipeline.contexts import SyncContext class EntityActionResolver: diff --git a/backend/airweave/platform/sync/actions/entity/builder.py b/backend/airweave/domains/sync_pipeline/entity_dispatcher_builder.py similarity index 61% rename from backend/airweave/platform/sync/actions/entity/builder.py rename to backend/airweave/domains/sync_pipeline/entity_dispatcher_builder.py index 225d32ca3..7ec330d4f 100644 --- a/backend/airweave/platform/sync/actions/entity/builder.py +++ b/backend/airweave/domains/sync_pipeline/entity_dispatcher_builder.py @@ -3,63 +3,50 @@ from typing import List, Optional from airweave.core.logging import ContextualLogger +from airweave.domains.entities.protocols import EntityRepositoryProtocol from airweave.domains.sync_pipeline.entity_action_dispatcher import EntityActionDispatcher from airweave.platform.destinations._base import BaseDestination -from airweave.platform.sync.config import SyncConfig -from airweave.platform.sync.handlers.arf import ArfHandler -from airweave.platform.sync.handlers.destination import DestinationHandler -from airweave.platform.sync.handlers.entity_postgres import EntityPostgresHandler -from airweave.platform.sync.handlers.protocol import EntityActionHandler +from airweave.domains.sync_pipeline.config import SyncConfig +from airweave.domains.sync_pipeline.handlers.arf import ArfHandler +from airweave.domains.sync_pipeline.handlers.destination import DestinationHandler +from airweave.domains.sync_pipeline.handlers.entity_postgres import EntityPostgresHandler +from airweave.domains.sync_pipeline.handlers.protocol import EntityActionHandler +from airweave.domains.sync_pipeline.protocols import ChunkEmbedProcessorProtocol class EntityDispatcherBuilder: """Builds entity action dispatcher with configured handlers.""" - @classmethod + def __init__( + self, + processor: ChunkEmbedProcessorProtocol, + entity_repo: EntityRepositoryProtocol, + ) -> None: + self._processor = processor + self._entity_repo = entity_repo + def build( - cls, + self, destinations: List[BaseDestination], execution_config: Optional[SyncConfig] = None, logger: Optional[ContextualLogger] = None, ) -> EntityActionDispatcher: - """Build dispatcher with handlers based on config. - - Args: - destinations: Destination instances - execution_config: Optional config to enable/disable handlers - logger: Optional logger for logging handler creation - - Returns: - EntityActionDispatcher with configured handlers. - """ - handlers = cls._build_handlers(destinations, execution_config, logger) + handlers = self._build_handlers(destinations, execution_config, logger) return EntityActionDispatcher(handlers=handlers) - @classmethod def build_for_cleanup( - cls, + self, destinations: List[BaseDestination], logger: Optional[ContextualLogger] = None, ) -> EntityActionDispatcher: - """Build dispatcher for cleanup operations (all handlers enabled). + return self.build(destinations=destinations, execution_config=None, logger=logger) - Args: - destinations: Destinations context - logger: Optional logger - - Returns: - EntityActionDispatcher for cleanup. - """ - return cls.build(destinations=destinations, execution_config=None, logger=logger) - - @classmethod def _build_handlers( - cls, + self, destinations: List[BaseDestination], execution_config: Optional[SyncConfig], logger: Optional[ContextualLogger], ) -> List[EntityActionHandler]: - """Build handler list based on config.""" enable_vector = ( execution_config.handlers.enable_vector_handlers if execution_config else True ) @@ -70,29 +57,29 @@ def _build_handlers( handlers: List[EntityActionHandler] = [] - cls._add_destination_handler(handlers, destinations, enable_vector, logger) - cls._add_arf_handler(handlers, enable_arf, logger) - cls._add_postgres_handler(handlers, enable_postgres, logger) + self._add_destination_handler(handlers, destinations, enable_vector, logger) + self._add_arf_handler(handlers, enable_arf, logger) + self._add_postgres_handler(handlers, enable_postgres, logger) if not handlers and logger: logger.warning("No handlers created - sync will fetch entities but not persist them") return handlers - @classmethod def _add_destination_handler( - cls, + self, handlers: List[EntityActionHandler], destinations: List[BaseDestination], enabled: bool, logger: Optional[ContextualLogger], ) -> None: - """Add destination handler if enabled and destinations exist.""" if not destinations: return if enabled: - handlers.append(DestinationHandler(destinations=destinations)) + handlers.append( + DestinationHandler(destinations=destinations, processor=self._processor) + ) if logger: dest_names = [d.__class__.__name__ for d in destinations] logger.info(f"Created DestinationHandler for {dest_names}") @@ -102,14 +89,12 @@ def _add_destination_handler( f"{len(destinations)} destination(s)" ) - @classmethod + @staticmethod def _add_arf_handler( - cls, handlers: List[EntityActionHandler], enabled: bool, logger: Optional[ContextualLogger], ) -> None: - """Add ARF handler if enabled.""" if enabled: handlers.append(ArfHandler()) if logger: @@ -117,16 +102,14 @@ def _add_arf_handler( elif logger: logger.info("Skipping ArfHandler (disabled by execution_config)") - @classmethod def _add_postgres_handler( - cls, + self, handlers: List[EntityActionHandler], enabled: bool, logger: Optional[ContextualLogger], ) -> None: - """Add Postgres metadata handler if enabled (always last).""" if enabled: - handlers.append(EntityPostgresHandler()) + handlers.append(EntityPostgresHandler(entity_repo=self._entity_repo)) if logger: logger.debug("Added EntityPostgresHandler") elif logger: diff --git a/backend/airweave/domains/sync_pipeline/entity_pipeline.py b/backend/airweave/domains/sync_pipeline/entity_pipeline.py index 53b29405f..8347e2e3b 100644 --- a/backend/airweave/domains/sync_pipeline/entity_pipeline.py +++ b/backend/airweave/domains/sync_pipeline/entity_pipeline.py @@ -22,14 +22,14 @@ EntityActionDispatcherProtocol, EntityActionResolverProtocol, ) -from airweave.platform.contexts import SyncContext -from airweave.platform.contexts.runtime import SyncRuntime +from airweave.domains.sync_pipeline.contexts import SyncContext +from airweave.domains.sync_pipeline.contexts.runtime import SyncRuntime from airweave.platform.entities._base import BaseEntity -from airweave.platform.sync.actions.entity.types import EntityActionBatch -from airweave.platform.sync.exceptions import SyncFailureError -from airweave.platform.sync.pipeline.cleanup_service import cleanup_service -from airweave.platform.sync.pipeline.entity_tracker import EntityTracker -from airweave.platform.sync.pipeline.hash_computer import hash_computer +from airweave.domains.sync_pipeline.types.entity_actions import EntityActionBatch +from airweave.domains.sync_pipeline.exceptions import SyncFailureError +from airweave.domains.sync_pipeline.pipeline.cleanup_service import cleanup_service +from airweave.domains.sync_pipeline.pipeline.entity_tracker import EntityTracker +from airweave.domains.sync_pipeline.pipeline.hash_computer import hash_computer if TYPE_CHECKING: from airweave.core.protocols.event_bus import EventBus diff --git a/backend/airweave/platform/sync/exceptions.py b/backend/airweave/domains/sync_pipeline/exceptions.py similarity index 100% rename from backend/airweave/platform/sync/exceptions.py rename to backend/airweave/domains/sync_pipeline/exceptions.py diff --git a/backend/airweave/domains/sync_pipeline/factory.py b/backend/airweave/domains/sync_pipeline/factory.py index bf4a6a54d..93ef08517 100644 --- a/backend/airweave/domains/sync_pipeline/factory.py +++ b/backend/airweave/domains/sync_pipeline/factory.py @@ -20,26 +20,26 @@ from airweave.core.context import BaseContext from airweave.core.logging import LoggerConfigurator, logger from airweave.core.protocols.event_bus import EventBus +from airweave.domains.access_control.protocols import AccessControlMembershipRepositoryProtocol from airweave.domains.embedders.protocols import DenseEmbedderProtocol, SparseEmbedderProtocol from airweave.domains.entities.protocols import EntityRepositoryProtocol from airweave.domains.source_connections.protocols import SourceConnectionRepositoryProtocol from airweave.domains.usage.protocols import UsageLimitCheckerProtocol -from airweave.platform.builders import SyncContextBuilder -from airweave.platform.builders.tracking import TrackingContextBuilder -from airweave.platform.contexts.runtime import SyncRuntime -from airweave.platform.sync.access_control_pipeline import AccessControlPipeline -from airweave.platform.sync.actions import ( - ACActionDispatcher, - ACActionResolver, - EntityDispatcherBuilder, -) -from airweave.platform.sync.config import SyncConfig, SyncConfigBuilder -from airweave.platform.sync.handlers import ACPostgresHandler -from airweave.platform.sync.orchestrator import SyncOrchestrator -from airweave.platform.sync.pipeline.acl_membership_tracker import ACLMembershipTracker -from airweave.platform.sync.pipeline.entity_tracker import EntityTracker -from airweave.platform.sync.stream import AsyncSourceStream -from airweave.platform.sync.worker_pool import AsyncWorkerPool +from airweave.domains.sync_pipeline.builders import SyncContextBuilder +from airweave.domains.sync_pipeline.builders.tracking import TrackingContextBuilder +from airweave.domains.sync_pipeline.contexts.runtime import SyncRuntime +from airweave.domains.sync_pipeline.access_control_pipeline import AccessControlPipeline +from airweave.domains.sync_pipeline.access_control_dispatcher import ACActionDispatcher +from airweave.domains.sync_pipeline.access_control_resolver import ACActionResolver +from airweave.domains.sync_pipeline.entity_dispatcher_builder import EntityDispatcherBuilder +from airweave.domains.sync_pipeline.config import SyncConfig, SyncConfigBuilder +from airweave.domains.sync_pipeline.handlers import ACPostgresHandler +from airweave.domains.sync_pipeline.orchestrator import SyncOrchestrator +from airweave.domains.sync_pipeline.pipeline.acl_membership_tracker import ACLMembershipTracker +from airweave.domains.sync_pipeline.pipeline.entity_tracker import EntityTracker +from airweave.domains.sync_pipeline.protocols import ChunkEmbedProcessorProtocol +from airweave.domains.sync_pipeline.stream import AsyncSourceStream +from airweave.domains.sync_pipeline.worker_pool import AsyncWorkerPool from .entity_action_resolver import EntityActionResolver from .entity_pipeline import EntityPipeline @@ -60,14 +60,17 @@ def __init__( dense_embedder: DenseEmbedderProtocol, sparse_embedder: SparseEmbedderProtocol, entity_repo: EntityRepositoryProtocol, + acl_repo: AccessControlMembershipRepositoryProtocol, + processor: ChunkEmbedProcessorProtocol, ) -> None: - """Initialize with all deployment-wide dependencies.""" self._sc_repo = sc_repo self._event_bus = event_bus self._usage_checker = usage_checker self._dense_embedder = dense_embedder self._sparse_embedder = sparse_embedder self._entity_repo = entity_repo + self._acl_repo = acl_repo + self._processor = processor async def create_orchestrator( self, @@ -156,7 +159,11 @@ async def create_orchestrator( logger.debug(f"Context + runtime built in {time.time() - init_start:.2f}s") - dispatcher = EntityDispatcherBuilder.build( + dispatcher_builder = EntityDispatcherBuilder( + processor=self._processor, + entity_repo=self._entity_repo, + ) + dispatcher = dispatcher_builder.build( destinations=runtime.destinations, execution_config=resolved_config, logger=sync_context.logger, @@ -177,12 +184,15 @@ async def create_orchestrator( access_control_pipeline = AccessControlPipeline( resolver=ACActionResolver(), - dispatcher=ACActionDispatcher(handlers=[ACPostgresHandler()]), + dispatcher=ACActionDispatcher( + handlers=[ACPostgresHandler(acl_repo=self._acl_repo)] + ), tracker=ACLMembershipTracker( source_connection_id=sync_context.source_connection_id, organization_id=sync_context.organization_id, logger=sync_context.logger, ), + acl_repo=self._acl_repo, ) worker_pool = AsyncWorkerPool(logger=sync_context.logger) @@ -212,8 +222,8 @@ async def create_orchestrator( @staticmethod async def _build_source(db, sync, sync_job, ctx, force_full_sync, execution_config): """Build source and cursor. Returns (source, cursor) tuple.""" - from airweave.platform.builders.source import SourceContextBuilder - from airweave.platform.contexts.infra import InfraContext + from airweave.domains.sync_pipeline.builders.source import SourceContextBuilder + from airweave.domains.sync_pipeline.contexts.infra import InfraContext sync_logger = LoggerConfigurator.configure_logger( "airweave.platform.sync.source_build", @@ -237,7 +247,7 @@ async def _build_source(db, sync, sync_job, ctx, force_full_sync, execution_conf @staticmethod async def _build_destinations(db, sync, collection, ctx, execution_config): """Build destinations and entity map. Returns (destinations, entity_map) tuple.""" - from airweave.platform.builders.destinations import DestinationsContextBuilder + from airweave.domains.sync_pipeline.builders.destinations import DestinationsContextBuilder dest_logger = LoggerConfigurator.configure_logger( "airweave.platform.sync.dest_build", diff --git a/backend/airweave/domains/sync_pipeline/fakes/entity_repository.py b/backend/airweave/domains/sync_pipeline/fakes/entity_repository.py index ee7bfc62d..dc86d5d55 100644 --- a/backend/airweave/domains/sync_pipeline/fakes/entity_repository.py +++ b/backend/airweave/domains/sync_pipeline/fakes/entity_repository.py @@ -1,6 +1,6 @@ """Fake entity repository for testing.""" -from typing import Dict, List, Tuple +from typing import Any, Dict, List, Tuple from uuid import UUID from sqlalchemy.ext.asyncio import AsyncSession @@ -25,3 +25,28 @@ async def bulk_get_by_entity_sync_and_definition( entity_requests: list[Tuple[str, str]], ) -> Dict[Tuple[str, str], Entity]: return {} + + async def bulk_create( + self, db: AsyncSession, *, objs: list, ctx: Any + ) -> List[Entity]: + return [] + + async def bulk_update_hash( + self, db: AsyncSession, *, rows: List[Tuple[UUID, str]] + ) -> None: + pass + + async def bulk_remove( + self, db: AsyncSession, *, ids: List[UUID], ctx: Any + ) -> List[Entity]: + self._entities = [e for e in self._entities if e.id not in ids] + return [] + + async def bulk_get_by_entity_and_sync( + self, db: AsyncSession, *, sync_id: UUID, entity_ids: List[str] + ) -> Dict[str, Entity]: + return { + e.entity_id: e + for e in self._entities + if e.sync_id == sync_id and e.entity_id in entity_ids + } diff --git a/backend/airweave/domains/sync_pipeline/fakes/factory.py b/backend/airweave/domains/sync_pipeline/fakes/factory.py index 011e5c5d7..440cc7e8b 100644 --- a/backend/airweave/domains/sync_pipeline/fakes/factory.py +++ b/backend/airweave/domains/sync_pipeline/fakes/factory.py @@ -7,7 +7,7 @@ from airweave import schemas from airweave.core.context import BaseContext -from airweave.platform.sync.config import SyncConfig +from airweave.domains.sync_pipeline.config import SyncConfig class FakeSyncFactory: diff --git a/backend/airweave/platform/sync/file_types.py b/backend/airweave/domains/sync_pipeline/file_types.py similarity index 100% rename from backend/airweave/platform/sync/file_types.py rename to backend/airweave/domains/sync_pipeline/file_types.py diff --git a/backend/airweave/domains/sync_pipeline/handlers/__init__.py b/backend/airweave/domains/sync_pipeline/handlers/__init__.py new file mode 100644 index 000000000..3106c83a9 --- /dev/null +++ b/backend/airweave/domains/sync_pipeline/handlers/__init__.py @@ -0,0 +1,10 @@ +"""Sync pipeline handlers — entity and access control action handlers.""" + +from airweave.domains.sync_pipeline.handlers.access_control_postgres import ACPostgresHandler +from airweave.domains.sync_pipeline.handlers.protocol import ACActionHandler, EntityActionHandler + +__all__ = [ + "EntityActionHandler", + "ACActionHandler", + "ACPostgresHandler", +] diff --git a/backend/airweave/platform/sync/handlers/access_control_postgres.py b/backend/airweave/domains/sync_pipeline/handlers/access_control_postgres.py similarity index 88% rename from backend/airweave/platform/sync/handlers/access_control_postgres.py rename to backend/airweave/domains/sync_pipeline/handlers/access_control_postgres.py index 1147ed6a2..63aa7c2d6 100644 --- a/backend/airweave/platform/sync/handlers/access_control_postgres.py +++ b/backend/airweave/domains/sync_pipeline/handlers/access_control_postgres.py @@ -6,28 +6,27 @@ from typing import TYPE_CHECKING, List -from airweave import crud from airweave.db.session import get_db_context -from airweave.platform.sync.actions.access_control import ( +from airweave.domains.access_control.protocols import AccessControlMembershipRepositoryProtocol +from airweave.domains.sync_pipeline.types.access_control_actions import ( ACActionBatch, ACDeleteAction, ACInsertAction, ACUpdateAction, ACUpsertAction, ) -from airweave.platform.sync.exceptions import SyncFailureError -from airweave.platform.sync.handlers.protocol import ACActionHandler +from airweave.domains.sync_pipeline.exceptions import SyncFailureError +from airweave.domains.sync_pipeline.handlers.protocol import ACActionHandler if TYPE_CHECKING: - from airweave.platform.contexts import SyncContext + from airweave.domains.sync_pipeline.contexts import SyncContext class ACPostgresHandler(ACActionHandler): - """Persists access control memberships to PostgreSQL. + """Persists access control memberships to PostgreSQL.""" - Implements ACActionHandler protocol. Structured with separate methods - per action type for easy extensibility when we add delete/update actions. - """ + def __init__(self, acl_repo: AccessControlMembershipRepositoryProtocol) -> None: + self._acl_repo = acl_repo @property def name(self) -> str: @@ -118,7 +117,7 @@ async def handle_upserts( batch = memberships[i : i + BATCH_SIZE] async with get_db_context() as db: - count = await crud.access_control_membership.bulk_create( + count = await self._acl_repo.bulk_create( db=db, memberships=batch, organization_id=sync_context.organization_id, diff --git a/backend/airweave/platform/sync/handlers/arf.py b/backend/airweave/domains/sync_pipeline/handlers/arf.py similarity index 93% rename from backend/airweave/platform/sync/handlers/arf.py rename to backend/airweave/domains/sync_pipeline/handlers/arf.py index 67cd41f8e..aed7b8018 100644 --- a/backend/airweave/platform/sync/handlers/arf.py +++ b/backend/airweave/domains/sync_pipeline/handlers/arf.py @@ -6,18 +6,18 @@ from typing import TYPE_CHECKING, List -from airweave.platform.sync.actions.entity.types import ( +from airweave.domains.sync_pipeline.types.entity_actions import ( EntityActionBatch, EntityDeleteAction, EntityInsertAction, EntityUpdateAction, ) -from airweave.platform.sync.exceptions import SyncFailureError -from airweave.platform.sync.handlers.protocol import EntityActionHandler +from airweave.domains.sync_pipeline.exceptions import SyncFailureError +from airweave.domains.sync_pipeline.handlers.protocol import EntityActionHandler if TYPE_CHECKING: - from airweave.platform.contexts import SyncContext - from airweave.platform.contexts.runtime import SyncRuntime + from airweave.domains.sync_pipeline.contexts import SyncContext + from airweave.domains.sync_pipeline.contexts.runtime import SyncRuntime from airweave.platform.entities import BaseEntity diff --git a/backend/airweave/platform/sync/handlers/destination.py b/backend/airweave/domains/sync_pipeline/handlers/destination.py similarity index 92% rename from backend/airweave/platform/sync/handlers/destination.py rename to backend/airweave/domains/sync_pipeline/handlers/destination.py index c20eba993..adba71d9a 100644 --- a/backend/airweave/platform/sync/handlers/destination.py +++ b/backend/airweave/domains/sync_pipeline/handlers/destination.py @@ -11,24 +11,22 @@ import httpx from airweave.platform.destinations._base import BaseDestination -from airweave.platform.sync.actions.entity.types import ( +from airweave.domains.sync_pipeline.types.entity_actions import ( EntityActionBatch, EntityDeleteAction, EntityInsertAction, EntityUpdateAction, ) -from airweave.platform.sync.exceptions import SyncFailureError -from airweave.platform.sync.handlers.protocol import EntityActionHandler -from airweave.platform.sync.processors import ChunkEmbedProcessor +from airweave.domains.sync_pipeline.exceptions import SyncFailureError +from airweave.domains.sync_pipeline.handlers.protocol import EntityActionHandler +from airweave.domains.sync_pipeline.protocols import ChunkEmbedProcessorProtocol if TYPE_CHECKING: - from airweave.platform.contexts import SyncContext - from airweave.platform.contexts.runtime import SyncRuntime + from airweave.domains.sync_pipeline.contexts import SyncContext + from airweave.domains.sync_pipeline.contexts.runtime import SyncRuntime from airweave.platform.entities import BaseEntity -_processor = ChunkEmbedProcessor() - _RETRYABLE_EXCEPTIONS: tuple = ( ConnectionError, TimeoutError, @@ -42,9 +40,13 @@ class DestinationHandler(EntityActionHandler): """Handler that chunks/embeds entities and inserts into destinations.""" - def __init__(self, destinations: List[BaseDestination]): - """Initialize handler with destinations.""" + def __init__( + self, + destinations: List[BaseDestination], + processor: ChunkEmbedProcessorProtocol, + ) -> None: self._destinations = destinations + self._processor = processor @property def name(self) -> str: @@ -151,7 +153,7 @@ async def _do_process_and_insert( """Process entities through ChunkEmbedProcessor and insert into destinations.""" copies = [e.model_copy(deep=True) for e in entities] proc_start = asyncio.get_running_loop().time() - processed = await _processor.process(copies, sync_context, runtime) + processed = await self._processor.process(copies, sync_context, runtime) proc_elapsed = asyncio.get_running_loop().time() - proc_start if proc_elapsed > 10: sync_context.logger.warning( diff --git a/backend/airweave/platform/sync/handlers/entity_postgres.py b/backend/airweave/domains/sync_pipeline/handlers/entity_postgres.py similarity index 90% rename from backend/airweave/platform/sync/handlers/entity_postgres.py rename to backend/airweave/domains/sync_pipeline/handlers/entity_postgres.py index a53adec33..04144d951 100644 --- a/backend/airweave/platform/sync/handlers/entity_postgres.py +++ b/backend/airweave/domains/sync_pipeline/handlers/entity_postgres.py @@ -9,33 +9,28 @@ from sqlalchemy.ext.asyncio import AsyncSession -from airweave import crud, schemas +from airweave import schemas from airweave.db.session import get_db_context -from airweave.platform.sync.actions.entity.types import ( +from airweave.domains.entities.protocols import EntityRepositoryProtocol +from airweave.domains.sync_pipeline.types.entity_actions import ( EntityActionBatch, EntityDeleteAction, EntityInsertAction, EntityUpdateAction, ) -from airweave.platform.sync.exceptions import SyncFailureError -from airweave.platform.sync.handlers.protocol import EntityActionHandler +from airweave.domains.sync_pipeline.exceptions import SyncFailureError +from airweave.domains.sync_pipeline.handlers.protocol import EntityActionHandler if TYPE_CHECKING: - from airweave.platform.contexts import SyncContext - from airweave.platform.contexts.runtime import SyncRuntime + from airweave.domains.sync_pipeline.contexts import SyncContext + from airweave.domains.sync_pipeline.contexts.runtime import SyncRuntime class EntityPostgresHandler(EntityActionHandler): - """Handler for PostgreSQL entity metadata. + """Handler for PostgreSQL entity metadata.""" - Stores entity records with: - - entity_id: Unique identifier from source - - entity_definition_short_name: Type classification - - hash: Content hash for change detection - - sync_id, sync_job_id: Sync tracking - - Runs AFTER destination handlers succeed (dispatcher handles ordering). - """ + def __init__(self, entity_repo: EntityRepositoryProtocol) -> None: + self._entity_repo = entity_repo @property def name(self) -> str: @@ -189,7 +184,7 @@ async def _do_inserts( sync_context.logger.debug( f"[EntityPostgres] Upserting {len(create_objs)} (sample: {sample_ids})" ) - await crud.entity.bulk_create(db, objs=create_objs, ctx=sync_context) + await self._entity_repo.bulk_create(db, objs=create_objs, ctx=sync_context) async def _do_updates( self, @@ -215,7 +210,7 @@ async def _do_updates( update_pairs.sort(key=lambda p: p[0]) sync_context.logger.debug(f"[EntityPostgres] Updating {len(update_pairs)} hashes") - await crud.entity.bulk_update_hash(db, rows=update_pairs) + await self._entity_repo.bulk_update_hash(db, rows=update_pairs) async def _do_deletes( self, @@ -237,7 +232,7 @@ async def _do_deletes( return sync_context.logger.debug(f"[EntityPostgres] Deleting {len(db_ids)} records") - await crud.entity.bulk_remove(db, ids=db_ids, ctx=sync_context) + await self._entity_repo.bulk_remove(db, ids=db_ids, ctx=sync_context) # ------------------------------------------------------------------------- # Private: Orphan Cleanup @@ -252,7 +247,7 @@ async def _do_orphan_cleanup( sync_context.logger.info(f"[EntityPostgres] Cleaning {len(orphan_entity_ids)} orphans") async with get_db_context() as db: - entity_map = await crud.entity.bulk_get_by_entity_and_sync( + entity_map = await self._entity_repo.bulk_get_by_entity_and_sync( db=db, entity_ids=orphan_entity_ids, sync_id=sync_context.sync.id, @@ -263,7 +258,7 @@ async def _do_orphan_cleanup( return db_ids = [e.id for e in entity_map.values()] - await crud.entity.bulk_remove(db=db, ids=db_ids, ctx=sync_context) + await self._entity_repo.bulk_remove(db=db, ids=db_ids, ctx=sync_context) await db.commit() sync_context.logger.info(f"[EntityPostgres] Deleted {len(db_ids)} orphan records") @@ -280,7 +275,7 @@ async def _fetch_existing_map( ) -> Dict[Tuple[str, str], Any]: """Fetch existing DB records for update/delete actions.""" entity_requests = [(a.entity_id, a.entity_definition_short_name) for a in actions] - return await crud.entity.bulk_get_by_entity_sync_and_definition( + return await self._entity_repo.bulk_get_by_entity_sync_and_definition( db=db, sync_id=sync_context.sync.id, entity_requests=entity_requests ) diff --git a/backend/airweave/platform/sync/handlers/protocol.py b/backend/airweave/domains/sync_pipeline/handlers/protocol.py similarity index 92% rename from backend/airweave/platform/sync/handlers/protocol.py rename to backend/airweave/domains/sync_pipeline/handlers/protocol.py index 1a5d7b920..80b510172 100644 --- a/backend/airweave/platform/sync/handlers/protocol.py +++ b/backend/airweave/domains/sync_pipeline/handlers/protocol.py @@ -3,16 +3,16 @@ from typing import TYPE_CHECKING, Any, List, Protocol, runtime_checkable if TYPE_CHECKING: - from airweave.platform.contexts import SyncContext - from airweave.platform.contexts.runtime import SyncRuntime - from airweave.platform.sync.actions.access_control import ( + from airweave.domains.sync_pipeline.contexts import SyncContext + from airweave.domains.sync_pipeline.contexts.runtime import SyncRuntime + from airweave.domains.sync_pipeline.types.access_control_actions import ( ACActionBatch, ACDeleteAction, ACInsertAction, ACUpdateAction, ACUpsertAction, ) - from airweave.platform.sync.actions.entity import ( + from airweave.domains.sync_pipeline.types.entity_actions import ( EntityActionBatch, EntityDeleteAction, EntityInsertAction, diff --git a/backend/airweave/platform/sync/orchestrator.py b/backend/airweave/domains/sync_pipeline/orchestrator.py similarity index 98% rename from backend/airweave/platform/sync/orchestrator.py rename to backend/airweave/domains/sync_pipeline/orchestrator.py index 02c98be40..127ff24d4 100644 --- a/backend/airweave/platform/sync/orchestrator.py +++ b/backend/airweave/domains/sync_pipeline/orchestrator.py @@ -18,12 +18,12 @@ UsageLimitExceededError, ) from airweave.domains.usage.types import ActionType -from airweave.platform.contexts import SyncContext -from airweave.platform.contexts.runtime import SyncRuntime -from airweave.platform.sync.access_control_pipeline import AccessControlPipeline -from airweave.platform.sync.exceptions import EntityProcessingError, SyncFailureError -from airweave.platform.sync.stream import AsyncSourceStream -from airweave.platform.sync.worker_pool import AsyncWorkerPool +from airweave.domains.sync_pipeline.contexts import SyncContext +from airweave.domains.sync_pipeline.contexts.runtime import SyncRuntime +from airweave.domains.sync_pipeline.access_control_pipeline import AccessControlPipeline +from airweave.domains.sync_pipeline.exceptions import EntityProcessingError, SyncFailureError +from airweave.domains.sync_pipeline.stream import AsyncSourceStream +from airweave.domains.sync_pipeline.worker_pool import AsyncWorkerPool from airweave.platform.utils.error_utils import get_error_message diff --git a/backend/airweave/domains/sync_pipeline/pipeline/__init__.py b/backend/airweave/domains/sync_pipeline/pipeline/__init__.py new file mode 100644 index 000000000..a440c2339 --- /dev/null +++ b/backend/airweave/domains/sync_pipeline/pipeline/__init__.py @@ -0,0 +1 @@ +"""Sync pipeline utilities — stateful per-sync components.""" diff --git a/backend/airweave/platform/sync/pipeline/acl_membership_tracker.py b/backend/airweave/domains/sync_pipeline/pipeline/acl_membership_tracker.py similarity index 100% rename from backend/airweave/platform/sync/pipeline/acl_membership_tracker.py rename to backend/airweave/domains/sync_pipeline/pipeline/acl_membership_tracker.py diff --git a/backend/airweave/platform/sync/pipeline/cleanup_service.py b/backend/airweave/domains/sync_pipeline/pipeline/cleanup_service.py similarity index 95% rename from backend/airweave/platform/sync/pipeline/cleanup_service.py rename to backend/airweave/domains/sync_pipeline/pipeline/cleanup_service.py index f06638d8d..9014b94e5 100644 --- a/backend/airweave/platform/sync/pipeline/cleanup_service.py +++ b/backend/airweave/domains/sync_pipeline/pipeline/cleanup_service.py @@ -4,11 +4,11 @@ from typing import TYPE_CHECKING, Any, Dict, List from airweave.platform.entities._base import FileEntity -from airweave.platform.sync.exceptions import SyncFailureError +from airweave.domains.sync_pipeline.exceptions import SyncFailureError if TYPE_CHECKING: - from airweave.platform.contexts import SyncContext - from airweave.platform.contexts.runtime import SyncRuntime + from airweave.domains.sync_pipeline.contexts import SyncContext + from airweave.domains.sync_pipeline.contexts.runtime import SyncRuntime class CleanupService: diff --git a/backend/airweave/platform/sync/pipeline/entity_tracker.py b/backend/airweave/domains/sync_pipeline/pipeline/entity_tracker.py similarity index 100% rename from backend/airweave/platform/sync/pipeline/entity_tracker.py rename to backend/airweave/domains/sync_pipeline/pipeline/entity_tracker.py diff --git a/backend/airweave/platform/sync/pipeline/hash_computer.py b/backend/airweave/domains/sync_pipeline/pipeline/hash_computer.py similarity index 97% rename from backend/airweave/platform/sync/pipeline/hash_computer.py rename to backend/airweave/domains/sync_pipeline/pipeline/hash_computer.py index 0aa90219b..c341a88cf 100644 --- a/backend/airweave/platform/sync/pipeline/hash_computer.py +++ b/backend/airweave/domains/sync_pipeline/pipeline/hash_computer.py @@ -7,12 +7,12 @@ from airweave.core.shared_models import AirweaveFieldFlag from airweave.platform.entities._base import BaseEntity, CodeFileEntity, FileEntity -from airweave.platform.sync.async_helpers import run_in_thread_pool -from airweave.platform.sync.exceptions import EntityProcessingError, SyncFailureError +from airweave.domains.sync_pipeline.async_helpers import run_in_thread_pool +from airweave.domains.sync_pipeline.exceptions import EntityProcessingError, SyncFailureError if TYPE_CHECKING: - from airweave.platform.contexts import SyncContext - from airweave.platform.contexts.runtime import SyncRuntime + from airweave.domains.sync_pipeline.contexts import SyncContext + from airweave.domains.sync_pipeline.contexts.runtime import SyncRuntime class HashComputer: diff --git a/backend/airweave/platform/sync/pipeline/text_builder.py b/backend/airweave/domains/sync_pipeline/pipeline/text_builder.py similarity index 98% rename from backend/airweave/platform/sync/pipeline/text_builder.py rename to backend/airweave/domains/sync_pipeline/pipeline/text_builder.py index 8aeedf8d8..c0d58285f 100644 --- a/backend/airweave/platform/sync/pipeline/text_builder.py +++ b/backend/airweave/domains/sync_pipeline/pipeline/text_builder.py @@ -7,12 +7,12 @@ from airweave.core.shared_models import AirweaveFieldFlag from airweave.platform.entities._base import BaseEntity, CodeFileEntity, FileEntity, WebEntity -from airweave.platform.sync.exceptions import EntityProcessingError, SyncFailureError -from airweave.platform.sync.file_types import SUPPORTED_FILE_EXTENSIONS +from airweave.domains.sync_pipeline.exceptions import EntityProcessingError, SyncFailureError +from airweave.domains.sync_pipeline.file_types import SUPPORTED_FILE_EXTENSIONS if TYPE_CHECKING: - from airweave.platform.contexts import SyncContext - from airweave.platform.contexts.runtime import SyncRuntime + from airweave.domains.sync_pipeline.contexts import SyncContext + from airweave.domains.sync_pipeline.contexts.runtime import SyncRuntime class TextualRepresentationBuilder: diff --git a/backend/airweave/domains/sync_pipeline/processors/__init__.py b/backend/airweave/domains/sync_pipeline/processors/__init__.py new file mode 100644 index 000000000..452cb8358 --- /dev/null +++ b/backend/airweave/domains/sync_pipeline/processors/__init__.py @@ -0,0 +1,5 @@ +"""Sync pipeline processors — chunking and embedding.""" + +from airweave.domains.sync_pipeline.processors.chunk_embed import ChunkEmbedProcessor + +__all__ = ["ChunkEmbedProcessor"] diff --git a/backend/airweave/platform/sync/processors/chunk_embed.py b/backend/airweave/domains/sync_pipeline/processors/chunk_embed.py similarity index 96% rename from backend/airweave/platform/sync/processors/chunk_embed.py rename to backend/airweave/domains/sync_pipeline/processors/chunk_embed.py index 67ed160e4..2cfc9a3df 100644 --- a/backend/airweave/platform/sync/processors/chunk_embed.py +++ b/backend/airweave/domains/sync_pipeline/processors/chunk_embed.py @@ -15,13 +15,13 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple from airweave.platform.entities._base import BaseEntity, CodeFileEntity -from airweave.platform.sync.exceptions import SyncFailureError -from airweave.platform.sync.pipeline.text_builder import text_builder -from airweave.platform.sync.processors.utils import filter_empty_representations +from airweave.domains.sync_pipeline.exceptions import SyncFailureError +from airweave.domains.sync_pipeline.pipeline.text_builder import text_builder +from airweave.domains.sync_pipeline.processors.utils import filter_empty_representations if TYPE_CHECKING: - from airweave.platform.contexts import SyncContext - from airweave.platform.contexts.runtime import SyncRuntime + from airweave.domains.sync_pipeline.contexts import SyncContext + from airweave.domains.sync_pipeline.contexts.runtime import SyncRuntime class ChunkEmbedProcessor: diff --git a/backend/airweave/platform/sync/processors/utils.py b/backend/airweave/domains/sync_pipeline/processors/utils.py similarity index 88% rename from backend/airweave/platform/sync/processors/utils.py rename to backend/airweave/domains/sync_pipeline/processors/utils.py index a42c0233e..ddef0c364 100644 --- a/backend/airweave/platform/sync/processors/utils.py +++ b/backend/airweave/domains/sync_pipeline/processors/utils.py @@ -5,8 +5,8 @@ from airweave.platform.entities._base import BaseEntity if TYPE_CHECKING: - from airweave.platform.contexts import SyncContext - from airweave.platform.contexts.runtime import SyncRuntime + from airweave.domains.sync_pipeline.contexts import SyncContext + from airweave.domains.sync_pipeline.contexts.runtime import SyncRuntime async def filter_empty_representations( diff --git a/backend/airweave/domains/sync_pipeline/protocols.py b/backend/airweave/domains/sync_pipeline/protocols.py index 5e11fdf08..6680dad02 100644 --- a/backend/airweave/domains/sync_pipeline/protocols.py +++ b/backend/airweave/domains/sync_pipeline/protocols.py @@ -8,14 +8,14 @@ from airweave import schemas from airweave.platform.entities._base import BaseEntity -from airweave.platform.sync.actions.entity.types import EntityActionBatch +from airweave.domains.sync_pipeline.types.entity_actions import EntityActionBatch if TYPE_CHECKING: from airweave.core.context import BaseContext - from airweave.platform.contexts import SyncContext - from airweave.platform.contexts.runtime import SyncRuntime - from airweave.platform.sync.config import SyncConfig - from airweave.platform.sync.orchestrator import SyncOrchestrator + from airweave.domains.sync_pipeline.contexts import SyncContext + from airweave.domains.sync_pipeline.contexts.runtime import SyncRuntime + from airweave.domains.sync_pipeline.config import SyncConfig + from airweave.domains.sync_pipeline.orchestrator import SyncOrchestrator class ChunkEmbedProcessorProtocol(Protocol): diff --git a/backend/airweave/platform/sync/stream.py b/backend/airweave/domains/sync_pipeline/stream.py similarity index 100% rename from backend/airweave/platform/sync/stream.py rename to backend/airweave/domains/sync_pipeline/stream.py diff --git a/backend/airweave/domains/sync_pipeline/subscribers/__init__.py b/backend/airweave/domains/sync_pipeline/subscribers/__init__.py new file mode 100644 index 000000000..454b7bfd1 --- /dev/null +++ b/backend/airweave/domains/sync_pipeline/subscribers/__init__.py @@ -0,0 +1,5 @@ +"""Sync pipeline subscribers — event handlers for sync progress.""" + +from airweave.domains.sync_pipeline.subscribers.progress_relay import SyncProgressRelay + +__all__ = ["SyncProgressRelay"] diff --git a/backend/airweave/platform/sync/subscribers/progress_relay.py b/backend/airweave/domains/sync_pipeline/subscribers/progress_relay.py similarity index 100% rename from backend/airweave/platform/sync/subscribers/progress_relay.py rename to backend/airweave/domains/sync_pipeline/subscribers/progress_relay.py diff --git a/backend/tests/unit/platform/sync/pipeline/test_acl_membership_tracker.py b/backend/airweave/domains/sync_pipeline/tests/test_acl_membership_tracker.py similarity index 99% rename from backend/tests/unit/platform/sync/pipeline/test_acl_membership_tracker.py rename to backend/airweave/domains/sync_pipeline/tests/test_acl_membership_tracker.py index e6ac16cfb..133def56a 100644 --- a/backend/tests/unit/platform/sync/pipeline/test_acl_membership_tracker.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_acl_membership_tracker.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock from uuid import uuid4 -from airweave.platform.sync.pipeline.acl_membership_tracker import ACLMembershipTracker +from airweave.domains.sync_pipeline.pipeline.acl_membership_tracker import ACLMembershipTracker @pytest.fixture diff --git a/backend/tests/unit/platform/sync/pipeline/test_acl_reconciliation.py b/backend/airweave/domains/sync_pipeline/tests/test_acl_reconciliation.py similarity index 76% rename from backend/tests/unit/platform/sync/pipeline/test_acl_reconciliation.py rename to backend/airweave/domains/sync_pipeline/tests/test_acl_reconciliation.py index 11b375c1f..dbe6eab10 100644 --- a/backend/tests/unit/platform/sync/pipeline/test_acl_reconciliation.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_acl_reconciliation.py @@ -19,7 +19,7 @@ import pytest from airweave.platform.access_control.schemas import ACLChangeType, MembershipChange -from airweave.platform.sync.access_control_pipeline import AccessControlPipeline +from airweave.domains.sync_pipeline.access_control_pipeline import AccessControlPipeline # --------------------------------------------------------------------------- @@ -81,6 +81,7 @@ def _make_pipeline(): resolver=MagicMock(), dispatcher=MagicMock(), tracker=MagicMock(), + acl_repo=MagicMock(), ) @@ -109,18 +110,17 @@ async def test_applies_adds_and_removes(self): modified_group_ids={"group-A", "group-B"}, ) - with patch("airweave.platform.sync.access_control_pipeline.crud") as mock_crud: - mock_crud.access_control_membership.upsert = AsyncMock() - mock_crud.access_control_membership.delete_by_key = AsyncMock() + pipeline._acl_repo.upsert = AsyncMock() + pipeline._acl_repo.delete_by_key = AsyncMock() - adds, removes = await pipeline._apply_membership_changes( - db, result, source, ctx - ) + adds, removes = await pipeline._apply_membership_changes( + db, result, source, ctx + ) assert adds == 2 assert removes == 1 - assert mock_crud.access_control_membership.upsert.call_count == 2 - assert mock_crud.access_control_membership.delete_by_key.call_count == 1 + assert pipeline._acl_repo.upsert.call_count == 2 + assert pipeline._acl_repo.delete_by_key.call_count == 1 @pytest.mark.asyncio async def test_basic_flags_does_not_reconcile(self): @@ -143,19 +143,15 @@ async def test_basic_flags_does_not_reconcile(self): incremental_values=False, ) - with patch("airweave.platform.sync.access_control_pipeline.crud") as mock_crud: - mock_crud.access_control_membership.upsert = AsyncMock() - mock_crud.access_control_membership.delete_by_key = AsyncMock() - mock_crud.access_control_membership.get_memberships_by_groups = AsyncMock() + pipeline._acl_repo.upsert = AsyncMock() + pipeline._acl_repo.delete_by_key = AsyncMock() - adds, removes = await pipeline._apply_membership_changes( - db, result, source, ctx - ) + adds, removes = await pipeline._apply_membership_changes( + db, result, source, ctx + ) assert adds == 2 assert removes == 0 - # No reconciliation — get_memberships_by_groups never called - mock_crud.access_control_membership.get_memberships_by_groups.assert_not_called() @pytest.mark.asyncio async def test_upsert_passes_correct_fields(self): @@ -174,14 +170,13 @@ async def test_upsert_passes_correct_fields(self): modified_group_ids=set(), ) - with patch("airweave.platform.sync.access_control_pipeline.crud") as mock_crud: - mock_crud.access_control_membership.upsert = AsyncMock() + pipeline._acl_repo.upsert = AsyncMock() - await pipeline._apply_membership_changes( - db, result, source, ctx - ) + await pipeline._apply_membership_changes( + db, result, source, ctx + ) - mock_crud.access_control_membership.upsert.assert_called_once_with( + pipeline._acl_repo.upsert.assert_called_once_with( db, member_id="alice@acme.com", member_type="user", @@ -202,13 +197,12 @@ async def test_empty_changes_returns_zero(self): result = FakeDirSyncResult(changes=[], modified_group_ids=set()) - with patch("airweave.platform.sync.access_control_pipeline.crud") as mock_crud: - mock_crud.access_control_membership.upsert = AsyncMock() - mock_crud.access_control_membership.delete_by_key = AsyncMock() + pipeline._acl_repo.upsert = AsyncMock() + pipeline._acl_repo.delete_by_key = AsyncMock() - adds, removes = await pipeline._apply_membership_changes( - db, result, source, ctx - ) + adds, removes = await pipeline._apply_membership_changes( + db, result, source, ctx + ) assert adds == 0 assert removes == 0 @@ -244,18 +238,17 @@ async def test_deleted_groups_remove_all_memberships(self): ) source.get_acl_changes.return_value = result - with patch("airweave.platform.sync.access_control_pipeline.crud") as mock_crud, \ - patch("airweave.platform.sync.access_control_pipeline.get_db_context") as mock_db_ctx: + with patch("airweave.domains.sync_pipeline.access_control_pipeline.get_db_context") as mock_db_ctx: mock_db = MagicMock() mock_db_ctx.return_value.__aenter__ = AsyncMock(return_value=mock_db) mock_db_ctx.return_value.__aexit__ = AsyncMock(return_value=False) - mock_crud.access_control_membership.delete_by_group = AsyncMock(return_value=5) + pipeline._acl_repo.delete_by_group = AsyncMock(return_value=5) total = await pipeline._process_incremental(source, ctx, runtime) assert total == 10 # 5 members x 2 groups - assert mock_crud.access_control_membership.delete_by_group.call_count == 2 + assert pipeline._acl_repo.delete_by_group.call_count == 2 @pytest.mark.asyncio async def test_no_changes_returns_zero_and_updates_cookie(self): @@ -328,24 +321,22 @@ async def test_full_flow_with_adds_removes_and_deletes(self): ) source.get_acl_changes.return_value = result - with patch("airweave.platform.sync.access_control_pipeline.crud") as mock_crud, \ - patch("airweave.platform.sync.access_control_pipeline.get_db_context") as mock_db_ctx: + with patch("airweave.domains.sync_pipeline.access_control_pipeline.get_db_context") as mock_db_ctx: mock_db = MagicMock() mock_db_ctx.return_value.__aenter__ = AsyncMock(return_value=mock_db) mock_db_ctx.return_value.__aexit__ = AsyncMock(return_value=False) - mock_crud.access_control_membership.upsert = AsyncMock() - mock_crud.access_control_membership.delete_by_key = AsyncMock() - mock_crud.access_control_membership.delete_by_group = AsyncMock(return_value=3) + pipeline._acl_repo.upsert = AsyncMock() + pipeline._acl_repo.delete_by_key = AsyncMock() + pipeline._acl_repo.delete_by_group = AsyncMock(return_value=3) total = await pipeline._process_incremental(source, ctx, runtime) # 2 adds + 1 remove + 3 group-deletion removals = 6 assert total == 6 - assert mock_crud.access_control_membership.upsert.call_count == 2 - assert mock_crud.access_control_membership.delete_by_key.call_count == 1 - assert mock_crud.access_control_membership.delete_by_group.call_count == 1 - # Cookie should be updated + assert pipeline._acl_repo.upsert.call_count == 2 + assert pipeline._acl_repo.delete_by_key.call_count == 1 + assert pipeline._acl_repo.delete_by_group.call_count == 1 assert runtime.cursor.data["acl_dirsync_cookie"] == "final_cookie" @pytest.mark.asyncio @@ -371,17 +362,13 @@ async def test_basic_flags_does_not_reconcile_in_full_flow(self): ) source.get_acl_changes.return_value = result - with patch("airweave.platform.sync.access_control_pipeline.crud") as mock_crud, \ - patch("airweave.platform.sync.access_control_pipeline.get_db_context") as mock_db_ctx: + with patch("airweave.domains.sync_pipeline.access_control_pipeline.get_db_context") as mock_db_ctx: mock_db = MagicMock() mock_db_ctx.return_value.__aenter__ = AsyncMock(return_value=mock_db) mock_db_ctx.return_value.__aexit__ = AsyncMock(return_value=False) - mock_crud.access_control_membership.upsert = AsyncMock() - mock_crud.access_control_membership.get_memberships_by_groups = AsyncMock() + pipeline._acl_repo.upsert = AsyncMock() total = await pipeline._process_incremental(source, ctx, runtime) - # Only the 1 ADD — no reconciliation removals assert total == 1 - mock_crud.access_control_membership.get_memberships_by_groups.assert_not_called() diff --git a/backend/tests/unit/platform/sync/processors/test_chunk_embed.py b/backend/airweave/domains/sync_pipeline/tests/test_chunk_embed.py similarity index 96% rename from backend/tests/unit/platform/sync/processors/test_chunk_embed.py rename to backend/airweave/domains/sync_pipeline/tests/test_chunk_embed.py index cfb784a7d..8ea1d9a10 100644 --- a/backend/tests/unit/platform/sync/processors/test_chunk_embed.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_chunk_embed.py @@ -4,7 +4,7 @@ import pytest -from airweave.platform.sync.processors.chunk_embed import ChunkEmbedProcessor +from airweave.domains.sync_pipeline.processors.chunk_embed import ChunkEmbedProcessor @pytest.fixture @@ -64,7 +64,7 @@ async def test_chunk_textual_entities_uses_semantic_chunker( self, processor, mock_sync_context, mock_runtime, mock_entity ): """Test textual entities routed to SemanticChunker.""" - with patch('airweave.platform.sync.processors.chunk_embed.text_builder') as mock_builder, \ + with patch('airweave.domains.sync_pipeline.processors.chunk_embed.text_builder') as mock_builder, \ patch('airweave.platform.chunkers.semantic.SemanticChunker') as MockSemanticChunker, \ patch.object(processor, '_embed_entities', new_callable=AsyncMock): @@ -310,7 +310,7 @@ def create_chunk(deep=False): return_value=[MagicMock(), MagicMock()] ) - with patch('airweave.platform.sync.processors.chunk_embed.text_builder') as mock_builder, \ + with patch('airweave.domains.sync_pipeline.processors.chunk_embed.text_builder') as mock_builder, \ patch('airweave.platform.chunkers.semantic.SemanticChunker') as MockChunker: # Setup mocks @@ -354,7 +354,7 @@ def create_chunk(deep=False): mock_runtime.dense_embedder.embed_many = AsyncMock(return_value=[dense_result]) mock_runtime.sparse_embedder.embed_many = AsyncMock(return_value=[MagicMock()]) - with patch('airweave.platform.sync.processors.chunk_embed.text_builder') as mock_builder, \ + with patch('airweave.domains.sync_pipeline.processors.chunk_embed.text_builder') as mock_builder, \ patch('airweave.platform.chunkers.semantic.SemanticChunker') as MockChunker: mock_builder.build_for_batch = AsyncMock(return_value=[mock_entity]) @@ -377,7 +377,7 @@ async def test_skips_entities_without_text( mock_entity.textual_representation = None # No text mock_entity.airweave_system_metadata = MagicMock() - with patch('airweave.platform.sync.processors.chunk_embed.text_builder') as mock_builder: + with patch('airweave.domains.sync_pipeline.processors.chunk_embed.text_builder') as mock_builder: mock_builder.build_for_batch = AsyncMock(return_value=[mock_entity]) result = await processor.process([mock_entity], mock_sync_context, mock_runtime) @@ -395,7 +395,7 @@ async def test_handles_empty_chunks_from_chunker( mock_entity.textual_representation = "Test" mock_entity.airweave_system_metadata = MagicMock() - with patch('airweave.platform.sync.processors.chunk_embed.text_builder') as mock_builder, \ + with patch('airweave.domains.sync_pipeline.processors.chunk_embed.text_builder') as mock_builder, \ patch('airweave.platform.chunkers.semantic.SemanticChunker') as MockChunker: mock_builder.build_for_batch = AsyncMock(return_value=[mock_entity]) diff --git a/backend/tests/unit/platform/sync/pipeline/test_cleanup_service.py b/backend/airweave/domains/sync_pipeline/tests/test_cleanup_service.py similarity index 98% rename from backend/tests/unit/platform/sync/pipeline/test_cleanup_service.py rename to backend/airweave/domains/sync_pipeline/tests/test_cleanup_service.py index 80f53996c..74a9f687f 100644 --- a/backend/tests/unit/platform/sync/pipeline/test_cleanup_service.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_cleanup_service.py @@ -17,8 +17,8 @@ from airweave.platform.entities._airweave_field import AirweaveField from airweave.platform.entities._base import BaseEntity, DeletionEntity, FileEntity -from airweave.platform.sync.exceptions import SyncFailureError -from airweave.platform.sync.pipeline.cleanup_service import cleanup_service +from airweave.domains.sync_pipeline.exceptions import SyncFailureError +from airweave.domains.sync_pipeline.pipeline.cleanup_service import cleanup_service # Test entity classes diff --git a/backend/tests/unit/platform/sync/sync_config/test_base.py b/backend/airweave/domains/sync_pipeline/tests/test_config_base.py similarity index 99% rename from backend/tests/unit/platform/sync/sync_config/test_base.py rename to backend/airweave/domains/sync_pipeline/tests/test_config_base.py index db41dc59a..69ec1ed69 100644 --- a/backend/tests/unit/platform/sync/sync_config/test_base.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_config_base.py @@ -6,7 +6,7 @@ import pytest -from airweave.platform.sync.config.base import ( +from airweave.domains.sync_pipeline.config.base import ( BehaviorConfig, CursorConfig, DestinationConfig, diff --git a/backend/tests/unit/platform/sync/sync_config/test_builder.py b/backend/airweave/domains/sync_pipeline/tests/test_config_builder.py similarity index 97% rename from backend/tests/unit/platform/sync/sync_config/test_builder.py rename to backend/airweave/domains/sync_pipeline/tests/test_config_builder.py index d6062f7ab..fc9d01261 100644 --- a/backend/tests/unit/platform/sync/sync_config/test_builder.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_config_builder.py @@ -3,14 +3,14 @@ import os from unittest.mock import patch -from airweave.platform.sync.config.base import ( +from airweave.domains.sync_pipeline.config.base import ( BehaviorConfig, CursorConfig, DestinationConfig, HandlerConfig, SyncConfig, ) -from airweave.platform.sync.config.builder import SyncConfigBuilder +from airweave.domains.sync_pipeline.config.builder import SyncConfigBuilder def _clean_env(): diff --git a/backend/tests/unit/platform/sync/handlers/test_destination_handler.py b/backend/airweave/domains/sync_pipeline/tests/test_destination_handler.py similarity index 88% rename from backend/tests/unit/platform/sync/handlers/test_destination_handler.py rename to backend/airweave/domains/sync_pipeline/tests/test_destination_handler.py index f48bc3bfd..288b897a9 100644 --- a/backend/tests/unit/platform/sync/handlers/test_destination_handler.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_destination_handler.py @@ -12,8 +12,8 @@ import pytest -from airweave.platform.sync.exceptions import SyncFailureError -from airweave.platform.sync.handlers.destination import DestinationHandler +from airweave.domains.sync_pipeline.exceptions import SyncFailureError +from airweave.domains.sync_pipeline.handlers.destination import DestinationHandler def _make_mock_destination(soft_fail=False): @@ -45,7 +45,7 @@ class TestExecuteWithRetryTimeout: async def test_timeout_error_is_retried(self): """TimeoutError should be caught and retried up to max_retries times.""" dest = _make_mock_destination() - handler = DestinationHandler([dest]) + handler = DestinationHandler([dest], processor=MagicMock()) ctx = _make_mock_sync_context() call_count = 0 @@ -55,7 +55,7 @@ async def failing_operation(): call_count += 1 raise TimeoutError("feed timed out") - with patch("airweave.platform.sync.handlers.destination.asyncio.sleep", new_callable=AsyncMock): + with patch("airweave.domains.sync_pipeline.handlers.destination.asyncio.sleep", new_callable=AsyncMock): with pytest.raises(SyncFailureError, match="Destination unavailable"): await handler._execute_with_retry( operation=failing_operation, @@ -72,7 +72,7 @@ async def failing_operation(): async def test_asyncio_timeout_error_is_retried(self): """asyncio.TimeoutError (subclass of TimeoutError) should also be retried.""" dest = _make_mock_destination() - handler = DestinationHandler([dest]) + handler = DestinationHandler([dest], processor=MagicMock()) ctx = _make_mock_sync_context() call_count = 0 @@ -82,7 +82,7 @@ async def failing_operation(): call_count += 1 raise asyncio.TimeoutError() - with patch("airweave.platform.sync.handlers.destination.asyncio.sleep", new_callable=AsyncMock): + with patch("airweave.domains.sync_pipeline.handlers.destination.asyncio.sleep", new_callable=AsyncMock): with pytest.raises(SyncFailureError, match="Destination unavailable"): await handler._execute_with_retry( operation=failing_operation, @@ -98,7 +98,7 @@ async def failing_operation(): async def test_timeout_succeeds_on_retry(self): """If operation succeeds on retry, no error is raised.""" dest = _make_mock_destination() - handler = DestinationHandler([dest]) + handler = DestinationHandler([dest], processor=MagicMock()) ctx = _make_mock_sync_context() call_count = 0 @@ -110,7 +110,7 @@ async def flaky_operation(): raise TimeoutError("temporary failure") return "success" - with patch("airweave.platform.sync.handlers.destination.asyncio.sleep", new_callable=AsyncMock): + with patch("airweave.domains.sync_pipeline.handlers.destination.asyncio.sleep", new_callable=AsyncMock): result = await handler._execute_with_retry( operation=flaky_operation, operation_name="insert_MockDestination", @@ -126,13 +126,13 @@ async def flaky_operation(): async def test_retry_logs_warning_on_each_failure(self): """Each retry should log a warning with attempt number.""" dest = _make_mock_destination() - handler = DestinationHandler([dest]) + handler = DestinationHandler([dest], processor=MagicMock()) ctx = _make_mock_sync_context() async def failing_operation(): raise TimeoutError("feed timed out") - with patch("airweave.platform.sync.handlers.destination.asyncio.sleep", new_callable=AsyncMock): + with patch("airweave.domains.sync_pipeline.handlers.destination.asyncio.sleep", new_callable=AsyncMock): with pytest.raises(SyncFailureError): await handler._execute_with_retry( operation=failing_operation, @@ -151,7 +151,7 @@ async def failing_operation(): async def test_non_retryable_exception_fails_immediately(self): """Non-retryable exceptions should fail immediately with SyncFailureError.""" dest = _make_mock_destination() - handler = DestinationHandler([dest]) + handler = DestinationHandler([dest], processor=MagicMock()) ctx = _make_mock_sync_context() call_count = 0 @@ -181,7 +181,7 @@ class TestTimingLogs: async def test_slow_operation_logs_warning(self): """Operations taking >10s should log a warning.""" dest = _make_mock_destination() - handler = DestinationHandler([dest]) + handler = DestinationHandler([dest], processor=MagicMock()) ctx = _make_mock_sync_context() # Mock the event loop time to simulate a 15-second operation @@ -212,7 +212,7 @@ async def slow_operation(): async def test_fast_operation_does_not_log_warning(self): """Operations completing in <10s should not log a warning.""" dest = _make_mock_destination() - handler = DestinationHandler([dest]) + handler = DestinationHandler([dest], processor=MagicMock()) ctx = _make_mock_sync_context() # Mock the event loop time to simulate a 0.5-second operation @@ -242,12 +242,11 @@ async def fast_operation(): async def test_slow_processing_logs_warning(self): """Content processing (chunking/embedding) taking >10s should log a warning.""" dest = _make_mock_destination() - handler = DestinationHandler([dest]) - ctx = _make_mock_sync_context() - mock_processor = MagicMock() mock_processor.__class__.__name__ = "ChunkEmbedProcessor" mock_processor.process = AsyncMock(return_value=[]) + handler = DestinationHandler([dest], processor=mock_processor) + ctx = _make_mock_sync_context() time_values = [0.0, 15.0] time_iter = iter(time_values) @@ -258,8 +257,7 @@ async def test_slow_processing_logs_warning(self): mock_entity = MagicMock() mock_runtime = MagicMock() - with patch("asyncio.get_running_loop", return_value=mock_loop), \ - patch("airweave.platform.sync.handlers.destination._processor", mock_processor): + with patch("asyncio.get_running_loop", return_value=mock_loop): await handler._do_process_and_insert([mock_entity], ctx, mock_runtime) warning_calls = ctx.logger.warning.call_args_list diff --git a/backend/airweave/domains/sync_pipeline/tests/test_entity_action_resolver.py b/backend/airweave/domains/sync_pipeline/tests/test_entity_action_resolver.py index 68ae540d6..52600f3b6 100644 --- a/backend/airweave/domains/sync_pipeline/tests/test_entity_action_resolver.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_entity_action_resolver.py @@ -12,12 +12,12 @@ BaseEntity, DeletionEntity, ) -from airweave.platform.sync.actions.entity.types import ( +from airweave.domains.sync_pipeline.types.entity_actions import ( EntityInsertAction, EntityKeepAction, EntityUpdateAction, ) -from airweave.platform.sync.exceptions import SyncFailureError +from airweave.domains.sync_pipeline.exceptions import SyncFailureError # --------------------------------------------------------------------------- diff --git a/backend/airweave/domains/sync_pipeline/tests/test_factory.py b/backend/airweave/domains/sync_pipeline/tests/test_factory.py index 9c0b35fe7..6553b89b8 100644 --- a/backend/airweave/domains/sync_pipeline/tests/test_factory.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_factory.py @@ -17,6 +17,8 @@ def _build_factory(**overrides): "dense_embedder": MagicMock(), "sparse_embedder": MagicMock(), "entity_repo": MagicMock(), + "acl_repo": MagicMock(), + "processor": MagicMock(), } defaults.update(overrides) return SyncFactory(**defaults) @@ -28,7 +30,7 @@ def _build_factory(**overrides): def test_constructor_stores_all_deps(): - """All six injected deps are stored on the instance.""" + """All injected deps are stored on the instance.""" deps = { "sc_repo": MagicMock(), "event_bus": MagicMock(), @@ -36,6 +38,8 @@ def test_constructor_stores_all_deps(): "dense_embedder": MagicMock(), "sparse_embedder": MagicMock(), "entity_repo": MagicMock(), + "acl_repo": MagicMock(), + "processor": MagicMock(), } f = SyncFactory(**deps) assert f._sc_repo is deps["sc_repo"] @@ -44,6 +48,8 @@ def test_constructor_stores_all_deps(): assert f._dense_embedder is deps["dense_embedder"] assert f._sparse_embedder is deps["sparse_embedder"] assert f._entity_repo is deps["entity_repo"] + assert f._acl_repo is deps["acl_repo"] + assert f._processor is deps["processor"] # --------------------------------------------------------------------------- diff --git a/backend/airweave/platform/sync/subscribers/tests/test_progress_relay.py b/backend/airweave/domains/sync_pipeline/tests/test_progress_relay.py similarity index 99% rename from backend/airweave/platform/sync/subscribers/tests/test_progress_relay.py rename to backend/airweave/domains/sync_pipeline/tests/test_progress_relay.py index 8eb6283d5..09a8e72f1 100644 --- a/backend/airweave/platform/sync/subscribers/tests/test_progress_relay.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_progress_relay.py @@ -14,7 +14,7 @@ TypeActionCounts, ) from airweave.core.shared_models import SyncJobStatus -from airweave.platform.sync.subscribers.progress_relay import SyncProgressRelay +from airweave.domains.sync_pipeline.subscribers.progress_relay import SyncProgressRelay ORG_ID = UUID("00000000-0000-0000-0000-000000000001") SYNC_ID = uuid4() diff --git a/backend/airweave/domains/sync_pipeline/types/__init__.py b/backend/airweave/domains/sync_pipeline/types/__init__.py new file mode 100644 index 000000000..a42fc90a7 --- /dev/null +++ b/backend/airweave/domains/sync_pipeline/types/__init__.py @@ -0,0 +1 @@ +"""Sync pipeline types — action dataclasses for entity and access control.""" diff --git a/backend/airweave/platform/sync/actions/access_control/types.py b/backend/airweave/domains/sync_pipeline/types/access_control_actions.py similarity index 100% rename from backend/airweave/platform/sync/actions/access_control/types.py rename to backend/airweave/domains/sync_pipeline/types/access_control_actions.py diff --git a/backend/airweave/platform/sync/actions/entity/types.py b/backend/airweave/domains/sync_pipeline/types/entity_actions.py similarity index 100% rename from backend/airweave/platform/sync/actions/entity/types.py rename to backend/airweave/domains/sync_pipeline/types/entity_actions.py diff --git a/backend/airweave/platform/sync/worker_pool.py b/backend/airweave/domains/sync_pipeline/worker_pool.py similarity index 100% rename from backend/airweave/platform/sync/worker_pool.py rename to backend/airweave/domains/sync_pipeline/worker_pool.py diff --git a/backend/airweave/domains/syncs/fakes/sync_job_service.py b/backend/airweave/domains/syncs/fakes/sync_job_service.py index a79dec869..ff4c9cee4 100644 --- a/backend/airweave/domains/syncs/fakes/sync_job_service.py +++ b/backend/airweave/domains/syncs/fakes/sync_job_service.py @@ -6,7 +6,7 @@ from airweave.api.context import ApiContext from airweave.core.shared_models import SyncJobStatus -from airweave.platform.sync.pipeline.entity_tracker import SyncStats +from airweave.domains.sync_pipeline.pipeline.entity_tracker import SyncStats class FakeSyncJobService: diff --git a/backend/airweave/domains/syncs/fakes/sync_service.py b/backend/airweave/domains/syncs/fakes/sync_service.py index b0bacfd02..d31050fb6 100644 --- a/backend/airweave/domains/syncs/fakes/sync_service.py +++ b/backend/airweave/domains/syncs/fakes/sync_service.py @@ -4,7 +4,7 @@ from airweave import schemas from airweave.api.context import ApiContext -from airweave.platform.sync.config import SyncConfig +from airweave.domains.sync_pipeline.config import SyncConfig class FakeSyncService: diff --git a/backend/airweave/domains/syncs/protocols.py b/backend/airweave/domains/syncs/protocols.py index cb33446e0..927e255c1 100644 --- a/backend/airweave/domains/syncs/protocols.py +++ b/backend/airweave/domains/syncs/protocols.py @@ -15,8 +15,8 @@ from airweave.models.sync import Sync from airweave.models.sync_cursor import SyncCursor from airweave.models.sync_job import SyncJob -from airweave.platform.sync.config import SyncConfig -from airweave.platform.sync.pipeline.entity_tracker import SyncStats +from airweave.domains.sync_pipeline.config import SyncConfig +from airweave.domains.sync_pipeline.pipeline.entity_tracker import SyncStats from airweave.schemas.source_connection import ScheduleConfig, SourceConnectionJob from airweave.schemas.sync import SyncCreate, SyncUpdate from airweave.schemas.sync_job import SyncJobCreate, SyncJobUpdate diff --git a/backend/airweave/domains/syncs/service.py b/backend/airweave/domains/syncs/service.py index 4e4889208..9d6886163 100644 --- a/backend/airweave/domains/syncs/service.py +++ b/backend/airweave/domains/syncs/service.py @@ -12,7 +12,7 @@ from airweave.db.session import get_db_context from airweave.domains.sync_pipeline.protocols import SyncFactoryProtocol from airweave.domains.syncs.protocols import SyncJobServiceProtocol, SyncServiceProtocol -from airweave.platform.sync.config import SyncConfig +from airweave.domains.sync_pipeline.config import SyncConfig class SyncService(SyncServiceProtocol): diff --git a/backend/airweave/domains/syncs/sync_job_service.py b/backend/airweave/domains/syncs/sync_job_service.py index 3c21a9b38..33de58434 100644 --- a/backend/airweave/domains/syncs/sync_job_service.py +++ b/backend/airweave/domains/syncs/sync_job_service.py @@ -15,7 +15,7 @@ from airweave.db.session import get_db_context from airweave.domains.syncs.protocols import SyncJobRepositoryProtocol, SyncJobServiceProtocol from airweave.domains.syncs.types import StatsUpdate, TimestampUpdate -from airweave.platform.sync.pipeline.entity_tracker import SyncStats +from airweave.domains.sync_pipeline.pipeline.entity_tracker import SyncStats class SyncJobService(SyncJobServiceProtocol): diff --git a/backend/airweave/domains/syncs/tests/test_sync_job_service.py b/backend/airweave/domains/syncs/tests/test_sync_job_service.py index a6f1088e1..a7a5ec9cd 100644 --- a/backend/airweave/domains/syncs/tests/test_sync_job_service.py +++ b/backend/airweave/domains/syncs/tests/test_sync_job_service.py @@ -15,7 +15,7 @@ from airweave.core.shared_models import SyncJobStatus from airweave.domains.syncs.sync_job_service import SyncJobService from airweave.domains.syncs.types import StatsUpdate, TimestampUpdate -from airweave.platform.sync.pipeline.entity_tracker import SyncStats +from airweave.domains.sync_pipeline.pipeline.entity_tracker import SyncStats NOW = datetime(2024, 6, 15, 12, 0, 0, tzinfo=timezone.utc) diff --git a/backend/airweave/platform/access_control/broker.py b/backend/airweave/platform/access_control/broker.py index bdcf6d3de..e56173811 100644 --- a/backend/airweave/platform/access_control/broker.py +++ b/backend/airweave/platform/access_control/broker.py @@ -5,21 +5,16 @@ from sqlalchemy.ext.asyncio import AsyncSession -from airweave import crud +from airweave.domains.access_control.protocols import AccessControlMembershipRepositoryProtocol from airweave.platform.access_control.schemas import AccessContext from airweave.platform.entities._base import AccessControl class AccessBroker: - """Resolves user access context by expanding group memberships. + """Resolves user access context by expanding group memberships.""" - Source-agnostic: works for SharePoint, Google Drive, etc. - Handles both direct user-group and nested group-group relationships. - - Access control is only applied when at least one source in the collection - has supports_access_control=True. For collections with only non-AC sources, - no filtering is applied (all entities visible to everyone). - """ + def __init__(self, acl_repo: AccessControlMembershipRepositoryProtocol) -> None: + self._acl_repo = acl_repo async def resolve_access_context( self, db: AsyncSession, user_principal: str, organization_id: UUID @@ -44,7 +39,7 @@ async def resolve_access_context( AccessContext with fully expanded principals """ # Query direct user-group memberships (member_type="user") - memberships = await crud.access_control_membership.get_by_member( + memberships = await self._acl_repo.get_by_member( db=db, member_id=user_principal, member_type="user", organization_id=organization_id ) @@ -109,7 +104,7 @@ async def resolve_access_context_for_collection( return None # Query user-group memberships scoped to collection (member_type="user") - memberships = await crud.access_control_membership.get_by_member_and_collection( + memberships = await self._acl_repo.get_by_member_and_collection( db=db, member_id=user_principal, member_type="user", @@ -207,7 +202,7 @@ async def _expand_group_memberships( visited.add(current_group) # Query for group-to-group memberships via CRUD layer (member_type="group") - nested_memberships = await crud.access_control_membership.get_by_member( + nested_memberships = await self._acl_repo.get_by_member( db=db, member_id=current_group, member_type="group", organization_id=organization_id ) @@ -258,5 +253,11 @@ def check_entity_access( return bool(access_context.all_principals & set(entity_access.viewers)) -# Singleton shared instance -access_broker = AccessBroker() +def _default_access_broker() -> "AccessBroker": + """Create a default AccessBroker backed by the real repository.""" + from airweave.domains.access_control.repository import AccessControlMembershipRepository + + return AccessBroker(acl_repo=AccessControlMembershipRepository()) + + +access_broker = _default_access_broker() diff --git a/backend/airweave/platform/builders/__init__.py b/backend/airweave/platform/builders/__init__.py deleted file mode 100644 index edcdd9044..000000000 --- a/backend/airweave/platform/builders/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Context builders for platform operations. - -Builders: -- SyncContextBuilder: Builds flat SyncContext with all components -""" - -from airweave.platform.builders.sync import SyncContextBuilder - -__all__ = [ - "SyncContextBuilder", -] diff --git a/backend/airweave/platform/chunkers/code.py b/backend/airweave/platform/chunkers/code.py index f488808f1..aabb3b0df 100644 --- a/backend/airweave/platform/chunkers/code.py +++ b/backend/airweave/platform/chunkers/code.py @@ -5,8 +5,8 @@ from airweave.core.logging import logger from airweave.platform.chunkers._base import BaseChunker from airweave.platform.chunkers.tiktoken_compat import SafeEncoding -from airweave.platform.sync.async_helpers import run_in_thread_pool -from airweave.platform.sync.exceptions import SyncFailureError +from airweave.domains.sync_pipeline.async_helpers import run_in_thread_pool +from airweave.domains.sync_pipeline.exceptions import SyncFailureError from airweave.platform.tokenizers import TikTokenTokenizer, get_tokenizer diff --git a/backend/airweave/platform/chunkers/semantic.py b/backend/airweave/platform/chunkers/semantic.py index be8df284d..f3b5386ea 100644 --- a/backend/airweave/platform/chunkers/semantic.py +++ b/backend/airweave/platform/chunkers/semantic.py @@ -5,8 +5,8 @@ from airweave.core.logging import logger from airweave.platform.chunkers._base import BaseChunker from airweave.platform.chunkers.tiktoken_compat import SafeEncoding -from airweave.platform.sync.async_helpers import run_in_thread_pool -from airweave.platform.sync.exceptions import SyncFailureError +from airweave.domains.sync_pipeline.async_helpers import run_in_thread_pool +from airweave.domains.sync_pipeline.exceptions import SyncFailureError from airweave.platform.tokenizers import TikTokenTokenizer, get_tokenizer diff --git a/backend/airweave/platform/contexts/__init__.py b/backend/airweave/platform/contexts/__init__.py deleted file mode 100644 index 1604d1a6c..000000000 --- a/backend/airweave/platform/contexts/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -"""Contexts for platform operations. - -Context Types: -- SyncContext: Frozen data for sync operations (inherits BaseContext) -- SyncRuntime: Live services for a sync run (source, cursor, trackers, etc.) -""" - -from airweave.platform.contexts.runtime import SyncRuntime -from airweave.platform.contexts.sync import SyncContext - -__all__ = [ - "SyncContext", - "SyncRuntime", -] diff --git a/backend/airweave/platform/converters/html_converter.py b/backend/airweave/platform/converters/html_converter.py index 2f3167b04..c26421a40 100644 --- a/backend/airweave/platform/converters/html_converter.py +++ b/backend/airweave/platform/converters/html_converter.py @@ -5,8 +5,8 @@ from airweave.core.logging import logger from airweave.platform.converters._base import BaseTextConverter -from airweave.platform.sync.async_helpers import run_in_thread_pool -from airweave.platform.sync.exceptions import EntityProcessingError +from airweave.domains.sync_pipeline.async_helpers import run_in_thread_pool +from airweave.domains.sync_pipeline.exceptions import EntityProcessingError class HtmlConverter(BaseTextConverter): diff --git a/backend/airweave/platform/converters/text_extractors/docx.py b/backend/airweave/platform/converters/text_extractors/docx.py index a6fecc789..b4f23611a 100644 --- a/backend/airweave/platform/converters/text_extractors/docx.py +++ b/backend/airweave/platform/converters/text_extractors/docx.py @@ -12,7 +12,7 @@ from typing import Any, Optional from airweave.core.logging import logger -from airweave.platform.sync.exceptions import SyncFailureError +from airweave.domains.sync_pipeline.exceptions import SyncFailureError # Minimum total characters to consider the extraction successful. MIN_TOTAL_CHARS = 50 diff --git a/backend/airweave/platform/converters/text_extractors/pdf.py b/backend/airweave/platform/converters/text_extractors/pdf.py index b1506a112..94b2405f0 100644 --- a/backend/airweave/platform/converters/text_extractors/pdf.py +++ b/backend/airweave/platform/converters/text_extractors/pdf.py @@ -16,7 +16,7 @@ from dataclasses import dataclass, field from airweave.core.logging import logger -from airweave.platform.sync.exceptions import SyncFailureError +from airweave.domains.sync_pipeline.exceptions import SyncFailureError # Minimum characters per page to consider it "has text layer". # Pages below this threshold are treated as image-only. diff --git a/backend/airweave/platform/converters/text_extractors/pptx.py b/backend/airweave/platform/converters/text_extractors/pptx.py index 482420a12..0e5ceb828 100644 --- a/backend/airweave/platform/converters/text_extractors/pptx.py +++ b/backend/airweave/platform/converters/text_extractors/pptx.py @@ -11,7 +11,7 @@ from typing import Any, Optional from airweave.core.logging import logger -from airweave.platform.sync.exceptions import SyncFailureError +from airweave.domains.sync_pipeline.exceptions import SyncFailureError # Minimum total characters to consider the extraction successful. MIN_TOTAL_CHARS = 50 diff --git a/backend/airweave/platform/converters/txt_converter.py b/backend/airweave/platform/converters/txt_converter.py index 4cf42c24f..84062f029 100644 --- a/backend/airweave/platform/converters/txt_converter.py +++ b/backend/airweave/platform/converters/txt_converter.py @@ -10,8 +10,8 @@ from airweave.core.logging import logger from airweave.platform.converters._base import BaseTextConverter -from airweave.platform.sync.async_helpers import run_in_thread_pool -from airweave.platform.sync.exceptions import EntityProcessingError +from airweave.domains.sync_pipeline.async_helpers import run_in_thread_pool +from airweave.domains.sync_pipeline.exceptions import EntityProcessingError class TxtConverter(BaseTextConverter): diff --git a/backend/airweave/platform/converters/web_converter.py b/backend/airweave/platform/converters/web_converter.py index 126b2cdb4..451f63201 100644 --- a/backend/airweave/platform/converters/web_converter.py +++ b/backend/airweave/platform/converters/web_converter.py @@ -10,7 +10,7 @@ from airweave.core.logging import logger from airweave.platform.converters._base import BaseTextConverter from airweave.platform.rate_limiters import FirecrawlRateLimiter -from airweave.platform.sync.exceptions import SyncFailureError +from airweave.domains.sync_pipeline.exceptions import SyncFailureError # ==================== CONFIGURATION ==================== diff --git a/backend/airweave/platform/converters/xlsx_converter.py b/backend/airweave/platform/converters/xlsx_converter.py index 8ad649776..8be6a3f34 100644 --- a/backend/airweave/platform/converters/xlsx_converter.py +++ b/backend/airweave/platform/converters/xlsx_converter.py @@ -5,8 +5,8 @@ from airweave.core.logging import logger from airweave.platform.converters._base import BaseTextConverter -from airweave.platform.sync.async_helpers import run_in_thread_pool -from airweave.platform.sync.exceptions import EntityProcessingError, SyncFailureError +from airweave.domains.sync_pipeline.async_helpers import run_in_thread_pool +from airweave.domains.sync_pipeline.exceptions import EntityProcessingError, SyncFailureError class XlsxConverter(BaseTextConverter): diff --git a/backend/airweave/platform/ocr/mistral/compressor.py b/backend/airweave/platform/ocr/mistral/compressor.py index c760ef05d..d3c3725a5 100644 --- a/backend/airweave/platform/ocr/mistral/compressor.py +++ b/backend/airweave/platform/ocr/mistral/compressor.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from PIL import Image from airweave.platform.ocr.mistral.models import CompressionResult -from airweave.platform.sync.exceptions import EntityProcessingError, SyncFailureError +from airweave.domains.sync_pipeline.exceptions import EntityProcessingError, SyncFailureError # Quality levels to try, from highest to lowest. _QUALITY_STEPS = range(85, 19, -10) diff --git a/backend/airweave/platform/ocr/mistral/converter.py b/backend/airweave/platform/ocr/mistral/converter.py index ac379ebee..156ce0316 100644 --- a/backend/airweave/platform/ocr/mistral/converter.py +++ b/backend/airweave/platform/ocr/mistral/converter.py @@ -41,7 +41,7 @@ PdfSplitter, RecursiveSplitter, ) -from airweave.platform.sync.exceptions import EntityProcessingError, SyncFailureError +from airweave.domains.sync_pipeline.exceptions import EntityProcessingError, SyncFailureError # Mistral upload limit. MAX_FILE_SIZE_BYTES = 50_000_000 # 50 MB diff --git a/backend/airweave/platform/ocr/mistral/ocr_client.py b/backend/airweave/platform/ocr/mistral/ocr_client.py index ff49d6c9a..ccd17f5cb 100644 --- a/backend/airweave/platform/ocr/mistral/ocr_client.py +++ b/backend/airweave/platform/ocr/mistral/ocr_client.py @@ -25,7 +25,7 @@ OcrResult, ) from airweave.platform.rate_limiters import MistralRateLimiter -from airweave.platform.sync.exceptions import SyncFailureError +from airweave.domains.sync_pipeline.exceptions import SyncFailureError # --------------------------------------------------------------------------- # Retry configuration diff --git a/backend/airweave/platform/ocr/mistral/splitters.py b/backend/airweave/platform/ocr/mistral/splitters.py index 5571b2e97..b0bc1b37a 100644 --- a/backend/airweave/platform/ocr/mistral/splitters.py +++ b/backend/airweave/platform/ocr/mistral/splitters.py @@ -19,7 +19,7 @@ import aiofiles.os from airweave.core.logging import logger -from airweave.platform.sync.exceptions import EntityProcessingError, SyncFailureError +from airweave.domains.sync_pipeline.exceptions import EntityProcessingError, SyncFailureError class RecursiveSplitter(ABC): diff --git a/backend/airweave/platform/sources/google_drive.py b/backend/airweave/platform/sources/google_drive.py index 750f6cab3..da86abb2b 100644 --- a/backend/airweave/platform/sources/google_drive.py +++ b/backend/airweave/platform/sources/google_drive.py @@ -1623,6 +1623,6 @@ async def _worker_match_user( except Exception as e: self.logger.error(f"Critical error in generate_entities: {str(e)}") # Re-raise as SyncFailureError to explicitly fail the sync - from airweave.platform.sync.exceptions import SyncFailureError + from airweave.domains.sync_pipeline.exceptions import SyncFailureError raise SyncFailureError(f"Google Drive sync failed: {str(e)}") from e diff --git a/backend/airweave/platform/sources/sharepoint2019v2/builders.py b/backend/airweave/platform/sources/sharepoint2019v2/builders.py index df815cb7e..fff472a01 100644 --- a/backend/airweave/platform/sources/sharepoint2019v2/builders.py +++ b/backend/airweave/platform/sources/sharepoint2019v2/builders.py @@ -26,7 +26,7 @@ clean_role_assignments, extract_access_control, ) -from airweave.platform.sync.exceptions import EntityProcessingError +from airweave.domains.sync_pipeline.exceptions import EntityProcessingError def _parse_datetime(dt_str: Optional[str]) -> Optional[datetime]: diff --git a/backend/airweave/platform/sources/sharepoint2019v2/source.py b/backend/airweave/platform/sources/sharepoint2019v2/source.py index 5d76a668e..6a978b9c7 100644 --- a/backend/airweave/platform/sources/sharepoint2019v2/source.py +++ b/backend/airweave/platform/sources/sharepoint2019v2/source.py @@ -49,7 +49,7 @@ ) from airweave.platform.sources.sharepoint2019v2.client import SharePointClient from airweave.platform.storage import FileSkippedException -from airweave.platform.sync.exceptions import EntityProcessingError +from airweave.domains.sync_pipeline.exceptions import EntityProcessingError from airweave.schemas.source_connection import AuthenticationMethod # Maximum concurrent file downloads diff --git a/backend/airweave/platform/sources/sharepoint_online/builders.py b/backend/airweave/platform/sources/sharepoint_online/builders.py index 9bfd2f87d..b88514944 100644 --- a/backend/airweave/platform/sources/sharepoint_online/builders.py +++ b/backend/airweave/platform/sources/sharepoint_online/builders.py @@ -17,7 +17,7 @@ SharePointOnlineSiteEntity, ) from airweave.platform.sources.sharepoint_online.acl import extract_access_control -from airweave.platform.sync.exceptions import EntityProcessingError +from airweave.domains.sync_pipeline.exceptions import EntityProcessingError def _parse_datetime(dt_str: Optional[str]) -> Optional[datetime]: diff --git a/backend/airweave/platform/sources/sharepoint_online/source.py b/backend/airweave/platform/sources/sharepoint_online/source.py index b1c4264a2..5363e5614 100644 --- a/backend/airweave/platform/sources/sharepoint_online/source.py +++ b/backend/airweave/platform/sources/sharepoint_online/source.py @@ -46,7 +46,7 @@ from airweave.platform.sources.sharepoint_online.client import GraphClient from airweave.platform.sources.sharepoint_online.graph_groups import EntraGroupExpander from airweave.platform.storage import FileSkippedException -from airweave.platform.sync.exceptions import EntityProcessingError +from airweave.domains.sync_pipeline.exceptions import EntityProcessingError from airweave.schemas.source_connection import AuthenticationMethod, OAuthType MAX_CONCURRENT_FILE_DOWNLOADS = 10 diff --git a/backend/airweave/platform/sources/slack.py b/backend/airweave/platform/sources/slack.py index 1daee7a35..b73181816 100644 --- a/backend/airweave/platform/sources/slack.py +++ b/backend/airweave/platform/sources/slack.py @@ -18,7 +18,7 @@ retry_if_rate_limit_or_timeout, wait_rate_limit_with_backoff, ) -from airweave.platform.sync.pipeline.text_builder import text_builder +from airweave.domains.sync_pipeline.pipeline.text_builder import text_builder from airweave.schemas.source_connection import AuthenticationMethod, OAuthType diff --git a/backend/airweave/platform/storage/file_service.py b/backend/airweave/platform/storage/file_service.py index 2dbb77aff..c0d874ae7 100644 --- a/backend/airweave/platform/storage/file_service.py +++ b/backend/airweave/platform/storage/file_service.py @@ -24,7 +24,7 @@ ) from airweave.platform.storage.exceptions import FileSkippedException from airweave.platform.storage.paths import paths -from airweave.platform.sync.file_types import SUPPORTED_FILE_EXTENSIONS +from airweave.domains.sync_pipeline.file_types import SUPPORTED_FILE_EXTENSIONS if TYPE_CHECKING: from airweave.platform.storage.protocol import StorageBackend diff --git a/backend/airweave/platform/sync/actions/__init__.py b/backend/airweave/platform/sync/actions/__init__.py deleted file mode 100644 index fef9b0a07..000000000 --- a/backend/airweave/platform/sync/actions/__init__.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Actions module for sync pipelines. - -Organized by domain: -- entity/: Entity action types, resolver, dispatcher, builder -- access_control/: Access control action types, resolver, dispatcher - -Each domain has its own types, resolver, and dispatcher tailored to its needs. - -Entity re-exports are lazy to avoid circular imports with domains/sync_pipeline. -""" - -from airweave.platform.sync.actions.access_control import ( - ACActionDispatcher, - ACActionResolver, - ACDeleteAction, - ACInsertAction, - ACKeepAction, - ACUpdateAction, -) - - -def __getattr__(name: str): - """Lazy re-exports for entity action symbols.""" - from airweave.platform.sync.actions import entity as _entity_pkg - - if name in _entity_pkg.__all__: - return getattr(_entity_pkg, name) - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - - -__all__ = [ - # Access control types - "ACDeleteAction", - "ACInsertAction", - "ACKeepAction", - "ACUpdateAction", - # Access control resolver and dispatcher - "ACActionResolver", - "ACActionDispatcher", - # Entity types (lazy) - "EntityActionBatch", - "EntityDeleteAction", - "EntityInsertAction", - "EntityKeepAction", - "EntityUpdateAction", - # Entity resolver and dispatcher (lazy) - "EntityActionResolver", - "EntityActionDispatcher", - "EntityDispatcherBuilder", -] diff --git a/backend/airweave/platform/sync/actions/access_control/__init__.py b/backend/airweave/platform/sync/actions/access_control/__init__.py deleted file mode 100644 index 779bc82ad..000000000 --- a/backend/airweave/platform/sync/actions/access_control/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Access control action types, resolver, and dispatcher. - -Access control membership action pipeline for sync operations. -""" - -from airweave.platform.sync.actions.access_control.dispatcher import ACActionDispatcher -from airweave.platform.sync.actions.access_control.resolver import ACActionResolver -from airweave.platform.sync.actions.access_control.types import ( - ACActionBatch, - ACDeleteAction, - ACInsertAction, - ACKeepAction, - ACUpdateAction, - ACUpsertAction, -) - -__all__ = [ - # Types - "ACActionBatch", - "ACDeleteAction", - "ACInsertAction", - "ACKeepAction", - "ACUpdateAction", - "ACUpsertAction", - # Resolver and Dispatcher - "ACActionResolver", - "ACActionDispatcher", -] diff --git a/backend/airweave/platform/sync/actions/entity/__init__.py b/backend/airweave/platform/sync/actions/entity/__init__.py deleted file mode 100644 index ec47afb06..000000000 --- a/backend/airweave/platform/sync/actions/entity/__init__.py +++ /dev/null @@ -1,63 +0,0 @@ -"""Entity action types, resolver, and dispatcher. - -Entity-specific action pipeline for sync operations. - -Re-exports are lazy to avoid circular imports: entity_action_resolver imports -from .types, which triggers this __init__. Using __getattr__ breaks the cycle. -""" - - -def __getattr__(name: str): - """Lazy re-exports to avoid circular imports.""" - _map = { - "EntityActionDispatcher": ( - "airweave.domains.sync_pipeline.entity_action_dispatcher", - "EntityActionDispatcher", - ), - "EntityActionResolver": ( - "airweave.domains.sync_pipeline.entity_action_resolver", - "EntityActionResolver", - ), - "EntityDispatcherBuilder": ( - "airweave.platform.sync.actions.entity.builder", - "EntityDispatcherBuilder", - ), - "EntityActionBatch": ( - "airweave.platform.sync.actions.entity.types", - "EntityActionBatch", - ), - "EntityDeleteAction": ( - "airweave.platform.sync.actions.entity.types", - "EntityDeleteAction", - ), - "EntityInsertAction": ( - "airweave.platform.sync.actions.entity.types", - "EntityInsertAction", - ), - "EntityKeepAction": ( - "airweave.platform.sync.actions.entity.types", - "EntityKeepAction", - ), - "EntityUpdateAction": ( - "airweave.platform.sync.actions.entity.types", - "EntityUpdateAction", - ), - } - if name in _map: - import importlib - - module_path, attr = _map[name] - return getattr(importlib.import_module(module_path), attr) - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - - -__all__ = [ - "EntityActionBatch", - "EntityDeleteAction", - "EntityInsertAction", - "EntityKeepAction", - "EntityUpdateAction", - "EntityActionResolver", - "EntityActionDispatcher", - "EntityDispatcherBuilder", -] diff --git a/backend/airweave/platform/sync/arf/service.py b/backend/airweave/platform/sync/arf/service.py index 491620b5d..7fab1e117 100644 --- a/backend/airweave/platform/sync/arf/service.py +++ b/backend/airweave/platform/sync/arf/service.py @@ -43,8 +43,8 @@ from airweave.platform.sync.arf.schema import SyncManifest if TYPE_CHECKING: - from airweave.platform.contexts import SyncContext - from airweave.platform.contexts.runtime import SyncRuntime + from airweave.domains.sync_pipeline.contexts import SyncContext + from airweave.domains.sync_pipeline.contexts.runtime import SyncRuntime from airweave.platform.entities._base import BaseEntity diff --git a/backend/airweave/platform/sync/config/__init__.py b/backend/airweave/platform/sync/config/__init__.py deleted file mode 100644 index 96e50c006..000000000 --- a/backend/airweave/platform/sync/config/__init__.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Sync configuration module with layered overrides. - -Resolution order (lowest → highest priority): -1. Schema defaults (SyncConfig field defaults) -2. Environment (SYNC_CONFIG__* env vars) -3. Collection (collection.sync_config) -4. Sync (sync.sync_config) -5. SyncJob (sync_job.sync_config) - -Env vars use double underscore delimiter: - SYNC_CONFIG__DESTINATIONS__SKIP_QDRANT=true - SYNC_CONFIG__HANDLERS__ENABLE_VECTOR_HANDLERS=false - SYNC_CONFIG__CURSOR__SKIP_LOAD=true - SYNC_CONFIG__BEHAVIOR__REPLAY_FROM_ARF=true - -Usage: - from airweave.platform.sync.config import SyncConfig, SyncConfigBuilder - - # Build config with all layers (sync_config is already typed as SyncConfig) - config = SyncConfigBuilder.build( - collection_overrides=collection.sync_config, - sync_overrides=sync.sync_config, - job_overrides=sync_job.sync_config, - ) - - # Use preset - config = SyncConfig.arf_capture_only() - - # Direct config with env var loading - config = SyncConfig() # Reads SYNC_CONFIG__* env vars automatically -""" - -from airweave.platform.sync.config.base import ( - BehaviorConfig, - CursorConfig, - DestinationConfig, - HandlerConfig, - SyncConfig, -) -from airweave.platform.sync.config.builder import SyncConfigBuilder - -# Backwards compatibility alias - TODO: remove after migration -SyncExecutionConfig = SyncConfig - -__all__ = [ - # Main config - "SyncConfig", - # Sub-configs - "DestinationConfig", - "HandlerConfig", - "CursorConfig", - "BehaviorConfig", - # Builder - "SyncConfigBuilder", - # Backwards compatibility - "SyncExecutionConfig", -] diff --git a/backend/airweave/platform/sync/handlers/__init__.py b/backend/airweave/platform/sync/handlers/__init__.py deleted file mode 100644 index 5bb02b05d..000000000 --- a/backend/airweave/platform/sync/handlers/__init__.py +++ /dev/null @@ -1,40 +0,0 @@ -"""Handlers module for sync pipeline. - -Contains handlers that execute resolved actions. - -Generic Protocol: - ActionHandler[T, B] - parameterized by payload type T and batch type B - -Type Aliases: - EntityActionHandler = ActionHandler[BaseEntity, EntityActionBatch] - -Entity Handlers: -- DestinationHandler: Chunks/embeds entities and inserts into vector destinations -- ArfHandler: Raw entity storage for audit/replay (ARF = Airweave Raw Format) -- EntityPostgresHandler: Entity metadata persistence (runs last) - -Architecture: - All handlers implement ActionHandler[T, B] with their specific types. - Entity handlers use T=BaseEntity, B=EntityActionBatch. - The dispatchers call handlers concurrently for their respective sync types. -""" - -# Handlers -from .access_control_postgres import ACPostgresHandler -from .arf import ArfHandler -from .destination import DestinationHandler -from .entity_postgres import EntityPostgresHandler - -# Protocol and type aliases -from .protocol import ACActionHandler, EntityActionHandler - -__all__ = [ - # Protocol and type aliases - "ACActionHandler", - "EntityActionHandler", - # Entity handlers - "ACPostgresHandler", - "ArfHandler", - "DestinationHandler", - "EntityPostgresHandler", -] diff --git a/backend/airweave/platform/sync/pipeline/__init__.py b/backend/airweave/platform/sync/pipeline/__init__.py deleted file mode 100644 index ae867f686..000000000 --- a/backend/airweave/platform/sync/pipeline/__init__.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Pipeline components for entity and ACL processing. - -This module contains the event-driven pipeline architecture: - -Core Components: -- EntityTracker: Central entity state tracking (dedup + counts + pubsub) -- ACLMembershipTracker: ACL membership tracking (dedup + orphan detection) - -Processing Helpers: -- HashComputer: Computes content hashes -- TextualRepresentationBuilder: Builds textual representations -- CleanupService: Handles orphan and temp file cleanup -""" - -# Core components -# Processing helpers -from airweave.platform.sync.pipeline.acl_membership_tracker import ( - ACLMembershipTracker, - ACLSyncStats, -) -from airweave.platform.sync.pipeline.cleanup_service import cleanup_service -from airweave.platform.sync.pipeline.entity_tracker import EntityTracker -from airweave.platform.sync.pipeline.hash_computer import hash_computer -from airweave.platform.sync.pipeline.text_builder import ( - TextualRepresentationBuilder, - text_builder, -) - -__all__ = [ - # Core components - "EntityTracker", - "ACLMembershipTracker", - "ACLSyncStats", - # Processing helpers - "cleanup_service", - "hash_computer", - "TextualRepresentationBuilder", - "text_builder", -] diff --git a/backend/airweave/platform/sync/processors/__init__.py b/backend/airweave/platform/sync/processors/__init__.py deleted file mode 100644 index 7e8a0c353..000000000 --- a/backend/airweave/platform/sync/processors/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Content processors for entity preparation. - -Available Processors: -- ChunkEmbedProcessor: Unified processor for chunk-as-document model (Qdrant, Vespa) - - With sparse=True: dense + sparse embeddings for hybrid search (Qdrant) - - With sparse=False: dense only, BM25 computed server-side (Vespa) -""" - -from .chunk_embed import ChunkEmbedProcessor - -__all__ = [ - "ChunkEmbedProcessor", -] diff --git a/backend/airweave/platform/sync/subscribers/__init__.py b/backend/airweave/platform/sync/subscribers/__init__.py deleted file mode 100644 index 0f4346f15..000000000 --- a/backend/airweave/platform/sync/subscribers/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Global event subscribers for entity batch events with per-sync session registries.""" - -from airweave.platform.sync.subscribers.progress_relay import SyncProgressRelay - -__all__ = [ - "SyncProgressRelay", -] diff --git a/backend/airweave/platform/sync/subscribers/tests/__init__.py b/backend/airweave/platform/sync/subscribers/tests/__init__.py deleted file mode 100644 index b59009c0a..000000000 --- a/backend/airweave/platform/sync/subscribers/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for sync subscriber modules.""" diff --git a/backend/airweave/platform/sync/web_fetcher.py b/backend/airweave/platform/sync/web_fetcher.py index 00c9e7313..1b36d4172 100644 --- a/backend/airweave/platform/sync/web_fetcher.py +++ b/backend/airweave/platform/sync/web_fetcher.py @@ -14,7 +14,7 @@ from airweave.core.logging import ContextualLogger from airweave.platform.entities._base import WebEntity from airweave.platform.entities.web import WebFileEntity -from airweave.platform.sync.async_helpers import run_in_thread_pool +from airweave.domains.sync_pipeline.async_helpers import run_in_thread_pool # Improved connection management _shared_firecrawl_client = None diff --git a/backend/airweave/platform/temporal/activities/sync.py b/backend/airweave/platform/temporal/activities/sync.py index da2f93372..168bfb993 100644 --- a/backend/airweave/platform/temporal/activities/sync.py +++ b/backend/airweave/platform/temporal/activities/sync.py @@ -420,7 +420,7 @@ async def _run_sync_task( from airweave import crud from airweave.core.exceptions import NotFoundException from airweave.db.session import get_db_context - from airweave.platform.sync.config import SyncConfig + from airweave.domains.sync_pipeline.config import SyncConfig execution_config = None try: diff --git a/backend/airweave/platform/temporal/worker/control_server.py b/backend/airweave/platform/temporal/worker/control_server.py index db35fb3a4..3f0414bcf 100644 --- a/backend/airweave/platform/temporal/worker/control_server.py +++ b/backend/airweave/platform/temporal/worker/control_server.py @@ -25,7 +25,7 @@ WorkerMetrics, WorkerMetricsRegistryProtocol, ) -from airweave.platform.sync.async_helpers import get_active_thread_count +from airweave.domains.sync_pipeline.async_helpers import get_active_thread_count from airweave.platform.temporal.worker_metrics_snapshot import ( ConnectorSnapshot, WorkerMetricsSnapshot, diff --git a/backend/airweave/schemas/collection.py b/backend/airweave/schemas/collection.py index bce69665f..eeda80469 100644 --- a/backend/airweave/schemas/collection.py +++ b/backend/airweave/schemas/collection.py @@ -13,7 +13,7 @@ from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator, model_validator from airweave.core.shared_models import CollectionStatus -from airweave.platform.sync.config.base import SyncConfig +from airweave.domains.sync_pipeline.config.base import SyncConfig def generate_readable_id(name: str) -> str: diff --git a/backend/airweave/schemas/sync.py b/backend/airweave/schemas/sync.py index 657c5c3b9..8738aa362 100644 --- a/backend/airweave/schemas/sync.py +++ b/backend/airweave/schemas/sync.py @@ -9,7 +9,7 @@ from airweave import schemas from airweave.core.shared_models import SyncStatus -from airweave.platform.sync.config.base import SyncConfig +from airweave.domains.sync_pipeline.config.base import SyncConfig class SyncBase(BaseModel): diff --git a/backend/airweave/schemas/sync_job.py b/backend/airweave/schemas/sync_job.py index 432e72335..ebbadfcdf 100644 --- a/backend/airweave/schemas/sync_job.py +++ b/backend/airweave/schemas/sync_job.py @@ -12,7 +12,7 @@ from pydantic import BaseModel, ConfigDict, EmailStr, Field from airweave.models.sync_job import SyncJobStatus -from airweave.platform.sync.config.base import SyncConfig +from airweave.domains.sync_pipeline.config.base import SyncConfig class SyncJobBase(BaseModel): diff --git a/backend/tests/unit/platform/converters/test_txt_converter.py b/backend/tests/unit/platform/converters/test_txt_converter.py index f8ed1e069..ceb6a85f0 100644 --- a/backend/tests/unit/platform/converters/test_txt_converter.py +++ b/backend/tests/unit/platform/converters/test_txt_converter.py @@ -6,7 +6,7 @@ from pathlib import Path from airweave.platform.converters.txt_converter import TxtConverter -from airweave.platform.sync.exceptions import EntityProcessingError +from airweave.domains.sync_pipeline.exceptions import EntityProcessingError @pytest.fixture From 09fd37bf437b80451b2302ab2eeebbfb13e50923 Mon Sep 17 00:00:00 2001 From: Felix Schmetz Date: Thu, 12 Mar 2026 17:31:05 -0700 Subject: [PATCH 06/13] style: apply ruff auto-format to moved modules --- backend/airweave/api/v1/endpoints/admin.py | 2 +- backend/airweave/core/container/factory.py | 4 +- .../airweave/domains/browse_tree/service.py | 2 +- .../airweave/domains/entities/protocols.py | 18 ++---- .../access_control_dispatcher.py | 2 +- .../sync_pipeline/access_control_pipeline.py | 6 +- .../sync_pipeline/access_control_resolver.py | 2 +- .../sync_pipeline/builders/destinations.py | 2 +- .../domains/sync_pipeline/builders/source.py | 4 +- .../domains/sync_pipeline/builders/sync.py | 2 +- .../domains/sync_pipeline/contexts/runtime.py | 4 +- .../domains/sync_pipeline/contexts/source.py | 2 +- .../domains/sync_pipeline/contexts/sync.py | 2 +- .../sync_pipeline/entity_action_dispatcher.py | 2 +- .../sync_pipeline/entity_action_resolver.py | 4 +- .../entity_dispatcher_builder.py | 4 +- .../domains/sync_pipeline/entity_pipeline.py | 12 ++-- .../airweave/domains/sync_pipeline/factory.py | 14 ++--- .../sync_pipeline/fakes/entity_repository.py | 12 +--- .../handlers/access_control_postgres.py | 4 +- .../domains/sync_pipeline/handlers/arf.py | 4 +- .../sync_pipeline/handlers/destination.py | 8 +-- .../sync_pipeline/handlers/entity_postgres.py | 4 +- .../domains/sync_pipeline/orchestrator.py | 12 ++-- .../sync_pipeline/pipeline/cleanup_service.py | 2 +- .../sync_pipeline/pipeline/hash_computer.py | 2 +- .../sync_pipeline/pipeline/text_builder.py | 2 +- .../sync_pipeline/processors/chunk_embed.py | 2 +- .../domains/sync_pipeline/protocols.py | 4 +- .../tests/test_acl_membership_tracker.py | 1 - .../tests/test_cleanup_service.py | 60 +++++++++---------- .../tests/test_config_builder.py | 2 +- backend/airweave/domains/syncs/protocols.py | 4 +- backend/airweave/domains/syncs/service.py | 2 +- .../domains/syncs/sync_job_service.py | 2 +- backend/airweave/platform/chunkers/code.py | 4 +- .../airweave/platform/chunkers/semantic.py | 4 +- .../platform/converters/html_converter.py | 2 +- .../platform/converters/txt_converter.py | 2 +- .../platform/converters/web_converter.py | 2 +- .../platform/converters/xlsx_converter.py | 2 +- .../platform/ocr/mistral/compressor.py | 2 +- .../platform/ocr/mistral/converter.py | 2 +- .../platform/ocr/mistral/ocr_client.py | 2 +- .../sources/sharepoint2019v2/builders.py | 2 +- .../sources/sharepoint2019v2/source.py | 2 +- .../sources/sharepoint_online/builders.py | 2 +- .../sources/sharepoint_online/source.py | 2 +- backend/airweave/platform/sources/slack.py | 2 +- .../airweave/platform/storage/file_service.py | 2 +- backend/airweave/platform/sync/web_fetcher.py | 2 +- backend/airweave/schemas/sync_job.py | 2 +- .../platform/converters/test_txt_converter.py | 1 - 53 files changed, 119 insertions(+), 135 deletions(-) diff --git a/backend/airweave/api/v1/endpoints/admin.py b/backend/airweave/api/v1/endpoints/admin.py index d7c0b2d94..e3af87a7a 100644 --- a/backend/airweave/api/v1/endpoints/admin.py +++ b/backend/airweave/api/v1/endpoints/admin.py @@ -43,13 +43,13 @@ from airweave.domains.embedders.protocols import DenseEmbedderProtocol, SparseEmbedderProtocol from airweave.domains.organizations.logic import generate_org_name from airweave.domains.source_connections.protocols import SourceConnectionServiceProtocol +from airweave.domains.sync_pipeline.config import SyncConfig from airweave.domains.syncs.protocols import SyncJobServiceProtocol from airweave.domains.temporal.protocols import TemporalWorkflowServiceProtocol from airweave.domains.usage.repository import UsageRepository from airweave.models.organization import Organization from airweave.models.organization_billing import OrganizationBilling from airweave.models.user_organization import UserOrganization -from airweave.domains.sync_pipeline.config import SyncConfig from airweave.schemas.organization_billing import BillingPlan, BillingStatus router = TrailingSlashRouter() diff --git a/backend/airweave/core/container/factory.py b/backend/airweave/core/container/factory.py index f3bfa0ee2..1890348b3 100644 --- a/backend/airweave/core/container/factory.py +++ b/backend/airweave/core/container/factory.py @@ -44,6 +44,7 @@ from airweave.core.protocols.webhooks import WebhookPublisher from airweave.core.redis_client import redis_client from airweave.db.session import health_check_engine +from airweave.domains.access_control.repository import AccessControlMembershipRepository from airweave.domains.auth_provider.registry import AuthProviderRegistry from airweave.domains.auth_provider.service import AuthProviderService from airweave.domains.browse_tree.repository import NodeSelectionRepository @@ -90,9 +91,9 @@ from airweave.domains.sources.registry import SourceRegistry from airweave.domains.sources.service import SourceService from airweave.domains.sources.validation import SourceValidationService -from airweave.domains.access_control.repository import AccessControlMembershipRepository from airweave.domains.sync_pipeline.factory import SyncFactory from airweave.domains.sync_pipeline.processors.chunk_embed import ChunkEmbedProcessor +from airweave.domains.sync_pipeline.subscribers.progress_relay import SyncProgressRelay from airweave.domains.syncs.service import SyncService from airweave.domains.syncs.sync_cursor_repository import SyncCursorRepository from airweave.domains.syncs.sync_job_repository import SyncJobRepository @@ -110,7 +111,6 @@ from airweave.domains.webhooks.service import WebhookServiceImpl from airweave.domains.webhooks.subscribers import WebhookEventSubscriber from airweave.platform.auth.settings import integration_settings -from airweave.domains.sync_pipeline.subscribers.progress_relay import SyncProgressRelay from airweave.platform.temporal.client import TemporalClient diff --git a/backend/airweave/domains/browse_tree/service.py b/backend/airweave/domains/browse_tree/service.py index 708c1d30e..ab16b593f 100644 --- a/backend/airweave/domains/browse_tree/service.py +++ b/backend/airweave/domains/browse_tree/service.py @@ -23,9 +23,9 @@ from airweave.domains.connections.protocols import ConnectionRepositoryProtocol from airweave.domains.source_connections.protocols import SourceConnectionRepositoryProtocol from airweave.domains.sources.protocols import SourceLifecycleServiceProtocol +from airweave.domains.sync_pipeline.config import SyncConfig from airweave.domains.syncs.protocols import SyncJobRepositoryProtocol, SyncRepositoryProtocol from airweave.domains.temporal.protocols import TemporalWorkflowServiceProtocol -from airweave.domains.sync_pipeline.config import SyncConfig from airweave.schemas.sync_job import SyncJobCreate, SyncJobStatus diff --git a/backend/airweave/domains/entities/protocols.py b/backend/airweave/domains/entities/protocols.py index e6f71d857..94e66e725 100644 --- a/backend/airweave/domains/entities/protocols.py +++ b/backend/airweave/domains/entities/protocols.py @@ -32,8 +32,7 @@ async def get_counts_per_sync_and_type( class EntityRepositoryProtocol(Protocol): """Entity data access used by the sync pipeline.""" - async def get_by_sync_id(self, db: AsyncSession, sync_id: UUID) -> List[Entity]: - ... + async def get_by_sync_id(self, db: AsyncSession, sync_id: UUID) -> List[Entity]: ... async def bulk_get_by_entity_sync_and_definition( self, @@ -41,8 +40,7 @@ async def bulk_get_by_entity_sync_and_definition( *, sync_id: UUID, entity_requests: list[Tuple[str, str]], - ) -> Dict[Tuple[str, str], Entity]: - ... + ) -> Dict[Tuple[str, str], Entity]: ... async def bulk_create( self, @@ -50,16 +48,14 @@ async def bulk_create( *, objs: list, ctx: Any, - ) -> List[Entity]: - ... + ) -> List[Entity]: ... async def bulk_update_hash( self, db: AsyncSession, *, rows: List[Tuple[UUID, str]], - ) -> None: - ... + ) -> None: ... async def bulk_remove( self, @@ -67,8 +63,7 @@ async def bulk_remove( *, ids: List[UUID], ctx: Any, - ) -> List[Entity]: - ... + ) -> List[Entity]: ... async def bulk_get_by_entity_and_sync( self, @@ -76,5 +71,4 @@ async def bulk_get_by_entity_and_sync( *, sync_id: UUID, entity_ids: List[str], - ) -> Dict[str, Entity]: - ... + ) -> Dict[str, Entity]: ... diff --git a/backend/airweave/domains/sync_pipeline/access_control_dispatcher.py b/backend/airweave/domains/sync_pipeline/access_control_dispatcher.py index 45e6a3889..5e31e0a94 100644 --- a/backend/airweave/domains/sync_pipeline/access_control_dispatcher.py +++ b/backend/airweave/domains/sync_pipeline/access_control_dispatcher.py @@ -6,8 +6,8 @@ from typing import TYPE_CHECKING, List -from airweave.domains.sync_pipeline.types.access_control_actions import ACActionBatch from airweave.domains.sync_pipeline.exceptions import SyncFailureError +from airweave.domains.sync_pipeline.types.access_control_actions import ACActionBatch if TYPE_CHECKING: from airweave.domains.sync_pipeline.contexts import SyncContext diff --git a/backend/airweave/domains/sync_pipeline/access_control_pipeline.py b/backend/airweave/domains/sync_pipeline/access_control_pipeline.py index 41c00ebd8..f91853509 100644 --- a/backend/airweave/domains/sync_pipeline/access_control_pipeline.py +++ b/backend/airweave/domains/sync_pipeline/access_control_pipeline.py @@ -13,13 +13,13 @@ from airweave.db.session import get_db_context from airweave.domains.access_control.protocols import AccessControlMembershipRepositoryProtocol +from airweave.domains.sync_pipeline.access_control_dispatcher import ACActionDispatcher +from airweave.domains.sync_pipeline.access_control_resolver import ACActionResolver +from airweave.domains.sync_pipeline.pipeline.acl_membership_tracker import ACLMembershipTracker from airweave.platform.access_control.schemas import ( ACLChangeType, MembershipTuple, ) -from airweave.domains.sync_pipeline.access_control_dispatcher import ACActionDispatcher -from airweave.domains.sync_pipeline.access_control_resolver import ACActionResolver -from airweave.domains.sync_pipeline.pipeline.acl_membership_tracker import ACLMembershipTracker from airweave.platform.utils.error_utils import get_error_message if TYPE_CHECKING: diff --git a/backend/airweave/domains/sync_pipeline/access_control_resolver.py b/backend/airweave/domains/sync_pipeline/access_control_resolver.py index 2fe41b1b3..744d7cab7 100644 --- a/backend/airweave/domains/sync_pipeline/access_control_resolver.py +++ b/backend/airweave/domains/sync_pipeline/access_control_resolver.py @@ -6,11 +6,11 @@ from typing import TYPE_CHECKING, List -from airweave.platform.access_control.schemas import MembershipTuple from airweave.domains.sync_pipeline.types.access_control_actions import ( ACActionBatch, ACUpsertAction, ) +from airweave.platform.access_control.schemas import MembershipTuple if TYPE_CHECKING: from airweave.domains.sync_pipeline.contexts import SyncContext diff --git a/backend/airweave/domains/sync_pipeline/builders/destinations.py b/backend/airweave/domains/sync_pipeline/builders/destinations.py index f4f387f9b..5b076d9bd 100644 --- a/backend/airweave/domains/sync_pipeline/builders/destinations.py +++ b/backend/airweave/domains/sync_pipeline/builders/destinations.py @@ -13,10 +13,10 @@ from airweave.core.constants.reserved_ids import NATIVE_VESPA_UUID from airweave.core.context import BaseContext from airweave.core.logging import ContextualLogger +from airweave.domains.sync_pipeline.config import SyncConfig from airweave.platform.destinations._base import BaseDestination from airweave.platform.destinations.vespa import VespaDestination from airweave.platform.entities._base import BaseEntity -from airweave.domains.sync_pipeline.config import SyncConfig class DestinationsContextBuilder: diff --git a/backend/airweave/domains/sync_pipeline/builders/source.py b/backend/airweave/domains/sync_pipeline/builders/source.py index 51bd131e8..8cd94c84c 100644 --- a/backend/airweave/domains/sync_pipeline/builders/source.py +++ b/backend/airweave/domains/sync_pipeline/builders/source.py @@ -23,11 +23,11 @@ from airweave.core.sync_cursor_service import sync_cursor_service from airweave.domains.browse_tree.repository import NodeSelectionRepository from airweave.domains.browse_tree.types import NodeSelectionData +from airweave.domains.sync_pipeline.config import SyncConfig from airweave.domains.sync_pipeline.contexts.infra import InfraContext from airweave.domains.sync_pipeline.contexts.source import SourceContext -from airweave.platform.sources._base import BaseSource -from airweave.domains.sync_pipeline.config import SyncConfig from airweave.domains.sync_pipeline.cursor import SyncCursor +from airweave.platform.sources._base import BaseSource class SourceContextBuilder: diff --git a/backend/airweave/domains/sync_pipeline/builders/sync.py b/backend/airweave/domains/sync_pipeline/builders/sync.py index 5662a715c..bc791d8be 100644 --- a/backend/airweave/domains/sync_pipeline/builders/sync.py +++ b/backend/airweave/domains/sync_pipeline/builders/sync.py @@ -12,8 +12,8 @@ from airweave import schemas from airweave.core.context import BaseContext from airweave.core.logging import ContextualLogger, LoggerConfigurator -from airweave.domains.sync_pipeline.contexts.sync import SyncContext from airweave.domains.sync_pipeline.config import SyncConfig +from airweave.domains.sync_pipeline.contexts.sync import SyncContext class SyncContextBuilder: diff --git a/backend/airweave/domains/sync_pipeline/contexts/runtime.py b/backend/airweave/domains/sync_pipeline/contexts/runtime.py index c0e3a15fd..f1dc6ed6a 100644 --- a/backend/airweave/domains/sync_pipeline/contexts/runtime.py +++ b/backend/airweave/domains/sync_pipeline/contexts/runtime.py @@ -12,10 +12,10 @@ if TYPE_CHECKING: from airweave.core.protocols.event_bus import EventBus from airweave.domains.embedders.protocols import DenseEmbedderProtocol, SparseEmbedderProtocol - from airweave.platform.destinations._base import BaseDestination - from airweave.platform.sources._base import BaseSource from airweave.domains.sync_pipeline.cursor import SyncCursor from airweave.domains.sync_pipeline.pipeline.entity_tracker import EntityTracker + from airweave.platform.destinations._base import BaseDestination + from airweave.platform.sources._base import BaseSource @dataclass diff --git a/backend/airweave/domains/sync_pipeline/contexts/source.py b/backend/airweave/domains/sync_pipeline/contexts/source.py index 85eaa27b1..92ecbcbb2 100644 --- a/backend/airweave/domains/sync_pipeline/contexts/source.py +++ b/backend/airweave/domains/sync_pipeline/contexts/source.py @@ -8,8 +8,8 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from airweave.platform.sources._base import BaseSource from airweave.domains.sync_pipeline.cursor import SyncCursor + from airweave.platform.sources._base import BaseSource @dataclass diff --git a/backend/airweave/domains/sync_pipeline/contexts/sync.py b/backend/airweave/domains/sync_pipeline/contexts/sync.py index 03050a357..24327861d 100644 --- a/backend/airweave/domains/sync_pipeline/contexts/sync.py +++ b/backend/airweave/domains/sync_pipeline/contexts/sync.py @@ -6,8 +6,8 @@ from airweave import schemas from airweave.core.context import BaseContext -from airweave.platform.entities._base import BaseEntity from airweave.domains.sync_pipeline.config.base import SyncConfig +from airweave.platform.entities._base import BaseEntity @dataclass diff --git a/backend/airweave/domains/sync_pipeline/entity_action_dispatcher.py b/backend/airweave/domains/sync_pipeline/entity_action_dispatcher.py index f320533eb..290423713 100644 --- a/backend/airweave/domains/sync_pipeline/entity_action_dispatcher.py +++ b/backend/airweave/domains/sync_pipeline/entity_action_dispatcher.py @@ -7,10 +7,10 @@ import asyncio from typing import TYPE_CHECKING, List -from airweave.domains.sync_pipeline.types.entity_actions import EntityActionBatch from airweave.domains.sync_pipeline.exceptions import SyncFailureError from airweave.domains.sync_pipeline.handlers.entity_postgres import EntityPostgresHandler from airweave.domains.sync_pipeline.handlers.protocol import EntityActionHandler +from airweave.domains.sync_pipeline.types.entity_actions import EntityActionBatch if TYPE_CHECKING: from airweave.domains.sync_pipeline.contexts import SyncContext diff --git a/backend/airweave/domains/sync_pipeline/entity_action_resolver.py b/backend/airweave/domains/sync_pipeline/entity_action_resolver.py index f9e5a8c4f..430dff8c8 100644 --- a/backend/airweave/domains/sync_pipeline/entity_action_resolver.py +++ b/backend/airweave/domains/sync_pipeline/entity_action_resolver.py @@ -10,7 +10,7 @@ from airweave import models from airweave.db.session import get_db_context from airweave.domains.entities.protocols import EntityRepositoryProtocol -from airweave.platform.entities._base import BaseEntity, DeletionEntity +from airweave.domains.sync_pipeline.exceptions import SyncFailureError from airweave.domains.sync_pipeline.types.entity_actions import ( EntityActionBatch, EntityDeleteAction, @@ -18,7 +18,7 @@ EntityKeepAction, EntityUpdateAction, ) -from airweave.domains.sync_pipeline.exceptions import SyncFailureError +from airweave.platform.entities._base import BaseEntity, DeletionEntity if TYPE_CHECKING: from airweave.domains.sync_pipeline.contexts import SyncContext diff --git a/backend/airweave/domains/sync_pipeline/entity_dispatcher_builder.py b/backend/airweave/domains/sync_pipeline/entity_dispatcher_builder.py index 7ec330d4f..0e805e275 100644 --- a/backend/airweave/domains/sync_pipeline/entity_dispatcher_builder.py +++ b/backend/airweave/domains/sync_pipeline/entity_dispatcher_builder.py @@ -4,14 +4,14 @@ from airweave.core.logging import ContextualLogger from airweave.domains.entities.protocols import EntityRepositoryProtocol -from airweave.domains.sync_pipeline.entity_action_dispatcher import EntityActionDispatcher -from airweave.platform.destinations._base import BaseDestination from airweave.domains.sync_pipeline.config import SyncConfig +from airweave.domains.sync_pipeline.entity_action_dispatcher import EntityActionDispatcher from airweave.domains.sync_pipeline.handlers.arf import ArfHandler from airweave.domains.sync_pipeline.handlers.destination import DestinationHandler from airweave.domains.sync_pipeline.handlers.entity_postgres import EntityPostgresHandler from airweave.domains.sync_pipeline.handlers.protocol import EntityActionHandler from airweave.domains.sync_pipeline.protocols import ChunkEmbedProcessorProtocol +from airweave.platform.destinations._base import BaseDestination class EntityDispatcherBuilder: diff --git a/backend/airweave/domains/sync_pipeline/entity_pipeline.py b/backend/airweave/domains/sync_pipeline/entity_pipeline.py index 8347e2e3b..117f6b9ed 100644 --- a/backend/airweave/domains/sync_pipeline/entity_pipeline.py +++ b/backend/airweave/domains/sync_pipeline/entity_pipeline.py @@ -18,18 +18,18 @@ from airweave.core.events.sync import EntityBatchProcessedEvent, TypeActionCounts from airweave.core.shared_models import AirweaveFieldFlag from airweave.domains.entities.protocols import EntityRepositoryProtocol -from airweave.domains.sync_pipeline.protocols import ( - EntityActionDispatcherProtocol, - EntityActionResolverProtocol, -) from airweave.domains.sync_pipeline.contexts import SyncContext from airweave.domains.sync_pipeline.contexts.runtime import SyncRuntime -from airweave.platform.entities._base import BaseEntity -from airweave.domains.sync_pipeline.types.entity_actions import EntityActionBatch from airweave.domains.sync_pipeline.exceptions import SyncFailureError from airweave.domains.sync_pipeline.pipeline.cleanup_service import cleanup_service from airweave.domains.sync_pipeline.pipeline.entity_tracker import EntityTracker from airweave.domains.sync_pipeline.pipeline.hash_computer import hash_computer +from airweave.domains.sync_pipeline.protocols import ( + EntityActionDispatcherProtocol, + EntityActionResolverProtocol, +) +from airweave.domains.sync_pipeline.types.entity_actions import EntityActionBatch +from airweave.platform.entities._base import BaseEntity if TYPE_CHECKING: from airweave.core.protocols.event_bus import EventBus diff --git a/backend/airweave/domains/sync_pipeline/factory.py b/backend/airweave/domains/sync_pipeline/factory.py index 93ef08517..d6b933ff3 100644 --- a/backend/airweave/domains/sync_pipeline/factory.py +++ b/backend/airweave/domains/sync_pipeline/factory.py @@ -24,15 +24,14 @@ from airweave.domains.embedders.protocols import DenseEmbedderProtocol, SparseEmbedderProtocol from airweave.domains.entities.protocols import EntityRepositoryProtocol from airweave.domains.source_connections.protocols import SourceConnectionRepositoryProtocol -from airweave.domains.usage.protocols import UsageLimitCheckerProtocol +from airweave.domains.sync_pipeline.access_control_dispatcher import ACActionDispatcher +from airweave.domains.sync_pipeline.access_control_pipeline import AccessControlPipeline +from airweave.domains.sync_pipeline.access_control_resolver import ACActionResolver from airweave.domains.sync_pipeline.builders import SyncContextBuilder from airweave.domains.sync_pipeline.builders.tracking import TrackingContextBuilder +from airweave.domains.sync_pipeline.config import SyncConfig, SyncConfigBuilder from airweave.domains.sync_pipeline.contexts.runtime import SyncRuntime -from airweave.domains.sync_pipeline.access_control_pipeline import AccessControlPipeline -from airweave.domains.sync_pipeline.access_control_dispatcher import ACActionDispatcher -from airweave.domains.sync_pipeline.access_control_resolver import ACActionResolver from airweave.domains.sync_pipeline.entity_dispatcher_builder import EntityDispatcherBuilder -from airweave.domains.sync_pipeline.config import SyncConfig, SyncConfigBuilder from airweave.domains.sync_pipeline.handlers import ACPostgresHandler from airweave.domains.sync_pipeline.orchestrator import SyncOrchestrator from airweave.domains.sync_pipeline.pipeline.acl_membership_tracker import ACLMembershipTracker @@ -40,6 +39,7 @@ from airweave.domains.sync_pipeline.protocols import ChunkEmbedProcessorProtocol from airweave.domains.sync_pipeline.stream import AsyncSourceStream from airweave.domains.sync_pipeline.worker_pool import AsyncWorkerPool +from airweave.domains.usage.protocols import UsageLimitCheckerProtocol from .entity_action_resolver import EntityActionResolver from .entity_pipeline import EntityPipeline @@ -184,9 +184,7 @@ async def create_orchestrator( access_control_pipeline = AccessControlPipeline( resolver=ACActionResolver(), - dispatcher=ACActionDispatcher( - handlers=[ACPostgresHandler(acl_repo=self._acl_repo)] - ), + dispatcher=ACActionDispatcher(handlers=[ACPostgresHandler(acl_repo=self._acl_repo)]), tracker=ACLMembershipTracker( source_connection_id=sync_context.source_connection_id, organization_id=sync_context.organization_id, diff --git a/backend/airweave/domains/sync_pipeline/fakes/entity_repository.py b/backend/airweave/domains/sync_pipeline/fakes/entity_repository.py index dc86d5d55..351b346dc 100644 --- a/backend/airweave/domains/sync_pipeline/fakes/entity_repository.py +++ b/backend/airweave/domains/sync_pipeline/fakes/entity_repository.py @@ -26,19 +26,13 @@ async def bulk_get_by_entity_sync_and_definition( ) -> Dict[Tuple[str, str], Entity]: return {} - async def bulk_create( - self, db: AsyncSession, *, objs: list, ctx: Any - ) -> List[Entity]: + async def bulk_create(self, db: AsyncSession, *, objs: list, ctx: Any) -> List[Entity]: return [] - async def bulk_update_hash( - self, db: AsyncSession, *, rows: List[Tuple[UUID, str]] - ) -> None: + async def bulk_update_hash(self, db: AsyncSession, *, rows: List[Tuple[UUID, str]]) -> None: pass - async def bulk_remove( - self, db: AsyncSession, *, ids: List[UUID], ctx: Any - ) -> List[Entity]: + async def bulk_remove(self, db: AsyncSession, *, ids: List[UUID], ctx: Any) -> List[Entity]: self._entities = [e for e in self._entities if e.id not in ids] return [] diff --git a/backend/airweave/domains/sync_pipeline/handlers/access_control_postgres.py b/backend/airweave/domains/sync_pipeline/handlers/access_control_postgres.py index 63aa7c2d6..8fcbada88 100644 --- a/backend/airweave/domains/sync_pipeline/handlers/access_control_postgres.py +++ b/backend/airweave/domains/sync_pipeline/handlers/access_control_postgres.py @@ -8,6 +8,8 @@ from airweave.db.session import get_db_context from airweave.domains.access_control.protocols import AccessControlMembershipRepositoryProtocol +from airweave.domains.sync_pipeline.exceptions import SyncFailureError +from airweave.domains.sync_pipeline.handlers.protocol import ACActionHandler from airweave.domains.sync_pipeline.types.access_control_actions import ( ACActionBatch, ACDeleteAction, @@ -15,8 +17,6 @@ ACUpdateAction, ACUpsertAction, ) -from airweave.domains.sync_pipeline.exceptions import SyncFailureError -from airweave.domains.sync_pipeline.handlers.protocol import ACActionHandler if TYPE_CHECKING: from airweave.domains.sync_pipeline.contexts import SyncContext diff --git a/backend/airweave/domains/sync_pipeline/handlers/arf.py b/backend/airweave/domains/sync_pipeline/handlers/arf.py index aed7b8018..3a12503e6 100644 --- a/backend/airweave/domains/sync_pipeline/handlers/arf.py +++ b/backend/airweave/domains/sync_pipeline/handlers/arf.py @@ -6,14 +6,14 @@ from typing import TYPE_CHECKING, List +from airweave.domains.sync_pipeline.exceptions import SyncFailureError +from airweave.domains.sync_pipeline.handlers.protocol import EntityActionHandler from airweave.domains.sync_pipeline.types.entity_actions import ( EntityActionBatch, EntityDeleteAction, EntityInsertAction, EntityUpdateAction, ) -from airweave.domains.sync_pipeline.exceptions import SyncFailureError -from airweave.domains.sync_pipeline.handlers.protocol import EntityActionHandler if TYPE_CHECKING: from airweave.domains.sync_pipeline.contexts import SyncContext diff --git a/backend/airweave/domains/sync_pipeline/handlers/destination.py b/backend/airweave/domains/sync_pipeline/handlers/destination.py index adba71d9a..846749493 100644 --- a/backend/airweave/domains/sync_pipeline/handlers/destination.py +++ b/backend/airweave/domains/sync_pipeline/handlers/destination.py @@ -10,16 +10,16 @@ import httpcore import httpx -from airweave.platform.destinations._base import BaseDestination +from airweave.domains.sync_pipeline.exceptions import SyncFailureError +from airweave.domains.sync_pipeline.handlers.protocol import EntityActionHandler +from airweave.domains.sync_pipeline.protocols import ChunkEmbedProcessorProtocol from airweave.domains.sync_pipeline.types.entity_actions import ( EntityActionBatch, EntityDeleteAction, EntityInsertAction, EntityUpdateAction, ) -from airweave.domains.sync_pipeline.exceptions import SyncFailureError -from airweave.domains.sync_pipeline.handlers.protocol import EntityActionHandler -from airweave.domains.sync_pipeline.protocols import ChunkEmbedProcessorProtocol +from airweave.platform.destinations._base import BaseDestination if TYPE_CHECKING: from airweave.domains.sync_pipeline.contexts import SyncContext diff --git a/backend/airweave/domains/sync_pipeline/handlers/entity_postgres.py b/backend/airweave/domains/sync_pipeline/handlers/entity_postgres.py index 04144d951..feb0584bb 100644 --- a/backend/airweave/domains/sync_pipeline/handlers/entity_postgres.py +++ b/backend/airweave/domains/sync_pipeline/handlers/entity_postgres.py @@ -12,14 +12,14 @@ from airweave import schemas from airweave.db.session import get_db_context from airweave.domains.entities.protocols import EntityRepositoryProtocol +from airweave.domains.sync_pipeline.exceptions import SyncFailureError +from airweave.domains.sync_pipeline.handlers.protocol import EntityActionHandler from airweave.domains.sync_pipeline.types.entity_actions import ( EntityActionBatch, EntityDeleteAction, EntityInsertAction, EntityUpdateAction, ) -from airweave.domains.sync_pipeline.exceptions import SyncFailureError -from airweave.domains.sync_pipeline.handlers.protocol import EntityActionHandler if TYPE_CHECKING: from airweave.domains.sync_pipeline.contexts import SyncContext diff --git a/backend/airweave/domains/sync_pipeline/orchestrator.py b/backend/airweave/domains/sync_pipeline/orchestrator.py index 127ff24d4..34a5ec57f 100644 --- a/backend/airweave/domains/sync_pipeline/orchestrator.py +++ b/backend/airweave/domains/sync_pipeline/orchestrator.py @@ -12,18 +12,18 @@ from airweave.core.sync_cursor_service import sync_cursor_service from airweave.core.sync_job_service import sync_job_service from airweave.db.session import get_db_context +from airweave.domains.sync_pipeline.access_control_pipeline import AccessControlPipeline +from airweave.domains.sync_pipeline.contexts import SyncContext +from airweave.domains.sync_pipeline.contexts.runtime import SyncRuntime from airweave.domains.sync_pipeline.entity_pipeline import EntityPipeline +from airweave.domains.sync_pipeline.exceptions import EntityProcessingError, SyncFailureError +from airweave.domains.sync_pipeline.stream import AsyncSourceStream +from airweave.domains.sync_pipeline.worker_pool import AsyncWorkerPool from airweave.domains.usage.exceptions import ( PaymentRequiredError, UsageLimitExceededError, ) from airweave.domains.usage.types import ActionType -from airweave.domains.sync_pipeline.contexts import SyncContext -from airweave.domains.sync_pipeline.contexts.runtime import SyncRuntime -from airweave.domains.sync_pipeline.access_control_pipeline import AccessControlPipeline -from airweave.domains.sync_pipeline.exceptions import EntityProcessingError, SyncFailureError -from airweave.domains.sync_pipeline.stream import AsyncSourceStream -from airweave.domains.sync_pipeline.worker_pool import AsyncWorkerPool from airweave.platform.utils.error_utils import get_error_message diff --git a/backend/airweave/domains/sync_pipeline/pipeline/cleanup_service.py b/backend/airweave/domains/sync_pipeline/pipeline/cleanup_service.py index 9014b94e5..36b343f77 100644 --- a/backend/airweave/domains/sync_pipeline/pipeline/cleanup_service.py +++ b/backend/airweave/domains/sync_pipeline/pipeline/cleanup_service.py @@ -3,8 +3,8 @@ import os from typing import TYPE_CHECKING, Any, Dict, List -from airweave.platform.entities._base import FileEntity from airweave.domains.sync_pipeline.exceptions import SyncFailureError +from airweave.platform.entities._base import FileEntity if TYPE_CHECKING: from airweave.domains.sync_pipeline.contexts import SyncContext diff --git a/backend/airweave/domains/sync_pipeline/pipeline/hash_computer.py b/backend/airweave/domains/sync_pipeline/pipeline/hash_computer.py index c341a88cf..bf0f28509 100644 --- a/backend/airweave/domains/sync_pipeline/pipeline/hash_computer.py +++ b/backend/airweave/domains/sync_pipeline/pipeline/hash_computer.py @@ -6,9 +6,9 @@ from typing import TYPE_CHECKING, Any, List, Optional, Tuple from airweave.core.shared_models import AirweaveFieldFlag -from airweave.platform.entities._base import BaseEntity, CodeFileEntity, FileEntity from airweave.domains.sync_pipeline.async_helpers import run_in_thread_pool from airweave.domains.sync_pipeline.exceptions import EntityProcessingError, SyncFailureError +from airweave.platform.entities._base import BaseEntity, CodeFileEntity, FileEntity if TYPE_CHECKING: from airweave.domains.sync_pipeline.contexts import SyncContext diff --git a/backend/airweave/domains/sync_pipeline/pipeline/text_builder.py b/backend/airweave/domains/sync_pipeline/pipeline/text_builder.py index c0d58285f..c6fe8d43a 100644 --- a/backend/airweave/domains/sync_pipeline/pipeline/text_builder.py +++ b/backend/airweave/domains/sync_pipeline/pipeline/text_builder.py @@ -6,9 +6,9 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from airweave.core.shared_models import AirweaveFieldFlag -from airweave.platform.entities._base import BaseEntity, CodeFileEntity, FileEntity, WebEntity from airweave.domains.sync_pipeline.exceptions import EntityProcessingError, SyncFailureError from airweave.domains.sync_pipeline.file_types import SUPPORTED_FILE_EXTENSIONS +from airweave.platform.entities._base import BaseEntity, CodeFileEntity, FileEntity, WebEntity if TYPE_CHECKING: from airweave.domains.sync_pipeline.contexts import SyncContext diff --git a/backend/airweave/domains/sync_pipeline/processors/chunk_embed.py b/backend/airweave/domains/sync_pipeline/processors/chunk_embed.py index 2cfc9a3df..28b6f54d5 100644 --- a/backend/airweave/domains/sync_pipeline/processors/chunk_embed.py +++ b/backend/airweave/domains/sync_pipeline/processors/chunk_embed.py @@ -14,10 +14,10 @@ import json from typing import TYPE_CHECKING, Any, Dict, List, Tuple -from airweave.platform.entities._base import BaseEntity, CodeFileEntity from airweave.domains.sync_pipeline.exceptions import SyncFailureError from airweave.domains.sync_pipeline.pipeline.text_builder import text_builder from airweave.domains.sync_pipeline.processors.utils import filter_empty_representations +from airweave.platform.entities._base import BaseEntity, CodeFileEntity if TYPE_CHECKING: from airweave.domains.sync_pipeline.contexts import SyncContext diff --git a/backend/airweave/domains/sync_pipeline/protocols.py b/backend/airweave/domains/sync_pipeline/protocols.py index 6680dad02..87cb7c38a 100644 --- a/backend/airweave/domains/sync_pipeline/protocols.py +++ b/backend/airweave/domains/sync_pipeline/protocols.py @@ -7,14 +7,14 @@ from sqlalchemy.ext.asyncio import AsyncSession from airweave import schemas -from airweave.platform.entities._base import BaseEntity from airweave.domains.sync_pipeline.types.entity_actions import EntityActionBatch +from airweave.platform.entities._base import BaseEntity if TYPE_CHECKING: from airweave.core.context import BaseContext + from airweave.domains.sync_pipeline.config import SyncConfig from airweave.domains.sync_pipeline.contexts import SyncContext from airweave.domains.sync_pipeline.contexts.runtime import SyncRuntime - from airweave.domains.sync_pipeline.config import SyncConfig from airweave.domains.sync_pipeline.orchestrator import SyncOrchestrator diff --git a/backend/airweave/domains/sync_pipeline/tests/test_acl_membership_tracker.py b/backend/airweave/domains/sync_pipeline/tests/test_acl_membership_tracker.py index 133def56a..a3a60008a 100644 --- a/backend/airweave/domains/sync_pipeline/tests/test_acl_membership_tracker.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_acl_membership_tracker.py @@ -302,4 +302,3 @@ def test_full_lifecycle(self, tracker): # 5. Log summary tracker.log_summary() assert tracker.logger.info.called - diff --git a/backend/airweave/domains/sync_pipeline/tests/test_cleanup_service.py b/backend/airweave/domains/sync_pipeline/tests/test_cleanup_service.py index 74a9f687f..15da6f6e8 100644 --- a/backend/airweave/domains/sync_pipeline/tests/test_cleanup_service.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_cleanup_service.py @@ -68,7 +68,7 @@ def create_file_entity_with_temp_file(temp_dir: str, filename: str) -> _TestFile """Create a FileEntity with an actual temp file on disk.""" file_path = os.path.join(temp_dir, filename) Path(file_path).write_text("test content") - + entity = _TestFileEntity( file_id=str(uuid4()), name=filename, @@ -84,20 +84,20 @@ async def test_cleanup_inserts(temp_dir, mock_sync_context): """Test that INSERT action files are cleaned up.""" # Create file entity with temp file entity = create_file_entity_with_temp_file(temp_dir, "insert.txt") - + partitions = { "inserts": [entity], "updates": [], "keeps": [], "deletes": [], } - + # Verify file exists before cleanup assert os.path.exists(entity.local_path) - + # Run cleanup await cleanup_service.cleanup_processed_files(partitions, mock_sync_context) - + # Verify file was deleted assert not os.path.exists(entity.local_path) mock_sync_context.logger.debug.assert_called() @@ -107,14 +107,14 @@ async def test_cleanup_inserts(temp_dir, mock_sync_context): async def test_cleanup_updates(temp_dir, mock_sync_context): """Test that UPDATE action files are cleaned up.""" entity = create_file_entity_with_temp_file(temp_dir, "update.txt") - + partitions = { "inserts": [], "updates": [entity], "keeps": [], "deletes": [], } - + assert os.path.exists(entity.local_path) await cleanup_service.cleanup_processed_files(partitions, mock_sync_context) assert not os.path.exists(entity.local_path) @@ -123,19 +123,19 @@ async def test_cleanup_updates(temp_dir, mock_sync_context): @pytest.mark.asyncio async def test_cleanup_keeps(temp_dir, mock_sync_context): """Test that KEEP action files are cleaned up. - + This is the critical bug fix - KEEP files (unchanged) were being downloaded and hashed but never cleaned up, causing disk buildup. """ entity = create_file_entity_with_temp_file(temp_dir, "keep.txt") - + partitions = { "inserts": [], "updates": [], "keeps": [entity], "deletes": [], } - + assert os.path.exists(entity.local_path) await cleanup_service.cleanup_processed_files(partitions, mock_sync_context) assert not os.path.exists(entity.local_path), "KEEP files should be cleaned up!" @@ -147,21 +147,21 @@ async def test_cleanup_mixed_actions(temp_dir, mock_sync_context): insert_entity = create_file_entity_with_temp_file(temp_dir, "insert.txt") update_entity = create_file_entity_with_temp_file(temp_dir, "update.txt") keep_entity = create_file_entity_with_temp_file(temp_dir, "keep.txt") - + partitions = { "inserts": [insert_entity], "updates": [update_entity], "keeps": [keep_entity], "deletes": [], } - + # All files exist assert os.path.exists(insert_entity.local_path) assert os.path.exists(update_entity.local_path) assert os.path.exists(keep_entity.local_path) - + await cleanup_service.cleanup_processed_files(partitions, mock_sync_context) - + # All files deleted assert not os.path.exists(insert_entity.local_path) assert not os.path.exists(update_entity.local_path) @@ -171,22 +171,22 @@ async def test_cleanup_mixed_actions(temp_dir, mock_sync_context): @pytest.mark.asyncio async def test_cleanup_ignores_deletes(temp_dir, mock_sync_context): """Test that DELETE actions are ignored (no files to clean). - + DeletionEntity is not a FileEntity, so it should be skipped. """ # Create a deletion entity (no file on disk) deletion_entity = _TestDeletionEntity(deletion_id=str(uuid4()), label="deleted-item", breadcrumbs=[]) - + partitions = { "inserts": [], "updates": [], "keeps": [], "deletes": [deletion_entity], } - + # Should not raise any errors await cleanup_service.cleanup_processed_files(partitions, mock_sync_context) - + # No cleanup logged since no FileEntities assert not any( "Progressive cleanup: deleted" in str(call) @@ -200,16 +200,16 @@ async def test_cleanup_ignores_non_file_entities(temp_dir, mock_sync_context): # Mix FileEntity with non-FileEntity file_entity = create_file_entity_with_temp_file(temp_dir, "file.txt") non_file_entity = _TestNonFileEntity(test_id=str(uuid4()), name="non-file-entity", breadcrumbs=[]) - + partitions = { "inserts": [file_entity, non_file_entity], "updates": [], "keeps": [], "deletes": [], } - + await cleanup_service.cleanup_processed_files(partitions, mock_sync_context) - + # Only FileEntity was cleaned assert not os.path.exists(file_entity.local_path) @@ -218,23 +218,23 @@ async def test_cleanup_ignores_non_file_entities(temp_dir, mock_sync_context): async def test_cleanup_raises_on_failed_deletion(temp_dir, mock_sync_context): """Test that cleanup raises SyncFailureError if file deletion fails.""" entity = create_file_entity_with_temp_file(temp_dir, "locked.txt") - + partitions = { "inserts": [entity], "updates": [], "keeps": [], "deletes": [], } - + # Mock os.remove to fail original_remove = os.remove - + def mock_remove(path): if "locked.txt" in path: # Simulate file still exists after removal attempt return original_remove(path) - + with pytest.raises(SyncFailureError, match="Failed to delete .* temp files"): with patch("os.remove", side_effect=mock_remove): await cleanup_service.cleanup_processed_files(partitions, mock_sync_context) @@ -245,14 +245,14 @@ async def test_cleanup_raises_on_missing_local_path(mock_sync_context): """Test that FileEntity without local_path raises error.""" entity = _TestFileEntity(file_id=str(uuid4()), name="no-path.txt", breadcrumbs=[]) # No local_path set - programming error - + partitions = { "inserts": [entity], "updates": [], "keeps": [], "deletes": [], } - + with pytest.raises(SyncFailureError, match="has no local_path after processing"): await cleanup_service.cleanup_processed_files(partitions, mock_sync_context) @@ -261,16 +261,16 @@ async def test_cleanup_raises_on_missing_local_path(mock_sync_context): async def test_cleanup_handles_already_deleted_files(temp_dir, mock_sync_context): """Test that cleanup handles files that were already deleted gracefully.""" entity = create_file_entity_with_temp_file(temp_dir, "already_deleted.txt") - + # Delete the file manually before cleanup os.remove(entity.local_path) - + partitions = { "inserts": [entity], "updates": [], "keeps": [], "deletes": [], } - + # Should not raise error await cleanup_service.cleanup_processed_files(partitions, mock_sync_context) diff --git a/backend/airweave/domains/sync_pipeline/tests/test_config_builder.py b/backend/airweave/domains/sync_pipeline/tests/test_config_builder.py index fc9d01261..28e24141d 100644 --- a/backend/airweave/domains/sync_pipeline/tests/test_config_builder.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_config_builder.py @@ -148,4 +148,4 @@ def test_from_db_json(self): collection_overrides=SyncConfig(**db_json) ) assert config.handlers.enable_vector_handlers is False - assert config.behavior.skip_hash_comparison is True \ No newline at end of file + assert config.behavior.skip_hash_comparison is True diff --git a/backend/airweave/domains/syncs/protocols.py b/backend/airweave/domains/syncs/protocols.py index 927e255c1..f921b03a7 100644 --- a/backend/airweave/domains/syncs/protocols.py +++ b/backend/airweave/domains/syncs/protocols.py @@ -11,12 +11,12 @@ from airweave.core.shared_models import SyncJobStatus from airweave.db.unit_of_work import UnitOfWork from airweave.domains.sources.types import SourceRegistryEntry +from airweave.domains.sync_pipeline.config import SyncConfig +from airweave.domains.sync_pipeline.pipeline.entity_tracker import SyncStats from airweave.domains.syncs.types import SyncProvisionResult from airweave.models.sync import Sync from airweave.models.sync_cursor import SyncCursor from airweave.models.sync_job import SyncJob -from airweave.domains.sync_pipeline.config import SyncConfig -from airweave.domains.sync_pipeline.pipeline.entity_tracker import SyncStats from airweave.schemas.source_connection import ScheduleConfig, SourceConnectionJob from airweave.schemas.sync import SyncCreate, SyncUpdate from airweave.schemas.sync_job import SyncJobCreate, SyncJobUpdate diff --git a/backend/airweave/domains/syncs/service.py b/backend/airweave/domains/syncs/service.py index 9d6886163..c93812f26 100644 --- a/backend/airweave/domains/syncs/service.py +++ b/backend/airweave/domains/syncs/service.py @@ -10,9 +10,9 @@ from airweave.core.datetime_utils import utc_now_naive from airweave.core.shared_models import SyncJobStatus from airweave.db.session import get_db_context +from airweave.domains.sync_pipeline.config import SyncConfig from airweave.domains.sync_pipeline.protocols import SyncFactoryProtocol from airweave.domains.syncs.protocols import SyncJobServiceProtocol, SyncServiceProtocol -from airweave.domains.sync_pipeline.config import SyncConfig class SyncService(SyncServiceProtocol): diff --git a/backend/airweave/domains/syncs/sync_job_service.py b/backend/airweave/domains/syncs/sync_job_service.py index 33de58434..2437a2db0 100644 --- a/backend/airweave/domains/syncs/sync_job_service.py +++ b/backend/airweave/domains/syncs/sync_job_service.py @@ -13,9 +13,9 @@ from airweave.core.logging import logger from airweave.core.shared_models import SyncJobStatus from airweave.db.session import get_db_context +from airweave.domains.sync_pipeline.pipeline.entity_tracker import SyncStats from airweave.domains.syncs.protocols import SyncJobRepositoryProtocol, SyncJobServiceProtocol from airweave.domains.syncs.types import StatsUpdate, TimestampUpdate -from airweave.domains.sync_pipeline.pipeline.entity_tracker import SyncStats class SyncJobService(SyncJobServiceProtocol): diff --git a/backend/airweave/platform/chunkers/code.py b/backend/airweave/platform/chunkers/code.py index aabb3b0df..dcc19c76c 100644 --- a/backend/airweave/platform/chunkers/code.py +++ b/backend/airweave/platform/chunkers/code.py @@ -3,10 +3,10 @@ from typing import Any, Dict, List, Optional from airweave.core.logging import logger -from airweave.platform.chunkers._base import BaseChunker -from airweave.platform.chunkers.tiktoken_compat import SafeEncoding from airweave.domains.sync_pipeline.async_helpers import run_in_thread_pool from airweave.domains.sync_pipeline.exceptions import SyncFailureError +from airweave.platform.chunkers._base import BaseChunker +from airweave.platform.chunkers.tiktoken_compat import SafeEncoding from airweave.platform.tokenizers import TikTokenTokenizer, get_tokenizer diff --git a/backend/airweave/platform/chunkers/semantic.py b/backend/airweave/platform/chunkers/semantic.py index f3b5386ea..31eea37b4 100644 --- a/backend/airweave/platform/chunkers/semantic.py +++ b/backend/airweave/platform/chunkers/semantic.py @@ -3,10 +3,10 @@ from typing import Any, Dict, List, Optional from airweave.core.logging import logger -from airweave.platform.chunkers._base import BaseChunker -from airweave.platform.chunkers.tiktoken_compat import SafeEncoding from airweave.domains.sync_pipeline.async_helpers import run_in_thread_pool from airweave.domains.sync_pipeline.exceptions import SyncFailureError +from airweave.platform.chunkers._base import BaseChunker +from airweave.platform.chunkers.tiktoken_compat import SafeEncoding from airweave.platform.tokenizers import TikTokenTokenizer, get_tokenizer diff --git a/backend/airweave/platform/converters/html_converter.py b/backend/airweave/platform/converters/html_converter.py index c26421a40..0be1c7eee 100644 --- a/backend/airweave/platform/converters/html_converter.py +++ b/backend/airweave/platform/converters/html_converter.py @@ -4,9 +4,9 @@ from typing import Dict, List from airweave.core.logging import logger -from airweave.platform.converters._base import BaseTextConverter from airweave.domains.sync_pipeline.async_helpers import run_in_thread_pool from airweave.domains.sync_pipeline.exceptions import EntityProcessingError +from airweave.platform.converters._base import BaseTextConverter class HtmlConverter(BaseTextConverter): diff --git a/backend/airweave/platform/converters/txt_converter.py b/backend/airweave/platform/converters/txt_converter.py index 84062f029..f4da0cdd2 100644 --- a/backend/airweave/platform/converters/txt_converter.py +++ b/backend/airweave/platform/converters/txt_converter.py @@ -9,9 +9,9 @@ import aiofiles from airweave.core.logging import logger -from airweave.platform.converters._base import BaseTextConverter from airweave.domains.sync_pipeline.async_helpers import run_in_thread_pool from airweave.domains.sync_pipeline.exceptions import EntityProcessingError +from airweave.platform.converters._base import BaseTextConverter class TxtConverter(BaseTextConverter): diff --git a/backend/airweave/platform/converters/web_converter.py b/backend/airweave/platform/converters/web_converter.py index 451f63201..d243a2536 100644 --- a/backend/airweave/platform/converters/web_converter.py +++ b/backend/airweave/platform/converters/web_converter.py @@ -8,9 +8,9 @@ from airweave.core.config import settings from airweave.core.logging import logger +from airweave.domains.sync_pipeline.exceptions import SyncFailureError from airweave.platform.converters._base import BaseTextConverter from airweave.platform.rate_limiters import FirecrawlRateLimiter -from airweave.domains.sync_pipeline.exceptions import SyncFailureError # ==================== CONFIGURATION ==================== diff --git a/backend/airweave/platform/converters/xlsx_converter.py b/backend/airweave/platform/converters/xlsx_converter.py index 8be6a3f34..ccc5208d4 100644 --- a/backend/airweave/platform/converters/xlsx_converter.py +++ b/backend/airweave/platform/converters/xlsx_converter.py @@ -4,9 +4,9 @@ from typing import Dict, List from airweave.core.logging import logger -from airweave.platform.converters._base import BaseTextConverter from airweave.domains.sync_pipeline.async_helpers import run_in_thread_pool from airweave.domains.sync_pipeline.exceptions import EntityProcessingError, SyncFailureError +from airweave.platform.converters._base import BaseTextConverter class XlsxConverter(BaseTextConverter): diff --git a/backend/airweave/platform/ocr/mistral/compressor.py b/backend/airweave/platform/ocr/mistral/compressor.py index d3c3725a5..e07196972 100644 --- a/backend/airweave/platform/ocr/mistral/compressor.py +++ b/backend/airweave/platform/ocr/mistral/compressor.py @@ -15,8 +15,8 @@ if TYPE_CHECKING: from PIL import Image -from airweave.platform.ocr.mistral.models import CompressionResult from airweave.domains.sync_pipeline.exceptions import EntityProcessingError, SyncFailureError +from airweave.platform.ocr.mistral.models import CompressionResult # Quality levels to try, from highest to lowest. _QUALITY_STEPS = range(85, 19, -10) diff --git a/backend/airweave/platform/ocr/mistral/converter.py b/backend/airweave/platform/ocr/mistral/converter.py index 156ce0316..db997485a 100644 --- a/backend/airweave/platform/ocr/mistral/converter.py +++ b/backend/airweave/platform/ocr/mistral/converter.py @@ -25,6 +25,7 @@ import aiofiles.os from airweave.core.logging import logger +from airweave.domains.sync_pipeline.exceptions import EntityProcessingError, SyncFailureError from airweave.platform.converters.text_extractors.pptx import extract_pptx_text from airweave.platform.ocr.mistral.compressor import compress_image from airweave.platform.ocr.mistral.models import ( @@ -41,7 +42,6 @@ PdfSplitter, RecursiveSplitter, ) -from airweave.domains.sync_pipeline.exceptions import EntityProcessingError, SyncFailureError # Mistral upload limit. MAX_FILE_SIZE_BYTES = 50_000_000 # 50 MB diff --git a/backend/airweave/platform/ocr/mistral/ocr_client.py b/backend/airweave/platform/ocr/mistral/ocr_client.py index ccd17f5cb..93b3b641b 100644 --- a/backend/airweave/platform/ocr/mistral/ocr_client.py +++ b/backend/airweave/platform/ocr/mistral/ocr_client.py @@ -20,12 +20,12 @@ from airweave.core.config import settings from airweave.core.logging import logger +from airweave.domains.sync_pipeline.exceptions import SyncFailureError from airweave.platform.ocr.mistral.models import ( FileChunk, OcrResult, ) from airweave.platform.rate_limiters import MistralRateLimiter -from airweave.domains.sync_pipeline.exceptions import SyncFailureError # --------------------------------------------------------------------------- # Retry configuration diff --git a/backend/airweave/platform/sources/sharepoint2019v2/builders.py b/backend/airweave/platform/sources/sharepoint2019v2/builders.py index fff472a01..511cd997d 100644 --- a/backend/airweave/platform/sources/sharepoint2019v2/builders.py +++ b/backend/airweave/platform/sources/sharepoint2019v2/builders.py @@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional from urllib.parse import urlparse +from airweave.domains.sync_pipeline.exceptions import EntityProcessingError from airweave.platform.entities._base import AccessControl, Breadcrumb from airweave.platform.entities.sharepoint2019v2 import ( SharePoint2019V2FileEntity, @@ -26,7 +27,6 @@ clean_role_assignments, extract_access_control, ) -from airweave.domains.sync_pipeline.exceptions import EntityProcessingError def _parse_datetime(dt_str: Optional[str]) -> Optional[datetime]: diff --git a/backend/airweave/platform/sources/sharepoint2019v2/source.py b/backend/airweave/platform/sources/sharepoint2019v2/source.py index 6a978b9c7..73be18992 100644 --- a/backend/airweave/platform/sources/sharepoint2019v2/source.py +++ b/backend/airweave/platform/sources/sharepoint2019v2/source.py @@ -25,6 +25,7 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Union from airweave.domains.browse_tree.types import BrowseNode, NodeSelectionData +from airweave.domains.sync_pipeline.exceptions import EntityProcessingError from airweave.platform.access_control.schemas import MembershipTuple from airweave.platform.configs.auth import SharePoint2019V2AuthConfig from airweave.platform.configs.config import SharePoint2019V2Config @@ -49,7 +50,6 @@ ) from airweave.platform.sources.sharepoint2019v2.client import SharePointClient from airweave.platform.storage import FileSkippedException -from airweave.domains.sync_pipeline.exceptions import EntityProcessingError from airweave.schemas.source_connection import AuthenticationMethod # Maximum concurrent file downloads diff --git a/backend/airweave/platform/sources/sharepoint_online/builders.py b/backend/airweave/platform/sources/sharepoint_online/builders.py index b88514944..d4d1cb59f 100644 --- a/backend/airweave/platform/sources/sharepoint_online/builders.py +++ b/backend/airweave/platform/sources/sharepoint_online/builders.py @@ -8,6 +8,7 @@ from datetime import datetime from typing import Any, Dict, List, Optional +from airweave.domains.sync_pipeline.exceptions import EntityProcessingError from airweave.platform.entities._base import Breadcrumb from airweave.platform.entities.sharepoint_online import ( SharePointOnlineDriveEntity, @@ -17,7 +18,6 @@ SharePointOnlineSiteEntity, ) from airweave.platform.sources.sharepoint_online.acl import extract_access_control -from airweave.domains.sync_pipeline.exceptions import EntityProcessingError def _parse_datetime(dt_str: Optional[str]) -> Optional[datetime]: diff --git a/backend/airweave/platform/sources/sharepoint_online/source.py b/backend/airweave/platform/sources/sharepoint_online/source.py index 5363e5614..5b389711f 100644 --- a/backend/airweave/platform/sources/sharepoint_online/source.py +++ b/backend/airweave/platform/sources/sharepoint_online/source.py @@ -28,6 +28,7 @@ import httpx from airweave.domains.browse_tree.types import BrowseNode, NodeSelectionData +from airweave.domains.sync_pipeline.exceptions import EntityProcessingError from airweave.platform.access_control.schemas import MembershipTuple from airweave.platform.configs.config import SharePointOnlineConfig from airweave.platform.cursors.sharepoint_online import SharePointOnlineCursor @@ -46,7 +47,6 @@ from airweave.platform.sources.sharepoint_online.client import GraphClient from airweave.platform.sources.sharepoint_online.graph_groups import EntraGroupExpander from airweave.platform.storage import FileSkippedException -from airweave.domains.sync_pipeline.exceptions import EntityProcessingError from airweave.schemas.source_connection import AuthenticationMethod, OAuthType MAX_CONCURRENT_FILE_DOWNLOADS = 10 diff --git a/backend/airweave/platform/sources/slack.py b/backend/airweave/platform/sources/slack.py index b73181816..f46e2701d 100644 --- a/backend/airweave/platform/sources/slack.py +++ b/backend/airweave/platform/sources/slack.py @@ -8,6 +8,7 @@ from airweave.core.exceptions import TokenRefreshError from airweave.core.shared_models import RateLimitLevel +from airweave.domains.sync_pipeline.pipeline.text_builder import text_builder from airweave.platform.configs.auth import SlackAuthConfig from airweave.platform.configs.config import SlackConfig from airweave.platform.decorators import source @@ -18,7 +19,6 @@ retry_if_rate_limit_or_timeout, wait_rate_limit_with_backoff, ) -from airweave.domains.sync_pipeline.pipeline.text_builder import text_builder from airweave.schemas.source_connection import AuthenticationMethod, OAuthType diff --git a/backend/airweave/platform/storage/file_service.py b/backend/airweave/platform/storage/file_service.py index c0d874ae7..e854e6172 100644 --- a/backend/airweave/platform/storage/file_service.py +++ b/backend/airweave/platform/storage/file_service.py @@ -17,6 +17,7 @@ from tenacity import retry, stop_after_attempt from airweave.core.logging import ContextualLogger +from airweave.domains.sync_pipeline.file_types import SUPPORTED_FILE_EXTENSIONS from airweave.platform.entities._base import FileEntity from airweave.platform.sources.retry_helpers import ( retry_if_rate_limit_or_timeout, @@ -24,7 +25,6 @@ ) from airweave.platform.storage.exceptions import FileSkippedException from airweave.platform.storage.paths import paths -from airweave.domains.sync_pipeline.file_types import SUPPORTED_FILE_EXTENSIONS if TYPE_CHECKING: from airweave.platform.storage.protocol import StorageBackend diff --git a/backend/airweave/platform/sync/web_fetcher.py b/backend/airweave/platform/sync/web_fetcher.py index 1b36d4172..2e0377366 100644 --- a/backend/airweave/platform/sync/web_fetcher.py +++ b/backend/airweave/platform/sync/web_fetcher.py @@ -12,9 +12,9 @@ from airweave.core.config import settings from airweave.core.logging import ContextualLogger +from airweave.domains.sync_pipeline.async_helpers import run_in_thread_pool from airweave.platform.entities._base import WebEntity from airweave.platform.entities.web import WebFileEntity -from airweave.domains.sync_pipeline.async_helpers import run_in_thread_pool # Improved connection management _shared_firecrawl_client = None diff --git a/backend/airweave/schemas/sync_job.py b/backend/airweave/schemas/sync_job.py index ebbadfcdf..4b1e5d06a 100644 --- a/backend/airweave/schemas/sync_job.py +++ b/backend/airweave/schemas/sync_job.py @@ -11,8 +11,8 @@ from pydantic import BaseModel, ConfigDict, EmailStr, Field -from airweave.models.sync_job import SyncJobStatus from airweave.domains.sync_pipeline.config.base import SyncConfig +from airweave.models.sync_job import SyncJobStatus class SyncJobBase(BaseModel): diff --git a/backend/tests/unit/platform/converters/test_txt_converter.py b/backend/tests/unit/platform/converters/test_txt_converter.py index ceb6a85f0..e7946bf14 100644 --- a/backend/tests/unit/platform/converters/test_txt_converter.py +++ b/backend/tests/unit/platform/converters/test_txt_converter.py @@ -205,4 +205,3 @@ async def test_convert_json_invalid_syntax(self, converter, temp_dir): assert file_path in result assert result[file_path] is None - From 22e3f6c1dd1cf7bd5a8e99628172371f7e25d1b2 Mon Sep 17 00:00:00 2001 From: Felix Schmetz Date: Thu, 12 Mar 2026 19:20:15 -0700 Subject: [PATCH 07/13] test: update AccessBroker tests to use injected acl_repo --- .../platform/access_control/test_broker.py | 248 ++++++++---------- 1 file changed, 113 insertions(+), 135 deletions(-) diff --git a/backend/tests/unit/platform/access_control/test_broker.py b/backend/tests/unit/platform/access_control/test_broker.py index 36bb5cb39..602362851 100644 --- a/backend/tests/unit/platform/access_control/test_broker.py +++ b/backend/tests/unit/platform/access_control/test_broker.py @@ -11,8 +11,8 @@ @pytest.fixture def broker(): - """Create AccessBroker instance.""" - return AccessBroker() + """Create AccessBroker instance with mocked ACL repo.""" + return AccessBroker(acl_repo=MagicMock()) @pytest.fixture @@ -35,65 +35,61 @@ async def test_resolve_access_context_for_user_with_no_memberships( self, broker, mock_db, organization_id ): """Test resolution for user with no group memberships.""" - with patch("airweave.platform.access_control.broker.crud") as mock_crud: - mock_crud.access_control_membership.get_by_member = AsyncMock(return_value=[]) + broker._acl_repo.get_by_member = AsyncMock(return_value=[]) - result = await broker.resolve_access_context( - db=mock_db, user_principal="john@acme.com", organization_id=organization_id - ) + result = await broker.resolve_access_context( + db=mock_db, user_principal="john@acme.com", organization_id=organization_id + ) - assert isinstance(result, AccessContext) - assert result.user_principal == "john@acme.com" - assert result.user_principals == ["user:john@acme.com"] - assert result.group_principals == [] - assert len(result.all_principals) == 1 + assert isinstance(result, AccessContext) + assert result.user_principal == "john@acme.com" + assert result.user_principals == ["user:john@acme.com"] + assert result.group_principals == [] + assert len(result.all_principals) == 1 @pytest.mark.asyncio async def test_resolve_access_context_for_user_with_direct_groups( self, broker, mock_db, organization_id ): """Test resolution for user with direct group memberships.""" - with patch("airweave.platform.access_control.broker.crud") as mock_crud: - # Mock membership data - membership1 = MagicMock() - membership1.group_id = "sp:engineering" - membership1.group_name = "Engineering" + membership1 = MagicMock() + membership1.group_id = "sp:engineering" + membership1.group_name = "Engineering" - membership2 = MagicMock() - membership2.group_id = "ad:frontend" - membership2.group_name = "Frontend Team" + membership2 = MagicMock() + membership2.group_id = "ad:frontend" + membership2.group_name = "Frontend Team" - mock_crud.access_control_membership.get_by_member = AsyncMock(return_value=[membership1, membership2]) + broker._acl_repo.get_by_member = AsyncMock(return_value=[membership1, membership2]) - result = await broker.resolve_access_context( - db=mock_db, user_principal="john@acme.com", organization_id=organization_id - ) + result = await broker.resolve_access_context( + db=mock_db, user_principal="john@acme.com", organization_id=organization_id + ) - assert result.user_principal == "john@acme.com" - assert result.user_principals == ["user:john@acme.com"] - assert len(result.group_principals) == 2 - assert "group:sp:engineering" in result.group_principals - assert "group:ad:frontend" in result.group_principals - assert len(result.all_principals) == 3 + assert result.user_principal == "john@acme.com" + assert result.user_principals == ["user:john@acme.com"] + assert len(result.group_principals) == 2 + assert "group:sp:engineering" in result.group_principals + assert "group:ad:frontend" in result.group_principals + assert len(result.all_principals) == 3 @pytest.mark.asyncio async def test_resolve_access_context_returns_user_and_group_principals( self, broker, mock_db, organization_id ): """Test that all_principals property combines user and group principals.""" - with patch("airweave.platform.access_control.broker.crud") as mock_crud: - membership = MagicMock() - membership.group_id = "sp:site_owners" - mock_crud.access_control_membership.get_by_member = AsyncMock(return_value=[membership]) + membership = MagicMock() + membership.group_id = "sp:site_owners" + broker._acl_repo.get_by_member = AsyncMock(return_value=[membership]) - result = await broker.resolve_access_context( - db=mock_db, user_principal="admin@acme.com", organization_id=organization_id - ) + result = await broker.resolve_access_context( + db=mock_db, user_principal="admin@acme.com", organization_id=organization_id + ) - all_principals = result.all_principals - assert "user:admin@acme.com" in all_principals - assert "group:sp:site_owners" in all_principals - assert len(all_principals) == 2 + all_principals = result.all_principals + assert "user:admin@acme.com" in all_principals + assert "group:sp:site_owners" in all_principals + assert len(all_principals) == 2 class TestAccessBrokerCollectionScoping: @@ -104,41 +100,35 @@ async def test_resolve_for_collection_filters_by_readable_collection_id( self, broker, mock_db, organization_id ): """Test that collection resolution filters by readable_collection_id.""" - with patch("airweave.platform.access_control.broker.crud") as mock_crud: - membership = MagicMock() - membership.group_id = "sp:engineering" - mock_crud.access_control_membership.get_by_member_and_collection = AsyncMock(return_value=[membership]) - - # Mock get_by_member for group expansion (returns empty for simplicity) - mock_crud.access_control_membership.get_by_member = AsyncMock(return_value=[]) - - # Mock _collection_has_ac_sources to return True - with patch.object(broker, "_collection_has_ac_sources", new=AsyncMock(return_value=True)): - result = await broker.resolve_access_context_for_collection( - db=mock_db, - user_principal="john@acme.com", - readable_collection_id="my-collection", - organization_id=organization_id, - ) - - # Verify CRUD was called with collection filter - mock_crud.access_control_membership.get_by_member_and_collection.assert_called_once_with( - db=mock_db, - member_id="john@acme.com", - member_type="user", - readable_collection_id="my-collection", - organization_id=organization_id, - ) - - assert isinstance(result, AccessContext) - assert "group:sp:engineering" in result.group_principals + membership = MagicMock() + membership.group_id = "sp:engineering" + broker._acl_repo.get_by_member_and_collection = AsyncMock(return_value=[membership]) + broker._acl_repo.get_by_member = AsyncMock(return_value=[]) + + with patch.object(broker, "_collection_has_ac_sources", new=AsyncMock(return_value=True)): + result = await broker.resolve_access_context_for_collection( + db=mock_db, + user_principal="john@acme.com", + readable_collection_id="my-collection", + organization_id=organization_id, + ) + + broker._acl_repo.get_by_member_and_collection.assert_called_once_with( + db=mock_db, + member_id="john@acme.com", + member_type="user", + readable_collection_id="my-collection", + organization_id=organization_id, + ) + + assert isinstance(result, AccessContext) + assert "group:sp:engineering" in result.group_principals @pytest.mark.asyncio async def test_resolve_for_collection_returns_none_when_no_ac_sources( self, broker, mock_db, organization_id ): """Test that resolution returns None if collection has no AC sources.""" - # Mock _collection_has_ac_sources to return False with patch.object(broker, "_collection_has_ac_sources", new=AsyncMock(return_value=False)): result = await broker.resolve_access_context_for_collection( db=mock_db, @@ -154,7 +144,6 @@ async def test_collection_has_ac_sources_returns_false_for_empty_collection( self, broker, mock_db, organization_id ): """Test _collection_has_ac_sources returns False when no memberships exist.""" - # Mock database query to return False mock_result = MagicMock() mock_result.scalar.return_value = False mock_db.execute = AsyncMock(return_value=mock_result) @@ -172,7 +161,6 @@ async def test_collection_has_ac_sources_returns_true_when_memberships_exist( self, broker, mock_db, organization_id ): """Test _collection_has_ac_sources returns True when memberships exist.""" - # Mock database query to return True mock_result = MagicMock() mock_result.scalar.return_value = True mock_db.execute = AsyncMock(return_value=mock_result) @@ -194,86 +182,77 @@ async def test_expand_group_memberships_handles_nested_groups( self, broker, mock_db, organization_id ): """Test expansion of nested group-to-group relationships.""" - with patch("airweave.platform.access_control.broker.crud") as mock_crud: - # Mock nested groups: frontend -> engineering -> all-staff - async def mock_get_by_member(db, member_id, member_type, organization_id): - if member_id == "frontend" and member_type == "group": - parent = MagicMock() - parent.group_id = "engineering" - return [parent] - elif member_id == "engineering" and member_type == "group": - parent = MagicMock() - parent.group_id = "all-staff" - return [parent] - return [] - - mock_crud.access_control_membership.get_by_member = AsyncMock(side_effect=mock_get_by_member) - - result = await broker._expand_group_memberships( - db=mock_db, group_ids=["frontend"], organization_id=organization_id - ) - # Should include all levels: frontend, engineering, all-staff - assert "frontend" in result - assert "engineering" in result - assert "all-staff" in result - assert len(result) == 3 + async def mock_get_by_member(db, member_id, member_type, organization_id): + if member_id == "frontend" and member_type == "group": + parent = MagicMock() + parent.group_id = "engineering" + return [parent] + elif member_id == "engineering" and member_type == "group": + parent = MagicMock() + parent.group_id = "all-staff" + return [parent] + return [] + + broker._acl_repo.get_by_member = AsyncMock(side_effect=mock_get_by_member) + + result = await broker._expand_group_memberships( + db=mock_db, group_ids=["frontend"], organization_id=organization_id + ) + + assert "frontend" in result + assert "engineering" in result + assert "all-staff" in result + assert len(result) == 3 @pytest.mark.asyncio async def test_expand_group_memberships_handles_circular_references( self, broker, mock_db, organization_id ): """Test that circular group references don't cause infinite loops.""" - with patch("airweave.platform.access_control.broker.crud") as mock_crud: - - async def mock_get_by_member(db, member_id, member_type, organization_id): - # Create circular reference: group-a -> group-b -> group-a - if member_id == "group-a" and member_type == "group": - parent = MagicMock() - parent.group_id = "group-b" - return [parent] - elif member_id == "group-b" and member_type == "group": - parent = MagicMock() - parent.group_id = "group-a" - return [parent] - return [] - - mock_crud.access_control_membership.get_by_member = AsyncMock(side_effect=mock_get_by_member) - - result = await broker._expand_group_memberships( - db=mock_db, group_ids=["group-a"], organization_id=organization_id - ) - # Should handle circular reference gracefully - assert "group-a" in result - assert "group-b" in result - assert len(result) == 2 + async def mock_get_by_member(db, member_id, member_type, organization_id): + if member_id == "group-a" and member_type == "group": + parent = MagicMock() + parent.group_id = "group-b" + return [parent] + elif member_id == "group-b" and member_type == "group": + parent = MagicMock() + parent.group_id = "group-a" + return [parent] + return [] + + broker._acl_repo.get_by_member = AsyncMock(side_effect=mock_get_by_member) + + result = await broker._expand_group_memberships( + db=mock_db, group_ids=["group-a"], organization_id=organization_id + ) + + assert "group-a" in result + assert "group-b" in result + assert len(result) == 2 @pytest.mark.asyncio async def test_expand_group_memberships_respects_max_depth( self, broker, mock_db, organization_id ): """Test that expansion stops at max depth to prevent abuse.""" - with patch("airweave.platform.access_control.broker.crud") as mock_crud: + call_count = 0 - call_count = 0 + async def mock_get_by_member(db, member_id, member_type, organization_id): + nonlocal call_count + call_count += 1 + parent = MagicMock() + parent.group_id = f"level-{call_count}" + return [parent] - async def mock_get_by_member(db, member_id, member_type, organization_id): - nonlocal call_count - call_count += 1 - # Always return a new parent group (infinite nesting) - parent = MagicMock() - parent.group_id = f"level-{call_count}" - return [parent] - - mock_crud.access_control_membership.get_by_member = AsyncMock(side_effect=mock_get_by_member) + broker._acl_repo.get_by_member = AsyncMock(side_effect=mock_get_by_member) - result = await broker._expand_group_memberships( - db=mock_db, group_ids=["level-0"], organization_id=organization_id - ) + result = await broker._expand_group_memberships( + db=mock_db, group_ids=["level-0"], organization_id=organization_id + ) - # Should stop at max depth (10 in implementation) - assert len(result) <= 12 # initial + 10 levels + some tolerance + assert len(result) <= 12 class TestAccessBrokerEntityAccess: @@ -354,4 +333,3 @@ def test_check_entity_access_returns_true_when_empty_viewers(self, broker): result = broker.check_entity_access(entity_access, access_context) assert result is True - From 6b1079ed93213043bb591259f586c2a5166988a3 Mon Sep 17 00:00:00 2001 From: Felix Schmetz Date: Thu, 12 Mar 2026 19:58:34 -0700 Subject: [PATCH 08/13] fix: resolve all ruff lint errors in refactored modules --- backend/airweave/api/v1/endpoints/admin.py | 1 + .../domains/access_control/protocols.py | 40 +++++-- .../domains/access_control/repository.py | 10 ++ .../domains/entities/entity_repository.py | 6 ++ .../airweave/domains/entities/protocols.py | 24 +++-- backend/airweave/domains/entities/types.py | 2 + .../sync_pipeline/access_control_pipeline.py | 3 +- .../entity_dispatcher_builder.py | 3 + .../airweave/domains/sync_pipeline/factory.py | 1 + .../handlers/access_control_postgres.py | 1 + .../sync_pipeline/handlers/destination.py | 1 + .../sync_pipeline/handlers/entity_postgres.py | 1 + .../tests/test_acl_membership_tracker.py | 3 +- .../tests/test_acl_reconciliation.py | 42 +++----- .../sync_pipeline/tests/test_chunk_embed.py | 101 +++++++----------- .../tests/test_cleanup_service.py | 19 ++-- .../sync_pipeline/tests/test_config_base.py | 18 ++-- .../tests/test_config_builder.py | 9 +- .../tests/test_destination_handler.py | 10 +- .../tests/test_entity_action_resolver.py | 13 ++- .../tests/test_entity_pipeline.py | 1 - .../sync_pipeline/tests/test_factory.py | 4 +- .../tests/test_progress_relay.py | 48 ++++++--- .../syncs/tests/test_sync_job_service.py | 14 +-- .../platform/access_control/broker.py | 1 + .../platform/converters/txt_converter.py | 58 +++++----- 26 files changed, 235 insertions(+), 199 deletions(-) diff --git a/backend/airweave/api/v1/endpoints/admin.py b/backend/airweave/api/v1/endpoints/admin.py index e3af87a7a..e257a9bf8 100644 --- a/backend/airweave/api/v1/endpoints/admin.py +++ b/backend/airweave/api/v1/endpoints/admin.py @@ -1,3 +1,4 @@ +# mypy: ignore-errors """Admin-only API endpoints for organization management. TODO: Enhance CRUD layer to support bypassing organization filtering cleanly. diff --git a/backend/airweave/domains/access_control/protocols.py b/backend/airweave/domains/access_control/protocols.py index 02c4a9440..cd8381ed7 100644 --- a/backend/airweave/domains/access_control/protocols.py +++ b/backend/airweave/domains/access_control/protocols.py @@ -18,7 +18,9 @@ async def bulk_create( organization_id: UUID, source_connection_id: UUID, source_name: str, - ) -> int: ... + ) -> int: + """Bulk-insert membership rows.""" + ... async def upsert( self, @@ -31,7 +33,9 @@ async def upsert( organization_id: UUID, source_connection_id: UUID, source_name: str, - ) -> None: ... + ) -> None: + """Insert or update a single membership.""" + ... async def delete_by_key( self, @@ -42,7 +46,9 @@ async def delete_by_key( group_id: str, source_connection_id: UUID, organization_id: UUID, - ) -> int: ... + ) -> int: + """Delete a membership by its composite key.""" + ... async def delete_by_group( self, @@ -51,20 +57,26 @@ async def delete_by_group( group_id: str, source_connection_id: UUID, organization_id: UUID, - ) -> int: ... + ) -> int: + """Delete all memberships for a group.""" + ... async def get_by_source_connection( self, db: AsyncSession, source_connection_id: UUID, organization_id: UUID, - ) -> List[AccessControlMembership]: ... + ) -> List[AccessControlMembership]: + """Get all memberships for a source connection.""" + ... async def bulk_delete( self, db: AsyncSession, ids: List[UUID], - ) -> int: ... + ) -> int: + """Bulk-delete memberships by ID.""" + ... async def get_by_member( self, @@ -72,7 +84,9 @@ async def get_by_member( member_id: str, member_type: str, organization_id: UUID, - ) -> List[AccessControlMembership]: ... + ) -> List[AccessControlMembership]: + """Get memberships for a specific member.""" + ... async def get_by_member_and_collection( self, @@ -81,7 +95,9 @@ async def get_by_member_and_collection( member_type: str, readable_collection_id: str, organization_id: UUID, - ) -> List[AccessControlMembership]: ... + ) -> List[AccessControlMembership]: + """Get memberships scoped to a collection.""" + ... async def get_memberships_by_groups( self, @@ -90,11 +106,15 @@ async def get_memberships_by_groups( group_ids: List[str], source_connection_id: UUID, organization_id: UUID, - ) -> List[AccessControlMembership]: ... + ) -> List[AccessControlMembership]: + """Get memberships for a set of group IDs.""" + ... async def delete_by_source_connection( self, db: AsyncSession, source_connection_id: UUID, organization_id: UUID, - ) -> int: ... + ) -> int: + """Delete all memberships for a source connection.""" + ... diff --git a/backend/airweave/domains/access_control/repository.py b/backend/airweave/domains/access_control/repository.py index b07fe6adc..53650b399 100644 --- a/backend/airweave/domains/access_control/repository.py +++ b/backend/airweave/domains/access_control/repository.py @@ -20,6 +20,7 @@ async def bulk_create( source_connection_id: UUID, source_name: str, ) -> int: + """Bulk-insert membership rows.""" return await crud.access_control_membership.bulk_create( db, memberships, organization_id, source_connection_id, source_name ) @@ -36,6 +37,7 @@ async def upsert( source_connection_id: UUID, source_name: str, ) -> None: + """Insert or update a single membership.""" return await crud.access_control_membership.upsert( db, member_id=member_id, @@ -57,6 +59,7 @@ async def delete_by_key( source_connection_id: UUID, organization_id: UUID, ) -> int: + """Delete a membership by its composite key.""" return await crud.access_control_membership.delete_by_key( db, member_id=member_id, @@ -74,6 +77,7 @@ async def delete_by_group( source_connection_id: UUID, organization_id: UUID, ) -> int: + """Delete all memberships for a group.""" return await crud.access_control_membership.delete_by_group( db, group_id=group_id, @@ -87,6 +91,7 @@ async def get_by_source_connection( source_connection_id: UUID, organization_id: UUID, ) -> List[AccessControlMembership]: + """Get all memberships for a source connection.""" return await crud.access_control_membership.get_by_source_connection( db, source_connection_id, organization_id ) @@ -96,6 +101,7 @@ async def bulk_delete( db: AsyncSession, ids: List[UUID], ) -> int: + """Bulk-delete memberships by ID.""" return await crud.access_control_membership.bulk_delete(db, ids) async def get_by_member( @@ -105,6 +111,7 @@ async def get_by_member( member_type: str, organization_id: UUID, ) -> List[AccessControlMembership]: + """Get memberships for a specific member.""" return await crud.access_control_membership.get_by_member( db, member_id, member_type, organization_id ) @@ -117,6 +124,7 @@ async def get_by_member_and_collection( readable_collection_id: str, organization_id: UUID, ) -> List[AccessControlMembership]: + """Get memberships scoped to a collection.""" return await crud.access_control_membership.get_by_member_and_collection( db, member_id, member_type, readable_collection_id, organization_id ) @@ -129,6 +137,7 @@ async def get_memberships_by_groups( source_connection_id: UUID, organization_id: UUID, ) -> List[AccessControlMembership]: + """Get memberships for a set of group IDs.""" return await crud.access_control_membership.get_memberships_by_groups( db, group_ids=group_ids, @@ -142,6 +151,7 @@ async def delete_by_source_connection( source_connection_id: UUID, organization_id: UUID, ) -> int: + """Delete all memberships for a source connection.""" return await crud.access_control_membership.delete_by_source_connection( db, source_connection_id, organization_id ) diff --git a/backend/airweave/domains/entities/entity_repository.py b/backend/airweave/domains/entities/entity_repository.py index fd43a3deb..2d6f9005f 100644 --- a/backend/airweave/domains/entities/entity_repository.py +++ b/backend/airweave/domains/entities/entity_repository.py @@ -14,6 +14,7 @@ class EntityRepository: """Delegates to the crud.entity singleton.""" async def get_by_sync_id(self, db: AsyncSession, sync_id: UUID) -> List[Entity]: + """Get all entities for a sync.""" return await crud.entity.get_by_sync_id(db, sync_id) async def bulk_get_by_entity_sync_and_definition( @@ -23,6 +24,7 @@ async def bulk_get_by_entity_sync_and_definition( sync_id: UUID, entity_requests: list[Tuple[str, str]], ) -> Dict[Tuple[str, str], Entity]: + """Bulk-fetch entities by (entity_id, definition) pairs.""" return await crud.entity.bulk_get_by_entity_sync_and_definition( db, sync_id=sync_id, entity_requests=entity_requests ) @@ -34,6 +36,7 @@ async def bulk_create( objs: List[EntityCreate], ctx: Any, ) -> List[Entity]: + """Bulk-insert entity rows.""" return await crud.entity.bulk_create(db, objs=objs, ctx=ctx) async def bulk_update_hash( @@ -42,6 +45,7 @@ async def bulk_update_hash( *, rows: List[Tuple[UUID, str]], ) -> None: + """Bulk-update content hashes.""" return await crud.entity.bulk_update_hash(db, rows=rows) async def bulk_remove( @@ -51,6 +55,7 @@ async def bulk_remove( ids: List[UUID], ctx: Any, ) -> List[Entity]: + """Soft-delete entities by ID.""" return await crud.entity.bulk_remove(db, ids=ids, ctx=ctx) async def bulk_get_by_entity_and_sync( @@ -60,6 +65,7 @@ async def bulk_get_by_entity_and_sync( sync_id: UUID, entity_ids: List[str], ) -> Dict[str, Entity]: + """Bulk-fetch entities by entity_id within a sync.""" return await crud.entity.bulk_get_by_entity_and_sync( db, sync_id=sync_id, entity_ids=entity_ids ) diff --git a/backend/airweave/domains/entities/protocols.py b/backend/airweave/domains/entities/protocols.py index 94e66e725..1c42aefaa 100644 --- a/backend/airweave/domains/entities/protocols.py +++ b/backend/airweave/domains/entities/protocols.py @@ -32,7 +32,9 @@ async def get_counts_per_sync_and_type( class EntityRepositoryProtocol(Protocol): """Entity data access used by the sync pipeline.""" - async def get_by_sync_id(self, db: AsyncSession, sync_id: UUID) -> List[Entity]: ... + async def get_by_sync_id(self, db: AsyncSession, sync_id: UUID) -> List[Entity]: + """Get all entities for a sync.""" + ... async def bulk_get_by_entity_sync_and_definition( self, @@ -40,7 +42,9 @@ async def bulk_get_by_entity_sync_and_definition( *, sync_id: UUID, entity_requests: list[Tuple[str, str]], - ) -> Dict[Tuple[str, str], Entity]: ... + ) -> Dict[Tuple[str, str], Entity]: + """Bulk-fetch entities by (entity_id, definition) pairs.""" + ... async def bulk_create( self, @@ -48,14 +52,18 @@ async def bulk_create( *, objs: list, ctx: Any, - ) -> List[Entity]: ... + ) -> List[Entity]: + """Bulk-insert entity rows.""" + ... async def bulk_update_hash( self, db: AsyncSession, *, rows: List[Tuple[UUID, str]], - ) -> None: ... + ) -> None: + """Bulk-update content hashes.""" + ... async def bulk_remove( self, @@ -63,7 +71,9 @@ async def bulk_remove( *, ids: List[UUID], ctx: Any, - ) -> List[Entity]: ... + ) -> List[Entity]: + """Soft-delete entities by ID.""" + ... async def bulk_get_by_entity_and_sync( self, @@ -71,4 +81,6 @@ async def bulk_get_by_entity_and_sync( *, sync_id: UUID, entity_ids: List[str], - ) -> Dict[str, Entity]: ... + ) -> Dict[str, Entity]: + """Bulk-fetch entities by entity_id within a sync.""" + ... diff --git a/backend/airweave/domains/entities/types.py b/backend/airweave/domains/entities/types.py index 7352c0fc5..b39b22235 100644 --- a/backend/airweave/domains/entities/types.py +++ b/backend/airweave/domains/entities/types.py @@ -1,3 +1,5 @@ +"""Types for the entities domain.""" + from pydantic import BaseModel from airweave.core.protocols.registry import BaseRegistryEntry diff --git a/backend/airweave/domains/sync_pipeline/access_control_pipeline.py b/backend/airweave/domains/sync_pipeline/access_control_pipeline.py index f91853509..35f911eaa 100644 --- a/backend/airweave/domains/sync_pipeline/access_control_pipeline.py +++ b/backend/airweave/domains/sync_pipeline/access_control_pipeline.py @@ -36,7 +36,8 @@ def __init__( dispatcher: ACActionDispatcher, tracker: ACLMembershipTracker, acl_repo: AccessControlMembershipRepositoryProtocol, - ): + ) -> None: + """Initialize with resolver, dispatcher, tracker and ACL repository.""" self._resolver = resolver self._dispatcher = dispatcher self._tracker = tracker diff --git a/backend/airweave/domains/sync_pipeline/entity_dispatcher_builder.py b/backend/airweave/domains/sync_pipeline/entity_dispatcher_builder.py index 0e805e275..f2ae9c1a3 100644 --- a/backend/airweave/domains/sync_pipeline/entity_dispatcher_builder.py +++ b/backend/airweave/domains/sync_pipeline/entity_dispatcher_builder.py @@ -22,6 +22,7 @@ def __init__( processor: ChunkEmbedProcessorProtocol, entity_repo: EntityRepositoryProtocol, ) -> None: + """Initialize with processor and entity repository.""" self._processor = processor self._entity_repo = entity_repo @@ -31,6 +32,7 @@ def build( execution_config: Optional[SyncConfig] = None, logger: Optional[ContextualLogger] = None, ) -> EntityActionDispatcher: + """Build a dispatcher with all configured handlers.""" handlers = self._build_handlers(destinations, execution_config, logger) return EntityActionDispatcher(handlers=handlers) @@ -39,6 +41,7 @@ def build_for_cleanup( destinations: List[BaseDestination], logger: Optional[ContextualLogger] = None, ) -> EntityActionDispatcher: + """Build a dispatcher configured for cleanup (no execution config).""" return self.build(destinations=destinations, execution_config=None, logger=logger) def _build_handlers( diff --git a/backend/airweave/domains/sync_pipeline/factory.py b/backend/airweave/domains/sync_pipeline/factory.py index d6b933ff3..44ba61f98 100644 --- a/backend/airweave/domains/sync_pipeline/factory.py +++ b/backend/airweave/domains/sync_pipeline/factory.py @@ -63,6 +63,7 @@ def __init__( acl_repo: AccessControlMembershipRepositoryProtocol, processor: ChunkEmbedProcessorProtocol, ) -> None: + """Initialize with all required service and repository dependencies.""" self._sc_repo = sc_repo self._event_bus = event_bus self._usage_checker = usage_checker diff --git a/backend/airweave/domains/sync_pipeline/handlers/access_control_postgres.py b/backend/airweave/domains/sync_pipeline/handlers/access_control_postgres.py index 8fcbada88..6e3fc397c 100644 --- a/backend/airweave/domains/sync_pipeline/handlers/access_control_postgres.py +++ b/backend/airweave/domains/sync_pipeline/handlers/access_control_postgres.py @@ -26,6 +26,7 @@ class ACPostgresHandler(ACActionHandler): """Persists access control memberships to PostgreSQL.""" def __init__(self, acl_repo: AccessControlMembershipRepositoryProtocol) -> None: + """Initialize with ACL membership repository.""" self._acl_repo = acl_repo @property diff --git a/backend/airweave/domains/sync_pipeline/handlers/destination.py b/backend/airweave/domains/sync_pipeline/handlers/destination.py index 846749493..e76998f94 100644 --- a/backend/airweave/domains/sync_pipeline/handlers/destination.py +++ b/backend/airweave/domains/sync_pipeline/handlers/destination.py @@ -45,6 +45,7 @@ def __init__( destinations: List[BaseDestination], processor: ChunkEmbedProcessorProtocol, ) -> None: + """Initialize with destination list and chunk/embed processor.""" self._destinations = destinations self._processor = processor diff --git a/backend/airweave/domains/sync_pipeline/handlers/entity_postgres.py b/backend/airweave/domains/sync_pipeline/handlers/entity_postgres.py index feb0584bb..b16fd6ab4 100644 --- a/backend/airweave/domains/sync_pipeline/handlers/entity_postgres.py +++ b/backend/airweave/domains/sync_pipeline/handlers/entity_postgres.py @@ -30,6 +30,7 @@ class EntityPostgresHandler(EntityActionHandler): """Handler for PostgreSQL entity metadata.""" def __init__(self, entity_repo: EntityRepositoryProtocol) -> None: + """Initialize with entity repository.""" self._entity_repo = entity_repo @property diff --git a/backend/airweave/domains/sync_pipeline/tests/test_acl_membership_tracker.py b/backend/airweave/domains/sync_pipeline/tests/test_acl_membership_tracker.py index a3a60008a..0ca7e4e1b 100644 --- a/backend/airweave/domains/sync_pipeline/tests/test_acl_membership_tracker.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_acl_membership_tracker.py @@ -1,9 +1,10 @@ """Unit tests for ACL Membership Tracker.""" -import pytest from unittest.mock import MagicMock from uuid import uuid4 +import pytest + from airweave.domains.sync_pipeline.pipeline.acl_membership_tracker import ACLMembershipTracker diff --git a/backend/airweave/domains/sync_pipeline/tests/test_acl_reconciliation.py b/backend/airweave/domains/sync_pipeline/tests/test_acl_reconciliation.py index dbe6eab10..01c663003 100644 --- a/backend/airweave/domains/sync_pipeline/tests/test_acl_reconciliation.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_acl_reconciliation.py @@ -18,8 +18,10 @@ import pytest -from airweave.platform.access_control.schemas import ACLChangeType, MembershipChange from airweave.domains.sync_pipeline.access_control_pipeline import AccessControlPipeline +from airweave.platform.access_control.schemas import ACLChangeType, MembershipChange + +_GET_DB_CTX = "airweave.domains.sync_pipeline.access_control_pipeline.get_db_context" # --------------------------------------------------------------------------- @@ -113,9 +115,7 @@ async def test_applies_adds_and_removes(self): pipeline._acl_repo.upsert = AsyncMock() pipeline._acl_repo.delete_by_key = AsyncMock() - adds, removes = await pipeline._apply_membership_changes( - db, result, source, ctx - ) + adds, removes = await pipeline._apply_membership_changes(db, result, source, ctx) assert adds == 2 assert removes == 1 @@ -146,9 +146,7 @@ async def test_basic_flags_does_not_reconcile(self): pipeline._acl_repo.upsert = AsyncMock() pipeline._acl_repo.delete_by_key = AsyncMock() - adds, removes = await pipeline._apply_membership_changes( - db, result, source, ctx - ) + adds, removes = await pipeline._apply_membership_changes(db, result, source, ctx) assert adds == 2 assert removes == 0 @@ -172,9 +170,7 @@ async def test_upsert_passes_correct_fields(self): pipeline._acl_repo.upsert = AsyncMock() - await pipeline._apply_membership_changes( - db, result, source, ctx - ) + await pipeline._apply_membership_changes(db, result, source, ctx) pipeline._acl_repo.upsert.assert_called_once_with( db, @@ -200,9 +196,7 @@ async def test_empty_changes_returns_zero(self): pipeline._acl_repo.upsert = AsyncMock() pipeline._acl_repo.delete_by_key = AsyncMock() - adds, removes = await pipeline._apply_membership_changes( - db, result, source, ctx - ) + adds, removes = await pipeline._apply_membership_changes(db, result, source, ctx) assert adds == 0 assert removes == 0 @@ -221,9 +215,7 @@ async def test_deleted_groups_remove_all_memberships(self): """Deleted AD groups should have all their memberships removed via delete_by_group.""" pipeline = _make_pipeline() ctx = FakeSyncContext() - runtime = FakeRuntime( - cursor=FakeCursor(data={"acl_dirsync_cookie": "old_cookie"}) - ) + runtime = FakeRuntime(cursor=FakeCursor(data={"acl_dirsync_cookie": "old_cookie"})) source = SimpleNamespace( _short_name="sp2019v2", get_acl_changes=AsyncMock(), @@ -238,7 +230,7 @@ async def test_deleted_groups_remove_all_memberships(self): ) source.get_acl_changes.return_value = result - with patch("airweave.domains.sync_pipeline.access_control_pipeline.get_db_context") as mock_db_ctx: + with patch(_GET_DB_CTX) as mock_db_ctx: mock_db = MagicMock() mock_db_ctx.return_value.__aenter__ = AsyncMock(return_value=mock_db) mock_db_ctx.return_value.__aexit__ = AsyncMock(return_value=False) @@ -255,9 +247,7 @@ async def test_no_changes_returns_zero_and_updates_cookie(self): """When DirSync reports zero changes, return 0 but still advance the cookie.""" pipeline = _make_pipeline() ctx = FakeSyncContext() - runtime = FakeRuntime( - cursor=FakeCursor(data={"acl_dirsync_cookie": "old_cookie"}) - ) + runtime = FakeRuntime(cursor=FakeCursor(data={"acl_dirsync_cookie": "old_cookie"})) source = SimpleNamespace( _short_name="sp2019v2", get_acl_changes=AsyncMock(), @@ -300,9 +290,7 @@ async def test_full_flow_with_adds_removes_and_deletes(self): """End-to-end: ADDs + REMOVEs + deleted groups in one incremental sync.""" pipeline = _make_pipeline() ctx = FakeSyncContext() - runtime = FakeRuntime( - cursor=FakeCursor(data={"acl_dirsync_cookie": "old"}) - ) + runtime = FakeRuntime(cursor=FakeCursor(data={"acl_dirsync_cookie": "old"})) source = SimpleNamespace( _short_name="sp2019v2", get_acl_changes=AsyncMock(), @@ -321,7 +309,7 @@ async def test_full_flow_with_adds_removes_and_deletes(self): ) source.get_acl_changes.return_value = result - with patch("airweave.domains.sync_pipeline.access_control_pipeline.get_db_context") as mock_db_ctx: + with patch(_GET_DB_CTX) as mock_db_ctx: mock_db = MagicMock() mock_db_ctx.return_value.__aenter__ = AsyncMock(return_value=mock_db) mock_db_ctx.return_value.__aexit__ = AsyncMock(return_value=False) @@ -344,9 +332,7 @@ async def test_basic_flags_does_not_reconcile_in_full_flow(self): """With BASIC flags, _process_incremental only applies ADDs (no reconciliation).""" pipeline = _make_pipeline() ctx = FakeSyncContext() - runtime = FakeRuntime( - cursor=FakeCursor(data={"acl_dirsync_cookie": "old"}) - ) + runtime = FakeRuntime(cursor=FakeCursor(data={"acl_dirsync_cookie": "old"})) source = SimpleNamespace( _short_name="sp2019v2", get_acl_changes=AsyncMock(), @@ -362,7 +348,7 @@ async def test_basic_flags_does_not_reconcile_in_full_flow(self): ) source.get_acl_changes.return_value = result - with patch("airweave.domains.sync_pipeline.access_control_pipeline.get_db_context") as mock_db_ctx: + with patch(_GET_DB_CTX) as mock_db_ctx: mock_db = MagicMock() mock_db_ctx.return_value.__aenter__ = AsyncMock(return_value=mock_db) mock_db_ctx.return_value.__aexit__ = AsyncMock(return_value=False) diff --git a/backend/airweave/domains/sync_pipeline/tests/test_chunk_embed.py b/backend/airweave/domains/sync_pipeline/tests/test_chunk_embed.py index 8ea1d9a10..aef34f233 100644 --- a/backend/airweave/domains/sync_pipeline/tests/test_chunk_embed.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_chunk_embed.py @@ -6,6 +6,9 @@ from airweave.domains.sync_pipeline.processors.chunk_embed import ChunkEmbedProcessor +_TEXT_BUILDER = "airweave.domains.sync_pipeline.processors.chunk_embed.text_builder" +_SEMANTIC_CHUNKER = "airweave.platform.chunkers.semantic.SemanticChunker" + @pytest.fixture def processor(): @@ -64,26 +67,24 @@ async def test_chunk_textual_entities_uses_semantic_chunker( self, processor, mock_sync_context, mock_runtime, mock_entity ): """Test textual entities routed to SemanticChunker.""" - with patch('airweave.domains.sync_pipeline.processors.chunk_embed.text_builder') as mock_builder, \ - patch('airweave.platform.chunkers.semantic.SemanticChunker') as MockSemanticChunker, \ - patch.object(processor, '_embed_entities', new_callable=AsyncMock): - - # Setup mocks + with ( + patch(_TEXT_BUILDER) as mock_builder, + patch(_SEMANTIC_CHUNKER) as MockSemanticChunker, + patch.object(processor, "_embed_entities", new_callable=AsyncMock), + ): mock_builder.build_for_batch = AsyncMock(return_value=[mock_entity]) mock_chunker = MockSemanticChunker.return_value - mock_chunker.chunk_batch = AsyncMock(return_value=[ - [{"text": "Chunk 1"}, {"text": "Chunk 2"}] - ]) + mock_chunker.chunk_batch = AsyncMock( + return_value=[[{"text": "Chunk 1"}, {"text": "Chunk 2"}]] + ) - result = await processor.process([mock_entity], mock_sync_context, mock_runtime) + await processor.process([mock_entity], mock_sync_context, mock_runtime) # Verify SemanticChunker was called mock_chunker.chunk_batch.assert_called_once() @pytest.mark.asyncio - async def test_multiply_entities_creates_chunk_suffix( - self, processor, mock_sync_context - ): + async def test_multiply_entities_creates_chunk_suffix(self, processor, mock_sync_context): """Test chunk entity creation with proper ID suffix.""" # Create mock entity mock_entity = MagicMock() @@ -104,9 +105,7 @@ def create_chunk_entity(deep=False): mock_entity.model_copy = MagicMock(side_effect=create_chunk_entity) - chunks = [ - [{"text": "Chunk 0"}, {"text": "Chunk 1"}] - ] + chunks = [[{"text": "Chunk 0"}, {"text": "Chunk 1"}]] result = processor._multiply_entities([mock_entity], chunks, mock_sync_context) @@ -116,9 +115,7 @@ def create_chunk_entity(deep=False): assert "__chunk_1" in result[1].entity_id @pytest.mark.asyncio - async def test_multiply_entities_sets_chunk_index( - self, processor, mock_sync_context - ): + async def test_multiply_entities_sets_chunk_index(self, processor, mock_sync_context): """Test chunk index set correctly.""" mock_entity = MagicMock() mock_entity.entity_id = "test-123" @@ -141,9 +138,7 @@ def create_chunk_entity(deep=False): assert result[0].airweave_system_metadata.chunk_index == 0 @pytest.mark.asyncio - async def test_multiply_entities_skips_empty_chunks( - self, processor, mock_sync_context - ): + async def test_multiply_entities_skips_empty_chunks(self, processor, mock_sync_context): """Test empty chunks are filtered out.""" mock_entity = MagicMock() mock_entity.entity_id = "test-123" @@ -159,9 +154,7 @@ def create_chunk_entity(deep=False): mock_entity.model_copy = MagicMock(side_effect=create_chunk_entity) - chunks = [ - [{"text": "Valid"}, {"text": ""}, {"text": " "}, {"text": "Another"}] - ] + chunks = [[{"text": "Valid"}, {"text": ""}, {"text": " "}, {"text": "Another"}]] result = processor._multiply_entities([mock_entity], chunks, mock_sync_context) @@ -169,9 +162,7 @@ def create_chunk_entity(deep=False): assert len(result) == 2 @pytest.mark.asyncio - async def test_embed_entities_calls_both_embedders( - self, processor, mock_runtime - ): + async def test_embed_entities_calls_both_embedders(self, processor, mock_runtime): """Test both dense and sparse embedders are called.""" mock_entity = MagicMock() mock_entity.textual_representation = "Test content" @@ -193,9 +184,7 @@ async def test_embed_entities_calls_both_embedders( mock_runtime.sparse_embedder.embed_many.assert_called_once() @pytest.mark.asyncio - async def test_embed_entities_assigns_embeddings( - self, processor, mock_runtime - ): + async def test_embed_entities_assigns_embeddings(self, processor, mock_runtime): """Test embeddings assigned to entity system metadata.""" mock_entity = MagicMock() mock_entity.textual_representation = "Test" @@ -221,17 +210,17 @@ async def test_embed_entities_assigns_embeddings( assert mock_entity.airweave_system_metadata.sparse_embedding == sparse_embedding @pytest.mark.asyncio - async def test_embed_entities_uses_full_json_for_sparse( - self, processor, mock_runtime - ): + async def test_embed_entities_uses_full_json_for_sparse(self, processor, mock_runtime): """Test sparse embedder receives full entity JSON.""" mock_entity = MagicMock() mock_entity.textual_representation = "Test" mock_entity.airweave_system_metadata = MagicMock() - mock_entity.model_dump = MagicMock(return_value={ - "entity_id": "test-123", - "name": "Test Entity", - }) + mock_entity.model_dump = MagicMock( + return_value={ + "entity_id": "test-123", + "name": "Test Entity", + } + ) chunk_entities = [mock_entity] @@ -249,13 +238,12 @@ async def test_embed_entities_uses_full_json_for_sparse( # Verify it's JSON import json + parsed = json.loads(call_args[0]) assert "entity_id" in parsed @pytest.mark.asyncio - async def test_embed_entities_validates_embeddings_exist( - self, processor, mock_runtime - ): + async def test_embed_entities_validates_embeddings_exist(self, processor, mock_runtime): """Test validation that all entities have embeddings.""" mock_entity = MagicMock() mock_entity.textual_representation = "Test" @@ -278,9 +266,7 @@ async def test_embed_entities_validates_embeddings_exist( assert "no dense embedding" in str(exc_info.value).lower() @pytest.mark.asyncio - async def test_full_pipeline_with_mocks( - self, processor, mock_sync_context, mock_runtime - ): + async def test_full_pipeline_with_mocks(self, processor, mock_sync_context, mock_runtime): """Test full pipeline with all mocked dependencies.""" mock_entity = MagicMock() mock_entity.entity_id = "test-123" @@ -306,20 +292,15 @@ def create_chunk(deep=False): mock_runtime.dense_embedder.embed_many = AsyncMock( return_value=[dense_result, dense_result_2] ) - mock_runtime.sparse_embedder.embed_many = AsyncMock( - return_value=[MagicMock(), MagicMock()] - ) + mock_runtime.sparse_embedder.embed_many = AsyncMock(return_value=[MagicMock(), MagicMock()]) - with patch('airweave.domains.sync_pipeline.processors.chunk_embed.text_builder') as mock_builder, \ - patch('airweave.platform.chunkers.semantic.SemanticChunker') as MockChunker: - - # Setup mocks + with patch(_TEXT_BUILDER) as mock_builder, patch(_SEMANTIC_CHUNKER) as MockChunker: mock_builder.build_for_batch = AsyncMock(return_value=[mock_entity]) mock_chunker = MockChunker.return_value - mock_chunker.chunk_batch = AsyncMock(return_value=[ - [{"text": "Chunk 1"}, {"text": "Chunk 2"}] - ]) + mock_chunker.chunk_batch = AsyncMock( + return_value=[[{"text": "Chunk 1"}, {"text": "Chunk 2"}]] + ) result = await processor.process([mock_entity], mock_sync_context, mock_runtime) @@ -354,9 +335,7 @@ def create_chunk(deep=False): mock_runtime.dense_embedder.embed_many = AsyncMock(return_value=[dense_result]) mock_runtime.sparse_embedder.embed_many = AsyncMock(return_value=[MagicMock()]) - with patch('airweave.domains.sync_pipeline.processors.chunk_embed.text_builder') as mock_builder, \ - patch('airweave.platform.chunkers.semantic.SemanticChunker') as MockChunker: - + with patch(_TEXT_BUILDER) as mock_builder, patch(_SEMANTIC_CHUNKER) as MockChunker: mock_builder.build_for_batch = AsyncMock(return_value=[mock_entity]) mock_chunker = MockChunker.return_value @@ -368,16 +347,14 @@ def create_chunk(deep=False): assert mock_entity.textual_representation is None @pytest.mark.asyncio - async def test_skips_entities_without_text( - self, processor, mock_sync_context, mock_runtime - ): + async def test_skips_entities_without_text(self, processor, mock_sync_context, mock_runtime): """Test entities with no textual_representation are skipped.""" mock_entity = MagicMock() mock_entity.entity_id = "test-123" mock_entity.textual_representation = None # No text mock_entity.airweave_system_metadata = MagicMock() - with patch('airweave.domains.sync_pipeline.processors.chunk_embed.text_builder') as mock_builder: + with patch(_TEXT_BUILDER) as mock_builder: mock_builder.build_for_batch = AsyncMock(return_value=[mock_entity]) result = await processor.process([mock_entity], mock_sync_context, mock_runtime) @@ -395,9 +372,7 @@ async def test_handles_empty_chunks_from_chunker( mock_entity.textual_representation = "Test" mock_entity.airweave_system_metadata = MagicMock() - with patch('airweave.domains.sync_pipeline.processors.chunk_embed.text_builder') as mock_builder, \ - patch('airweave.platform.chunkers.semantic.SemanticChunker') as MockChunker: - + with patch(_TEXT_BUILDER) as mock_builder, patch(_SEMANTIC_CHUNKER) as MockChunker: mock_builder.build_for_batch = AsyncMock(return_value=[mock_entity]) mock_chunker = MockChunker.return_value diff --git a/backend/airweave/domains/sync_pipeline/tests/test_cleanup_service.py b/backend/airweave/domains/sync_pipeline/tests/test_cleanup_service.py index 15da6f6e8..c211aa1c4 100644 --- a/backend/airweave/domains/sync_pipeline/tests/test_cleanup_service.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_cleanup_service.py @@ -10,15 +10,15 @@ import os import tempfile from pathlib import Path -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest -from airweave.platform.entities._airweave_field import AirweaveField -from airweave.platform.entities._base import BaseEntity, DeletionEntity, FileEntity from airweave.domains.sync_pipeline.exceptions import SyncFailureError from airweave.domains.sync_pipeline.pipeline.cleanup_service import cleanup_service +from airweave.platform.entities._airweave_field import AirweaveField +from airweave.platform.entities._base import BaseEntity, DeletionEntity, FileEntity # Test entity classes @@ -70,10 +70,7 @@ def create_file_entity_with_temp_file(temp_dir: str, filename: str) -> _TestFile Path(file_path).write_text("test content") entity = _TestFileEntity( - file_id=str(uuid4()), - name=filename, - url=f"https://example.com/{filename}", - breadcrumbs=[] + file_id=str(uuid4()), name=filename, url=f"https://example.com/{filename}", breadcrumbs=[] ) entity.local_path = file_path return entity @@ -175,7 +172,9 @@ async def test_cleanup_ignores_deletes(temp_dir, mock_sync_context): DeletionEntity is not a FileEntity, so it should be skipped. """ # Create a deletion entity (no file on disk) - deletion_entity = _TestDeletionEntity(deletion_id=str(uuid4()), label="deleted-item", breadcrumbs=[]) + deletion_entity = _TestDeletionEntity( + deletion_id=str(uuid4()), label="deleted-item", breadcrumbs=[] + ) partitions = { "inserts": [], @@ -199,7 +198,9 @@ async def test_cleanup_ignores_non_file_entities(temp_dir, mock_sync_context): """Test that non-FileEntity types are ignored.""" # Mix FileEntity with non-FileEntity file_entity = create_file_entity_with_temp_file(temp_dir, "file.txt") - non_file_entity = _TestNonFileEntity(test_id=str(uuid4()), name="non-file-entity", breadcrumbs=[]) + non_file_entity = _TestNonFileEntity( + test_id=str(uuid4()), name="non-file-entity", breadcrumbs=[] + ) partitions = { "inserts": [file_entity, non_file_entity], diff --git a/backend/airweave/domains/sync_pipeline/tests/test_config_base.py b/backend/airweave/domains/sync_pipeline/tests/test_config_base.py index 69ec1ed69..0ca614782 100644 --- a/backend/airweave/domains/sync_pipeline/tests/test_config_base.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_config_base.py @@ -162,16 +162,12 @@ def test_no_conflict_with_different_destinations(self): def test_no_conflict_with_only_target(self): """Test that only target destinations doesn't raise.""" - config = SyncConfig( - destinations=DestinationConfig(target_destinations=[uuid4()]) - ) + config = SyncConfig(destinations=DestinationConfig(target_destinations=[uuid4()])) assert config is not None def test_no_conflict_with_only_exclude(self): """Test that only exclude destinations doesn't raise.""" - config = SyncConfig( - destinations=DestinationConfig(exclude_destinations=[uuid4()]) - ) + config = SyncConfig(destinations=DestinationConfig(exclude_destinations=[uuid4()])) assert config is not None @@ -236,10 +232,12 @@ def test_merge_overwrites_values(self): def test_merge_deep_nested(self): """Test deep merge of nested values.""" config = SyncConfig.default() - merged = config.merge_with({ - "handlers": {"enable_postgres_handler": False}, - "behavior": {"skip_hash_comparison": True}, - }) + merged = config.merge_with( + { + "handlers": {"enable_postgres_handler": False}, + "behavior": {"skip_hash_comparison": True}, + } + ) assert merged.handlers.enable_postgres_handler is False assert merged.behavior.skip_hash_comparison is True assert merged.handlers.enable_vector_handlers is True # Preserved diff --git a/backend/airweave/domains/sync_pipeline/tests/test_config_builder.py b/backend/airweave/domains/sync_pipeline/tests/test_config_builder.py index 28e24141d..8aadb7e27 100644 --- a/backend/airweave/domains/sync_pipeline/tests/test_config_builder.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_config_builder.py @@ -6,7 +6,6 @@ from airweave.domains.sync_pipeline.config.base import ( BehaviorConfig, CursorConfig, - DestinationConfig, HandlerConfig, SyncConfig, ) @@ -110,9 +109,7 @@ def test_partial_override_preserves_other_sections(self): """Test that overriding one section preserves other sections.""" with _clean_env(): config = SyncConfigBuilder.build( - job_overrides=SyncConfig( - behavior=BehaviorConfig(skip_hash_comparison=True) - ) + job_overrides=SyncConfig(behavior=BehaviorConfig(skip_hash_comparison=True)) ) assert config.behavior.skip_hash_comparison is True assert config.handlers.enable_vector_handlers is True # Other section default @@ -144,8 +141,6 @@ def test_from_db_json(self): } with _clean_env(): - config = SyncConfigBuilder.build( - collection_overrides=SyncConfig(**db_json) - ) + config = SyncConfigBuilder.build(collection_overrides=SyncConfig(**db_json)) assert config.handlers.enable_vector_handlers is False assert config.behavior.skip_hash_comparison is True diff --git a/backend/airweave/domains/sync_pipeline/tests/test_destination_handler.py b/backend/airweave/domains/sync_pipeline/tests/test_destination_handler.py index 288b897a9..ff0910faf 100644 --- a/backend/airweave/domains/sync_pipeline/tests/test_destination_handler.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_destination_handler.py @@ -15,6 +15,8 @@ from airweave.domains.sync_pipeline.exceptions import SyncFailureError from airweave.domains.sync_pipeline.handlers.destination import DestinationHandler +_ASYNC_SLEEP = "airweave.domains.sync_pipeline.handlers.destination.asyncio.sleep" + def _make_mock_destination(soft_fail=False): """Create a mock destination with required attributes.""" @@ -55,7 +57,7 @@ async def failing_operation(): call_count += 1 raise TimeoutError("feed timed out") - with patch("airweave.domains.sync_pipeline.handlers.destination.asyncio.sleep", new_callable=AsyncMock): + with patch(_ASYNC_SLEEP, new_callable=AsyncMock): with pytest.raises(SyncFailureError, match="Destination unavailable"): await handler._execute_with_retry( operation=failing_operation, @@ -82,7 +84,7 @@ async def failing_operation(): call_count += 1 raise asyncio.TimeoutError() - with patch("airweave.domains.sync_pipeline.handlers.destination.asyncio.sleep", new_callable=AsyncMock): + with patch(_ASYNC_SLEEP, new_callable=AsyncMock): with pytest.raises(SyncFailureError, match="Destination unavailable"): await handler._execute_with_retry( operation=failing_operation, @@ -110,7 +112,7 @@ async def flaky_operation(): raise TimeoutError("temporary failure") return "success" - with patch("airweave.domains.sync_pipeline.handlers.destination.asyncio.sleep", new_callable=AsyncMock): + with patch(_ASYNC_SLEEP, new_callable=AsyncMock): result = await handler._execute_with_retry( operation=flaky_operation, operation_name="insert_MockDestination", @@ -132,7 +134,7 @@ async def test_retry_logs_warning_on_each_failure(self): async def failing_operation(): raise TimeoutError("feed timed out") - with patch("airweave.domains.sync_pipeline.handlers.destination.asyncio.sleep", new_callable=AsyncMock): + with patch(_ASYNC_SLEEP, new_callable=AsyncMock): with pytest.raises(SyncFailureError): await handler._execute_with_retry( operation=failing_operation, diff --git a/backend/airweave/domains/sync_pipeline/tests/test_entity_action_resolver.py b/backend/airweave/domains/sync_pipeline/tests/test_entity_action_resolver.py index 52600f3b6..5de9f8e75 100644 --- a/backend/airweave/domains/sync_pipeline/tests/test_entity_action_resolver.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_entity_action_resolver.py @@ -6,19 +6,18 @@ import pytest from airweave.domains.sync_pipeline.entity_action_resolver import EntityActionResolver +from airweave.domains.sync_pipeline.exceptions import SyncFailureError +from airweave.domains.sync_pipeline.types.entity_actions import ( + EntityInsertAction, + EntityKeepAction, + EntityUpdateAction, +) from airweave.platform.entities._airweave_field import AirweaveField from airweave.platform.entities._base import ( AirweaveSystemMetadata, BaseEntity, DeletionEntity, ) -from airweave.domains.sync_pipeline.types.entity_actions import ( - EntityInsertAction, - EntityKeepAction, - EntityUpdateAction, -) -from airweave.domains.sync_pipeline.exceptions import SyncFailureError - # --------------------------------------------------------------------------- # Helpers diff --git a/backend/airweave/domains/sync_pipeline/tests/test_entity_pipeline.py b/backend/airweave/domains/sync_pipeline/tests/test_entity_pipeline.py index 4aa665100..eff579a5c 100644 --- a/backend/airweave/domains/sync_pipeline/tests/test_entity_pipeline.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_entity_pipeline.py @@ -7,7 +7,6 @@ from airweave.domains.sync_pipeline.entity_pipeline import EntityPipeline - # --------------------------------------------------------------------------- # Constructor # --------------------------------------------------------------------------- diff --git a/backend/airweave/domains/sync_pipeline/tests/test_factory.py b/backend/airweave/domains/sync_pipeline/tests/test_factory.py index 6553b89b8..6cee8e685 100644 --- a/backend/airweave/domains/sync_pipeline/tests/test_factory.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_factory.py @@ -123,9 +123,7 @@ async def test_create_orchestrator_passes_entity_repo_to_pipeline(): db = AsyncMock() with ( - patch( - "airweave.domains.sync_pipeline.factory.SyncContextBuilder" - ) as mock_sc_builder, + patch("airweave.domains.sync_pipeline.factory.SyncContextBuilder") as mock_sc_builder, patch( "airweave.domains.sync_pipeline.factory.EntityDispatcherBuilder" ) as mock_disp_builder, diff --git a/backend/airweave/domains/sync_pipeline/tests/test_progress_relay.py b/backend/airweave/domains/sync_pipeline/tests/test_progress_relay.py index 09a8e72f1..c6b8dd2b2 100644 --- a/backend/airweave/domains/sync_pipeline/tests/test_progress_relay.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_progress_relay.py @@ -6,7 +6,6 @@ import pytest from airweave.adapters.pubsub.fake import FakePubSub -from airweave.core.events.enums import SyncEventType from airweave.core.events.sync import ( AccessControlMembershipBatchProcessedEvent, EntityBatchProcessedEvent, @@ -29,9 +28,7 @@ def _make_relay(): return relay, pubsub -def _running_event( - sync_id=SYNC_ID, job_id=JOB_ID -) -> SyncLifecycleEvent: +def _running_event(sync_id=SYNC_ID, job_id=JOB_ID) -> SyncLifecycleEvent: return SyncLifecycleEvent.running( organization_id=ORG_ID, sync_id=sync_id, @@ -45,8 +42,13 @@ def _running_event( def _batch_event( - inserted=5, updated=3, deleted=1, kept=2, - job_id=JOB_ID, billable=True, type_breakdown=None, + inserted=5, + updated=3, + deleted=1, + kept=2, + job_id=JOB_ID, + billable=True, + type_breakdown=None, ) -> EntityBatchProcessedEvent: return EntityBatchProcessedEvent( organization_id=ORG_ID, @@ -166,10 +168,15 @@ async def test_accumulates_type_breakdown(self): "FileEntity": TypeActionCounts(inserted=5, updated=2, deleted=0, kept=1), "FolderEntity": TypeActionCounts(inserted=3, updated=0, deleted=0, kept=0), } - await relay.handle(_batch_event( - inserted=8, updated=2, deleted=0, kept=1, - type_breakdown=breakdown, - )) + await relay.handle( + _batch_event( + inserted=8, + updated=2, + deleted=0, + kept=1, + type_breakdown=breakdown, + ) + ) session = relay._sessions[JOB_ID] assert "FileEntity" in session.type_counts @@ -184,8 +191,12 @@ async def test_type_breakdown_merges_across_batches(self): batch1 = {"FileEntity": TypeActionCounts(inserted=3, updated=1, deleted=0, kept=0)} batch2 = {"FileEntity": TypeActionCounts(inserted=2, updated=4, deleted=0, kept=0)} - await relay.handle(_batch_event(inserted=3, updated=1, deleted=0, kept=0, type_breakdown=batch1)) - await relay.handle(_batch_event(inserted=2, updated=4, deleted=0, kept=0, type_breakdown=batch2)) + await relay.handle( + _batch_event(inserted=3, updated=1, deleted=0, kept=0, type_breakdown=batch1) + ) + await relay.handle( + _batch_event(inserted=2, updated=4, deleted=0, kept=0, type_breakdown=batch2) + ) session = relay._sessions[JOB_ID] assert session.type_counts["FileEntity"].inserted == 5 @@ -417,10 +428,15 @@ async def test_named_counts_sums_inserted_updated_kept(self): breakdown = { "FileEntity": TypeActionCounts(inserted=5, updated=2, deleted=3, kept=10), } - await relay.handle(_batch_event( - inserted=5, updated=2, deleted=3, kept=10, - type_breakdown=breakdown, - )) + await relay.handle( + _batch_event( + inserted=5, + updated=2, + deleted=3, + kept=10, + type_breakdown=breakdown, + ) + ) session = relay._sessions[JOB_ID] assert session.named_counts["FileEntity"] == 17 # 5 + 2 + 10 diff --git a/backend/airweave/domains/syncs/tests/test_sync_job_service.py b/backend/airweave/domains/syncs/tests/test_sync_job_service.py index a7a5ec9cd..05013aa63 100644 --- a/backend/airweave/domains/syncs/tests/test_sync_job_service.py +++ b/backend/airweave/domains/syncs/tests/test_sync_job_service.py @@ -13,9 +13,9 @@ import pytest from airweave.core.shared_models import SyncJobStatus +from airweave.domains.sync_pipeline.pipeline.entity_tracker import SyncStats from airweave.domains.syncs.sync_job_service import SyncJobService from airweave.domains.syncs.types import StatsUpdate, TimestampUpdate -from airweave.domains.sync_pipeline.pipeline.entity_tracker import SyncStats NOW = datetime(2024, 6, 15, 12, 0, 0, tzinfo=timezone.utc) @@ -50,7 +50,11 @@ class StatsCase: StatsCase( name="mixed_values", stats=SyncStats( - inserted=5, updated=3, deleted=1, kept=10, skipped=2, + inserted=5, + updated=3, + deleted=1, + kept=10, + skipped=2, entities_encountered={"Document": 15, "Image": 6}, ), expected=StatsUpdate( @@ -86,7 +90,7 @@ class TimestampCase: error: Optional[str] = None expected: Optional[TimestampUpdate] = None - def __post_init__(self): + def __post_init__(self): # noqa: D105 if self.expected is None: self.expected = TimestampUpdate() @@ -212,9 +216,7 @@ async def test_update_status(case: UpdateStatusCase): mock_ctx.organization = MagicMock() mock_ctx.organization.id = uuid4() - with patch( - "airweave.domains.syncs.sync_job_service.get_db_context" - ) as mock_ctx_mgr: + with patch("airweave.domains.syncs.sync_job_service.get_db_context") as mock_ctx_mgr: mock_ctx_mgr.return_value.__aenter__ = AsyncMock(return_value=mock_db) mock_ctx_mgr.return_value.__aexit__ = AsyncMock(return_value=False) diff --git a/backend/airweave/platform/access_control/broker.py b/backend/airweave/platform/access_control/broker.py index e56173811..918a62ff4 100644 --- a/backend/airweave/platform/access_control/broker.py +++ b/backend/airweave/platform/access_control/broker.py @@ -14,6 +14,7 @@ class AccessBroker: """Resolves user access context by expanding group memberships.""" def __init__(self, acl_repo: AccessControlMembershipRepositoryProtocol) -> None: + """Initialize with ACL membership repository.""" self._acl_repo = acl_repo async def resolve_access_context( diff --git a/backend/airweave/platform/converters/txt_converter.py b/backend/airweave/platform/converters/txt_converter.py index f4da0cdd2..6592ea0a2 100644 --- a/backend/airweave/platform/converters/txt_converter.py +++ b/backend/airweave/platform/converters/txt_converter.py @@ -71,6 +71,31 @@ async def _convert_one(path: str): return results + @staticmethod + def _try_chardet_decode(raw_bytes: bytes, path: str) -> str | None: + """Attempt to decode bytes using chardet-detected encoding. + + Returns decoded text on success, None otherwise. + """ + try: + import chardet + + detection = chardet.detect(raw_bytes[:100000]) + if not detection or detection.get("confidence", 0) <= 0.7: + return None + detected_encoding = detection["encoding"] + if not detected_encoding: + return None + text = raw_bytes.decode(detected_encoding) + if text.count("\ufffd") == 0: + logger.debug(f"Detected encoding {detected_encoding} for {os.path.basename(path)}") + return text + except (UnicodeDecodeError, LookupError): + pass + except ImportError: + logger.debug("chardet not available, falling back to UTF-8 with ignore") + return None + async def _convert_plain_text(self, path: str) -> str: """Read plain text file with encoding detection. @@ -83,44 +108,23 @@ async def _convert_plain_text(self, path: str) -> str: Raises: EntityProcessingError: If file contains excessive binary/corrupted data """ - # Read raw bytes for encoding detection async with aiofiles.open(path, "rb") as f: raw_bytes = await f.read() if not raw_bytes: return "" - # Try UTF-8 first (most common) try: text = raw_bytes.decode("utf-8") - replacement_count = text.count("\ufffd") - if replacement_count == 0: + if text.count("\ufffd") == 0: return text except UnicodeDecodeError: pass - # Try encoding detection - try: - import chardet - - detection = chardet.detect(raw_bytes[:100000]) # Sample first 100KB - if detection and detection.get("confidence", 0) > 0.7: - detected_encoding = detection["encoding"] - if detected_encoding: - try: - text = raw_bytes.decode(detected_encoding) - replacement_count = text.count("\ufffd") - if replacement_count == 0: - logger.debug( - f"Detected encoding {detected_encoding} for {os.path.basename(path)}" - ) - return text - except (UnicodeDecodeError, LookupError): - pass - except ImportError: - logger.debug("chardet not available, falling back to UTF-8 with ignore") + chardet_result = self._try_chardet_decode(raw_bytes, path) + if chardet_result is not None: + return chardet_result - # Fallback: decode with replace to create U+FFFD for validation text = raw_bytes.decode("utf-8", errors="replace") replacement_count = text.count("\ufffd") @@ -128,7 +132,6 @@ async def _convert_plain_text(self, path: str) -> str: text_length = len(text) replacement_ratio = replacement_count / text_length if text_length > 0 else 0 - # Warn if high replacement ratio if replacement_ratio > 0.25 or replacement_count > 5000: logger.warning( f"File {os.path.basename(path)} contains {replacement_count} " @@ -233,7 +236,8 @@ def _read_and_format(): replacement_count = raw.count("\ufffd") if replacement_count > 100: # More lenient for fallback raise EntityProcessingError( - f"XML contains excessive binary data ({replacement_count} replacement chars)" + f"XML contains excessive binary data " + f"({replacement_count} replacement chars)" ) return f"```xml\n{raw}\n```" if raw.strip() else None From 69531618df5f03175d282b01938ed0f9458306b3 Mon Sep 17 00:00:00 2001 From: Felix Schmetz Date: Fri, 13 Mar 2026 10:42:58 -0700 Subject: [PATCH 09/13] refactor: protocolize ACL pipeline deps, move AccessBroker to domain with container DI - Add ACActionResolverProtocol + ACActionDispatcherProtocol; update AccessControlPipeline to depend on protocols instead of concrete classes - Eliminate isinstance check in EntityActionDispatcher by splitting into explicit destination_handlers + metadata_handler constructor params - Move lazy imports in SyncFactory to top-level (no circular dep risk) - Move AccessBroker from platform to domains/access_control with protocol, delete module-level singleton, wire via Container + factory - Update all consumers (AccessControlFilter, admin.py, SearchService) to receive broker from container instead of importing singleton - Add FakeAccessBroker, move broker tests to domains/, register in conftest --- backend/airweave/api/v1/endpoints/admin.py | 4 +- backend/airweave/core/container/container.py | 4 + backend/airweave/core/container/factory.py | 3 + .../access_control/broker.py | 105 +----- .../domains/access_control/fakes/broker.py | 60 ++++ .../domains/access_control/protocols.py | 35 +- .../sync_pipeline/access_control_pipeline.py | 10 +- .../sync_pipeline/entity_action_dispatcher.py | 29 +- .../entity_dispatcher_builder.py | 35 +- .../airweave/domains/sync_pipeline/factory.py | 11 +- .../domains/sync_pipeline/protocols.py | 26 ++ .../platform/access_control/__init__.py | 3 +- backend/airweave/search/factory.py | 2 + .../operations/access_control_filter.py | 14 +- backend/airweave/search/service.py | 7 +- backend/conftest.py | 10 + .../unit/api/test_admin_user_principals.py | 24 +- .../unit/domains/access_control/__init__.py | 0 .../access_control/test_broker.py | 5 +- .../operations/test_access_control_filter.py | 307 +++++++++--------- 20 files changed, 376 insertions(+), 318 deletions(-) rename backend/airweave/{platform => domains}/access_control/broker.py (55%) create mode 100644 backend/airweave/domains/access_control/fakes/broker.py create mode 100644 backend/tests/unit/domains/access_control/__init__.py rename backend/tests/unit/{platform => domains}/access_control/test_broker.py (99%) diff --git a/backend/airweave/api/v1/endpoints/admin.py b/backend/airweave/api/v1/endpoints/admin.py index e257a9bf8..e9d76718a 100644 --- a/backend/airweave/api/v1/endpoints/admin.py +++ b/backend/airweave/api/v1/endpoints/admin.py @@ -35,6 +35,7 @@ from airweave.core.shared_models import FeatureFlag as FeatureFlagEnum from airweave.crud.crud_organization_billing import organization_billing from airweave.db.unit_of_work import UnitOfWork +from airweave.domains.access_control.protocols import AccessBrokerProtocol from airweave.domains.billing.operations import BillingOperations from airweave.domains.billing.repository import ( BillingPeriodRepository, @@ -1410,14 +1411,13 @@ async def admin_get_user_principals( db: AsyncSession = Depends(deps.get_db), ctx: ApiContext = Depends(deps.get_context), collection_repo: CollectionRepositoryProtocol = Inject(CollectionRepositoryProtocol), + access_broker: AccessBrokerProtocol = Inject(AccessBrokerProtocol), ) -> List[str]: """Admin-only: Get the resolved access principals for a user in a collection. Returns all principals (user + group memberships) that would be used for access control filtering when the user searches the collection. """ - from airweave.platform.access_control.broker import access_broker - _require_admin_permission(ctx, FeatureFlagEnum.API_KEY_ADMIN_SYNC) collection = await collection_repo.get_by_readable_id( diff --git a/backend/airweave/core/container/container.py b/backend/airweave/core/container/container.py index 02da8978e..dc8f664f6 100644 --- a/backend/airweave/core/container/container.py +++ b/backend/airweave/core/container/container.py @@ -30,6 +30,7 @@ ) from airweave.core.protocols.identity import IdentityProvider from airweave.core.protocols.payment import PaymentGatewayProtocol +from airweave.domains.access_control.protocols import AccessBrokerProtocol from airweave.domains.auth_provider.protocols import ( AuthProviderRegistryProtocol, AuthProviderServiceProtocol, @@ -194,6 +195,9 @@ async def my_endpoint(event_bus: EventBus = Inject(EventBus)): entity_repo: EntityRepositoryProtocol + # Access control broker (resolves user → group principals) + access_broker: AccessBrokerProtocol + # Temporal domain temporal_workflow_service: TemporalWorkflowServiceProtocol temporal_schedule_service: TemporalScheduleServiceProtocol diff --git a/backend/airweave/core/container/factory.py b/backend/airweave/core/container/factory.py index 1890348b3..bc12e1206 100644 --- a/backend/airweave/core/container/factory.py +++ b/backend/airweave/core/container/factory.py @@ -44,6 +44,7 @@ from airweave.core.protocols.webhooks import WebhookPublisher from airweave.core.redis_client import redis_client from airweave.db.session import health_check_engine +from airweave.domains.access_control.broker import AccessBroker from airweave.domains.access_control.repository import AccessControlMembershipRepository from airweave.domains.auth_provider.registry import AuthProviderRegistry from airweave.domains.auth_provider.service import AuthProviderService @@ -384,6 +385,7 @@ def create_container(settings: Settings) -> Container: # Access control membership repo + chunk embed processor # ----------------------------------------------------------------- acl_membership_repo = AccessControlMembershipRepository() + access_broker = AccessBroker(acl_repo=acl_membership_repo) chunk_embed_processor = ChunkEmbedProcessor() # ----------------------------------------------------------------- @@ -511,6 +513,7 @@ def create_container(settings: Settings) -> Container: sync_lifecycle=sync_deps["sync_lifecycle"], sync_factory=sync_factory, entity_repo=sync_deps["entity_repo"], + access_broker=access_broker, temporal_workflow_service=sync_deps["temporal_workflow_service"], temporal_schedule_service=sync_deps["temporal_schedule_service"], usage_checker=usage_checker, diff --git a/backend/airweave/platform/access_control/broker.py b/backend/airweave/domains/access_control/broker.py similarity index 55% rename from backend/airweave/platform/access_control/broker.py rename to backend/airweave/domains/access_control/broker.py index 918a62ff4..b0d962d17 100644 --- a/backend/airweave/platform/access_control/broker.py +++ b/backend/airweave/domains/access_control/broker.py @@ -30,26 +30,13 @@ async def resolve_access_context( Note: SharePoint uses /transitivemembers so group expansion happens server-side. Other sources may store group-group tuples that need recursive expansion here. - - Args: - db: Database session - user_principal: User principal (username or identifier) - organization_id: Organization ID - - Returns: - AccessContext with fully expanded principals """ - # Query direct user-group memberships (member_type="user") memberships = await self._acl_repo.get_by_member( db=db, member_id=user_principal, member_type="user", organization_id=organization_id ) - # Build principals user_principals = [f"user:{user_principal}"] - # Recursively expand group-to-group relationships (if any exist) - # For SharePoint: no group-group tuples exist (uses /transitivemembers) - # For other sources: this handles nested group expansion all_groups = await self._expand_group_memberships( db=db, group_ids=[m.group_id for m in memberships], organization_id=organization_id ) @@ -69,31 +56,9 @@ async def resolve_access_context_for_collection( ) -> Optional[AccessContext]: """Resolve user's access context scoped to a collection's source connections. - This method only considers group memberships from source connections that belong - to the specified collection, enabling collection-scoped access control. - - IMPORTANT: Returns None if the collection has no sources with access control - support. This allows the search layer to skip filtering entirely for collections - that only contain sources like Slack, Asana, etc. - - Steps: - 1. Check if collection has any sources with access control - 2. If no AC sources, return None (no filtering needed) - 3. Query database for user's group memberships within the collection - 4. Recursively expand group-to-group relationships (if any) - 5. Build AccessContext with user + all expanded group principals - - Args: - db: Database session - user_principal: User principal (username or identifier) - readable_collection_id: Collection readable_id (string) to scope the access context - organization_id: Organization ID - - Returns: - AccessContext with fully expanded principals scoped to collection, - or None if collection has no access-control-enabled sources + Returns None if the collection has no sources with access control + support, allowing the search layer to skip filtering entirely. """ - # Check if collection has any sources with access control has_ac_sources = await self._collection_has_ac_sources( db=db, readable_collection_id=readable_collection_id, @@ -101,10 +66,8 @@ async def resolve_access_context_for_collection( ) if not has_ac_sources: - # No access control sources in collection → skip filtering return None - # Query user-group memberships scoped to collection (member_type="user") memberships = await self._acl_repo.get_by_member_and_collection( db=db, member_id=user_principal, @@ -113,11 +76,8 @@ async def resolve_access_context_for_collection( organization_id=organization_id, ) - # Build principals user_principals = [f"user:{user_principal}"] - # Recursively expand group-to-group relationships (if any exist) - # Note: Group expansion is still organization-wide, not collection-scoped all_groups = await self._expand_group_memberships( db=db, group_ids=[m.group_id for m in memberships], organization_id=organization_id ) @@ -134,26 +94,12 @@ async def _collection_has_ac_sources( readable_collection_id: str, organization_id: UUID, ) -> bool: - """Check if a collection has any sources with access control enabled. - - This queries the access_control_membership table to see if there are - any memberships for this collection. If there are memberships, at least - one source in the collection supports access control. - - Args: - db: Database session - readable_collection_id: Collection readable_id - organization_id: Organization ID - - Returns: - True if collection has at least one AC-enabled source - """ + """Check if a collection has any sources with access control enabled.""" from sqlalchemy import exists, select from airweave.models.access_control_membership import AccessControlMembership from airweave.models.source_connection import SourceConnection - # Check if any memberships exist for source connections in this collection stmt = select( exists( select(AccessControlMembership.id) @@ -176,23 +122,12 @@ async def _expand_group_memberships( ) -> Set[str]: """Recursively expand group memberships to handle nested groups. - For sources that store group-to-group relationships (e.g., Google Drive), - this recursively expands nested groups via CRUD layer. For SharePoint, - /transitivemembers handles this server-side, so no group-group tuples exist. - - Args: - db: Database session - group_ids: List of initial group IDs - organization_id: Organization ID - - Returns: - Set of all group IDs (direct + transitive) + Max depth of 10 to prevent infinite loops from circular group references. """ all_groups = set(group_ids) to_process = set(group_ids) visited = set() - # Recursively expand (max depth: 10 to prevent infinite loops) max_depth = 10 depth = 0 @@ -202,12 +137,10 @@ async def _expand_group_memberships( continue visited.add(current_group) - # Query for group-to-group memberships via CRUD layer (member_type="group") nested_memberships = await self._acl_repo.get_by_member( db=db, member_id=current_group, member_type="group", organization_id=organization_id ) - # Add parent groups and queue for processing for m in nested_memberships: if m.group_id not in all_groups: all_groups.add(m.group_id) @@ -220,45 +153,17 @@ async def _expand_group_memberships( def check_entity_access( self, entity_access: Optional[AccessControl], access_context: Optional[AccessContext] ) -> bool: - """Check if user can access entity based on access control. - - Args: - entity_access: Entity's AccessControl field (entity.access), may be None - access_context: User's AccessContext (from resolve_access_context), may be None - - Returns: - True if user has access to the entity: - - True if entity_access is None (no AC = public for non-AC sources) - - True if entity_access.is_public is True - - True if access_context is None (no AC context = no filtering) - - True if any of user's principals match entity.access.viewers - - False otherwise - """ - # No access control on entity = visible to everyone (non-AC source) + """Check if user can access entity based on access control.""" if entity_access is None: return True - # Public entity = visible to everyone if entity_access.is_public: return True - # No access context = no filtering (collection has no AC sources) if access_context is None: return True - # No viewers specified = visible to everyone (legacy behavior) if not entity_access.viewers: return True - # Check if any principal matches return bool(access_context.all_principals & set(entity_access.viewers)) - - -def _default_access_broker() -> "AccessBroker": - """Create a default AccessBroker backed by the real repository.""" - from airweave.domains.access_control.repository import AccessControlMembershipRepository - - return AccessBroker(acl_repo=AccessControlMembershipRepository()) - - -access_broker = _default_access_broker() diff --git a/backend/airweave/domains/access_control/fakes/broker.py b/backend/airweave/domains/access_control/fakes/broker.py new file mode 100644 index 000000000..63d23700b --- /dev/null +++ b/backend/airweave/domains/access_control/fakes/broker.py @@ -0,0 +1,60 @@ +"""Fake access broker for testing.""" + +from typing import Optional +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + +from airweave.platform.access_control.schemas import AccessContext +from airweave.platform.entities._base import AccessControl + + +class FakeAccessBroker: + """In-memory fake for AccessBrokerProtocol. + + Returns a simple access context with just the user principal. + Override `_access_context` to customize resolution behavior in tests. + """ + + def __init__(self) -> None: + self._access_context: Optional[AccessContext] = None + + async def resolve_access_context( + self, + db: AsyncSession, + user_principal: str, + organization_id: UUID, + ) -> AccessContext: + if self._access_context is not None: + return self._access_context + return AccessContext( + user_principal=user_principal, + user_principals=[f"user:{user_principal}"], + group_principals=[], + ) + + async def resolve_access_context_for_collection( + self, + db: AsyncSession, + user_principal: str, + readable_collection_id: str, + organization_id: UUID, + ) -> Optional[AccessContext]: + if self._access_context is not None: + return self._access_context + return None + + def check_entity_access( + self, + entity_access: Optional[AccessControl], + access_context: Optional[AccessContext], + ) -> bool: + if entity_access is None: + return True + if entity_access.is_public: + return True + if access_context is None: + return True + if not entity_access.viewers: + return True + return bool(access_context.all_principals & set(entity_access.viewers)) diff --git a/backend/airweave/domains/access_control/protocols.py b/backend/airweave/domains/access_control/protocols.py index cd8381ed7..dc0d3e53f 100644 --- a/backend/airweave/domains/access_control/protocols.py +++ b/backend/airweave/domains/access_control/protocols.py @@ -1,11 +1,13 @@ """Protocols for the access control domain.""" -from typing import List, Protocol +from typing import List, Optional, Protocol from uuid import UUID from sqlalchemy.ext.asyncio import AsyncSession from airweave.models.access_control_membership import AccessControlMembership +from airweave.platform.access_control.schemas import AccessContext +from airweave.platform.entities._base import AccessControl class AccessControlMembershipRepositoryProtocol(Protocol): @@ -118,3 +120,34 @@ async def delete_by_source_connection( ) -> int: """Delete all memberships for a source connection.""" ... + + +class AccessBrokerProtocol(Protocol): + """Resolves user access context by expanding group memberships.""" + + async def resolve_access_context( + self, + db: AsyncSession, + user_principal: str, + organization_id: UUID, + ) -> AccessContext: + """Resolve user's access context by expanding group memberships.""" + ... + + async def resolve_access_context_for_collection( + self, + db: AsyncSession, + user_principal: str, + readable_collection_id: str, + organization_id: UUID, + ) -> Optional[AccessContext]: + """Resolve user's access context scoped to a collection.""" + ... + + def check_entity_access( + self, + entity_access: Optional[AccessControl], + access_context: Optional[AccessContext], + ) -> bool: + """Check if user can access entity based on access control.""" + ... diff --git a/backend/airweave/domains/sync_pipeline/access_control_pipeline.py b/backend/airweave/domains/sync_pipeline/access_control_pipeline.py index 35f911eaa..a3951148c 100644 --- a/backend/airweave/domains/sync_pipeline/access_control_pipeline.py +++ b/backend/airweave/domains/sync_pipeline/access_control_pipeline.py @@ -13,9 +13,11 @@ from airweave.db.session import get_db_context from airweave.domains.access_control.protocols import AccessControlMembershipRepositoryProtocol -from airweave.domains.sync_pipeline.access_control_dispatcher import ACActionDispatcher -from airweave.domains.sync_pipeline.access_control_resolver import ACActionResolver from airweave.domains.sync_pipeline.pipeline.acl_membership_tracker import ACLMembershipTracker +from airweave.domains.sync_pipeline.protocols import ( + ACActionDispatcherProtocol, + ACActionResolverProtocol, +) from airweave.platform.access_control.schemas import ( ACLChangeType, MembershipTuple, @@ -32,8 +34,8 @@ class AccessControlPipeline: def __init__( self, - resolver: ACActionResolver, - dispatcher: ACActionDispatcher, + resolver: ACActionResolverProtocol, + dispatcher: ACActionDispatcherProtocol, tracker: ACLMembershipTracker, acl_repo: AccessControlMembershipRepositoryProtocol, ) -> None: diff --git a/backend/airweave/domains/sync_pipeline/entity_action_dispatcher.py b/backend/airweave/domains/sync_pipeline/entity_action_dispatcher.py index 290423713..04ef5f571 100644 --- a/backend/airweave/domains/sync_pipeline/entity_action_dispatcher.py +++ b/backend/airweave/domains/sync_pipeline/entity_action_dispatcher.py @@ -5,10 +5,9 @@ """ import asyncio -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional from airweave.domains.sync_pipeline.exceptions import SyncFailureError -from airweave.domains.sync_pipeline.handlers.entity_postgres import EntityPostgresHandler from airweave.domains.sync_pipeline.handlers.protocol import EntityActionHandler from airweave.domains.sync_pipeline.types.entity_actions import EntityActionBatch @@ -23,25 +22,23 @@ class EntityActionDispatcher: Implements all-or-nothing semantics: - Destination handlers (Qdrant, RawData) run concurrently - If ANY destination handler fails, SyncFailureError bubbles up - - PostgreSQL metadata handler runs ONLY AFTER all destination handlers succeed + - Metadata handler runs ONLY AFTER all destination handlers succeed - This ensures consistency between vector stores and metadata Execution Order: - 1. All destination handlers (non-Postgres) execute concurrently - 2. If all succeed → PostgreSQL metadata handler executes - 3. If any fails → SyncFailureError, no Postgres writes + 1. All destination handlers execute concurrently + 2. If all succeed → metadata handler executes + 3. If any fails → SyncFailureError, no metadata writes """ - def __init__(self, handlers: List[EntityActionHandler]): - """Initialize with handler list, separating Postgres from destinations.""" - self._destination_handlers: List[EntityActionHandler] = [] - self._postgres_handler: EntityPostgresHandler | None = None - - for handler in handlers: - if isinstance(handler, EntityPostgresHandler): - self._postgres_handler = handler - else: - self._destination_handlers.append(handler) + def __init__( + self, + destination_handlers: List[EntityActionHandler], + metadata_handler: Optional[EntityActionHandler] = None, + ): + """Initialize with destination handlers and optional metadata handler.""" + self._destination_handlers = destination_handlers + self._postgres_handler = metadata_handler # ------------------------------------------------------------------------- # Public API diff --git a/backend/airweave/domains/sync_pipeline/entity_dispatcher_builder.py b/backend/airweave/domains/sync_pipeline/entity_dispatcher_builder.py index f2ae9c1a3..8b0206fa7 100644 --- a/backend/airweave/domains/sync_pipeline/entity_dispatcher_builder.py +++ b/backend/airweave/domains/sync_pipeline/entity_dispatcher_builder.py @@ -33,8 +33,13 @@ def build( logger: Optional[ContextualLogger] = None, ) -> EntityActionDispatcher: """Build a dispatcher with all configured handlers.""" - handlers = self._build_handlers(destinations, execution_config, logger) - return EntityActionDispatcher(handlers=handlers) + destination_handlers, metadata_handler = self._build_handlers( + destinations, execution_config, logger + ) + return EntityActionDispatcher( + destination_handlers=destination_handlers, + metadata_handler=metadata_handler, + ) def build_for_cleanup( self, @@ -49,7 +54,7 @@ def _build_handlers( destinations: List[BaseDestination], execution_config: Optional[SyncConfig], logger: Optional[ContextualLogger], - ) -> List[EntityActionHandler]: + ) -> tuple[List[EntityActionHandler], Optional[EntityActionHandler]]: enable_vector = ( execution_config.handlers.enable_vector_handlers if execution_config else True ) @@ -58,16 +63,17 @@ def _build_handlers( execution_config.handlers.enable_postgres_handler if execution_config else True ) - handlers: List[EntityActionHandler] = [] + destination_handlers: List[EntityActionHandler] = [] + metadata_handler: Optional[EntityActionHandler] = None - self._add_destination_handler(handlers, destinations, enable_vector, logger) - self._add_arf_handler(handlers, enable_arf, logger) - self._add_postgres_handler(handlers, enable_postgres, logger) + self._add_destination_handler(destination_handlers, destinations, enable_vector, logger) + self._add_arf_handler(destination_handlers, enable_arf, logger) + metadata_handler = self._build_postgres_handler(enable_postgres, logger) - if not handlers and logger: + if not destination_handlers and not metadata_handler and logger: logger.warning("No handlers created - sync will fetch entities but not persist them") - return handlers + return destination_handlers, metadata_handler def _add_destination_handler( self, @@ -105,15 +111,16 @@ def _add_arf_handler( elif logger: logger.info("Skipping ArfHandler (disabled by execution_config)") - def _add_postgres_handler( + def _build_postgres_handler( self, - handlers: List[EntityActionHandler], enabled: bool, logger: Optional[ContextualLogger], - ) -> None: + ) -> Optional[EntityActionHandler]: if enabled: - handlers.append(EntityPostgresHandler(entity_repo=self._entity_repo)) + handler = EntityPostgresHandler(entity_repo=self._entity_repo) if logger: logger.debug("Added EntityPostgresHandler") - elif logger: + return handler + if logger: logger.info("Skipping EntityPostgresHandler (disabled by execution_config)") + return None diff --git a/backend/airweave/domains/sync_pipeline/factory.py b/backend/airweave/domains/sync_pipeline/factory.py index 44ba61f98..adb8e1ce5 100644 --- a/backend/airweave/domains/sync_pipeline/factory.py +++ b/backend/airweave/domains/sync_pipeline/factory.py @@ -18,6 +18,7 @@ from airweave import schemas from airweave.core.context import BaseContext +from airweave.core.exceptions import NotFoundException from airweave.core.logging import LoggerConfigurator, logger from airweave.core.protocols.event_bus import EventBus from airweave.domains.access_control.protocols import AccessControlMembershipRepositoryProtocol @@ -28,8 +29,11 @@ from airweave.domains.sync_pipeline.access_control_pipeline import AccessControlPipeline from airweave.domains.sync_pipeline.access_control_resolver import ACActionResolver from airweave.domains.sync_pipeline.builders import SyncContextBuilder +from airweave.domains.sync_pipeline.builders.destinations import DestinationsContextBuilder +from airweave.domains.sync_pipeline.builders.source import SourceContextBuilder from airweave.domains.sync_pipeline.builders.tracking import TrackingContextBuilder from airweave.domains.sync_pipeline.config import SyncConfig, SyncConfigBuilder +from airweave.domains.sync_pipeline.contexts.infra import InfraContext from airweave.domains.sync_pipeline.contexts.runtime import SyncRuntime from airweave.domains.sync_pipeline.entity_dispatcher_builder import EntityDispatcherBuilder from airweave.domains.sync_pipeline.handlers import ACPostgresHandler @@ -101,8 +105,6 @@ async def create_orchestrator( # Direct repo call — replaces SyncContextBuilder -> SourceContextBuilder chain sc = await self._sc_repo.get_by_sync_id(db, sync_id=sync.id, ctx=ctx) if not sc: - from airweave.core.exceptions import NotFoundException - raise NotFoundException(f"Source connection record not found for sync {sync.id}") source_connection_id = sc.id @@ -221,9 +223,6 @@ async def create_orchestrator( @staticmethod async def _build_source(db, sync, sync_job, ctx, force_full_sync, execution_config): """Build source and cursor. Returns (source, cursor) tuple.""" - from airweave.domains.sync_pipeline.builders.source import SourceContextBuilder - from airweave.domains.sync_pipeline.contexts.infra import InfraContext - sync_logger = LoggerConfigurator.configure_logger( "airweave.platform.sync.source_build", dimensions={ @@ -246,8 +245,6 @@ async def _build_source(db, sync, sync_job, ctx, force_full_sync, execution_conf @staticmethod async def _build_destinations(db, sync, collection, ctx, execution_config): """Build destinations and entity map. Returns (destinations, entity_map) tuple.""" - from airweave.domains.sync_pipeline.builders.destinations import DestinationsContextBuilder - dest_logger = LoggerConfigurator.configure_logger( "airweave.platform.sync.dest_build", dimensions={ diff --git a/backend/airweave/domains/sync_pipeline/protocols.py b/backend/airweave/domains/sync_pipeline/protocols.py index 87cb7c38a..edcffe683 100644 --- a/backend/airweave/domains/sync_pipeline/protocols.py +++ b/backend/airweave/domains/sync_pipeline/protocols.py @@ -7,7 +7,9 @@ from sqlalchemy.ext.asyncio import AsyncSession from airweave import schemas +from airweave.domains.sync_pipeline.types.access_control_actions import ACActionBatch from airweave.domains.sync_pipeline.types.entity_actions import EntityActionBatch +from airweave.platform.access_control.schemas import MembershipTuple from airweave.platform.entities._base import BaseEntity if TYPE_CHECKING: @@ -91,6 +93,30 @@ async def cleanup_temp_files(self, sync_context: SyncContext, runtime: SyncRunti ... +class ACActionResolverProtocol(Protocol): + """Resolves membership tuples to action objects.""" + + async def resolve( + self, + memberships: List[MembershipTuple], + sync_context: SyncContext, + ) -> ACActionBatch: + """Resolve memberships to actions.""" + ... + + +class ACActionDispatcherProtocol(Protocol): + """Dispatches resolved AC membership actions to handlers.""" + + async def dispatch( + self, + batch: ACActionBatch, + sync_context: SyncContext, + ) -> int: + """Dispatch action batch to all handlers.""" + ... + + class SyncFactoryProtocol(Protocol): """Builds a SyncOrchestrator for a given sync run.""" diff --git a/backend/airweave/platform/access_control/__init__.py b/backend/airweave/platform/access_control/__init__.py index a32ab46df..c72868f79 100644 --- a/backend/airweave/platform/access_control/__init__.py +++ b/backend/airweave/platform/access_control/__init__.py @@ -1,6 +1,5 @@ """Access control module for permission resolution and filtering.""" -from .broker import AccessBroker, access_broker from .schemas import AccessContext, MembershipTuple -__all__ = ["AccessBroker", "access_broker", "AccessContext", "MembershipTuple"] +__all__ = ["AccessContext", "MembershipTuple"] diff --git a/backend/airweave/search/factory.py b/backend/airweave/search/factory.py index de091a2ce..222b08290 100644 --- a/backend/airweave/search/factory.py +++ b/backend/airweave/search/factory.py @@ -352,10 +352,12 @@ def _build_operations( and has_vector_sources ) if has_acl_context: + assert _container_module.container is not None access_control_op = AccessControlFilter( db=db, user_email=acl_user_email, organization_id=acl_org_id, + access_broker=_container_module.container.access_broker, ) if user_principal_override: ctx.logger.info( diff --git a/backend/airweave/search/operations/access_control_filter.py b/backend/airweave/search/operations/access_control_filter.py index f6bc703ea..df3b0625b 100644 --- a/backend/airweave/search/operations/access_control_filter.py +++ b/backend/airweave/search/operations/access_control_filter.py @@ -13,7 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from airweave.api.context import ApiContext -from airweave.platform.access_control.broker import access_broker +from airweave.domains.access_control.protocols import AccessBrokerProtocol from airweave.search.context import SearchContext from ._base import SearchOperation @@ -45,17 +45,13 @@ def __init__( db: AsyncSession, user_email: str, organization_id: UUID, + access_broker: AccessBrokerProtocol, ) -> None: - """Initialize with database session and user info. - - Args: - db: Database session for AccessBroker queries - user_email: User's email for principal resolution - organization_id: Organization ID for scoped queries - """ + """Initialize with database session, user info, and access broker.""" self.db = db self.user_email = user_email self.organization_id = organization_id + self._access_broker = access_broker def depends_on(self) -> List[str]: """No dependencies - runs early in the pipeline.""" @@ -75,7 +71,7 @@ async def execute( # Resolve access context for this collection # Returns None if collection has no AC sources (skip filtering) - access_context = await access_broker.resolve_access_context_for_collection( + access_context = await self._access_broker.resolve_access_context_for_collection( db=self.db, user_principal=self.user_email, readable_collection_id=context.readable_collection_id, diff --git a/backend/airweave/search/service.py b/backend/airweave/search/service.py index 236c7073b..d2e693e93 100644 --- a/backend/airweave/search/service.py +++ b/backend/airweave/search/service.py @@ -259,8 +259,11 @@ async def search_as_user( start_time = time.monotonic() # Get collection without organization filtering + import airweave.core.container as _container_module from airweave.models.collection import Collection - from airweave.platform.access_control.broker import access_broker + + assert _container_module.container is not None + _access_broker = _container_module.container.access_broker result = await db.execute( sa_select(Collection).where(Collection.readable_id == readable_collection_id) @@ -276,7 +279,7 @@ async def search_as_user( ) # Resolve access context for the specified user - access_context = await access_broker.resolve_access_context_for_collection( + access_context = await _access_broker.resolve_access_context_for_collection( db=db, user_principal=user_principal, readable_collection_id=readable_collection_id, diff --git a/backend/conftest.py b/backend/conftest.py index cfd17d47f..e7372773d 100644 --- a/backend/conftest.py +++ b/backend/conftest.py @@ -401,6 +401,14 @@ def fake_entity_repo(): return FakeEntityRepository() +@pytest.fixture +def fake_access_broker(): + """Fake AccessBroker.""" + from airweave.domains.access_control.fakes.broker import FakeAccessBroker + + return FakeAccessBroker() + + @pytest.fixture def fake_billing_webhook(): """Fake BillingWebhookProcessor.""" @@ -643,6 +651,7 @@ def test_container( fake_selection_repo, fake_sync_factory, fake_entity_repo, + fake_access_broker, ): """A Container with all dependencies replaced by fakes. @@ -714,4 +723,5 @@ def test_container( user_service=fake_user_service, sync_factory=fake_sync_factory, entity_repo=fake_entity_repo, + access_broker=fake_access_broker, ) diff --git a/backend/tests/unit/api/test_admin_user_principals.py b/backend/tests/unit/api/test_admin_user_principals.py index 48084b642..544a028ab 100644 --- a/backend/tests/unit/api/test_admin_user_principals.py +++ b/backend/tests/unit/api/test_admin_user_principals.py @@ -53,21 +53,19 @@ async def test_returns_principals_for_user(self, mock_db, mock_ctx): collection_repo = _collection_repo("test-collection", mock_ctx.organization_id) - with patch( - "airweave.api.v1.endpoints.admin._require_admin_permission" - ), patch( - "airweave.platform.access_control.broker.access_broker" - ) as mock_broker: - mock_broker.resolve_access_context_for_collection = AsyncMock( - return_value=fake_access_ctx - ) + mock_broker = MagicMock() + mock_broker.resolve_access_context_for_collection = AsyncMock( + return_value=fake_access_ctx + ) + with patch("airweave.api.v1.endpoints.admin._require_admin_permission"): result = await admin_get_user_principals( readable_id="test-collection", user_principal="sp_admin", db=mock_db, ctx=mock_ctx, collection_repo=collection_repo, + access_broker=mock_broker, ) assert "user:sp_admin" in result @@ -79,19 +77,17 @@ async def test_returns_empty_when_no_access_context(self, mock_db, mock_ctx): """Returns empty list when access broker returns None.""" collection_repo = _collection_repo("test-collection", mock_ctx.organization_id) - with patch( - "airweave.api.v1.endpoints.admin._require_admin_permission" - ), patch( - "airweave.platform.access_control.broker.access_broker" - ) as mock_broker: - mock_broker.resolve_access_context_for_collection = AsyncMock(return_value=None) + mock_broker = MagicMock() + mock_broker.resolve_access_context_for_collection = AsyncMock(return_value=None) + with patch("airweave.api.v1.endpoints.admin._require_admin_permission"): result = await admin_get_user_principals( readable_id="test-collection", user_principal="unknown_user", db=mock_db, ctx=mock_ctx, collection_repo=collection_repo, + access_broker=mock_broker, ) assert result == [] diff --git a/backend/tests/unit/domains/access_control/__init__.py b/backend/tests/unit/domains/access_control/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/tests/unit/platform/access_control/test_broker.py b/backend/tests/unit/domains/access_control/test_broker.py similarity index 99% rename from backend/tests/unit/platform/access_control/test_broker.py rename to backend/tests/unit/domains/access_control/test_broker.py index 602362851..5f014bfd4 100644 --- a/backend/tests/unit/platform/access_control/test_broker.py +++ b/backend/tests/unit/domains/access_control/test_broker.py @@ -1,10 +1,11 @@ """Unit tests for AccessBroker.""" -import pytest from unittest.mock import AsyncMock, MagicMock, patch from uuid import uuid4 -from airweave.platform.access_control.broker import AccessBroker +import pytest + +from airweave.domains.access_control.broker import AccessBroker from airweave.platform.access_control.schemas import AccessContext from airweave.platform.entities._base import AccessControl diff --git a/backend/tests/unit/search/operations/test_access_control_filter.py b/backend/tests/unit/search/operations/test_access_control_filter.py index 6cc02b1b3..46160140f 100644 --- a/backend/tests/unit/search/operations/test_access_control_filter.py +++ b/backend/tests/unit/search/operations/test_access_control_filter.py @@ -1,9 +1,10 @@ """Unit tests for AccessControlFilter operation.""" -import pytest -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock from uuid import uuid4 +import pytest + from airweave.platform.access_control.schemas import AccessContext from airweave.search.operations.access_control_filter import AccessControlFilter from airweave.search.state import SearchState @@ -21,6 +22,12 @@ def mock_db(): return AsyncMock() +@pytest.fixture +def mock_access_broker(): + """Mock access broker.""" + return MagicMock() + + @pytest.fixture def mock_context(): """Mock SearchContext.""" @@ -46,128 +53,132 @@ async def test_execute_resolves_access_context_for_user( self, mock_db, organization_id, + mock_access_broker, mock_context, mock_api_context, ): """Test that execute resolves access context for the user.""" - with patch("airweave.search.operations.access_control_filter.access_broker") as mock_access_broker: - # Setup mock access context - access_context = AccessContext( - user_principal="john@acme.com", - user_principals=["user:john@acme.com"], - group_principals=["group:sp:engineering"], - ) - mock_access_broker.resolve_access_context_for_collection = AsyncMock(return_value=access_context) - - # Create operation and execute - operation = AccessControlFilter( - db=mock_db, - user_email="john@acme.com", - organization_id=organization_id, - ) - - state = SearchState() - await operation.execute(mock_context, state, mock_api_context) - - # Verify access broker was called - mock_access_broker.resolve_access_context_for_collection.assert_called_once_with( - db=mock_db, - user_principal="john@acme.com", - readable_collection_id="test-collection", - organization_id=organization_id, - ) + access_context = AccessContext( + user_principal="john@acme.com", + user_principals=["user:john@acme.com"], + group_principals=["group:sp:engineering"], + ) + mock_access_broker.resolve_access_context_for_collection = AsyncMock( + return_value=access_context + ) + + operation = AccessControlFilter( + db=mock_db, + user_email="john@acme.com", + organization_id=organization_id, + access_broker=mock_access_broker, + ) + + state = SearchState() + await operation.execute(mock_context, state, mock_api_context) + + mock_access_broker.resolve_access_context_for_collection.assert_called_once_with( + db=mock_db, + user_principal="john@acme.com", + readable_collection_id="test-collection", + organization_id=organization_id, + ) async def test_execute_builds_filter_with_user_principals( self, mock_db, organization_id, + mock_access_broker, mock_context, mock_api_context, ): """Test that execute builds filter with resolved principals.""" - with patch("airweave.search.operations.access_control_filter.access_broker") as mock_access_broker: - # Setup mock access context - access_context = AccessContext( - user_principal="john@acme.com", - user_principals=["user:john@acme.com"], - group_principals=["group:sp:engineering", "group:ad:frontend"], - ) - mock_access_broker.resolve_access_context_for_collection = AsyncMock(return_value=access_context) - - operation = AccessControlFilter( - db=mock_db, - user_email="john@acme.com", - organization_id=organization_id, - ) - - state = SearchState() - await operation.execute(mock_context, state, mock_api_context) - - # Verify filter was built with principals - assert state.filter is not None - assert "should" in state.filter - # Should have public OR viewers conditions - assert len(state.filter["should"]) == 2 + access_context = AccessContext( + user_principal="john@acme.com", + user_principals=["user:john@acme.com"], + group_principals=["group:sp:engineering", "group:ad:frontend"], + ) + mock_access_broker.resolve_access_context_for_collection = AsyncMock( + return_value=access_context + ) + + operation = AccessControlFilter( + db=mock_db, + user_email="john@acme.com", + organization_id=organization_id, + access_broker=mock_access_broker, + ) + + state = SearchState() + await operation.execute(mock_context, state, mock_api_context) + + assert state.filter is not None + assert "should" in state.filter + assert len(state.filter["should"]) == 2 async def test_execute_writes_filter_to_state( self, mock_db, organization_id, + mock_access_broker, mock_context, mock_api_context, ): """Test that execute writes filter to state.filter.""" - with patch("airweave.search.operations.access_control_filter.access_broker") as mock_access_broker: - access_context = AccessContext( - user_principal="john@acme.com", - user_principals=["user:john@acme.com"], - group_principals=["group:sp:engineering"], - ) - mock_access_broker.resolve_access_context_for_collection = AsyncMock(return_value=access_context) - - operation = AccessControlFilter( - db=mock_db, - user_email="john@acme.com", - organization_id=organization_id, - ) - - state = SearchState() - await operation.execute(mock_context, state, mock_api_context) - - # Verify state.filter was set - assert state.filter is not None - assert isinstance(state.filter, dict) + access_context = AccessContext( + user_principal="john@acme.com", + user_principals=["user:john@acme.com"], + group_principals=["group:sp:engineering"], + ) + mock_access_broker.resolve_access_context_for_collection = AsyncMock( + return_value=access_context + ) + + operation = AccessControlFilter( + db=mock_db, + user_email="john@acme.com", + organization_id=organization_id, + access_broker=mock_access_broker, + ) + + state = SearchState() + await operation.execute(mock_context, state, mock_api_context) + + assert state.filter is not None + assert isinstance(state.filter, dict) async def test_execute_sets_access_principals_in_state( self, mock_db, organization_id, + mock_access_broker, mock_context, mock_api_context, ): """Test that execute sets access_principals in state.""" - with patch("airweave.search.operations.access_control_filter.access_broker") as mock_access_broker: - access_context = AccessContext( - user_principal="john@acme.com", - user_principals=["user:john@acme.com"], - group_principals=["group:sp:engineering"], - ) - mock_access_broker.resolve_access_context_for_collection = AsyncMock(return_value=access_context) - - operation = AccessControlFilter( - db=mock_db, - user_email="john@acme.com", - organization_id=organization_id, - ) - - state = SearchState() - await operation.execute(mock_context, state, mock_api_context) - - # Verify access_principals was set - assert state.access_principals is not None - assert len(state.access_principals) == 2 # user + 1 group - assert "user:john@acme.com" in state.access_principals - assert "group:sp:engineering" in state.access_principals + access_context = AccessContext( + user_principal="john@acme.com", + user_principals=["user:john@acme.com"], + group_principals=["group:sp:engineering"], + ) + mock_access_broker.resolve_access_context_for_collection = AsyncMock( + return_value=access_context + ) + + operation = AccessControlFilter( + db=mock_db, + user_email="john@acme.com", + organization_id=organization_id, + access_broker=mock_access_broker, + ) + + state = SearchState() + await operation.execute(mock_context, state, mock_api_context) + + assert state.access_principals is not None + assert len(state.access_principals) == 2 + assert "user:john@acme.com" in state.access_principals + assert "group:sp:engineering" in state.access_principals @pytest.mark.asyncio @@ -178,96 +189,95 @@ async def test_execute_skips_filtering_when_no_ac_sources( self, mock_db, organization_id, + mock_access_broker, mock_context, mock_api_context, ): """Test that filtering is skipped when collection has no AC sources.""" - with patch("airweave.search.operations.access_control_filter.access_broker") as mock_access_broker: - # Mock broker returns None (no AC sources) - mock_access_broker.resolve_access_context_for_collection = AsyncMock(return_value=None) + mock_access_broker.resolve_access_context_for_collection = AsyncMock(return_value=None) - operation = AccessControlFilter( - db=mock_db, - user_email="john@acme.com", - organization_id=organization_id, - ) + operation = AccessControlFilter( + db=mock_db, + user_email="john@acme.com", + organization_id=organization_id, + access_broker=mock_access_broker, + ) - state = SearchState() - await operation.execute(mock_context, state, mock_api_context) + state = SearchState() + await operation.execute(mock_context, state, mock_api_context) - # Verify no filter was set - assert state.filter is None - assert state.access_principals is None + assert state.filter is None + assert state.access_principals is None async def test_execute_sets_access_principals_to_none_when_no_ac_sources( self, mock_db, organization_id, + mock_access_broker, mock_context, mock_api_context, ): """Test access_principals is None when no AC sources.""" - with patch("airweave.search.operations.access_control_filter.access_broker") as mock_access_broker: - mock_access_broker.resolve_access_context_for_collection = AsyncMock(return_value=None) + mock_access_broker.resolve_access_context_for_collection = AsyncMock(return_value=None) - operation = AccessControlFilter( - db=mock_db, - user_email="john@acme.com", - organization_id=organization_id, - ) + operation = AccessControlFilter( + db=mock_db, + user_email="john@acme.com", + organization_id=organization_id, + access_broker=mock_access_broker, + ) - state = SearchState() - await operation.execute(mock_context, state, mock_api_context) + state = SearchState() + await operation.execute(mock_context, state, mock_api_context) - assert state.access_principals is None + assert state.access_principals is None async def test_execute_emits_skipped_event( self, mock_db, organization_id, + mock_access_broker, mock_context, mock_api_context, ): """Test that skipped event is emitted when no AC sources.""" - with patch("airweave.search.operations.access_control_filter.access_broker") as mock_access_broker: - mock_access_broker.resolve_access_context_for_collection = AsyncMock(return_value=None) + mock_access_broker.resolve_access_context_for_collection = AsyncMock(return_value=None) - operation = AccessControlFilter( - db=mock_db, - user_email="john@acme.com", - organization_id=organization_id, - ) + operation = AccessControlFilter( + db=mock_db, + user_email="john@acme.com", + organization_id=organization_id, + access_broker=mock_access_broker, + ) - state = SearchState() - await operation.execute(mock_context, state, mock_api_context) + state = SearchState() + await operation.execute(mock_context, state, mock_api_context) - # Verify emitter was called with skipped event - mock_context.emitter.emit.assert_called_once() - call_args = mock_context.emitter.emit.call_args - assert call_args[0][0] == "access_control_skipped" + mock_context.emitter.emit.assert_called_once() + call_args = mock_context.emitter.emit.call_args + assert call_args[0][0] == "access_control_skipped" class TestAccessControlFilterBuildFilter: """Test filter building logic.""" def test_build_filter_includes_is_public_condition( - self, mock_db, organization_id + self, mock_db, organization_id, mock_access_broker ): """Test that filter includes is_public condition.""" operation = AccessControlFilter( db=mock_db, user_email="john@acme.com", organization_id=organization_id, + access_broker=mock_access_broker, ) principals = ["user:john@acme.com", "group:sp:engineering"] filter_result = operation._build_access_control_filter(principals) - # Should have OR condition with is_public assert "should" in filter_result conditions = filter_result["should"] - - # Find is_public condition + public_condition = next( (c for c in conditions if c.get("key") == "access.is_public"), None ) @@ -275,23 +285,22 @@ def test_build_filter_includes_is_public_condition( assert public_condition["match"]["value"] is True def test_build_filter_includes_viewers_any_condition( - self, mock_db, organization_id + self, mock_db, organization_id, mock_access_broker ): """Test that filter includes viewers any condition.""" operation = AccessControlFilter( db=mock_db, user_email="john@acme.com", organization_id=organization_id, + access_broker=mock_access_broker, ) principals = ["user:john@acme.com", "group:sp:engineering"] filter_result = operation._build_access_control_filter(principals) - # Should have OR condition with viewers assert "should" in filter_result conditions = filter_result["should"] - - # Find viewers condition + viewers_condition = next( (c for c in conditions if c.get("key") == "access.viewers"), None ) @@ -299,17 +308,19 @@ def test_build_filter_includes_viewers_any_condition( assert "any" in viewers_condition["match"] assert set(viewers_condition["match"]["any"]) == set(principals) - def test_build_filter_handles_empty_principals(self, mock_db, organization_id): + def test_build_filter_handles_empty_principals( + self, mock_db, organization_id, mock_access_broker + ): """Test filter building with empty principals list.""" operation = AccessControlFilter( db=mock_db, user_email="john@acme.com", organization_id=organization_id, + access_broker=mock_access_broker, ) filter_result = operation._build_access_control_filter([]) - # Should only have is_public condition assert "should" in filter_result conditions = filter_result["should"] assert len(conditions) == 1 @@ -319,12 +330,15 @@ def test_build_filter_handles_empty_principals(self, mock_db, organization_id): class TestAccessControlFilterMerging: """Test filter merging with existing filters.""" - def test_merge_combines_with_existing_filter(self, mock_db, organization_id): + def test_merge_combines_with_existing_filter( + self, mock_db, organization_id, mock_access_broker + ): """Test that AC filter merges with existing filter.""" operation = AccessControlFilter( db=mock_db, user_email="john@acme.com", organization_id=organization_id, + access_broker=mock_access_broker, ) ac_filter = {"should": [{"key": "access.is_public", "match": {"value": True}}]} @@ -332,18 +346,18 @@ def test_merge_combines_with_existing_filter(self, mock_db, organization_id): merged = operation._merge_with_existing_filter(ac_filter, existing_filter) - # Should create must condition with both filters assert "must" in merged assert len(merged["must"]) == 2 def test_merge_creates_must_condition_when_both_exist( - self, mock_db, organization_id + self, mock_db, organization_id, mock_access_broker ): """Test that merge creates must AND condition.""" operation = AccessControlFilter( db=mock_db, user_email="john@acme.com", organization_id=organization_id, + access_broker=mock_access_broker, ) ac_filter = {"should": [{"key": "access.is_public", "match": {"value": True}}]} @@ -356,13 +370,14 @@ def test_merge_creates_must_condition_when_both_exist( assert existing_filter in merged["must"] def test_merge_returns_new_filter_when_no_existing( - self, mock_db, organization_id + self, mock_db, organization_id, mock_access_broker ): """Test that merge returns AC filter when no existing filter.""" operation = AccessControlFilter( db=mock_db, user_email="john@acme.com", organization_id=organization_id, + access_broker=mock_access_broker, ) ac_filter = {"should": [{"key": "access.is_public", "match": {"value": True}}]} @@ -375,15 +390,17 @@ def test_merge_returns_new_filter_when_no_existing( class TestAccessControlFilterDependencies: """Test operation dependencies.""" - def test_depends_on_returns_empty_list(self, mock_db, organization_id): + def test_depends_on_returns_empty_list( + self, mock_db, organization_id, mock_access_broker + ): """Test that AccessControlFilter has no dependencies.""" operation = AccessControlFilter( db=mock_db, user_email="john@acme.com", organization_id=organization_id, + access_broker=mock_access_broker, ) dependencies = operation.depends_on() assert dependencies == [] - From 195968d27a4b187bd1bb27f395261f0023592703 Mon Sep 17 00:00:00 2001 From: Felix Schmetz Date: Fri, 13 Mar 2026 11:17:19 -0700 Subject: [PATCH 10/13] fix: restore SourceContextBuilder lazy import to break circular container init MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit builders/source.py has a top-level container import, so importing it during core/container/__init__.py execution resolves the `container` name to the submodule instead of the variable — causing "module has no attribute 'source_lifecycle_service'" at runtime. --- backend/airweave/domains/sync_pipeline/factory.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/backend/airweave/domains/sync_pipeline/factory.py b/backend/airweave/domains/sync_pipeline/factory.py index adb8e1ce5..38b675890 100644 --- a/backend/airweave/domains/sync_pipeline/factory.py +++ b/backend/airweave/domains/sync_pipeline/factory.py @@ -30,7 +30,6 @@ from airweave.domains.sync_pipeline.access_control_resolver import ACActionResolver from airweave.domains.sync_pipeline.builders import SyncContextBuilder from airweave.domains.sync_pipeline.builders.destinations import DestinationsContextBuilder -from airweave.domains.sync_pipeline.builders.source import SourceContextBuilder from airweave.domains.sync_pipeline.builders.tracking import TrackingContextBuilder from airweave.domains.sync_pipeline.config import SyncConfig, SyncConfigBuilder from airweave.domains.sync_pipeline.contexts.infra import InfraContext @@ -223,6 +222,8 @@ async def create_orchestrator( @staticmethod async def _build_source(db, sync, sync_job, ctx, force_full_sync, execution_config): """Build source and cursor. Returns (source, cursor) tuple.""" + from airweave.domains.sync_pipeline.builders.source import SourceContextBuilder + sync_logger = LoggerConfigurator.configure_logger( "airweave.platform.sync.source_build", dimensions={ From 420a07c647004d2f8786e1c9aa89e443f068baa7 Mon Sep 17 00:00:00 2001 From: Felix Schmetz Date: Fri, 13 Mar 2026 12:08:01 -0700 Subject: [PATCH 11/13] refactor: move converters to domain with DI via ConverterRegistry Replace platform/converters singleton system with domains/converters/ domain. ConverterRegistry is built in the container factory and injected into TextualRepresentationBuilder and ChunkEmbedProcessor, eliminating initialize_converters() calls and lazy imports. --- backend/airweave/core/container/container.py | 4 + backend/airweave/core/container/factory.py | 5 +- .../airweave/domains/converters/__init__.py | 5 + .../{platform => domains}/converters/_base.py | 37 +-- .../converters/code.py} | 23 +- backend/airweave/domains/converters/docx.py | 15 + .../domains/converters/fakes/__init__.py | 0 .../domains/converters/fakes/registry.py | 30 ++ .../converters/html.py} | 22 +- backend/airweave/domains/converters/pdf.py | 21 ++ backend/airweave/domains/converters/pptx.py | 15 + .../airweave/domains/converters/protocols.py | 19 ++ .../airweave/domains/converters/registry.py | 86 +++++ .../converters/text_extractors/__init__.py | 13 + .../converters/text_extractors/docx.py | 37 +-- .../converters/text_extractors/pdf.py | 93 +----- .../converters/text_extractors/pptx.py | 41 +-- .../converters/txt.py} | 76 +---- .../converters/web.py} | 112 +------ .../converters/xlsx.py} | 53 +--- .../sync_pipeline/pipeline/text_builder.py | 109 +------ .../sync_pipeline/processors/chunk_embed.py | 28 +- .../sync_pipeline/tests/test_chunk_embed.py | 112 +++---- backend/airweave/main.py | 5 - .../airweave/platform/converters/__init__.py | 114 ------- .../platform/converters/docx_converter.py | 33 -- .../platform/converters/pdf_converter.py | 41 --- .../platform/converters/pptx_converter.py | 33 -- .../converters/text_extractors/__init__.py | 20 -- .../platform/ocr/mistral/converter.py | 2 +- .../platform/temporal/worker/__init__.py | 7 +- backend/conftest.py | 10 + .../tests/unit/domains/converters/__init__.py | 0 .../unit/domains/converters/test_code.py | 69 ++++ .../unit/domains/converters/test_html.py | 76 +++++ .../unit/domains/converters/test_registry.py | 36 +++ .../converters/test_txt.py} | 79 +---- .../unit/platform/converters/__init__.py | 2 - .../converters/test_code_converter.py | 240 -------------- .../converters/test_html_converter.py | 297 ------------------ .../converters/test_init_converters.py | 33 -- 41 files changed, 519 insertions(+), 1534 deletions(-) create mode 100644 backend/airweave/domains/converters/__init__.py rename backend/airweave/{platform => domains}/converters/_base.py (75%) rename backend/airweave/{platform/converters/code_converter.py => domains/converters/code.py} (74%) create mode 100644 backend/airweave/domains/converters/docx.py create mode 100644 backend/airweave/domains/converters/fakes/__init__.py create mode 100644 backend/airweave/domains/converters/fakes/registry.py rename backend/airweave/{platform/converters/html_converter.py => domains/converters/html.py} (78%) create mode 100644 backend/airweave/domains/converters/pdf.py create mode 100644 backend/airweave/domains/converters/pptx.py create mode 100644 backend/airweave/domains/converters/protocols.py create mode 100644 backend/airweave/domains/converters/registry.py create mode 100644 backend/airweave/domains/converters/text_extractors/__init__.py rename backend/airweave/{platform => domains}/converters/text_extractors/docx.py (67%) rename backend/airweave/{platform => domains}/converters/text_extractors/pdf.py (63%) rename backend/airweave/{platform => domains}/converters/text_extractors/pptx.py (64%) rename backend/airweave/{platform/converters/txt_converter.py => domains/converters/txt.py} (75%) rename backend/airweave/{platform/converters/web_converter.py => domains/converters/web.py} (58%) rename backend/airweave/{platform/converters/xlsx_converter.py => domains/converters/xlsx.py} (72%) delete mode 100644 backend/airweave/platform/converters/__init__.py delete mode 100644 backend/airweave/platform/converters/docx_converter.py delete mode 100644 backend/airweave/platform/converters/pdf_converter.py delete mode 100644 backend/airweave/platform/converters/pptx_converter.py delete mode 100644 backend/airweave/platform/converters/text_extractors/__init__.py create mode 100644 backend/tests/unit/domains/converters/__init__.py create mode 100644 backend/tests/unit/domains/converters/test_code.py create mode 100644 backend/tests/unit/domains/converters/test_html.py create mode 100644 backend/tests/unit/domains/converters/test_registry.py rename backend/tests/unit/{platform/converters/test_txt_converter.py => domains/converters/test_txt.py} (58%) delete mode 100644 backend/tests/unit/platform/converters/__init__.py delete mode 100644 backend/tests/unit/platform/converters/test_code_converter.py delete mode 100644 backend/tests/unit/platform/converters/test_html_converter.py delete mode 100644 backend/tests/unit/platform/converters/test_init_converters.py diff --git a/backend/airweave/core/container/container.py b/backend/airweave/core/container/container.py index dc8f664f6..242d0f25d 100644 --- a/backend/airweave/core/container/container.py +++ b/backend/airweave/core/container/container.py @@ -46,6 +46,7 @@ ) from airweave.domains.connect.protocols import ConnectServiceProtocol from airweave.domains.connections.protocols import ConnectionRepositoryProtocol +from airweave.domains.converters.protocols import ConverterRegistryProtocol from airweave.domains.credentials.protocols import IntegrationCredentialRepositoryProtocol from airweave.domains.embedders.protocols import ( DenseEmbedderProtocol, @@ -235,6 +236,9 @@ async def my_endpoint(event_bus: EventBus = Inject(EventBus)): # Connect domain service (session-based frontend integration flows) connect_service: ConnectServiceProtocol + # Converter registry (maps file extensions to converter instances) + converter_registry: ConverterRegistryProtocol + # OCR provider (with fallback chain + circuit breaking) # Optional: None when no OCR backend (Mistral/Docling) is configured ocr_provider: Optional[OcrProvider] = None diff --git a/backend/airweave/core/container/factory.py b/backend/airweave/core/container/factory.py index bc12e1206..284a856e2 100644 --- a/backend/airweave/core/container/factory.py +++ b/backend/airweave/core/container/factory.py @@ -56,6 +56,7 @@ VectorDbDeploymentMetadataRepository, ) from airweave.domains.connections.repository import ConnectionRepository +from airweave.domains.converters.registry import ConverterRegistry from airweave.domains.credentials.repository import IntegrationCredentialRepository from airweave.domains.embedders.config import ( DENSE_EMBEDDER, @@ -386,7 +387,8 @@ def create_container(settings: Settings) -> Container: # ----------------------------------------------------------------- acl_membership_repo = AccessControlMembershipRepository() access_broker = AccessBroker(acl_repo=acl_membership_repo) - chunk_embed_processor = ChunkEmbedProcessor() + converter_registry = ConverterRegistry(ocr_provider=ocr_provider) + chunk_embed_processor = ChunkEmbedProcessor(converter_registry=converter_registry) # ----------------------------------------------------------------- # Sync factory + service @@ -514,6 +516,7 @@ def create_container(settings: Settings) -> Container: sync_factory=sync_factory, entity_repo=sync_deps["entity_repo"], access_broker=access_broker, + converter_registry=converter_registry, temporal_workflow_service=sync_deps["temporal_workflow_service"], temporal_schedule_service=sync_deps["temporal_schedule_service"], usage_checker=usage_checker, diff --git a/backend/airweave/domains/converters/__init__.py b/backend/airweave/domains/converters/__init__.py new file mode 100644 index 000000000..aba3d5579 --- /dev/null +++ b/backend/airweave/domains/converters/__init__.py @@ -0,0 +1,5 @@ +"""Document and file converters domain. + +Provides converters for transforming files, URLs, and code into markdown text. +The ConverterRegistry is the main entry point, wired via the DI container. +""" diff --git a/backend/airweave/platform/converters/_base.py b/backend/airweave/domains/converters/_base.py similarity index 75% rename from backend/airweave/platform/converters/_base.py rename to backend/airweave/domains/converters/_base.py index 2d1a1d097..7be5f20fd 100644 --- a/backend/airweave/platform/converters/_base.py +++ b/backend/airweave/domains/converters/_base.py @@ -44,12 +44,6 @@ async def _try_extract(self, path: str) -> Optional[str]: """ def __init__(self, ocr_provider: Optional[OcrProvider] = None) -> None: - """Initialize the converter. - - Args: - ocr_provider: OCR provider for fallback. If ``None``, files that - cannot be text-extracted will return ``None``. - """ self._ocr_provider = ocr_provider @abstractmethod @@ -62,20 +56,7 @@ async def _try_extract(self, path: str) -> Optional[str]: @staticmethod def _try_read_as_text(path: str, max_probe_bytes: int = 8192) -> Optional[str]: - """Check if a file is actually plain text despite its extension. - - Reads a small probe of the file and checks if it decodes as valid UTF-8 - with a low ratio of control characters. This catches files that have - binary extensions (e.g. .docx, .pdf) but actually contain plain text -- - common with auto-generated test data or legacy systems. - - Args: - path: Path to the file. - max_probe_bytes: How many bytes to probe for text detection. - - Returns: - Full file content as string if it's valid text, None otherwise. - """ + """Check if a file is actually plain text despite its extension.""" try: with open(path, "rb") as f: probe = f.read(max_probe_bytes) @@ -83,23 +64,18 @@ def _try_read_as_text(path: str, max_probe_bytes: int = 8192) -> Optional[str]: if not probe: return None - # Try UTF-8 decode on the probe try: probe.decode("utf-8") except UnicodeDecodeError: return None - # Check for excessive control characters (binary indicator) - # Allow common whitespace: \n, \r, \t control_count = sum(1 for b in probe if b < 32 and b not in (9, 10, 13)) - if control_count / len(probe) > 0.05: # >5% control chars = binary + if control_count / len(probe) > 0.05: return None - # It's text -- read the full file with open(path, "r", encoding="utf-8") as f: content = f.read() - # Must have meaningful content if len(content.strip()) < 10: return None @@ -113,12 +89,6 @@ async def convert_batch(self, file_paths: List[str]) -> Dict[str, Optional[str]] For each file, calls :meth:`_try_extract`. If that returns content, uses it directly (0 API calls). Otherwise, batches the file for OCR. - - Args: - file_paths: Local file paths to convert. - - Returns: - Mapping of ``file_path -> markdown`` (``None`` on failure). """ results: Dict[str, Optional[str]] = {} needs_ocr: List[str] = [] @@ -131,8 +101,6 @@ async def convert_batch(self, file_paths: List[str]) -> Dict[str, Optional[str]] results[path] = markdown logger.debug(f"{name}: extracted via text layer") else: - # Before falling back to OCR, check if the file is actually - # plain text with a misleading extension (e.g. .docx containing text) text_content = self._try_read_as_text(path) if text_content: results[path] = text_content @@ -145,7 +113,6 @@ async def convert_batch(self, file_paths: List[str]) -> Dict[str, Optional[str]] needs_ocr.append(path) except Exception as exc: logger.warning(f"{name}: extraction error ({exc}), needs OCR") - # Same fallback check on extraction errors text_content = self._try_read_as_text(path) if text_content: results[path] = text_content diff --git a/backend/airweave/platform/converters/code_converter.py b/backend/airweave/domains/converters/code.py similarity index 74% rename from backend/airweave/platform/converters/code_converter.py rename to backend/airweave/domains/converters/code.py index 368452558..186b00be9 100644 --- a/backend/airweave/platform/converters/code_converter.py +++ b/backend/airweave/domains/converters/code.py @@ -6,35 +6,22 @@ import aiofiles from airweave.core.logging import logger -from airweave.platform.converters._base import BaseTextConverter +from airweave.domains.converters._base import BaseTextConverter class CodeConverter(BaseTextConverter): - """Converts code files to markdown code fences. - - Simple converter that wraps code content in markdown code fences - with appropriate language tags. No AI summarization - code-specific - embeddings will be used later for optimal retrieval. - """ + """Converts code files to markdown code fences.""" async def convert_batch(self, file_paths: List[str]) -> Dict[str, str]: - """Convert code files to markdown code fences. - - Args: - file_paths: List of code file paths - - Returns: - Dict mapping file_path -> markdown (code fence with language tag) - """ + """Convert code files to markdown code fences.""" logger.debug(f"Converting {len(file_paths)} code files to markdown...") results = {} - semaphore = asyncio.Semaphore(20) # Limit concurrent file reads + semaphore = asyncio.Semaphore(20) async def _convert_one(path: str): async with semaphore: try: - # Read raw bytes for encoding detection async with aiofiles.open(path, "rb") as f: raw_bytes = await f.read() @@ -43,7 +30,6 @@ async def _convert_one(path: str): results[path] = None return - # Try UTF-8 first (most common for code) try: code = raw_bytes.decode("utf-8") if "\ufffd" not in code: @@ -53,7 +39,6 @@ async def _convert_one(path: str): except UnicodeDecodeError: pass - # Fallback: decode with replace to detect corruption code = raw_bytes.decode("utf-8", errors="replace") replacement_count = code.count("\ufffd") diff --git a/backend/airweave/domains/converters/docx.py b/backend/airweave/domains/converters/docx.py new file mode 100644 index 000000000..31f76bf33 --- /dev/null +++ b/backend/airweave/domains/converters/docx.py @@ -0,0 +1,15 @@ +"""DOCX converter with hybrid text extraction + OCR fallback.""" + +from __future__ import annotations + +from typing import Optional + +from airweave.domains.converters._base import HybridDocumentConverter +from airweave.domains.converters.text_extractors.docx import extract_docx_text + + +class DocxConverter(HybridDocumentConverter): + """Converts DOCX files to markdown using text extraction with OCR fallback.""" + + async def _try_extract(self, path: str) -> Optional[str]: + return await extract_docx_text(path) diff --git a/backend/airweave/domains/converters/fakes/__init__.py b/backend/airweave/domains/converters/fakes/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/airweave/domains/converters/fakes/registry.py b/backend/airweave/domains/converters/fakes/registry.py new file mode 100644 index 000000000..707112a4a --- /dev/null +++ b/backend/airweave/domains/converters/fakes/registry.py @@ -0,0 +1,30 @@ +"""Fake converter registry for testing.""" + +from __future__ import annotations + +from typing import Dict, List, Optional + +from airweave.domains.converters._base import BaseTextConverter + + +class _StubConverter(BaseTextConverter): + """Returns canned markdown for every file/URL.""" + + def __init__(self, text: str = "fake-markdown") -> None: + self._text = text + + async def convert_batch(self, file_paths: List[str]) -> Dict[str, str]: + return {p: self._text for p in file_paths} + + +class FakeConverterRegistry: + """In-memory registry returning stub converters for all lookups.""" + + def __init__(self, text: str = "fake-markdown") -> None: + self._stub = _StubConverter(text) + + def for_extension(self, ext: str) -> Optional[BaseTextConverter]: + return self._stub + + def for_web(self) -> BaseTextConverter: + return self._stub diff --git a/backend/airweave/platform/converters/html_converter.py b/backend/airweave/domains/converters/html.py similarity index 78% rename from backend/airweave/platform/converters/html_converter.py rename to backend/airweave/domains/converters/html.py index 0be1c7eee..e87f75f26 100644 --- a/backend/airweave/platform/converters/html_converter.py +++ b/backend/airweave/domains/converters/html.py @@ -4,26 +4,16 @@ from typing import Dict, List from airweave.core.logging import logger +from airweave.domains.converters._base import BaseTextConverter from airweave.domains.sync_pipeline.async_helpers import run_in_thread_pool from airweave.domains.sync_pipeline.exceptions import EntityProcessingError -from airweave.platform.converters._base import BaseTextConverter class HtmlConverter(BaseTextConverter): """Converts HTML files to markdown text using html-to-markdown.""" async def convert_batch(self, file_paths: List[str]) -> Dict[str, str]: - """Convert HTML files to markdown text. - - Args: - file_paths: List of file paths to convert - - Returns: - Dict mapping file_path -> markdown text content (None if failed) - - Raises: - EntityProcessingError: If html-to-markdown package not installed - """ + """Convert HTML files to markdown text.""" try: from html_to_markdown import convert except ImportError: @@ -36,28 +26,25 @@ async def convert_batch(self, file_paths: List[str]) -> Dict[str, str]: logger.info(f"Converting {len(file_paths)} HTML files to markdown...") results = {} - semaphore = asyncio.Semaphore(20) # Limit concurrent conversions + semaphore = asyncio.Semaphore(20) async def _convert_one(path: str): async with semaphore: try: def _convert(): - # Read raw bytes for encoding detection with open(path, "rb") as f: raw_bytes = f.read() if not raw_bytes: return None - # Try UTF-8 first try: html_content = raw_bytes.decode("utf-8") except UnicodeDecodeError: - # Fallback with replace to detect corruption html_content = raw_bytes.decode("utf-8", errors="replace") replacement_count = html_content.count("\ufffd") - if replacement_count > 100: # Lenient for HTML + if replacement_count > 100: raise EntityProcessingError( f"HTML contains excessive binary data " f"({replacement_count} replacement chars)" @@ -66,7 +53,6 @@ def _convert(): if not html_content.strip(): return None - # Convert to markdown using html-to-markdown (Rust-powered) markdown = convert(html_content) return markdown.strip() if markdown else None diff --git a/backend/airweave/domains/converters/pdf.py b/backend/airweave/domains/converters/pdf.py new file mode 100644 index 000000000..5a11d5c16 --- /dev/null +++ b/backend/airweave/domains/converters/pdf.py @@ -0,0 +1,21 @@ +"""PDF converter with hybrid text extraction + OCR fallback.""" + +from __future__ import annotations + +from typing import Optional + +from airweave.domains.converters._base import HybridDocumentConverter +from airweave.domains.converters.text_extractors.pdf import ( + extract_pdf_text, + text_to_markdown, +) + + +class PdfConverter(HybridDocumentConverter): + """Converts PDFs to markdown using text extraction with OCR fallback.""" + + async def _try_extract(self, path: str) -> Optional[str]: + extraction = await extract_pdf_text(path) + if extraction.fully_extracted and extraction.full_text: + return text_to_markdown(extraction.full_text) + return None diff --git a/backend/airweave/domains/converters/pptx.py b/backend/airweave/domains/converters/pptx.py new file mode 100644 index 000000000..9148be559 --- /dev/null +++ b/backend/airweave/domains/converters/pptx.py @@ -0,0 +1,15 @@ +"""PPTX converter with hybrid text extraction + OCR fallback.""" + +from __future__ import annotations + +from typing import Optional + +from airweave.domains.converters._base import HybridDocumentConverter +from airweave.domains.converters.text_extractors.pptx import extract_pptx_text + + +class PptxConverter(HybridDocumentConverter): + """Converts PPTX files to markdown using text extraction with OCR fallback.""" + + async def _try_extract(self, path: str) -> Optional[str]: + return await extract_pptx_text(path) diff --git a/backend/airweave/domains/converters/protocols.py b/backend/airweave/domains/converters/protocols.py new file mode 100644 index 000000000..c0f76b253 --- /dev/null +++ b/backend/airweave/domains/converters/protocols.py @@ -0,0 +1,19 @@ +"""Protocols for the converters domain.""" + +from __future__ import annotations + +from typing import Optional, Protocol + +from airweave.domains.converters._base import BaseTextConverter + + +class ConverterRegistryProtocol(Protocol): + """Registry that maps file extensions to converter instances.""" + + def for_extension(self, ext: str) -> Optional[BaseTextConverter]: + """Return the converter for a given file extension, or None.""" + ... + + def for_web(self) -> BaseTextConverter: + """Return the web converter.""" + ... diff --git a/backend/airweave/domains/converters/registry.py b/backend/airweave/domains/converters/registry.py new file mode 100644 index 000000000..ef6f03f8b --- /dev/null +++ b/backend/airweave/domains/converters/registry.py @@ -0,0 +1,86 @@ +"""Converter registry — maps file extensions to converter instances.""" + +from __future__ import annotations + +from typing import Dict, Optional + +from airweave.core.protocols.ocr import OcrProvider +from airweave.domains.converters._base import BaseTextConverter +from airweave.domains.converters.code import CodeConverter +from airweave.domains.converters.docx import DocxConverter +from airweave.domains.converters.html import HtmlConverter +from airweave.domains.converters.pdf import PdfConverter +from airweave.domains.converters.pptx import PptxConverter +from airweave.domains.converters.txt import TxtConverter +from airweave.domains.converters.web import WebConverter +from airweave.domains.converters.xlsx import XlsxConverter + + +class ConverterRegistry: + """Concrete registry that creates and owns all converter instances. + + Built once by the container factory with the resolved OCR provider. + """ + + def __init__(self, ocr_provider: Optional[OcrProvider] = None) -> None: + """Build all converter instances and the extension mapping.""" + pdf = PdfConverter(ocr_provider=ocr_provider) + docx = DocxConverter(ocr_provider=ocr_provider) + pptx = PptxConverter(ocr_provider=ocr_provider) + html = HtmlConverter() + txt = TxtConverter() + xlsx = XlsxConverter() + code = CodeConverter() + self._web = WebConverter() + + self._extension_map: Dict[str, BaseTextConverter] = { + # Documents — text extraction + OCR fallback + ".pdf": pdf, + ".docx": docx, + ".pptx": pptx, + # Images — direct OCR (ocr_provider itself implements convert_batch) + ".jpg": ocr_provider, + ".jpeg": ocr_provider, + ".png": ocr_provider, + # Spreadsheets + ".xlsx": xlsx, + # HTML + ".html": html, + ".htm": html, + # Text / structured text + ".txt": txt, + ".json": txt, + ".xml": txt, + ".md": txt, + ".yaml": txt, + ".yml": txt, + ".toml": txt, + # Code + ".py": code, + ".js": code, + ".ts": code, + ".tsx": code, + ".jsx": code, + ".java": code, + ".cpp": code, + ".c": code, + ".h": code, + ".hpp": code, + ".go": code, + ".rs": code, + ".rb": code, + ".php": code, + ".swift": code, + ".kt": code, + ".kts": code, + ".tf": code, + ".tfvars": code, + } + + def for_extension(self, ext: str) -> Optional[BaseTextConverter]: + """Return the converter for a given file extension, or None.""" + return self._extension_map.get(ext) + + def for_web(self) -> BaseTextConverter: + """Return the web converter.""" + return self._web diff --git a/backend/airweave/domains/converters/text_extractors/__init__.py b/backend/airweave/domains/converters/text_extractors/__init__.py new file mode 100644 index 000000000..4aa21eca5 --- /dev/null +++ b/backend/airweave/domains/converters/text_extractors/__init__.py @@ -0,0 +1,13 @@ +"""Text extraction utilities for various document formats.""" + +from .docx import extract_docx_text +from .pdf import PdfExtractionResult, extract_pdf_text, text_to_markdown +from .pptx import extract_pptx_text + +__all__ = [ + "PdfExtractionResult", + "extract_pdf_text", + "text_to_markdown", + "extract_docx_text", + "extract_pptx_text", +] diff --git a/backend/airweave/platform/converters/text_extractors/docx.py b/backend/airweave/domains/converters/text_extractors/docx.py similarity index 67% rename from backend/airweave/platform/converters/text_extractors/docx.py rename to backend/airweave/domains/converters/text_extractors/docx.py index b4f23611a..2e1de94a7 100644 --- a/backend/airweave/platform/converters/text_extractors/docx.py +++ b/backend/airweave/domains/converters/text_extractors/docx.py @@ -1,9 +1,4 @@ -"""Direct text extraction from DOCX files using python-docx. - -Extracts paragraph text and basic structure (headings, lists) without any -API calls. If the DOCX has extractable text, this is orders of magnitude -faster and cheaper than sending it through OCR. -""" +"""Direct text extraction from DOCX files using python-docx.""" from __future__ import annotations @@ -14,10 +9,8 @@ from airweave.core.logging import logger from airweave.domains.sync_pipeline.exceptions import SyncFailureError -# Minimum total characters to consider the extraction successful. MIN_TOTAL_CHARS = 50 -# Heading style → markdown prefix mapping (checked in order). _HEADING_MAP = ( ("heading 1", "# "), ("heading 2", "## "), @@ -27,14 +20,6 @@ def _format_paragraph(para: Any) -> Optional[str]: - """Convert a single DOCX paragraph to a markdown line. - - Args: - para: A ``docx.text.paragraph.Paragraph`` instance. - - Returns: - Markdown string or ``None`` if the paragraph is empty. - """ text = para.text.strip() if not text: return None @@ -52,14 +37,6 @@ def _format_paragraph(para: Any) -> Optional[str]: def _format_table(table: Any) -> str: - """Convert a DOCX table to a markdown table string. - - Args: - table: A ``docx.table.Table`` instance. - - Returns: - Markdown table string (may be empty if the table has no rows). - """ rows: list[str] = [] for row in table.rows: cells = [cell.text.strip() for cell in row.cells] @@ -74,17 +51,7 @@ def _format_table(table: Any) -> str: async def extract_docx_text(path: str) -> Optional[str]: - """Extract text from a DOCX and return markdown. - - Args: - path: Path to the DOCX file. - - Returns: - Markdown string if extraction yielded sufficient text, ``None`` otherwise. - - Raises: - SyncFailureError: If python-docx is not installed. - """ + """Extract text from a DOCX and return markdown.""" try: from docx import Document except ImportError: diff --git a/backend/airweave/platform/converters/text_extractors/pdf.py b/backend/airweave/domains/converters/text_extractors/pdf.py similarity index 63% rename from backend/airweave/platform/converters/text_extractors/pdf.py rename to backend/airweave/domains/converters/text_extractors/pdf.py index 94b2405f0..4ef1529d0 100644 --- a/backend/airweave/platform/converters/text_extractors/pdf.py +++ b/backend/airweave/domains/converters/text_extractors/pdf.py @@ -5,8 +5,8 @@ than OCR for documents that have a text layer. The module detects whether the whole PDF has extractable text: -- If all pages have sufficient text → return extracted content -- If any page is image-only → caller should use OCR for whole PDF +- If all pages have sufficient text -> return extracted content +- If any page is image-only -> caller should use OCR for whole PDF """ from __future__ import annotations @@ -18,20 +18,12 @@ from airweave.core.logging import logger from airweave.domains.sync_pipeline.exceptions import SyncFailureError -# Minimum characters per page to consider it "has text layer". -# Pages below this threshold are treated as image-only. MIN_CHARS_PER_PAGE = 50 @dataclass class PageExtractionResult: - """Result of attempting text extraction on a single page. - - Attributes: - page_num: 0-based page number. - text: Extracted text (empty string if extraction failed or page is image-only). - needs_ocr: True if this page should be sent to OCR. - """ + """Result of attempting text extraction on a single page.""" page_num: int text: str @@ -40,30 +32,25 @@ class PageExtractionResult: @dataclass class PdfExtractionResult: - """Result of extracting text from an entire PDF. - - Attributes: - path: Original PDF path. - pages: Per-page extraction results. - """ + """Result of extracting text from an entire PDF.""" path: str pages: list[PageExtractionResult] = field(default_factory=list) @property def full_text(self) -> str: - """Combined text from all extracted pages.""" + """Return concatenated text from all successfully extracted pages.""" texts = [p.text for p in self.pages if p.text and not p.needs_ocr] return "\n\n".join(texts) @property def pages_needing_ocr(self) -> list[int]: - """0-based page numbers that need OCR.""" + """Return page numbers that need OCR (image-only pages).""" return [p.page_num for p in self.pages if p.needs_ocr] @property def extraction_ratio(self) -> float: - """Fraction of pages that were successfully extracted.""" + """Return fraction of pages that were extracted without OCR.""" if not self.pages: return 0.0 extracted = sum(1 for p in self.pages if not p.needs_ocr) @@ -71,22 +58,12 @@ def extraction_ratio(self) -> float: @property def fully_extracted(self) -> bool: - """True if all pages were extracted without needing OCR.""" + """Return True if all pages have a text layer (no OCR needed).""" return bool(self.pages) and len(self.pages_needing_ocr) == 0 async def extract_pdf_text(path: str) -> PdfExtractionResult: - """Extract text from a PDF, detecting which pages need OCR. - - Args: - path: Path to the PDF file. - - Returns: - A :class:`PdfExtractionResult` with per-page extraction results. - - Raises: - SyncFailureError: If PyMuPDF is not installed. - """ + """Extract text from a PDF, detecting which pages need OCR.""" try: import fitz # PyMuPDF except ImportError: @@ -109,7 +86,6 @@ def _extract() -> PdfExtractionResult: finally: doc.close() - # Log summary name = os.path.basename(path) total = len(result.pages) extracted = total - len(result.pages_needing_ocr) @@ -130,64 +106,26 @@ def _extract() -> PdfExtractionResult: def _extract_page(page, page_num: int) -> PageExtractionResult: - """Extract text from a single PDF page. - - Args: - page: PyMuPDF page object. - page_num: 0-based page number. - - Returns: - A :class:`PageExtractionResult`. - """ try: - # Extract text with layout preservation text = page.get_text("text") char_count = len(text.strip()) if char_count < MIN_CHARS_PER_PAGE: - return PageExtractionResult( - page_num=page_num, - text="", - needs_ocr=True, - ) + return PageExtractionResult(page_num=page_num, text="", needs_ocr=True) - # Check if page is primarily images with minimal text image_list = page.get_images() if image_list and char_count < 200: - # Has images and very little text - likely a scan - return PageExtractionResult( - page_num=page_num, - text="", - needs_ocr=True, - ) + return PageExtractionResult(page_num=page_num, text="", needs_ocr=True) - # Text extraction successful - return PageExtractionResult( - page_num=page_num, - text=text.strip(), - needs_ocr=False, - ) + return PageExtractionResult(page_num=page_num, text=text.strip(), needs_ocr=False) except Exception as exc: logger.warning(f"Text extraction failed for page {page_num}: {exc}") - return PageExtractionResult( - page_num=page_num, - text="", - needs_ocr=True, - ) + return PageExtractionResult(page_num=page_num, text="", needs_ocr=True) def text_to_markdown(text: str) -> str: - """Convert extracted plain text to basic Markdown. - - Applies simple heuristics to detect headings, lists, and paragraphs. - - Args: - text: Raw extracted text. - - Returns: - Markdown-formatted text. - """ + """Convert extracted plain text to basic Markdown.""" if not text: return "" @@ -206,17 +144,14 @@ def text_to_markdown(text: str) -> str: prev_blank = False - # Detect potential headings (short lines, possibly all caps or title case) is_short = len(stripped) < 80 is_uppercase = stripped.isupper() and len(stripped) > 3 is_titlecase = stripped.istitle() and len(stripped) < 60 - # Detect bullet points if stripped.startswith(("• ", "· ", "- ", "* ", "◦ ")): result_lines.append(f"- {stripped[2:].strip()}") elif stripped.startswith(("1.", "2.", "3.", "4.", "5.", "6.", "7.", "8.", "9.")): result_lines.append(stripped) - # Potential heading elif is_short and (is_uppercase or is_titlecase) and not stripped.endswith((".", ",", ";")): if is_uppercase: result_lines.append(f"## {stripped.title()}") diff --git a/backend/airweave/platform/converters/text_extractors/pptx.py b/backend/airweave/domains/converters/text_extractors/pptx.py similarity index 64% rename from backend/airweave/platform/converters/text_extractors/pptx.py rename to backend/airweave/domains/converters/text_extractors/pptx.py index 0e5ceb828..9df686b7a 100644 --- a/backend/airweave/platform/converters/text_extractors/pptx.py +++ b/backend/airweave/domains/converters/text_extractors/pptx.py @@ -1,8 +1,4 @@ -"""Direct text extraction from PPTX files using python-pptx. - -Extracts text from slide shapes, tables, and notes to produce markdown. -Images and diagrams are not captured -- this is a text-only extraction. -""" +"""Direct text extraction from PPTX files using python-pptx.""" from __future__ import annotations @@ -13,21 +9,10 @@ from airweave.core.logging import logger from airweave.domains.sync_pipeline.exceptions import SyncFailureError -# Minimum total characters to consider the extraction successful. MIN_TOTAL_CHARS = 50 def _extract_shape_text(shape: Any) -> list[str]: - """Extract text lines from a single PPTX shape. - - Handles both text-frame shapes and table shapes. - - Args: - shape: A ``pptx.shapes.base.BaseShape`` instance. - - Returns: - List of text lines (may be empty). - """ lines: list[str] = [] if shape.has_text_frame: @@ -45,15 +30,6 @@ def _extract_shape_text(shape: Any) -> list[str]: def _extract_slide(slide: Any, slide_idx: int) -> str: - """Extract markdown for a single slide. - - Args: - slide: A ``pptx.slide.Slide`` instance. - slide_idx: 1-based slide number (used for the heading). - - Returns: - Markdown string for the slide. - """ parts: list[str] = [f"## Slide {slide_idx}"] for shape in slide.shapes: @@ -68,20 +44,7 @@ def _extract_slide(slide: Any, slide_idx: int) -> str: async def extract_pptx_text(path: str) -> Optional[str]: - """Extract text from a PPTX and return markdown. - - Iterates over slides, shapes, tables, and notes to produce a markdown - representation. Images and diagrams are not captured. - - Args: - path: Path to the PPTX file. - - Returns: - Markdown string if extraction yielded sufficient text, ``None`` otherwise. - - Raises: - SyncFailureError: If python-pptx is not installed. - """ + """Extract text from a PPTX and return markdown.""" try: from pptx import Presentation except ImportError: diff --git a/backend/airweave/platform/converters/txt_converter.py b/backend/airweave/domains/converters/txt.py similarity index 75% rename from backend/airweave/platform/converters/txt_converter.py rename to backend/airweave/domains/converters/txt.py index 6592ea0a2..512609666 100644 --- a/backend/airweave/platform/converters/txt_converter.py +++ b/backend/airweave/domains/converters/txt.py @@ -9,48 +9,32 @@ import aiofiles from airweave.core.logging import logger +from airweave.domains.converters._base import BaseTextConverter from airweave.domains.sync_pipeline.async_helpers import run_in_thread_pool from airweave.domains.sync_pipeline.exceptions import EntityProcessingError -from airweave.platform.converters._base import BaseTextConverter class TxtConverter(BaseTextConverter): - """Converts text files (TXT, JSON, XML, MD, YAML, TOML) to markdown. - - Features: - - JSON: Pretty-prints with code fence - - XML: Pretty-prints with code fence - - Others: Returns as plain text - """ + """Converts text files (TXT, JSON, XML, MD, YAML, TOML) to markdown.""" async def convert_batch(self, file_paths: List[str]) -> Dict[str, str]: - """Convert text files to markdown. - - Args: - file_paths: List of text file paths - - Returns: - Dict mapping file_path -> markdown content (None if failed) - """ + """Convert text files to markdown.""" logger.debug(f"Converting {len(file_paths)} text files to markdown...") results = {} - semaphore = asyncio.Semaphore(20) # Limit concurrent file reads + semaphore = asyncio.Semaphore(20) async def _convert_one(path: str): async with semaphore: try: - # Determine format from extension _, ext = os.path.splitext(path) ext = ext.lower() - # Dispatch to format-specific handler if ext == ".json": text = await self._convert_json(path) elif ext == ".xml": text = await self._convert_xml(path) else: - # Plain text (TXT, MD, YAML, TOML, etc.) text = await self._convert_plain_text(path) if text and text.strip(): @@ -73,10 +57,6 @@ async def _convert_one(path: str): @staticmethod def _try_chardet_decode(raw_bytes: bytes, path: str) -> str | None: - """Attempt to decode bytes using chardet-detected encoding. - - Returns decoded text on success, None otherwise. - """ try: import chardet @@ -97,17 +77,6 @@ def _try_chardet_decode(raw_bytes: bytes, path: str) -> str | None: return None async def _convert_plain_text(self, path: str) -> str: - """Read plain text file with encoding detection. - - Args: - path: Path to text file - - Returns: - File content as string - - Raises: - EntityProcessingError: If file contains excessive binary/corrupted data - """ async with aiofiles.open(path, "rb") as f: raw_bytes = await f.read() @@ -146,31 +115,16 @@ async def _convert_plain_text(self, path: str) -> str: return text async def _convert_json(self, path: str) -> str: - """Convert JSON to pretty-printed code fence. - - Args: - path: Path to JSON file - - Returns: - Formatted JSON in markdown code fence - - Raises: - EntityProcessingError: If JSON syntax is invalid or contains corrupted data - """ - def _read_and_format(): - # Read raw bytes with open(path, "rb") as f: raw_bytes = f.read() - # Try UTF-8 first try: text = raw_bytes.decode("utf-8") except UnicodeDecodeError: - # Fallback with replace to detect corruption text = raw_bytes.decode("utf-8", errors="replace") replacement_count = text.count("\ufffd") - if replacement_count > 50: # Strict for JSON + if replacement_count > 50: raise EntityProcessingError( f"JSON contains binary data ({replacement_count} replacement chars)" ) @@ -186,31 +140,16 @@ def _read_and_format(): raise EntityProcessingError(f"Invalid JSON syntax in {path}") async def _convert_xml(self, path: str) -> str: - """Convert XML to pretty-printed code fence. - - Args: - path: Path to XML file - - Returns: - Formatted XML in markdown code fence - - Raises: - EntityProcessingError: If XML contains corrupted data - """ - def _read_and_format(): - # Read raw bytes with open(path, "rb") as f: raw_bytes = f.read() - # Try UTF-8 first try: content = raw_bytes.decode("utf-8") except UnicodeDecodeError: - # Fallback with replace to detect corruption content = raw_bytes.decode("utf-8", errors="replace") replacement_count = content.count("\ufffd") - if replacement_count > 50: # Strict for XML + if replacement_count > 50: raise EntityProcessingError( f"XML contains binary data ({replacement_count} replacement chars)" ) @@ -225,7 +164,6 @@ def _read_and_format(): raise except Exception as e: logger.warning(f"XML parsing failed for {path}: {e}, using raw content") - # Fallback to raw content - read with validation with open(path, "rb") as f: raw_bytes = f.read() @@ -234,7 +172,7 @@ def _read_and_format(): except UnicodeDecodeError: raw = raw_bytes.decode("utf-8", errors="replace") replacement_count = raw.count("\ufffd") - if replacement_count > 100: # More lenient for fallback + if replacement_count > 100: raise EntityProcessingError( f"XML contains excessive binary data " f"({replacement_count} replacement chars)" diff --git a/backend/airweave/platform/converters/web_converter.py b/backend/airweave/domains/converters/web.py similarity index 58% rename from backend/airweave/platform/converters/web_converter.py rename to backend/airweave/domains/converters/web.py index d243a2536..8c3f848e5 100644 --- a/backend/airweave/platform/converters/web_converter.py +++ b/backend/airweave/domains/converters/web.py @@ -8,50 +8,34 @@ from airweave.core.config import settings from airweave.core.logging import logger +from airweave.domains.converters._base import BaseTextConverter from airweave.domains.sync_pipeline.exceptions import SyncFailureError -from airweave.platform.converters._base import BaseTextConverter from airweave.platform.rate_limiters import FirecrawlRateLimiter -# ==================== CONFIGURATION ==================== - -# Retry configuration MAX_RETRIES = 3 -RETRY_MIN_WAIT = 10 # seconds -RETRY_MAX_WAIT = 120 # seconds (longer for rate limits) +RETRY_MIN_WAIT = 10 +RETRY_MAX_WAIT = 120 RETRY_MULTIPLIER = 2 -# Batch job polling POLL_INTERVAL_SECONDS = 2 -POLL_TIMEOUT_SECONDS = 600 # 10 minutes max for a batch +POLL_TIMEOUT_SECONDS = 600 class WebConverter(BaseTextConverter): """Converter that fetches URLs and converts HTML to markdown. Uses Firecrawl batch scrape API to efficiently process multiple URLs. - Returns markdown content for each URL. - - Error handling: - - Per-URL failures: Returns None for that URL (entity will be skipped) - - Batch failures: Returns all None (all entities in batch will be skipped) - - Infrastructure failures (API key, auth, quota): Raises SyncFailureError (fails entire sync) """ - # Batch size from Firecrawl Growth plan (100 concurrent browsers) BATCH_SIZE = FirecrawlRateLimiter.FIRECRAWL_CONCURRENT_BROWSERS def __init__(self): """Initialize the web converter with lazy Firecrawl client.""" - self.rate_limiter = FirecrawlRateLimiter() # Singleton - shared across pod + self.rate_limiter = FirecrawlRateLimiter() self._firecrawl_client: Optional[Any] = None self._initialized = False def _ensure_client(self): - """Ensure Firecrawl client is initialized (lazy initialization). - - Raises: - SyncFailureError: If API key not configured or package not installed - """ if self._initialized: return @@ -69,48 +53,25 @@ def _ensure_client(self): raise SyncFailureError("firecrawl-py package required but not installed") async def convert_batch(self, urls: List[str]) -> Dict[str, str]: - """Fetch URLs and convert to markdown using Firecrawl batch scrape. - - Args: - urls: List of URLs to fetch and convert - - Returns: - Dict mapping URL -> markdown content (None if that URL failed). - Even if the entire batch fails, returns all None values so entities - can be skipped individually rather than failing the entire sync. - - Raises: - SyncFailureError: Only for true infrastructure failures (API key missing, - unauthorized, forbidden, payment required, quota exceeded) - """ + """Fetch URLs and convert to markdown using Firecrawl batch scrape.""" if not urls: return {} - # Ensure client is initialized (raises SyncFailureError if not possible) self._ensure_client() - # Initialize all URLs as None (failed) - will be updated with successful results results: Dict[str, str] = {url: None for url in urls} try: - # Rate limit before API call (batch counts as 1 request) await self.rate_limiter.acquire() - - # Start batch scrape and wait for completion batch_result = await self._batch_scrape_with_retry(urls) - - # Extract results - updates dict with successful conversions self._extract_results(urls, batch_result, results) - return results except SyncFailureError: - # Infrastructure failure - propagate to fail sync raise except Exception as e: error_msg = str(e).lower() - # Check for infrastructure failures that should fail the sync is_infrastructure = any( kw in error_msg for kw in [ @@ -127,27 +88,10 @@ async def convert_batch(self, urls: List[str]) -> Dict[str, str]: logger.error(f"Firecrawl infrastructure failure: {e}") raise SyncFailureError(f"Firecrawl infrastructure failure: {e}") - # Other errors (timeout, network issues) - log but return partial results - # Individual URL failures are already handled by returning None - # Even if entire batch fails, return all None - entities will be skipped individually logger.warning(f"Firecrawl batch scrape error (entities will be skipped): {e}") - - # Return partial results - URLs with None will be skipped by entity pipeline return results async def _batch_scrape_with_retry(self, urls: List[str]): - """Execute batch scrape with retry logic. - - Args: - urls: List of URLs to scrape - - Returns: - Firecrawl batch scrape result object - - Raises: - Exception: If all retries fail - """ - @retry( retry=retry_if_exception_type( (TimeoutException, ReadTimeout, HTTPStatusError, asyncio.TimeoutError) @@ -159,7 +103,6 @@ async def _batch_scrape_with_retry(self, urls: List[str]): reraise=True, ) async def _call(): - # batch_scrape polls internally until complete return await self._firecrawl_client.batch_scrape( urls, formats=["markdown"], @@ -170,48 +113,30 @@ async def _call(): return await _call() def _extract_results(self, urls: List[str], batch_result, results: Dict[str, str]) -> None: - """Extract markdown content from batch scrape result. - - Updates results dict in-place with successful conversions. - URLs that fail remain as None in the dict. - - Args: - urls: Original list of URLs - batch_result: Firecrawl batch scrape result object - results: Dict to update (already initialized with all URLs -> None) - """ - # Check if we got any data if not hasattr(batch_result, "data") or not batch_result.data: logger.warning("Firecrawl batch returned no data") return - # Process each document in the result for doc in batch_result.data: - # Extract source URL from metadata source_url = self._get_source_url(doc) if not source_url: logger.warning("Firecrawl doc missing sourceURL in metadata") continue - # Extract markdown content markdown = getattr(doc, "markdown", None) if not markdown: logger.warning(f"Firecrawl returned no markdown for {source_url}") - # Leave as None in results continue - # Match back to original URL (handle trailing slashes etc) matched_url = self._match_url(source_url, urls) if matched_url: results[matched_url] = markdown elif source_url in results: - # Fallback: use source_url directly if it was in input results[source_url] = markdown else: logger.warning(f"Could not match Firecrawl result URL: {source_url}") - # Log summary successful = sum(1 for v in results.values() if v is not None) failed = len(results) - successful @@ -225,51 +150,26 @@ def _extract_results(self, urls: List[str], batch_result, results: Dict[str, str logger.debug(f"Firecrawl: all {successful} URLs converted successfully") def _get_source_url(self, doc) -> Optional[str]: - """Extract source URL from Firecrawl document metadata. - - Args: - doc: Firecrawl document object - - Returns: - Source URL string or None - """ if not hasattr(doc, "metadata") or not doc.metadata: return None - # Firecrawl v4 uses snake_case: source_url - # Try attribute access first (for typed objects) source_url = getattr(doc.metadata, "source_url", None) if source_url: return source_url - # Fallback: try camelCase for older SDK versions source_url = getattr(doc.metadata, "sourceURL", None) if source_url: return source_url - # Try dict access (for untyped dicts) if isinstance(doc.metadata, dict): return doc.metadata.get("source_url") or doc.metadata.get("sourceURL") return None def _match_url(self, source_url: str, original_urls: List[str]) -> Optional[str]: - """Match a source URL back to the original URL list. - - Handles minor differences like trailing slashes. - - Args: - source_url: URL from Firecrawl response - original_urls: List of original input URLs - - Returns: - Matched original URL or None - """ - # Exact match if source_url in original_urls: return source_url - # Try normalized comparison (trailing slashes) normalized_source = source_url.rstrip("/") for url in original_urls: if url.rstrip("/") == normalized_source: diff --git a/backend/airweave/platform/converters/xlsx_converter.py b/backend/airweave/domains/converters/xlsx.py similarity index 72% rename from backend/airweave/platform/converters/xlsx_converter.py rename to backend/airweave/domains/converters/xlsx.py index ccc5208d4..59c7707d3 100644 --- a/backend/airweave/platform/converters/xlsx_converter.py +++ b/backend/airweave/domains/converters/xlsx.py @@ -4,31 +4,16 @@ from typing import Dict, List from airweave.core.logging import logger +from airweave.domains.converters._base import BaseTextConverter from airweave.domains.sync_pipeline.async_helpers import run_in_thread_pool from airweave.domains.sync_pipeline.exceptions import EntityProcessingError, SyncFailureError -from airweave.platform.converters._base import BaseTextConverter class XlsxConverter(BaseTextConverter): - """Converts XLSX files to markdown using local openpyxl extraction. - - Note: XLSX is not supported by Mistral OCR, so we use local extraction. - Extracts all sheets as markdown tables with formulas and cell values. - """ + """Converts XLSX files to markdown using local openpyxl extraction.""" async def convert_batch(self, file_paths: List[str]) -> Dict[str, str]: - """Convert XLSX files to markdown text using openpyxl. - - Args: - file_paths: List of XLSX file paths to convert - - Returns: - Dict mapping file_path -> markdown text content (None if failed) - - Raises: - SyncFailureError: If openpyxl package not installed - """ - # Check package availability upfront + """Convert XLSX files to markdown text using openpyxl.""" try: import openpyxl # noqa: F401 except ImportError: @@ -39,7 +24,7 @@ async def convert_batch(self, file_paths: List[str]) -> Dict[str, str]: logger.debug(f"Converting {len(file_paths)} XLSX files to markdown...") results = {} - semaphore = asyncio.Semaphore(10) # Limit concurrent file reads + semaphore = asyncio.Semaphore(10) async def _convert_one(path: str): async with semaphore: @@ -68,23 +53,10 @@ async def _convert_one(path: str): return results async def _extract_xlsx_to_markdown(self, xlsx_path: str) -> str: # noqa: C901 - """Extract XLSX content to markdown format. - - Args: - xlsx_path: Path to XLSX file - - Returns: - Markdown formatted string with all sheets - - Raises: - EntityProcessingError: If file cannot be opened or has no sheets - """ - def _extract() -> str: # noqa: C901 from openpyxl import load_workbook try: - # Load workbook with formula evaluation wb = load_workbook(xlsx_path, data_only=False) except Exception as e: raise EntityProcessingError(f"Failed to open XLSX file {xlsx_path}: {e}") @@ -96,28 +68,22 @@ def _extract() -> str: # noqa: C901 markdown_parts = [] - # Process each sheet for sheet_name in sheet_names: sheet = wb[sheet_name] - # Get max row and column max_row = sheet.max_row max_col = sheet.max_column if max_row == 0 or max_col == 0: - # Empty sheet - skip logger.debug(f"Sheet '{sheet_name}' is empty, skipping") continue - # Add sheet header markdown_parts.append(f"## Sheet: {sheet_name}\n") - # Extract all rows rows_data = [] for row in sheet.iter_rows(min_row=1, max_row=max_row, max_col=max_col): row_values = [] for cell in row: - # Get cell value (formulas will be evaluated if data_only=True) value = cell.value if value is None: row_values.append("") @@ -129,32 +95,23 @@ def _extract() -> str: # noqa: C901 markdown_parts.append("*Empty sheet*\n") continue - # Convert to markdown table - # Use first row as header if len(rows_data) > 1: header = rows_data[0] data_rows = rows_data[1:] - # Create markdown table - # Header row markdown_parts.append("| " + " | ".join(header) + " |") - # Separator row markdown_parts.append("| " + " | ".join(["---"] * len(header)) + " |") - # Data rows for row in data_rows: - # Pad row if shorter than header padded_row = row + [""] * (len(header) - len(row)) markdown_parts.append("| " + " | ".join(padded_row[: len(header)]) + " |") else: - # Single row - just show as list for value in rows_data[0]: if value: markdown_parts.append(f"- {value}") - markdown_parts.append("") # Blank line between sheets + markdown_parts.append("") - # Combine all sheets if not markdown_parts: raise EntityProcessingError(f"XLSX file {xlsx_path} has no extractable content") diff --git a/backend/airweave/domains/sync_pipeline/pipeline/text_builder.py b/backend/airweave/domains/sync_pipeline/pipeline/text_builder.py index c6fe8d43a..84788abdb 100644 --- a/backend/airweave/domains/sync_pipeline/pipeline/text_builder.py +++ b/backend/airweave/domains/sync_pipeline/pipeline/text_builder.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from airweave.core.shared_models import AirweaveFieldFlag +from airweave.domains.converters.protocols import ConverterRegistryProtocol from airweave.domains.sync_pipeline.exceptions import EntityProcessingError, SyncFailureError from airweave.domains.sync_pipeline.file_types import SUPPORTED_FILE_EXTENSIONS from airweave.platform.entities._base import BaseEntity, CodeFileEntity, FileEntity, WebEntity @@ -24,9 +25,12 @@ class TextualRepresentationBuilder: - Batch conversion orchestration """ - # Default batch size for converters without specific config DEFAULT_CONVERTER_BATCH_SIZE = 10 + def __init__(self, converter_registry: ConverterRegistryProtocol) -> None: + """Initialize with a converter registry for routing entities to converters.""" + self._registry = converter_registry + # ------------------------------------------------------------------------------------ # Public API # ------------------------------------------------------------------------------------ @@ -272,107 +276,26 @@ def _partition_by_converter( return converter_groups, failed_entities def _get_converter_and_key(self, entity: BaseEntity) -> Tuple[Any, Optional[str]]: - """Get the appropriate converter and key for an entity. - - Args: - entity: The entity to get converter for - - Returns: - Tuple of (converter, key) where: - - converter: The converter module/instance to use - - key: The key to pass to convert_batch - - Raises: - EntityProcessingError: If entity type is not supported or missing required fields - """ - from airweave.platform import converters - - # WebEntity: use web_converter with crawl_url + """Get the appropriate converter and key for an entity.""" if isinstance(entity, WebEntity): if not entity.crawl_url: raise EntityProcessingError(f"WebEntity {entity.entity_id} missing crawl_url") - return converters.web_converter, entity.crawl_url + return self._registry.for_web(), entity.crawl_url - # FileEntity: use file-type specific converter with local_path if isinstance(entity, FileEntity): if not entity.local_path: raise EntityProcessingError(f"FileEntity {entity.entity_id} missing local_path") - converter = self._determine_converter_for_file(entity.local_path) + _, ext = os.path.splitext(entity.local_path) + ext = ext.lower() + if ext not in SUPPORTED_FILE_EXTENSIONS: + raise EntityProcessingError(f"Unsupported file type: {ext}") + converter = self._registry.for_extension(ext) + if not converter: + raise EntityProcessingError(f"Unsupported file type: {ext}") return converter, entity.local_path - # Other entity types don't need content conversion return None, None - def _determine_converter_for_file(self, file_path: str) -> Any: - """Determine converter module based on file extension. - - Args: - file_path: Path to the file - - Returns: - Converter module with convert_batch function - - Raises: - EntityProcessingError: If file type is not supported - """ - from airweave.platform import converters - - _, ext = os.path.splitext(file_path) - ext = ext.lower() - - if ext not in SUPPORTED_FILE_EXTENSIONS: - raise EntityProcessingError(f"Unsupported file type: {ext}") - - converter_map = { - # Documents - Text extraction + Mistral OCR fallback - ".pdf": converters.pdf_converter, - ".docx": converters.docx_converter, - ".pptx": converters.pptx_converter, - # Mistral OCR - Images - ".jpg": converters.mistral_converter, - ".jpeg": converters.mistral_converter, - ".png": converters.mistral_converter, - # XLSX - local extraction - ".xlsx": converters.xlsx_converter, - # HTML - ".html": converters.html_converter, - ".htm": converters.html_converter, - # Text files - ".txt": converters.txt_converter, - ".json": converters.txt_converter, - ".xml": converters.txt_converter, - ".md": converters.txt_converter, - ".yaml": converters.txt_converter, - ".yml": converters.txt_converter, - ".toml": converters.txt_converter, - # Code file extensions - ".py": converters.code_converter, - ".js": converters.code_converter, - ".ts": converters.code_converter, - ".tsx": converters.code_converter, - ".jsx": converters.code_converter, - ".java": converters.code_converter, - ".cpp": converters.code_converter, - ".c": converters.code_converter, - ".h": converters.code_converter, - ".hpp": converters.code_converter, - ".go": converters.code_converter, - ".rs": converters.code_converter, - ".rb": converters.code_converter, - ".php": converters.code_converter, - ".swift": converters.code_converter, - ".kt": converters.code_converter, - ".kts": converters.code_converter, - ".tf": converters.code_converter, - ".tfvars": converters.code_converter, - } - - converter = converter_map.get(ext) - if not converter: - raise EntityProcessingError(f"Unsupported file type: {ext}") - - return converter - # ------------------------------------------------------------------------------------ # Conversion Execution # ------------------------------------------------------------------------------------ @@ -492,7 +415,3 @@ async def _handle_conversion_failures( sync_context.logger.warning( f"Removed {len(failed_entities)} entities that failed conversion" ) - - -# Singleton instance -text_builder = TextualRepresentationBuilder() diff --git a/backend/airweave/domains/sync_pipeline/processors/chunk_embed.py b/backend/airweave/domains/sync_pipeline/processors/chunk_embed.py index 28b6f54d5..8e704cd92 100644 --- a/backend/airweave/domains/sync_pipeline/processors/chunk_embed.py +++ b/backend/airweave/domains/sync_pipeline/processors/chunk_embed.py @@ -14,8 +14,9 @@ import json from typing import TYPE_CHECKING, Any, Dict, List, Tuple +from airweave.domains.converters.protocols import ConverterRegistryProtocol from airweave.domains.sync_pipeline.exceptions import SyncFailureError -from airweave.domains.sync_pipeline.pipeline.text_builder import text_builder +from airweave.domains.sync_pipeline.pipeline.text_builder import TextualRepresentationBuilder from airweave.domains.sync_pipeline.processors.utils import filter_empty_representations from airweave.platform.entities._base import BaseEntity, CodeFileEntity @@ -25,24 +26,11 @@ class ChunkEmbedProcessor: - """Unified processor that chunks text and computes embeddings. - - Pipeline: - 1. Build textual representation (text extraction from files/web) - 2. Chunk text (semantic for text, AST for code) - 3. Compute embeddings: - - Dense embeddings (3072-dim for neural/semantic search) - - Sparse embeddings (FastEmbed Qdrant/bm25 for keyword search scoring) - - Output: - Chunk entities with: - - entity_id: "{original_id}__chunk_{idx}" - - textual_representation: chunk text - - airweave_system_metadata.dense_embedding: 3072-dim vector - - airweave_system_metadata.sparse_embedding: FastEmbed BM25 sparse vector - - airweave_system_metadata.original_entity_id: original entity_id - - airweave_system_metadata.chunk_index: chunk position - """ + """Unified processor that chunks text and computes embeddings.""" + + def __init__(self, converter_registry: ConverterRegistryProtocol) -> None: + """Initialize with a converter registry for text building.""" + self._text_builder = TextualRepresentationBuilder(converter_registry) async def process( self, @@ -55,7 +43,7 @@ async def process( return [] # Step 1: Build textual representations - processed = await text_builder.build_for_batch(entities, sync_context, runtime) + processed = await self._text_builder.build_for_batch(entities, sync_context, runtime) # Step 2: Filter empty representations processed = await filter_empty_representations( diff --git a/backend/airweave/domains/sync_pipeline/tests/test_chunk_embed.py b/backend/airweave/domains/sync_pipeline/tests/test_chunk_embed.py index aef34f233..53bf02496 100644 --- a/backend/airweave/domains/sync_pipeline/tests/test_chunk_embed.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_chunk_embed.py @@ -4,21 +4,22 @@ import pytest +from airweave.domains.converters.fakes.registry import FakeConverterRegistry from airweave.domains.sync_pipeline.processors.chunk_embed import ChunkEmbedProcessor -_TEXT_BUILDER = "airweave.domains.sync_pipeline.processors.chunk_embed.text_builder" +_TEXT_BUILDER_CLS = ( + "airweave.domains.sync_pipeline.processors.chunk_embed.TextualRepresentationBuilder" +) _SEMANTIC_CHUNKER = "airweave.platform.chunkers.semantic.SemanticChunker" @pytest.fixture def processor(): - """Create ChunkEmbedProcessor instance.""" - return ChunkEmbedProcessor() + return ChunkEmbedProcessor(converter_registry=FakeConverterRegistry()) @pytest.fixture def mock_sync_context(): - """Create mock SyncContext.""" context = MagicMock() context.logger = MagicMock() context.collection = MagicMock() @@ -27,7 +28,6 @@ def mock_sync_context(): @pytest.fixture def mock_runtime(): - """Create mock SyncRuntime.""" runtime = MagicMock() runtime.entity_tracker = AsyncMock() runtime.dense_embedder = MagicMock() @@ -40,7 +40,6 @@ def mock_runtime(): @pytest.fixture def mock_entity(): - """Create a simple mock entity.""" entity = MagicMock() entity.entity_id = "test-123" entity.textual_representation = "Test content" @@ -54,11 +53,9 @@ def mock_entity(): class TestChunkEmbedProcessor: - """Test ChunkEmbedProcessor chunks text and computes embeddings.""" @pytest.mark.asyncio async def test_process_empty_list(self, processor, mock_sync_context, mock_runtime): - """Test processing empty entity list returns empty.""" result = await processor.process([], mock_sync_context, mock_runtime) assert result == [] @@ -66,13 +63,14 @@ async def test_process_empty_list(self, processor, mock_sync_context, mock_runti async def test_chunk_textual_entities_uses_semantic_chunker( self, processor, mock_sync_context, mock_runtime, mock_entity ): - """Test textual entities routed to SemanticChunker.""" with ( - patch(_TEXT_BUILDER) as mock_builder, + patch.object( + processor._text_builder, "build_for_batch", new_callable=AsyncMock + ) as mock_build, patch(_SEMANTIC_CHUNKER) as MockSemanticChunker, patch.object(processor, "_embed_entities", new_callable=AsyncMock), ): - mock_builder.build_for_batch = AsyncMock(return_value=[mock_entity]) + mock_build.return_value = [mock_entity] mock_chunker = MockSemanticChunker.return_value mock_chunker.chunk_batch = AsyncMock( return_value=[[{"text": "Chunk 1"}, {"text": "Chunk 2"}]] @@ -80,20 +78,15 @@ async def test_chunk_textual_entities_uses_semantic_chunker( await processor.process([mock_entity], mock_sync_context, mock_runtime) - # Verify SemanticChunker was called mock_chunker.chunk_batch.assert_called_once() @pytest.mark.asyncio async def test_multiply_entities_creates_chunk_suffix(self, processor, mock_sync_context): - """Test chunk entity creation with proper ID suffix.""" - # Create mock entity mock_entity = MagicMock() mock_entity.entity_id = "parent-123" mock_entity.textual_representation = "Original text" mock_entity.airweave_system_metadata = MagicMock() - mock_entity.model_copy = MagicMock(return_value=MagicMock()) - # Configure model_copy to return new mock with modifiable attributes def create_chunk_entity(deep=False): chunk = MagicMock() chunk.entity_id = None @@ -106,17 +99,14 @@ def create_chunk_entity(deep=False): mock_entity.model_copy = MagicMock(side_effect=create_chunk_entity) chunks = [[{"text": "Chunk 0"}, {"text": "Chunk 1"}]] - result = processor._multiply_entities([mock_entity], chunks, mock_sync_context) assert len(result) == 2 - # Check that entity IDs have chunk suffix assert "__chunk_0" in result[0].entity_id assert "__chunk_1" in result[1].entity_id @pytest.mark.asyncio async def test_multiply_entities_sets_chunk_index(self, processor, mock_sync_context): - """Test chunk index set correctly.""" mock_entity = MagicMock() mock_entity.entity_id = "test-123" @@ -132,14 +122,12 @@ def create_chunk_entity(deep=False): mock_entity.model_copy = MagicMock(side_effect=create_chunk_entity) chunks = [[{"text": "Chunk"}]] - result = processor._multiply_entities([mock_entity], chunks, mock_sync_context) assert result[0].airweave_system_metadata.chunk_index == 0 @pytest.mark.asyncio async def test_multiply_entities_skips_empty_chunks(self, processor, mock_sync_context): - """Test empty chunks are filtered out.""" mock_entity = MagicMock() mock_entity.entity_id = "test-123" @@ -155,37 +143,29 @@ def create_chunk_entity(deep=False): mock_entity.model_copy = MagicMock(side_effect=create_chunk_entity) chunks = [[{"text": "Valid"}, {"text": ""}, {"text": " "}, {"text": "Another"}]] - result = processor._multiply_entities([mock_entity], chunks, mock_sync_context) - # Should only have 2 chunks (empty ones filtered) assert len(result) == 2 @pytest.mark.asyncio async def test_embed_entities_calls_both_embedders(self, processor, mock_runtime): - """Test both dense and sparse embedders are called.""" mock_entity = MagicMock() mock_entity.textual_representation = "Test content" mock_entity.airweave_system_metadata = MagicMock() mock_entity.model_dump = MagicMock(return_value={"entity_id": "test"}) - chunk_entities = [mock_entity] - - # Setup runtime embedder mocks dense_result = MagicMock() dense_result.vector = [0.1] * 3072 mock_runtime.dense_embedder.embed_many = AsyncMock(return_value=[dense_result]) mock_runtime.sparse_embedder.embed_many = AsyncMock(return_value=[MagicMock()]) - await processor._embed_entities(chunk_entities, mock_runtime) + await processor._embed_entities([mock_entity], mock_runtime) - # Verify both embedders called mock_runtime.dense_embedder.embed_many.assert_called_once() mock_runtime.sparse_embedder.embed_many.assert_called_once() @pytest.mark.asyncio async def test_embed_entities_assigns_embeddings(self, processor, mock_runtime): - """Test embeddings assigned to entity system metadata.""" mock_entity = MagicMock() mock_entity.textual_representation = "Test" mock_entity.airweave_system_metadata = MagicMock() @@ -193,8 +173,6 @@ async def test_embed_entities_assigns_embeddings(self, processor, mock_runtime): mock_entity.airweave_system_metadata.sparse_embedding = None mock_entity.model_dump = MagicMock(return_value={"entity_id": "test"}) - chunk_entities = [mock_entity] - dense_vector = [0.1] * 3072 dense_result = MagicMock() dense_result.vector = dense_vector @@ -203,40 +181,31 @@ async def test_embed_entities_assigns_embeddings(self, processor, mock_runtime): mock_runtime.dense_embedder.embed_many = AsyncMock(return_value=[dense_result]) mock_runtime.sparse_embedder.embed_many = AsyncMock(return_value=[sparse_embedding]) - await processor._embed_entities(chunk_entities, mock_runtime) + await processor._embed_entities([mock_entity], mock_runtime) - # Check embeddings assigned assert mock_entity.airweave_system_metadata.dense_embedding == dense_vector assert mock_entity.airweave_system_metadata.sparse_embedding == sparse_embedding @pytest.mark.asyncio async def test_embed_entities_uses_full_json_for_sparse(self, processor, mock_runtime): - """Test sparse embedder receives full entity JSON.""" mock_entity = MagicMock() mock_entity.textual_representation = "Test" mock_entity.airweave_system_metadata = MagicMock() mock_entity.model_dump = MagicMock( - return_value={ - "entity_id": "test-123", - "name": "Test Entity", - } + return_value={"entity_id": "test-123", "name": "Test Entity"} ) - chunk_entities = [mock_entity] - dense_result = MagicMock() dense_result.vector = [0.1] * 3072 mock_runtime.dense_embedder.embed_many = AsyncMock(return_value=[dense_result]) mock_runtime.sparse_embedder.embed_many = AsyncMock(return_value=[MagicMock()]) - await processor._embed_entities(chunk_entities, mock_runtime) + await processor._embed_entities([mock_entity], mock_runtime) - # Verify sparse embedder got JSON strings call_args = mock_runtime.sparse_embedder.embed_many.call_args[0][0] assert isinstance(call_args, list) assert isinstance(call_args[0], str) - # Verify it's JSON import json parsed = json.loads(call_args[0]) @@ -244,30 +213,24 @@ async def test_embed_entities_uses_full_json_for_sparse(self, processor, mock_ru @pytest.mark.asyncio async def test_embed_entities_validates_embeddings_exist(self, processor, mock_runtime): - """Test validation that all entities have embeddings.""" mock_entity = MagicMock() mock_entity.textual_representation = "Test" mock_entity.entity_id = "test-123" mock_entity.airweave_system_metadata = MagicMock() mock_entity.model_dump = MagicMock(return_value={"entity_id": "test"}) - chunk_entities = [mock_entity] - - # Return None for dense embedding vector dense_result = MagicMock() dense_result.vector = None mock_runtime.dense_embedder.embed_many = AsyncMock(return_value=[dense_result]) mock_runtime.sparse_embedder.embed_many = AsyncMock(return_value=[MagicMock()]) - # Should raise error with pytest.raises(Exception) as exc_info: - await processor._embed_entities(chunk_entities, mock_runtime) + await processor._embed_entities([mock_entity], mock_runtime) assert "no dense embedding" in str(exc_info.value).lower() @pytest.mark.asyncio async def test_full_pipeline_with_mocks(self, processor, mock_sync_context, mock_runtime): - """Test full pipeline with all mocked dependencies.""" mock_entity = MagicMock() mock_entity.entity_id = "test-123" mock_entity.textual_representation = "Original text" @@ -294,8 +257,13 @@ def create_chunk(deep=False): ) mock_runtime.sparse_embedder.embed_many = AsyncMock(return_value=[MagicMock(), MagicMock()]) - with patch(_TEXT_BUILDER) as mock_builder, patch(_SEMANTIC_CHUNKER) as MockChunker: - mock_builder.build_for_batch = AsyncMock(return_value=[mock_entity]) + with ( + patch.object( + processor._text_builder, "build_for_batch", new_callable=AsyncMock + ) as mock_build, + patch(_SEMANTIC_CHUNKER) as MockChunker, + ): + mock_build.return_value = [mock_entity] mock_chunker = MockChunker.return_value mock_chunker.chunk_batch = AsyncMock( @@ -304,10 +272,8 @@ def create_chunk(deep=False): result = await processor.process([mock_entity], mock_sync_context, mock_runtime) - # Should have 2 chunks assert len(result) == 2 - # Verify pipeline steps were called - mock_builder.build_for_batch.assert_called_once() + mock_build.assert_called_once() mock_chunker.chunk_batch.assert_called_once() mock_runtime.dense_embedder.embed_many.assert_called_once() mock_runtime.sparse_embedder.embed_many.assert_called_once() @@ -316,7 +282,6 @@ def create_chunk(deep=False): async def test_memory_optimization_clears_parent_text( self, processor, mock_sync_context, mock_runtime ): - """Test parent entity text released after chunking.""" mock_entity = MagicMock() mock_entity.entity_id = "test-123" mock_entity.textual_representation = "Original text" @@ -335,50 +300,57 @@ def create_chunk(deep=False): mock_runtime.dense_embedder.embed_many = AsyncMock(return_value=[dense_result]) mock_runtime.sparse_embedder.embed_many = AsyncMock(return_value=[MagicMock()]) - with patch(_TEXT_BUILDER) as mock_builder, patch(_SEMANTIC_CHUNKER) as MockChunker: - mock_builder.build_for_batch = AsyncMock(return_value=[mock_entity]) + with ( + patch.object( + processor._text_builder, "build_for_batch", new_callable=AsyncMock + ) as mock_build, + patch(_SEMANTIC_CHUNKER) as MockChunker, + ): + mock_build.return_value = [mock_entity] mock_chunker = MockChunker.return_value mock_chunker.chunk_batch = AsyncMock(return_value=[[{"text": "Chunk"}]]) await processor.process([mock_entity], mock_sync_context, mock_runtime) - # Parent entity's textual_representation should be None assert mock_entity.textual_representation is None @pytest.mark.asyncio async def test_skips_entities_without_text(self, processor, mock_sync_context, mock_runtime): - """Test entities with no textual_representation are skipped.""" mock_entity = MagicMock() mock_entity.entity_id = "test-123" - mock_entity.textual_representation = None # No text + mock_entity.textual_representation = None mock_entity.airweave_system_metadata = MagicMock() - with patch(_TEXT_BUILDER) as mock_builder: - mock_builder.build_for_batch = AsyncMock(return_value=[mock_entity]) + with patch.object( + processor._text_builder, "build_for_batch", new_callable=AsyncMock + ) as mock_build: + mock_build.return_value = [mock_entity] result = await processor.process([mock_entity], mock_sync_context, mock_runtime) - # Should return empty list (skipped) assert len(result) == 0 @pytest.mark.asyncio async def test_handles_empty_chunks_from_chunker( self, processor, mock_sync_context, mock_runtime ): - """Test handling when chunker returns empty list.""" mock_entity = MagicMock() mock_entity.entity_id = "test-123" mock_entity.textual_representation = "Test" mock_entity.airweave_system_metadata = MagicMock() - with patch(_TEXT_BUILDER) as mock_builder, patch(_SEMANTIC_CHUNKER) as MockChunker: - mock_builder.build_for_batch = AsyncMock(return_value=[mock_entity]) + with ( + patch.object( + processor._text_builder, "build_for_batch", new_callable=AsyncMock + ) as mock_build, + patch(_SEMANTIC_CHUNKER) as MockChunker, + ): + mock_build.return_value = [mock_entity] mock_chunker = MockChunker.return_value - mock_chunker.chunk_batch = AsyncMock(return_value=[[]]) # Empty chunks + mock_chunker.chunk_batch = AsyncMock(return_value=[[]]) result = await processor.process([mock_entity], mock_sync_context, mock_runtime) - # Should skip entity with no chunks assert len(result) == 0 diff --git a/backend/airweave/main.py b/backend/airweave/main.py index b095b4200..fd9229829 100644 --- a/backend/airweave/main.py +++ b/backend/airweave/main.py @@ -63,11 +63,6 @@ async def lifespan(app: FastAPI): initialize_container(settings) logger.info("Container initialized successfully") - # Initialize converters with OCR from the container - from airweave.platform.converters import initialize_converters - - initialize_converters(ocr_provider=container_mod.container.ocr_provider) - async with AsyncSessionLocal() as db: if settings.RUN_ALEMBIC_MIGRATIONS: logger.info("Running alembic migrations...") diff --git a/backend/airweave/platform/converters/__init__.py b/backend/airweave/platform/converters/__init__.py deleted file mode 100644 index 602ce5ded..000000000 --- a/backend/airweave/platform/converters/__init__.py +++ /dev/null @@ -1,114 +0,0 @@ -"""Text converters for converting files and URLs to markdown. - -Converter singletons are initialized explicitly at startup via -``initialize_converters()``. OCR is injected as a parameter — the -converters module never imports the DI container. -""" - -import sys -from typing import TYPE_CHECKING - -from .code_converter import CodeConverter -from .docx_converter import DocxConverter -from .html_converter import HtmlConverter -from .pdf_converter import PdfConverter -from .pptx_converter import PptxConverter -from .txt_converter import TxtConverter -from .web_converter import WebConverter -from .xlsx_converter import XlsxConverter - -if TYPE_CHECKING: - from airweave.core.protocols import OcrProvider - -# --------------------------------------------------------------------------- -# Singleton management -# --------------------------------------------------------------------------- -# -# ``from .pdf_converter import PdfConverter`` also adds the *module* -# ``pdf_converter`` as an attribute of this package. That shadows the -# singleton of the same name and prevents ``__getattr__`` from firing. -# Remove the module references so the singleton lookup works correctly. -# (The submodules remain in ``sys.modules`` so direct imports still work.) - -_SINGLETON_NAMES = frozenset( - { - "mistral_converter", - "pdf_converter", - "docx_converter", - "pptx_converter", - "img_converter", - "html_converter", - "txt_converter", - "xlsx_converter", - "code_converter", - "web_converter", - } -) - -for _mod in ( - "code_converter", - "docx_converter", - "html_converter", - "pdf_converter", - "pptx_converter", - "txt_converter", - "web_converter", - "xlsx_converter", -): - vars().pop(_mod, None) -del _mod - -_singletons: dict | None = None - - -def initialize_converters(ocr_provider: "OcrProvider | None" = None) -> None: - """Initialize converter singletons with the given OCR provider. - - Called once at startup from ``main.py`` lifespan and ``worker main()``. - The OCR provider is passed explicitly — no container import needed. - - When *ocr_provider* is ``None`` (no OCR credentials configured), the - hybrid document converters (PDF, DOCX, PPTX) still work for local text - extraction and only log a warning when OCR fallback would be needed. - - Args: - ocr_provider: The OCR provider (e.g., FallbackOcrProvider with - circuit breaking) to inject into document converters, or - ``None`` if OCR is unavailable. - """ - global _singletons - if _singletons is not None: - return - - _singletons = { - "mistral_converter": ocr_provider, - "pdf_converter": PdfConverter(ocr_provider=ocr_provider), - "docx_converter": DocxConverter(ocr_provider=ocr_provider), - "pptx_converter": PptxConverter(ocr_provider=ocr_provider), - "img_converter": ocr_provider, # Images go directly to OCR - "html_converter": HtmlConverter(), - "txt_converter": TxtConverter(), - "xlsx_converter": XlsxConverter(), - "code_converter": CodeConverter(), - "web_converter": WebConverter(), - } - - # Also set as module attributes so subsequent lookups are O(1) - # (bypasses __getattr__ after first access). - this_module = sys.modules[__name__] - for _name, _value in _singletons.items(): - setattr(this_module, _name, _value) - - -def __getattr__(name: str): - """PEP 562 module-level ``__getattr__`` for singleton access.""" - if name in _SINGLETON_NAMES: - if _singletons is None: - raise RuntimeError( - "Converters not initialized. Call initialize_converters() at startup." - ) - return _singletons[name] - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - - -__all__ = sorted(_SINGLETON_NAMES) diff --git a/backend/airweave/platform/converters/docx_converter.py b/backend/airweave/platform/converters/docx_converter.py deleted file mode 100644 index f9089c7b6..000000000 --- a/backend/airweave/platform/converters/docx_converter.py +++ /dev/null @@ -1,33 +0,0 @@ -"""DOCX converter with hybrid text extraction + OCR fallback. - -Uses :class:`HybridDocumentConverter` to try python-docx text extraction first -and fall back to OCR only when extraction is insufficient. -""" - -from __future__ import annotations - -from typing import Optional - -from airweave.platform.converters._base import HybridDocumentConverter -from airweave.platform.converters.text_extractors.docx import extract_docx_text - - -class DocxConverter(HybridDocumentConverter): - """Converts DOCX files to markdown using text extraction with OCR fallback. - - Most DOCX files have extractable text via python-docx. If extraction - yields insufficient content (e.g. the DOCX is mostly images), the file - is sent to the OCR provider. - - Usage:: - - converter = DocxConverter(ocr_provider=MistralOCR()) - results = await converter.convert_batch(["/tmp/doc.docx"]) - """ - - async def _try_extract(self, path: str) -> Optional[str]: - """Extract text from a DOCX using python-docx. - - Returns markdown if sufficient text was extracted, ``None`` otherwise. - """ - return await extract_docx_text(path) diff --git a/backend/airweave/platform/converters/pdf_converter.py b/backend/airweave/platform/converters/pdf_converter.py deleted file mode 100644 index 072f745ec..000000000 --- a/backend/airweave/platform/converters/pdf_converter.py +++ /dev/null @@ -1,41 +0,0 @@ -"""PDF converter with hybrid text extraction + OCR fallback. - -Uses :class:`HybridDocumentConverter` to try PyMuPDF text extraction first -and fall back to OCR only when pages lack a text layer. -""" - -from __future__ import annotations - -from typing import Optional - -from airweave.platform.converters._base import HybridDocumentConverter -from airweave.platform.converters.text_extractors.pdf import ( - extract_pdf_text, - text_to_markdown, -) - - -class PdfConverter(HybridDocumentConverter): - """Converts PDFs to markdown using text extraction with OCR fallback. - - For PDFs with embedded text layers on ALL pages, text is extracted directly - without any API calls. If any page lacks a text layer, the entire PDF is - sent to the OCR provider. - - Usage:: - - converter = PdfConverter(ocr_provider=MistralOCR()) - results = await converter.convert_batch(["/tmp/doc.pdf"]) - """ - - async def _try_extract(self, path: str) -> Optional[str]: - """Extract text from a PDF using PyMuPDF. - - Returns markdown if all pages have a text layer, ``None`` otherwise. - """ - extraction = await extract_pdf_text(path) - - if extraction.fully_extracted and extraction.full_text: - return text_to_markdown(extraction.full_text) - - return None diff --git a/backend/airweave/platform/converters/pptx_converter.py b/backend/airweave/platform/converters/pptx_converter.py deleted file mode 100644 index 20d818c48..000000000 --- a/backend/airweave/platform/converters/pptx_converter.py +++ /dev/null @@ -1,33 +0,0 @@ -"""PPTX converter with hybrid text extraction + OCR fallback. - -Uses :class:`HybridDocumentConverter` to try python-pptx text extraction first -and fall back to OCR only when extraction is insufficient. -""" - -from __future__ import annotations - -from typing import Optional - -from airweave.platform.converters._base import HybridDocumentConverter -from airweave.platform.converters.text_extractors.pptx import extract_pptx_text - - -class PptxConverter(HybridDocumentConverter): - """Converts PPTX files to markdown using text extraction with OCR fallback. - - Most PPTX files have extractable text via python-pptx. If extraction - yields insufficient content (e.g. slides are mostly images/diagrams), - the file is sent to the OCR provider. - - Usage:: - - converter = PptxConverter(ocr_provider=MistralOCR()) - results = await converter.convert_batch(["/tmp/slides.pptx"]) - """ - - async def _try_extract(self, path: str) -> Optional[str]: - """Extract text from a PPTX using python-pptx. - - Returns markdown if sufficient text was extracted, ``None`` otherwise. - """ - return await extract_pptx_text(path) diff --git a/backend/airweave/platform/converters/text_extractors/__init__.py b/backend/airweave/platform/converters/text_extractors/__init__.py deleted file mode 100644 index 3248db675..000000000 --- a/backend/airweave/platform/converters/text_extractors/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -"""Text extraction utilities for various document formats. - -Re-exports all extractor functions and result types for convenience:: - - from airweave.platform.converters.text_extractors import extract_pdf_text - from airweave.platform.converters.text_extractors import extract_docx_text - from airweave.platform.converters.text_extractors import extract_pptx_text -""" - -from .docx import extract_docx_text -from .pdf import PdfExtractionResult, extract_pdf_text, text_to_markdown -from .pptx import extract_pptx_text - -__all__ = [ - "PdfExtractionResult", - "extract_pdf_text", - "text_to_markdown", - "extract_docx_text", - "extract_pptx_text", -] diff --git a/backend/airweave/platform/ocr/mistral/converter.py b/backend/airweave/platform/ocr/mistral/converter.py index db997485a..d4e10607c 100644 --- a/backend/airweave/platform/ocr/mistral/converter.py +++ b/backend/airweave/platform/ocr/mistral/converter.py @@ -25,8 +25,8 @@ import aiofiles.os from airweave.core.logging import logger +from airweave.domains.converters.text_extractors.pptx import extract_pptx_text from airweave.domains.sync_pipeline.exceptions import EntityProcessingError, SyncFailureError -from airweave.platform.converters.text_extractors.pptx import extract_pptx_text from airweave.platform.ocr.mistral.compressor import compress_image from airweave.platform.ocr.mistral.models import ( IMAGE_EXTENSIONS, diff --git a/backend/airweave/platform/temporal/worker/__init__.py b/backend/airweave/platform/temporal/worker/__init__.py index 9644099b3..128c92ea0 100644 --- a/backend/airweave/platform/temporal/worker/__init__.py +++ b/backend/airweave/platform/temporal/worker/__init__.py @@ -183,12 +183,7 @@ async def main() -> None: ) raise SystemExit(1) - # 3. Initialize converters with OCR from the container - from airweave.platform.converters import initialize_converters - - initialize_converters(ocr_provider=container_mod.container.ocr_provider) - - # 4. Create worker with config + # 3. Create worker with config config = WorkerConfig.from_settings() worker = TemporalWorker(config) diff --git a/backend/conftest.py b/backend/conftest.py index e7372773d..907d503fc 100644 --- a/backend/conftest.py +++ b/backend/conftest.py @@ -409,6 +409,14 @@ def fake_access_broker(): return FakeAccessBroker() +@pytest.fixture +def fake_converter_registry(): + """Fake ConverterRegistry.""" + from airweave.domains.converters.fakes.registry import FakeConverterRegistry + + return FakeConverterRegistry() + + @pytest.fixture def fake_billing_webhook(): """Fake BillingWebhookProcessor.""" @@ -652,6 +660,7 @@ def test_container( fake_sync_factory, fake_entity_repo, fake_access_broker, + fake_converter_registry, ): """A Container with all dependencies replaced by fakes. @@ -724,4 +733,5 @@ def test_container( sync_factory=fake_sync_factory, entity_repo=fake_entity_repo, access_broker=fake_access_broker, + converter_registry=fake_converter_registry, ) diff --git a/backend/tests/unit/domains/converters/__init__.py b/backend/tests/unit/domains/converters/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/tests/unit/domains/converters/test_code.py b/backend/tests/unit/domains/converters/test_code.py new file mode 100644 index 000000000..46487ab3f --- /dev/null +++ b/backend/tests/unit/domains/converters/test_code.py @@ -0,0 +1,69 @@ +"""Unit tests for CodeConverter encoding validation.""" + +import os +import tempfile + +import pytest + +from airweave.domains.converters.code import CodeConverter + + +@pytest.fixture +def converter(): + return CodeConverter() + + +@pytest.fixture +def temp_dir(): + with tempfile.TemporaryDirectory() as tmpdir: + yield tmpdir + + +class TestCodeConverterEncodingValidation: + + @pytest.mark.asyncio + async def test_convert_clean_python_code(self, converter, temp_dir): + file_path = os.path.join(temp_dir, "clean.py") + code = """def hello_world(): + print("Hello, world!") + return True +""" + with open(file_path, "w", encoding="utf-8") as f: + f.write(code) + + result = await converter.convert_batch([file_path]) + + assert file_path in result + assert result[file_path] == code + + @pytest.mark.asyncio + async def test_convert_empty_code_file(self, converter, temp_dir): + file_path = os.path.join(temp_dir, "empty.py") + with open(file_path, "w", encoding="utf-8") as f: + f.write("") + + result = await converter.convert_batch([file_path]) + + assert file_path in result + assert result[file_path] is None + + @pytest.mark.asyncio + async def test_convert_batch_multiple_code_files(self, converter, temp_dir): + py_path = os.path.join(temp_dir, "script.py") + with open(py_path, "w", encoding="utf-8") as f: + f.write("print('Python')") + + js_path = os.path.join(temp_dir, "script.js") + with open(js_path, "w", encoding="utf-8") as f: + f.write("console.log('JavaScript');") + + result = await converter.convert_batch([py_path, js_path]) + + assert result[py_path] == "print('Python')" + assert result[js_path] == "console.log('JavaScript');" + + @pytest.mark.asyncio + async def test_convert_nonexistent_file(self, converter): + result = await converter.convert_batch(["/nonexistent/code.py"]) + assert "/nonexistent/code.py" in result + assert result["/nonexistent/code.py"] is None diff --git a/backend/tests/unit/domains/converters/test_html.py b/backend/tests/unit/domains/converters/test_html.py new file mode 100644 index 000000000..d5a04e1c3 --- /dev/null +++ b/backend/tests/unit/domains/converters/test_html.py @@ -0,0 +1,76 @@ +"""Unit tests for HtmlConverter encoding validation.""" + +import os +import tempfile + +import pytest + +from airweave.domains.converters.html import HtmlConverter + + +@pytest.fixture +def converter(): + return HtmlConverter() + + +@pytest.fixture +def temp_dir(): + with tempfile.TemporaryDirectory() as tmpdir: + yield tmpdir + + +class TestHtmlConverterEncodingValidation: + + @pytest.mark.asyncio + async def test_convert_clean_html(self, converter, temp_dir): + file_path = os.path.join(temp_dir, "clean.html") + html = """ + +Test Page + +

Hello World

+

This is a test paragraph.

+ + +""" + with open(file_path, "w", encoding="utf-8") as f: + f.write(html) + + result = await converter.convert_batch([file_path]) + + assert file_path in result + assert result[file_path] is not None + assert "Hello World" in result[file_path] + assert "test paragraph" in result[file_path] + + @pytest.mark.asyncio + async def test_convert_empty_html(self, converter, temp_dir): + file_path = os.path.join(temp_dir, "empty.html") + with open(file_path, "w", encoding="utf-8") as f: + f.write("") + + result = await converter.convert_batch([file_path]) + + assert file_path in result + assert result[file_path] is None + + @pytest.mark.asyncio + async def test_convert_batch_multiple_html_files(self, converter, temp_dir): + html1_path = os.path.join(temp_dir, "page1.html") + with open(html1_path, "w", encoding="utf-8") as f: + f.write("

Page 1

") + + html2_path = os.path.join(temp_dir, "page2.html") + with open(html2_path, "w", encoding="utf-8") as f: + f.write("

Page 2

") + + result = await converter.convert_batch([html1_path, html2_path]) + + assert html1_path in result + assert html2_path in result + + @pytest.mark.asyncio + async def test_convert_nonexistent_file(self, converter): + result = await converter.convert_batch(["/nonexistent/page.html"]) + assert "/nonexistent/page.html" in result + assert result["/nonexistent/page.html"] is None diff --git a/backend/tests/unit/domains/converters/test_registry.py b/backend/tests/unit/domains/converters/test_registry.py new file mode 100644 index 000000000..6d0dc102b --- /dev/null +++ b/backend/tests/unit/domains/converters/test_registry.py @@ -0,0 +1,36 @@ +"""Tests for ConverterRegistry.""" + +from airweave.domains.converters.code import CodeConverter +from airweave.domains.converters.html import HtmlConverter +from airweave.domains.converters.pdf import PdfConverter +from airweave.domains.converters.registry import ConverterRegistry +from airweave.domains.converters.txt import TxtConverter +from airweave.domains.converters.web import WebConverter +from airweave.domains.converters.xlsx import XlsxConverter + + +class TestConverterRegistry: + def test_builds_without_ocr(self): + registry = ConverterRegistry(ocr_provider=None) + assert registry.for_extension(".pdf") is not None + assert isinstance(registry.for_extension(".pdf"), PdfConverter) + + def test_extension_mapping(self): + registry = ConverterRegistry(ocr_provider=None) + assert isinstance(registry.for_extension(".html"), HtmlConverter) + assert isinstance(registry.for_extension(".txt"), TxtConverter) + assert isinstance(registry.for_extension(".xlsx"), XlsxConverter) + assert isinstance(registry.for_extension(".py"), CodeConverter) + + def test_unknown_extension_returns_none(self): + registry = ConverterRegistry(ocr_provider=None) + assert registry.for_extension(".unknown") is None + + def test_for_web_returns_web_converter(self): + registry = ConverterRegistry(ocr_provider=None) + assert isinstance(registry.for_web(), WebConverter) + + def test_image_extensions_use_ocr_provider(self): + registry = ConverterRegistry(ocr_provider=None) + assert registry.for_extension(".jpg") is None + assert registry.for_extension(".png") is None diff --git a/backend/tests/unit/platform/converters/test_txt_converter.py b/backend/tests/unit/domains/converters/test_txt.py similarity index 58% rename from backend/tests/unit/platform/converters/test_txt_converter.py rename to backend/tests/unit/domains/converters/test_txt.py index e7946bf14..e684009e9 100644 --- a/backend/tests/unit/platform/converters/test_txt_converter.py +++ b/backend/tests/unit/domains/converters/test_txt.py @@ -1,33 +1,29 @@ """Unit tests for TxtConverter encoding validation.""" import os -import pytest import tempfile -from pathlib import Path -from airweave.platform.converters.txt_converter import TxtConverter +import pytest + +from airweave.domains.converters.txt import TxtConverter from airweave.domains.sync_pipeline.exceptions import EntityProcessingError @pytest.fixture def converter(): - """Create TxtConverter instance.""" return TxtConverter() @pytest.fixture def temp_dir(): - """Create temporary directory for test files.""" with tempfile.TemporaryDirectory() as tmpdir: yield tmpdir class TestTxtConverterEncodingValidation: - """Test TxtConverter encoding detection and validation.""" @pytest.mark.asyncio async def test_convert_clean_utf8_text(self, converter, temp_dir): - """Test conversion of clean UTF-8 text.""" file_path = os.path.join(temp_dir, "clean.txt") with open(file_path, "w", encoding="utf-8") as f: f.write("Hello world! This is clean UTF-8 text.") @@ -39,7 +35,6 @@ async def test_convert_clean_utf8_text(self, converter, temp_dir): @pytest.mark.asyncio async def test_convert_unicode_text(self, converter, temp_dir): - """Test conversion of Unicode text.""" file_path = os.path.join(temp_dir, "unicode.txt") with open(file_path, "w", encoding="utf-8") as f: f.write("Hello 世界 🌍 こんにちは") @@ -51,25 +46,16 @@ async def test_convert_unicode_text(self, converter, temp_dir): @pytest.mark.asyncio async def test_convert_corrupted_text_file(self, converter, temp_dir): - """Test rejection of file with excessive replacement characters.""" file_path = os.path.join(temp_dir, "corrupted.txt") - # Write truly invalid UTF-8 sequences (incomplete multi-byte sequences) - # These will produce replacement characters in UTF-8 decoding with open(file_path, "wb") as f: - # Write many incomplete UTF-8 sequences for _ in range(10000): - f.write(b"\xc0\x80") # Invalid/overlong UTF-8 sequence + f.write(b"\xc0\x80") result = await converter.convert_batch([file_path]) - - # Should fail due to excessive replacement characters or raise EntityProcessingError assert file_path in result - # May be None (rejected) or may have decoded with chardet - # The important thing is it doesn't crash @pytest.mark.asyncio async def test_convert_empty_file(self, converter, temp_dir): - """Test conversion of empty file.""" file_path = os.path.join(temp_dir, "empty.txt") with open(file_path, "w", encoding="utf-8") as f: f.write("") @@ -81,7 +67,6 @@ async def test_convert_empty_file(self, converter, temp_dir): @pytest.mark.asyncio async def test_convert_json_clean(self, converter, temp_dir): - """Test JSON conversion with clean data.""" file_path = os.path.join(temp_dir, "clean.json") with open(file_path, "w", encoding="utf-8") as f: f.write('{"name": "test", "value": 123}') @@ -91,25 +76,20 @@ async def test_convert_json_clean(self, converter, temp_dir): assert file_path in result assert result[file_path] is not None assert "name" in result[file_path] - assert "test" in result[file_path] @pytest.mark.asyncio async def test_convert_json_with_corruption(self, converter, temp_dir): - """Test JSON with invalid syntax fails gracefully.""" file_path = os.path.join(temp_dir, "corrupted.json") - # Write invalid JSON (will fail JSON parsing, not encoding) with open(file_path, "w", encoding="utf-8") as f: f.write('{"name": invalid}') result = await converter.convert_batch([file_path]) - # Should fail due to invalid JSON syntax assert file_path in result assert result[file_path] is None @pytest.mark.asyncio async def test_convert_xml_clean(self, converter, temp_dir): - """Test XML conversion with clean data.""" file_path = os.path.join(temp_dir, "clean.xml") with open(file_path, "w", encoding="utf-8") as f: f.write('test') @@ -122,7 +102,6 @@ async def test_convert_xml_clean(self, converter, temp_dir): @pytest.mark.asyncio async def test_convert_batch_mixed_files(self, converter, temp_dir): - """Test batch conversion with mix of clean and empty files.""" clean_path = os.path.join(temp_dir, "clean.txt") with open(clean_path, "w", encoding="utf-8") as f: f.write("Clean text") @@ -133,43 +112,20 @@ async def test_convert_batch_mixed_files(self, converter, temp_dir): result = await converter.convert_batch([clean_path, empty_path]) - # Clean file should succeed assert result[clean_path] == "Clean text" - # Empty file should return None assert result[empty_path] is None - @pytest.mark.asyncio - async def test_convert_latin1_encoding(self, converter, temp_dir): - """Test conversion of Latin-1 encoded file.""" - file_path = os.path.join(temp_dir, "latin1.txt") - # Write Latin-1 text with special characters - text = "Café résumé naïve" - with open(file_path, "wb") as f: - f.write(text.encode("latin-1")) - - result = await converter.convert_batch([file_path]) - - # Should detect encoding or handle gracefully - assert file_path in result - # Result should either be correct or None (if chardet not available) - if result[file_path] is not None: - assert len(result[file_path]) > 0 - class TestTxtConverterEdgeCases: - """Test edge cases in TxtConverter.""" @pytest.mark.asyncio async def test_convert_nonexistent_file(self, converter): - """Test conversion of nonexistent file.""" result = await converter.convert_batch(["/nonexistent/file.txt"]) - assert "/nonexistent/file.txt" in result assert result["/nonexistent/file.txt"] is None @pytest.mark.asyncio async def test_convert_whitespace_only_file(self, converter, temp_dir): - """Test conversion of file with only whitespace.""" file_path = os.path.join(temp_dir, "whitespace.txt") with open(file_path, "w", encoding="utf-8") as f: f.write(" \n\n \t\t ") @@ -178,30 +134,3 @@ async def test_convert_whitespace_only_file(self, converter, temp_dir): assert file_path in result assert result[file_path] is None - - @pytest.mark.asyncio - async def test_convert_large_clean_file(self, converter, temp_dir): - """Test conversion of large clean file.""" - file_path = os.path.join(temp_dir, "large.txt") - # Create 1MB of clean text - large_text = "Hello world! " * 100000 - with open(file_path, "w", encoding="utf-8") as f: - f.write(large_text) - - result = await converter.convert_batch([file_path]) - - assert file_path in result - assert result[file_path] is not None - assert len(result[file_path]) > 1000000 - - @pytest.mark.asyncio - async def test_convert_json_invalid_syntax(self, converter, temp_dir): - """Test JSON with invalid syntax.""" - file_path = os.path.join(temp_dir, "invalid.json") - with open(file_path, "w", encoding="utf-8") as f: - f.write('{"invalid": }') - - result = await converter.convert_batch([file_path]) - - assert file_path in result - assert result[file_path] is None diff --git a/backend/tests/unit/platform/converters/__init__.py b/backend/tests/unit/platform/converters/__init__.py deleted file mode 100644 index 17689ece7..000000000 --- a/backend/tests/unit/platform/converters/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""Tests for platform converters.""" - diff --git a/backend/tests/unit/platform/converters/test_code_converter.py b/backend/tests/unit/platform/converters/test_code_converter.py deleted file mode 100644 index 04d657380..000000000 --- a/backend/tests/unit/platform/converters/test_code_converter.py +++ /dev/null @@ -1,240 +0,0 @@ -"""Unit tests for CodeConverter encoding validation.""" - -import os -import pytest -import tempfile - -from airweave.platform.converters.code_converter import CodeConverter - - -@pytest.fixture -def converter(): - """Create CodeConverter instance.""" - return CodeConverter() - - -@pytest.fixture -def temp_dir(): - """Create temporary directory for test files.""" - with tempfile.TemporaryDirectory() as tmpdir: - yield tmpdir - - -class TestCodeConverterEncodingValidation: - """Test CodeConverter encoding detection and validation.""" - - @pytest.mark.asyncio - async def test_convert_clean_python_code(self, converter, temp_dir): - """Test conversion of clean Python code.""" - file_path = os.path.join(temp_dir, "clean.py") - code = """def hello_world(): - print("Hello, world!") - return True -""" - with open(file_path, "w", encoding="utf-8") as f: - f.write(code) - - result = await converter.convert_batch([file_path]) - - assert file_path in result - assert result[file_path] == code - - @pytest.mark.asyncio - async def test_convert_code_with_unicode_comments(self, converter, temp_dir): - """Test conversion of code with Unicode comments.""" - file_path = os.path.join(temp_dir, "unicode.py") - code = """# 这是中文注释 - This is a Chinese comment -def hello(): - # Café résumé 🎉 - return "Hello" -""" - with open(file_path, "w", encoding="utf-8") as f: - f.write(code) - - result = await converter.convert_batch([file_path]) - - assert file_path in result - assert result[file_path] == code - assert "中文" in result[file_path] - assert "🎉" in result[file_path] - - @pytest.mark.asyncio - async def test_convert_binary_file_as_code(self, converter, temp_dir): - """Test handling of file with null bytes (clearly binary).""" - file_path = os.path.join(temp_dir, "binary.py") - # Write data with null bytes (clearly not text) - with open(file_path, "wb") as f: - f.write(b"some_code = 1\x00\x00\x00" * 100) - - result = await converter.convert_batch([file_path]) - - # Should reject due to null bytes or handle gracefully - assert file_path in result - # May pass if chardet handles it or fail - either is OK - - @pytest.mark.asyncio - async def test_convert_empty_code_file(self, converter, temp_dir): - """Test conversion of empty code file.""" - file_path = os.path.join(temp_dir, "empty.py") - with open(file_path, "w", encoding="utf-8") as f: - f.write("") - - result = await converter.convert_batch([file_path]) - - assert file_path in result - assert result[file_path] is None - - @pytest.mark.asyncio - async def test_convert_whitespace_only_code(self, converter, temp_dir): - """Test conversion of code file with only whitespace.""" - file_path = os.path.join(temp_dir, "whitespace.py") - with open(file_path, "w", encoding="utf-8") as f: - f.write(" \n\n \t\t ") - - result = await converter.convert_batch([file_path]) - - assert file_path in result - # Whitespace-only files should return None (no meaningful content) - assert result[file_path] is None or (result[file_path] and not result[file_path].strip()) - - @pytest.mark.asyncio - async def test_convert_javascript_code(self, converter, temp_dir): - """Test conversion of JavaScript code.""" - file_path = os.path.join(temp_dir, "app.js") - code = """function greet(name) { - console.log(`Hello, ${name}!`); -} - -export default greet; -""" - with open(file_path, "w", encoding="utf-8") as f: - f.write(code) - - result = await converter.convert_batch([file_path]) - - assert file_path in result - assert result[file_path] == code - assert "function greet" in result[file_path] - - @pytest.mark.asyncio - async def test_convert_cpp_code(self, converter, temp_dir): - """Test conversion of C++ code.""" - file_path = os.path.join(temp_dir, "main.cpp") - code = """#include - -int main() { - std::cout << "Hello World!" << std::endl; - return 0; -} -""" - with open(file_path, "w", encoding="utf-8") as f: - f.write(code) - - result = await converter.convert_batch([file_path]) - - assert file_path in result - assert result[file_path] == code - assert "#include" in result[file_path] - - @pytest.mark.asyncio - async def test_convert_batch_multiple_code_files(self, converter, temp_dir): - """Test batch conversion of multiple code files.""" - py_path = os.path.join(temp_dir, "script.py") - with open(py_path, "w", encoding="utf-8") as f: - f.write("print('Python')") - - js_path = os.path.join(temp_dir, "script.js") - with open(js_path, "w", encoding="utf-8") as f: - f.write("console.log('JavaScript');") - - result = await converter.convert_batch([py_path, js_path]) - - assert result[py_path] == "print('Python')" - assert result[js_path] == "console.log('JavaScript');" - - @pytest.mark.asyncio - async def test_convert_code_with_comments(self, converter, temp_dir): - """Test conversion of code with comments.""" - file_path = os.path.join(temp_dir, "commented.py") - code = """def hello(): - # This is a comment - # Another comment - return True -""" - with open(file_path, "w", encoding="utf-8") as f: - f.write(code) - - result = await converter.convert_batch([file_path]) - - # Should successfully convert - assert file_path in result - assert result[file_path] == code - - @pytest.mark.asyncio - async def test_convert_large_code_file(self, converter, temp_dir): - """Test conversion of large code file.""" - file_path = os.path.join(temp_dir, "large.py") - # Generate large code file - lines = [] - for i in range(10000): - lines.append(f"def function_{i}():\n") - lines.append(f" return {i}\n") - lines.append("\n") - - with open(file_path, "w", encoding="utf-8") as f: - f.writelines(lines) - - result = await converter.convert_batch([file_path]) - - assert file_path in result - assert result[file_path] is not None - assert "function_0" in result[file_path] - assert "function_9999" in result[file_path] - - -class TestCodeConverterEdgeCases: - """Test edge cases in CodeConverter.""" - - @pytest.mark.asyncio - async def test_convert_nonexistent_file(self, converter): - """Test conversion of nonexistent file.""" - result = await converter.convert_batch(["/nonexistent/code.py"]) - - assert "/nonexistent/code.py" in result - assert result["/nonexistent/code.py"] is None - - @pytest.mark.asyncio - async def test_convert_code_with_long_lines(self, converter, temp_dir): - """Test conversion of code with very long lines.""" - file_path = os.path.join(temp_dir, "long_lines.py") - # Create a file with a very long line - long_string = "x" * 10000 - code = f'long_var = "{long_string}"\n' - - with open(file_path, "w", encoding="utf-8") as f: - f.write(code) - - result = await converter.convert_batch([file_path]) - - assert file_path in result - assert result[file_path] is not None - assert len(result[file_path]) > 10000 - - @pytest.mark.asyncio - async def test_convert_code_with_special_characters(self, converter, temp_dir): - """Test conversion of code with special characters in strings.""" - file_path = os.path.join(temp_dir, "special.py") - code = r'''def test(): - s1 = "Line with \n newline" - s2 = "Tab with \t tab" - s3 = 'Quote with \' quote' - return True -''' - with open(file_path, "w", encoding="utf-8") as f: - f.write(code) - - result = await converter.convert_batch([file_path]) - - assert file_path in result - assert result[file_path] == code - diff --git a/backend/tests/unit/platform/converters/test_html_converter.py b/backend/tests/unit/platform/converters/test_html_converter.py deleted file mode 100644 index ee5e3ac7b..000000000 --- a/backend/tests/unit/platform/converters/test_html_converter.py +++ /dev/null @@ -1,297 +0,0 @@ -"""Unit tests for HtmlConverter encoding validation.""" - -import os -import pytest -import tempfile - -from airweave.platform.converters.html_converter import HtmlConverter - - -@pytest.fixture -def converter(): - """Create HtmlConverter instance.""" - return HtmlConverter() - - -@pytest.fixture -def temp_dir(): - """Create temporary directory for test files.""" - with tempfile.TemporaryDirectory() as tmpdir: - yield tmpdir - - -class TestHtmlConverterEncodingValidation: - """Test HtmlConverter encoding detection and validation.""" - - @pytest.mark.asyncio - async def test_convert_clean_html(self, converter, temp_dir): - """Test conversion of clean HTML.""" - file_path = os.path.join(temp_dir, "clean.html") - html = """ - -Test Page - -

Hello World

-

This is a test paragraph.

- - -""" - with open(file_path, "w", encoding="utf-8") as f: - f.write(html) - - result = await converter.convert_batch([file_path]) - - assert file_path in result - assert result[file_path] is not None - # Check that HTML was converted (markdown should have text) - assert "Hello World" in result[file_path] - assert "test paragraph" in result[file_path] - - @pytest.mark.asyncio - async def test_convert_html_with_unicode(self, converter, temp_dir): - """Test conversion of HTML with Unicode content.""" - file_path = os.path.join(temp_dir, "unicode.html") - html = """ - - -

Unicode: 世界 🌍 Café

- - -""" - with open(file_path, "w", encoding="utf-8") as f: - f.write(html) - - result = await converter.convert_batch([file_path]) - - assert file_path in result - assert result[file_path] is not None - assert "世界" in result[file_path] - assert "🌍" in result[file_path] - - @pytest.mark.asyncio - async def test_convert_html_with_invalid_tags(self, converter, temp_dir): - """Test HTML with broken/invalid tags.""" - file_path = os.path.join(temp_dir, "invalid.html") - html = """ - -

Broken tag

Another paragraph

- -""" - with open(file_path, "w", encoding="utf-8") as f: - f.write(html) - - result = await converter.convert_batch([file_path]) - - # Should handle malformed HTML gracefully - assert file_path in result - # May succeed or fail depending on html-to-markdown tolerance - - @pytest.mark.asyncio - async def test_convert_empty_html(self, converter, temp_dir): - """Test conversion of empty HTML file.""" - file_path = os.path.join(temp_dir, "empty.html") - with open(file_path, "w", encoding="utf-8") as f: - f.write("") - - result = await converter.convert_batch([file_path]) - - assert file_path in result - assert result[file_path] is None - - @pytest.mark.asyncio - async def test_convert_html_with_special_entities(self, converter, temp_dir): - """Test conversion of HTML with special entities.""" - file_path = os.path.join(temp_dir, "entities.html") - html = """ - - -

<div> & "quotes" © 2024

- - -""" - with open(file_path, "w", encoding="utf-8") as f: - f.write(html) - - result = await converter.convert_batch([file_path]) - - assert file_path in result - assert result[file_path] is not None - - @pytest.mark.asyncio - async def test_convert_html_with_nested_structure(self, converter, temp_dir): - """Test conversion of HTML with nested structure.""" - file_path = os.path.join(temp_dir, "nested.html") - html = """ - - -
-
-

Main Title

-
-
-

Section Title

-

Section content

-
-
- - -""" - with open(file_path, "w", encoding="utf-8") as f: - f.write(html) - - result = await converter.convert_batch([file_path]) - - assert file_path in result - assert result[file_path] is not None - assert "Main Title" in result[file_path] - assert "Section Title" in result[file_path] - - @pytest.mark.asyncio - async def test_convert_html_with_links(self, converter, temp_dir): - """Test conversion of HTML with links.""" - file_path = os.path.join(temp_dir, "links.html") - html = """ - - -

Visit our website for more.

- - -""" - with open(file_path, "w", encoding="utf-8") as f: - f.write(html) - - result = await converter.convert_batch([file_path]) - - assert file_path in result - assert result[file_path] is not None - assert "website" in result[file_path] - - @pytest.mark.asyncio - async def test_convert_malformed_html(self, converter, temp_dir): - """Test conversion of malformed HTML.""" - file_path = os.path.join(temp_dir, "malformed.html") - html = """ - -

Unclosed paragraph -

Unclosed div - -""" - with open(file_path, "w", encoding="utf-8") as f: - f.write(html) - - result = await converter.convert_batch([file_path]) - - # Should still convert malformed HTML - assert file_path in result - # Result depends on html-to-markdown behavior - # May succeed or fail gracefully - - @pytest.mark.asyncio - async def test_convert_batch_multiple_html_files(self, converter, temp_dir): - """Test batch conversion of multiple HTML files.""" - html1_path = os.path.join(temp_dir, "page1.html") - with open(html1_path, "w", encoding="utf-8") as f: - f.write("

Page 1

") - - html2_path = os.path.join(temp_dir, "page2.html") - with open(html2_path, "w", encoding="utf-8") as f: - f.write("

Page 2

") - - result = await converter.convert_batch([html1_path, html2_path]) - - assert html1_path in result - assert html2_path in result - if result[html1_path]: - assert "Page 1" in result[html1_path] - if result[html2_path]: - assert "Page 2" in result[html2_path] - - @pytest.mark.asyncio - async def test_convert_html_with_scripts_and_styles(self, converter, temp_dir): - """Test conversion of HTML with script and style tags.""" - file_path = os.path.join(temp_dir, "with_scripts.html") - html = """ - - - - - - -

Visible content

- - -""" - with open(file_path, "w", encoding="utf-8") as f: - f.write(html) - - result = await converter.convert_batch([file_path]) - - assert file_path in result - # Should extract visible content, scripts/styles may be filtered - if result[file_path]: - assert "Visible content" in result[file_path] - - -class TestHtmlConverterEdgeCases: - """Test edge cases in HtmlConverter.""" - - @pytest.mark.asyncio - async def test_convert_nonexistent_file(self, converter): - """Test conversion of nonexistent file.""" - result = await converter.convert_batch(["/nonexistent/page.html"]) - - assert "/nonexistent/page.html" in result - assert result["/nonexistent/page.html"] is None - - @pytest.mark.asyncio - async def test_convert_whitespace_only_html(self, converter, temp_dir): - """Test conversion of HTML with only whitespace.""" - file_path = os.path.join(temp_dir, "whitespace.html") - with open(file_path, "w", encoding="utf-8") as f: - f.write(" \n\n ") - - result = await converter.convert_batch([file_path]) - - assert file_path in result - assert result[file_path] is None - - @pytest.mark.asyncio - async def test_convert_large_html_file(self, converter, temp_dir): - """Test conversion of large HTML file.""" - file_path = os.path.join(temp_dir, "large.html") - # Generate large HTML - html_parts = [""] - for i in range(1000): - html_parts.append(f"

Paragraph {i}

") - html_parts.append("") - - with open(file_path, "w", encoding="utf-8") as f: - f.write("".join(html_parts)) - - result = await converter.convert_batch([file_path]) - - assert file_path in result - # Should handle large files - if result[file_path]: - assert "Paragraph 0" in result[file_path] - - @pytest.mark.asyncio - async def test_convert_html_with_missing_closing_tags(self, converter, temp_dir): - """Test HTML with missing closing tags.""" - file_path = os.path.join(temp_dir, "unclosed.html") - html = """ - -

Paragraph without closing tag -

Another div - Nested span -""" - with open(file_path, "w", encoding="utf-8") as f: - f.write(html) - - result = await converter.convert_batch([file_path]) - - # Should handle unclosed tags gracefully - assert file_path in result - # html-to-markdown is usually tolerant of malformed HTML - diff --git a/backend/tests/unit/platform/converters/test_init_converters.py b/backend/tests/unit/platform/converters/test_init_converters.py deleted file mode 100644 index 4819fd47a..000000000 --- a/backend/tests/unit/platform/converters/test_init_converters.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Tests for initialize_converters with optional OCR provider.""" - -import airweave.platform.converters as converters_mod -from airweave.adapters.ocr.fake import FakeOcrProvider -from airweave.platform.converters import PdfConverter, initialize_converters - - -class TestInitializeConverters: - def setup_method(self): - """Reset singleton state before each test.""" - converters_mod._singletons = None - - def teardown_method(self): - """Clean up singleton state after each test.""" - converters_mod._singletons = None - - def test_accepts_none_ocr_provider(self): - """When OCR is None, document converters work without OCR.""" - initialize_converters(None) - - assert converters_mod._singletons is not None - assert converters_mod._singletons["mistral_converter"] is None - assert converters_mod._singletons["img_converter"] is None - assert isinstance(converters_mod._singletons["pdf_converter"], PdfConverter) - - def test_accepts_fake_ocr_provider(self): - """When OCR is provided, it is wired into document converters.""" - fake = FakeOcrProvider() - initialize_converters(fake) - - assert converters_mod._singletons is not None - assert converters_mod._singletons["mistral_converter"] is fake - assert converters_mod._singletons["img_converter"] is fake From ad19e04c175fef4dc6ffa0398a539ab64a08380e Mon Sep 17 00:00:00 2001 From: Felix Schmetz Date: Fri, 13 Mar 2026 12:37:12 -0700 Subject: [PATCH 12/13] fix: make converter_registry optional for metadata-only callers (slack federated search) --- .../domains/sync_pipeline/pipeline/text_builder.py | 9 +++++++-- backend/airweave/platform/sources/slack.py | 6 ++---- .../unit/platform/temporal/test_worker_ocr_guard.py | 3 --- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/backend/airweave/domains/sync_pipeline/pipeline/text_builder.py b/backend/airweave/domains/sync_pipeline/pipeline/text_builder.py index 84788abdb..9da424d22 100644 --- a/backend/airweave/domains/sync_pipeline/pipeline/text_builder.py +++ b/backend/airweave/domains/sync_pipeline/pipeline/text_builder.py @@ -27,8 +27,13 @@ class TextualRepresentationBuilder: DEFAULT_CONVERTER_BATCH_SIZE = 10 - def __init__(self, converter_registry: ConverterRegistryProtocol) -> None: - """Initialize with a converter registry for routing entities to converters.""" + def __init__(self, converter_registry: Optional[ConverterRegistryProtocol] = None) -> None: + """Initialize with a converter registry for routing entities to converters. + + The registry is optional for callers that only need metadata building + (e.g. federated search sources). Converter routing will fail at runtime + if the registry is None and a file/web entity is encountered. + """ self._registry = converter_registry # ------------------------------------------------------------------------------------ diff --git a/backend/airweave/platform/sources/slack.py b/backend/airweave/platform/sources/slack.py index f46e2701d..9cb6aaa0c 100644 --- a/backend/airweave/platform/sources/slack.py +++ b/backend/airweave/platform/sources/slack.py @@ -8,7 +8,7 @@ from airweave.core.exceptions import TokenRefreshError from airweave.core.shared_models import RateLimitLevel -from airweave.domains.sync_pipeline.pipeline.text_builder import text_builder +from airweave.domains.sync_pipeline.pipeline.text_builder import TextualRepresentationBuilder from airweave.platform.configs.auth import SlackAuthConfig from airweave.platform.configs.config import SlackConfig from airweave.platform.decorators import source @@ -326,9 +326,7 @@ async def _create_message_entity(self, message: Dict[str, Any]) -> Optional[Slac web_url_value=message.get("permalink"), ) - # Build textual representation using shared utility - # (normally built by TextualRepresentationBuilder in sync pipeline) - entity.textual_representation = text_builder.build_metadata_section( + entity.textual_representation = TextualRepresentationBuilder().build_metadata_section( entity=entity, source_name="slack", ) diff --git a/backend/tests/unit/platform/temporal/test_worker_ocr_guard.py b/backend/tests/unit/platform/temporal/test_worker_ocr_guard.py index ae8546718..6be117e8e 100644 --- a/backend/tests/unit/platform/temporal/test_worker_ocr_guard.py +++ b/backend/tests/unit/platform/temporal/test_worker_ocr_guard.py @@ -41,9 +41,6 @@ async def test_passes_guard_when_ocr_present(test_container): container_mod, "initialize_container", ), - patch( - "airweave.platform.converters.initialize_converters", - ), patch( "airweave.platform.temporal.worker.WorkerConfig.from_settings", side_effect=SystemExit(99), From 9afe7e38baefd2ecd1f8fda81bbd5cbd020b6989 Mon Sep 17 00:00:00 2001 From: Felix Schmetz Date: Fri, 13 Mar 2026 13:46:48 -0700 Subject: [PATCH 13/13] refactor: decompose sync_pipeline, extract ACL to domain, consolidate OCR - Move ACL pipeline/resolver/dispatcher/handler/tracker/actions/schemas from sync_pipeline and platform/access_control into domains/access_control - Reorganize entity pipeline into sync_pipeline/entity/ sub-package - Consolidate OCR from core/protocols, adapters/ocr, platform/ocr into domains/ocr; eliminate MistralOcrAdapter passthrough - Convert TYPE_CHECKING guards to regular imports in ACL protocols - Update all importers across sources, tests, container factory --- backend/airweave/adapters/ocr/__init__.py | 13 --- backend/airweave/adapters/ocr/mistral.py | 47 --------- backend/airweave/core/container/factory.py | 11 ++- backend/airweave/core/protocols/__init__.py | 2 +- .../actions.py} | 2 +- .../airweave/domains/access_control/broker.py | 2 +- .../dispatcher.py} | 4 +- .../domains/access_control/fakes/broker.py | 2 +- .../membership_tracker.py} | 0 .../pipeline.py} | 8 +- .../postgres_handler.py} | 10 +- .../domains/access_control/protocols.py | 95 ++++++++++++++++++- .../resolver.py} | 4 +- .../access_control/schemas.py | 0 .../domains/access_control/tests/__init__.py | 1 + .../tests/test_membership_tracker.py} | 2 +- .../tests/test_pipeline.py} | 6 +- backend/airweave/domains/converters/_base.py | 2 +- .../airweave/domains/converters/registry.py | 2 +- backend/airweave/domains/ocr/__init__.py | 0 .../{adapters => domains}/ocr/docling.py | 0 .../airweave/domains/ocr/fakes/__init__.py | 0 .../fake.py => domains/ocr/fakes/provider.py} | 0 .../{adapters => domains}/ocr/fallback.py | 3 +- .../airweave/domains/ocr/mistral/__init__.py | 1 + .../ocr/mistral/compressor.py | 2 +- .../ocr/mistral/converter.py | 10 +- .../ocr/mistral/models.py | 0 .../ocr/mistral/ocr_client.py | 4 +- .../ocr/mistral/splitters.py | 0 .../ocr.py => domains/ocr/protocols.py} | 0 .../domains/sync_pipeline/entity/__init__.py | 0 .../entity_actions.py => entity/actions.py} | 0 .../dispatcher.py} | 4 +- .../dispatcher_builder.py} | 10 +- .../sync_pipeline/entity/handlers/__init__.py | 0 .../{ => entity}/handlers/arf.py | 6 +- .../{ => entity}/handlers/destination.py | 8 +- .../handlers/postgres.py} | 6 +- .../{ => entity}/handlers/protocol.py | 79 +-------------- .../pipeline.py} | 2 +- .../resolver.py} | 4 +- .../airweave/domains/sync_pipeline/factory.py | 16 ++-- .../sync_pipeline/handlers/__init__.py | 10 -- .../domains/sync_pipeline/orchestrator.py | 4 +- .../domains/sync_pipeline/protocols.py | 28 +----- .../tests/test_destination_handler.py | 4 +- .../tests/test_entity_action_resolver.py | 4 +- .../tests/test_entity_pipeline.py | 2 +- .../domains/sync_pipeline/types/__init__.py | 1 - .../platform/access_control/__init__.py | 5 - backend/airweave/platform/ocr/__init__.py | 10 -- .../airweave/platform/ocr/mistral/__init__.py | 8 -- backend/airweave/platform/sources/_base.py | 2 +- .../platform/sources/sharepoint2019v2/ldap.py | 2 +- .../sources/sharepoint2019v2/source.py | 2 +- .../sources/sharepoint_online/graph_groups.py | 2 +- .../sources/sharepoint_online/source.py | 2 +- backend/conftest.py | 2 +- .../unit/api/test_admin_user_principals.py | 2 +- .../unit/core/test_ocr_provider_factory.py | 12 +-- .../domains/access_control/test_broker.py | 2 +- backend/tests/unit/domains/ocr/__init__.py | 0 .../ocr/test_docling.py} | 10 +- .../ocr/test_fallback.py} | 6 +- .../sources/test_sharepoint2019v2_dirsync.py | 2 +- .../operations/test_access_control_filter.py | 2 +- 67 files changed, 196 insertions(+), 296 deletions(-) delete mode 100644 backend/airweave/adapters/ocr/__init__.py delete mode 100644 backend/airweave/adapters/ocr/mistral.py rename backend/airweave/domains/{sync_pipeline/types/access_control_actions.py => access_control/actions.py} (98%) rename backend/airweave/domains/{sync_pipeline/access_control_dispatcher.py => access_control/dispatcher.py} (93%) rename backend/airweave/domains/{sync_pipeline/pipeline/acl_membership_tracker.py => access_control/membership_tracker.py} (100%) rename backend/airweave/domains/{sync_pipeline/access_control_pipeline.py => access_control/pipeline.py} (98%) rename backend/airweave/domains/{sync_pipeline/handlers/access_control_postgres.py => access_control/postgres_handler.py} (95%) rename backend/airweave/domains/{sync_pipeline/access_control_resolver.py => access_control/resolver.py} (93%) rename backend/airweave/{platform => domains}/access_control/schemas.py (100%) create mode 100644 backend/airweave/domains/access_control/tests/__init__.py rename backend/airweave/domains/{sync_pipeline/tests/test_acl_membership_tracker.py => access_control/tests/test_membership_tracker.py} (99%) rename backend/airweave/domains/{sync_pipeline/tests/test_acl_reconciliation.py => access_control/tests/test_pipeline.py} (98%) create mode 100644 backend/airweave/domains/ocr/__init__.py rename backend/airweave/{adapters => domains}/ocr/docling.py (100%) create mode 100644 backend/airweave/domains/ocr/fakes/__init__.py rename backend/airweave/{adapters/ocr/fake.py => domains/ocr/fakes/provider.py} (100%) rename backend/airweave/{adapters => domains}/ocr/fallback.py (96%) create mode 100644 backend/airweave/domains/ocr/mistral/__init__.py rename backend/airweave/{platform => domains}/ocr/mistral/compressor.py (98%) rename backend/airweave/{platform => domains}/ocr/mistral/converter.py (98%) rename backend/airweave/{platform => domains}/ocr/mistral/models.py (100%) rename backend/airweave/{platform => domains}/ocr/mistral/ocr_client.py (99%) rename backend/airweave/{platform => domains}/ocr/mistral/splitters.py (100%) rename backend/airweave/{core/protocols/ocr.py => domains/ocr/protocols.py} (100%) create mode 100644 backend/airweave/domains/sync_pipeline/entity/__init__.py rename backend/airweave/domains/sync_pipeline/{types/entity_actions.py => entity/actions.py} (100%) rename backend/airweave/domains/sync_pipeline/{entity_action_dispatcher.py => entity/dispatcher.py} (97%) rename backend/airweave/domains/sync_pipeline/{entity_dispatcher_builder.py => entity/dispatcher_builder.py} (92%) create mode 100644 backend/airweave/domains/sync_pipeline/entity/handlers/__init__.py rename backend/airweave/domains/sync_pipeline/{ => entity}/handlers/arf.py (97%) rename backend/airweave/domains/sync_pipeline/{ => entity}/handlers/destination.py (98%) rename backend/airweave/domains/sync_pipeline/{handlers/entity_postgres.py => entity/handlers/postgres.py} (98%) rename backend/airweave/domains/sync_pipeline/{ => entity}/handlers/protocol.py (51%) rename backend/airweave/domains/sync_pipeline/{entity_pipeline.py => entity/pipeline.py} (99%) rename backend/airweave/domains/sync_pipeline/{entity_action_resolver.py => entity/resolver.py} (99%) delete mode 100644 backend/airweave/domains/sync_pipeline/handlers/__init__.py delete mode 100644 backend/airweave/domains/sync_pipeline/types/__init__.py delete mode 100644 backend/airweave/platform/access_control/__init__.py delete mode 100644 backend/airweave/platform/ocr/__init__.py delete mode 100644 backend/airweave/platform/ocr/mistral/__init__.py create mode 100644 backend/tests/unit/domains/ocr/__init__.py rename backend/tests/unit/{adapters/test_docling_ocr.py => domains/ocr/test_docling.py} (92%) rename backend/tests/unit/{adapters/test_fallback_ocr.py => domains/ocr/test_fallback.py} (95%) diff --git a/backend/airweave/adapters/ocr/__init__.py b/backend/airweave/adapters/ocr/__init__.py deleted file mode 100644 index d79a65926..000000000 --- a/backend/airweave/adapters/ocr/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""OCR adapters.""" - -from airweave.adapters.ocr.docling import DoclingOcrAdapter -from airweave.adapters.ocr.fake import FakeOcrProvider -from airweave.adapters.ocr.fallback import FallbackOcrProvider -from airweave.adapters.ocr.mistral import MistralOcrAdapter - -__all__ = [ - "DoclingOcrAdapter", - "FakeOcrProvider", - "FallbackOcrProvider", - "MistralOcrAdapter", -] diff --git a/backend/airweave/adapters/ocr/mistral.py b/backend/airweave/adapters/ocr/mistral.py deleted file mode 100644 index b1e2df4af..000000000 --- a/backend/airweave/adapters/ocr/mistral.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Mistral OCR adapter. - -Thin adapter that delegates to the platform MistralOCR implementation, -exposing it through the adapters layer for dependency injection. - -Satisfies the :class:`~airweave.core.protocols.ocr.OcrProvider` protocol. -""" - -from __future__ import annotations - -from typing import Dict, List, Optional - -from airweave.platform.ocr.mistral.converter import MistralOCR - - -class MistralOcrAdapter: - """Adapter for Mistral OCR. - - Wraps :class:`~airweave.platform.ocr.mistral.converter.MistralOCR` - so callers depend on the adapter layer rather than reaching into platform - internals. - - Usage:: - - ocr: OcrProvider = MistralOcrAdapter() - results = await ocr.convert_batch(["/tmp/doc.pdf"]) - """ - - def __init__(self, concurrency: int = 10) -> None: - """Initialize the adapter. - - Args: - concurrency: Maximum concurrent OCR calls passed to the - underlying Mistral client. - """ - self._impl = MistralOCR(concurrency=concurrency) - - async def convert_batch(self, file_paths: List[str]) -> Dict[str, Optional[str]]: - """Convert files to markdown via Mistral OCR. - - Args: - file_paths: Local file paths (PDF, DOCX, PPTX, JPG, JPEG, PNG). - - Returns: - Mapping of ``file_path -> markdown`` (``None`` on per-file failure). - """ - return await self._impl.convert_batch(file_paths) diff --git a/backend/airweave/core/container/factory.py b/backend/airweave/core/container/factory.py index 284a856e2..7cb856289 100644 --- a/backend/airweave/core/container/factory.py +++ b/backend/airweave/core/container/factory.py @@ -26,9 +26,6 @@ PrometheusHttpMetrics, PrometheusMetricsRenderer, ) -from airweave.adapters.ocr.docling import DoclingOcrAdapter -from airweave.adapters.ocr.fallback import FallbackOcrProvider -from airweave.adapters.ocr.mistral import MistralOcrAdapter from airweave.adapters.pubsub.redis import RedisPubSub from airweave.adapters.webhooks.endpoint_verifier import HttpEndpointVerifier from airweave.adapters.webhooks.svix import SvixAdapter @@ -37,7 +34,7 @@ from airweave.core.health.service import HealthService from airweave.core.logging import logger from airweave.core.metrics_service import PrometheusMetricsService -from airweave.core.protocols import CircuitBreaker, OcrProvider, PubSub +from airweave.core.protocols import CircuitBreaker, PubSub from airweave.core.protocols.event_bus import EventBus from airweave.core.protocols.identity import IdentityProvider from airweave.core.protocols.payment import PaymentGatewayProtocol @@ -80,6 +77,10 @@ OAuthInitSessionRepository, OAuthRedirectSessionRepository, ) +from airweave.domains.ocr.docling import DoclingOcrAdapter +from airweave.domains.ocr.fallback import FallbackOcrProvider +from airweave.domains.ocr.mistral.converter import MistralOCR +from airweave.domains.ocr.protocols import OcrProvider from airweave.domains.organizations.protocols import UserOrganizationRepositoryProtocol from airweave.domains.organizations.repository import OrganizationRepository as OrgRepo from airweave.domains.organizations.repository import UserOrganizationRepository @@ -657,7 +658,7 @@ def _create_ocr_provider( Returns None with a warning when no providers are available. """ try: - mistral_ocr = MistralOcrAdapter() + mistral_ocr = MistralOCR() except Exception as e: logger.error(f"Error creating Mistral OCR adapter: {e}") mistral_ocr = None diff --git a/backend/airweave/core/protocols/__init__.py b/backend/airweave/core/protocols/__init__.py index a94409ff9..c43a0bcfa 100644 --- a/backend/airweave/core/protocols/__init__.py +++ b/backend/airweave/core/protocols/__init__.py @@ -21,7 +21,6 @@ MetricsService, WorkerMetrics, ) -from airweave.core.protocols.ocr import OcrProvider from airweave.core.protocols.payment import PaymentGatewayProtocol from airweave.core.protocols.pubsub import PubSub, PubSubSubscription from airweave.core.protocols.rate_limiter import RateLimiter @@ -32,6 +31,7 @@ WebhookServiceProtocol, ) from airweave.core.protocols.worker_metrics_registry import WorkerMetricsRegistryProtocol +from airweave.domains.ocr.protocols import OcrProvider __all__ = [ "AgenticSearchMetrics", diff --git a/backend/airweave/domains/sync_pipeline/types/access_control_actions.py b/backend/airweave/domains/access_control/actions.py similarity index 98% rename from backend/airweave/domains/sync_pipeline/types/access_control_actions.py rename to backend/airweave/domains/access_control/actions.py index 021f474f0..a23fa4504 100644 --- a/backend/airweave/domains/sync_pipeline/types/access_control_actions.py +++ b/backend/airweave/domains/access_control/actions.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, List if TYPE_CHECKING: - from airweave.platform.access_control.schemas import MembershipTuple + from airweave.domains.access_control.schemas import MembershipTuple # ============================================================================= diff --git a/backend/airweave/domains/access_control/broker.py b/backend/airweave/domains/access_control/broker.py index b0d962d17..a5b46d7e0 100644 --- a/backend/airweave/domains/access_control/broker.py +++ b/backend/airweave/domains/access_control/broker.py @@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from airweave.domains.access_control.protocols import AccessControlMembershipRepositoryProtocol -from airweave.platform.access_control.schemas import AccessContext +from airweave.domains.access_control.schemas import AccessContext from airweave.platform.entities._base import AccessControl diff --git a/backend/airweave/domains/sync_pipeline/access_control_dispatcher.py b/backend/airweave/domains/access_control/dispatcher.py similarity index 93% rename from backend/airweave/domains/sync_pipeline/access_control_dispatcher.py rename to backend/airweave/domains/access_control/dispatcher.py index 5e31e0a94..35ff60044 100644 --- a/backend/airweave/domains/sync_pipeline/access_control_dispatcher.py +++ b/backend/airweave/domains/access_control/dispatcher.py @@ -6,12 +6,12 @@ from typing import TYPE_CHECKING, List +from airweave.domains.access_control.actions import ACActionBatch from airweave.domains.sync_pipeline.exceptions import SyncFailureError -from airweave.domains.sync_pipeline.types.access_control_actions import ACActionBatch if TYPE_CHECKING: + from airweave.domains.access_control.protocols import ACActionHandler from airweave.domains.sync_pipeline.contexts import SyncContext - from airweave.domains.sync_pipeline.handlers.protocol import ACActionHandler class ACActionDispatcher: diff --git a/backend/airweave/domains/access_control/fakes/broker.py b/backend/airweave/domains/access_control/fakes/broker.py index 63d23700b..5803b6caa 100644 --- a/backend/airweave/domains/access_control/fakes/broker.py +++ b/backend/airweave/domains/access_control/fakes/broker.py @@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession -from airweave.platform.access_control.schemas import AccessContext +from airweave.domains.access_control.schemas import AccessContext from airweave.platform.entities._base import AccessControl diff --git a/backend/airweave/domains/sync_pipeline/pipeline/acl_membership_tracker.py b/backend/airweave/domains/access_control/membership_tracker.py similarity index 100% rename from backend/airweave/domains/sync_pipeline/pipeline/acl_membership_tracker.py rename to backend/airweave/domains/access_control/membership_tracker.py diff --git a/backend/airweave/domains/sync_pipeline/access_control_pipeline.py b/backend/airweave/domains/access_control/pipeline.py similarity index 98% rename from backend/airweave/domains/sync_pipeline/access_control_pipeline.py rename to backend/airweave/domains/access_control/pipeline.py index a3951148c..572b681ed 100644 --- a/backend/airweave/domains/sync_pipeline/access_control_pipeline.py +++ b/backend/airweave/domains/access_control/pipeline.py @@ -12,13 +12,13 @@ from typing import TYPE_CHECKING, List, Set, Tuple from airweave.db.session import get_db_context -from airweave.domains.access_control.protocols import AccessControlMembershipRepositoryProtocol -from airweave.domains.sync_pipeline.pipeline.acl_membership_tracker import ACLMembershipTracker -from airweave.domains.sync_pipeline.protocols import ( +from airweave.domains.access_control.membership_tracker import ACLMembershipTracker +from airweave.domains.access_control.protocols import ( ACActionDispatcherProtocol, ACActionResolverProtocol, + AccessControlMembershipRepositoryProtocol, ) -from airweave.platform.access_control.schemas import ( +from airweave.domains.access_control.schemas import ( ACLChangeType, MembershipTuple, ) diff --git a/backend/airweave/domains/sync_pipeline/handlers/access_control_postgres.py b/backend/airweave/domains/access_control/postgres_handler.py similarity index 95% rename from backend/airweave/domains/sync_pipeline/handlers/access_control_postgres.py rename to backend/airweave/domains/access_control/postgres_handler.py index 6e3fc397c..5a41483ac 100644 --- a/backend/airweave/domains/sync_pipeline/handlers/access_control_postgres.py +++ b/backend/airweave/domains/access_control/postgres_handler.py @@ -7,16 +7,18 @@ from typing import TYPE_CHECKING, List from airweave.db.session import get_db_context -from airweave.domains.access_control.protocols import AccessControlMembershipRepositoryProtocol -from airweave.domains.sync_pipeline.exceptions import SyncFailureError -from airweave.domains.sync_pipeline.handlers.protocol import ACActionHandler -from airweave.domains.sync_pipeline.types.access_control_actions import ( +from airweave.domains.access_control.actions import ( ACActionBatch, ACDeleteAction, ACInsertAction, ACUpdateAction, ACUpsertAction, ) +from airweave.domains.access_control.protocols import ( + ACActionHandler, + AccessControlMembershipRepositoryProtocol, +) +from airweave.domains.sync_pipeline.exceptions import SyncFailureError if TYPE_CHECKING: from airweave.domains.sync_pipeline.contexts import SyncContext diff --git a/backend/airweave/domains/access_control/protocols.py b/backend/airweave/domains/access_control/protocols.py index dc0d3e53f..8227096cc 100644 --- a/backend/airweave/domains/access_control/protocols.py +++ b/backend/airweave/domains/access_control/protocols.py @@ -1,12 +1,22 @@ """Protocols for the access control domain.""" -from typing import List, Optional, Protocol +from __future__ import annotations + +from typing import List, Optional, Protocol, runtime_checkable from uuid import UUID from sqlalchemy.ext.asyncio import AsyncSession +from airweave.domains.access_control.actions import ( + ACActionBatch, + ACDeleteAction, + ACInsertAction, + ACUpdateAction, + ACUpsertAction, +) +from airweave.domains.access_control.schemas import AccessContext, MembershipTuple +from airweave.domains.sync_pipeline.contexts import SyncContext from airweave.models.access_control_membership import AccessControlMembership -from airweave.platform.access_control.schemas import AccessContext from airweave.platform.entities._base import AccessControl @@ -151,3 +161,84 @@ def check_entity_access( ) -> bool: """Check if user can access entity based on access control.""" ... + + +class ACActionResolverProtocol(Protocol): + """Resolves membership tuples to action objects.""" + + async def resolve( + self, + memberships: List[MembershipTuple], + sync_context: SyncContext, + ) -> ACActionBatch: + """Resolve memberships to actions.""" + ... + + +class ACActionDispatcherProtocol(Protocol): + """Dispatches resolved AC membership actions to handlers.""" + + async def dispatch( + self, + batch: ACActionBatch, + sync_context: SyncContext, + ) -> int: + """Dispatch action batch to all handlers.""" + ... + + +@runtime_checkable +class ACActionHandler(Protocol): + """Protocol for access control membership action handlers. + + Handlers receive resolved AC actions and persist them to their destination. + + Contract: + - Handlers MUST be idempotent (safe to retry on failure) + - Handlers MUST raise SyncFailureError for non-recoverable errors + """ + + @property + def name(self) -> str: + """Handler name for logging and debugging.""" + ... + + async def handle_batch( + self, + batch: "ACActionBatch", + sync_context: "SyncContext", + ) -> int: + """Handle a full action batch (main entry point).""" + ... + + async def handle_upserts( + self, + actions: List["ACUpsertAction"], + sync_context: "SyncContext", + ) -> int: + """Handle upsert actions.""" + ... + + async def handle_inserts( + self, + actions: List["ACInsertAction"], + sync_context: "SyncContext", + ) -> int: + """Handle insert actions.""" + ... + + async def handle_updates( + self, + actions: List["ACUpdateAction"], + sync_context: "SyncContext", + ) -> int: + """Handle update actions.""" + ... + + async def handle_deletes( + self, + actions: List["ACDeleteAction"], + sync_context: "SyncContext", + ) -> int: + """Handle delete actions.""" + ... diff --git a/backend/airweave/domains/sync_pipeline/access_control_resolver.py b/backend/airweave/domains/access_control/resolver.py similarity index 93% rename from backend/airweave/domains/sync_pipeline/access_control_resolver.py rename to backend/airweave/domains/access_control/resolver.py index 744d7cab7..182284946 100644 --- a/backend/airweave/domains/sync_pipeline/access_control_resolver.py +++ b/backend/airweave/domains/access_control/resolver.py @@ -6,11 +6,11 @@ from typing import TYPE_CHECKING, List -from airweave.domains.sync_pipeline.types.access_control_actions import ( +from airweave.domains.access_control.actions import ( ACActionBatch, ACUpsertAction, ) -from airweave.platform.access_control.schemas import MembershipTuple +from airweave.domains.access_control.schemas import MembershipTuple if TYPE_CHECKING: from airweave.domains.sync_pipeline.contexts import SyncContext diff --git a/backend/airweave/platform/access_control/schemas.py b/backend/airweave/domains/access_control/schemas.py similarity index 100% rename from backend/airweave/platform/access_control/schemas.py rename to backend/airweave/domains/access_control/schemas.py diff --git a/backend/airweave/domains/access_control/tests/__init__.py b/backend/airweave/domains/access_control/tests/__init__.py new file mode 100644 index 000000000..33542712f --- /dev/null +++ b/backend/airweave/domains/access_control/tests/__init__.py @@ -0,0 +1 @@ +"""Access control domain tests.""" diff --git a/backend/airweave/domains/sync_pipeline/tests/test_acl_membership_tracker.py b/backend/airweave/domains/access_control/tests/test_membership_tracker.py similarity index 99% rename from backend/airweave/domains/sync_pipeline/tests/test_acl_membership_tracker.py rename to backend/airweave/domains/access_control/tests/test_membership_tracker.py index 0ca7e4e1b..355066342 100644 --- a/backend/airweave/domains/sync_pipeline/tests/test_acl_membership_tracker.py +++ b/backend/airweave/domains/access_control/tests/test_membership_tracker.py @@ -5,7 +5,7 @@ import pytest -from airweave.domains.sync_pipeline.pipeline.acl_membership_tracker import ACLMembershipTracker +from airweave.domains.access_control.membership_tracker import ACLMembershipTracker @pytest.fixture diff --git a/backend/airweave/domains/sync_pipeline/tests/test_acl_reconciliation.py b/backend/airweave/domains/access_control/tests/test_pipeline.py similarity index 98% rename from backend/airweave/domains/sync_pipeline/tests/test_acl_reconciliation.py rename to backend/airweave/domains/access_control/tests/test_pipeline.py index 01c663003..681739aa1 100644 --- a/backend/airweave/domains/sync_pipeline/tests/test_acl_reconciliation.py +++ b/backend/airweave/domains/access_control/tests/test_pipeline.py @@ -18,10 +18,10 @@ import pytest -from airweave.domains.sync_pipeline.access_control_pipeline import AccessControlPipeline -from airweave.platform.access_control.schemas import ACLChangeType, MembershipChange +from airweave.domains.access_control.pipeline import AccessControlPipeline +from airweave.domains.access_control.schemas import ACLChangeType, MembershipChange -_GET_DB_CTX = "airweave.domains.sync_pipeline.access_control_pipeline.get_db_context" +_GET_DB_CTX = "airweave.domains.access_control.pipeline.get_db_context" # --------------------------------------------------------------------------- diff --git a/backend/airweave/domains/converters/_base.py b/backend/airweave/domains/converters/_base.py index 7be5f20fd..f949f3629 100644 --- a/backend/airweave/domains/converters/_base.py +++ b/backend/airweave/domains/converters/_base.py @@ -7,7 +7,7 @@ from typing import Dict, List, Optional from airweave.core.logging import logger -from airweave.core.protocols.ocr import OcrProvider +from airweave.domains.ocr.protocols import OcrProvider class BaseTextConverter(ABC): diff --git a/backend/airweave/domains/converters/registry.py b/backend/airweave/domains/converters/registry.py index ef6f03f8b..4ae25b315 100644 --- a/backend/airweave/domains/converters/registry.py +++ b/backend/airweave/domains/converters/registry.py @@ -4,7 +4,6 @@ from typing import Dict, Optional -from airweave.core.protocols.ocr import OcrProvider from airweave.domains.converters._base import BaseTextConverter from airweave.domains.converters.code import CodeConverter from airweave.domains.converters.docx import DocxConverter @@ -14,6 +13,7 @@ from airweave.domains.converters.txt import TxtConverter from airweave.domains.converters.web import WebConverter from airweave.domains.converters.xlsx import XlsxConverter +from airweave.domains.ocr.protocols import OcrProvider class ConverterRegistry: diff --git a/backend/airweave/domains/ocr/__init__.py b/backend/airweave/domains/ocr/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/airweave/adapters/ocr/docling.py b/backend/airweave/domains/ocr/docling.py similarity index 100% rename from backend/airweave/adapters/ocr/docling.py rename to backend/airweave/domains/ocr/docling.py diff --git a/backend/airweave/domains/ocr/fakes/__init__.py b/backend/airweave/domains/ocr/fakes/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/airweave/adapters/ocr/fake.py b/backend/airweave/domains/ocr/fakes/provider.py similarity index 100% rename from backend/airweave/adapters/ocr/fake.py rename to backend/airweave/domains/ocr/fakes/provider.py diff --git a/backend/airweave/adapters/ocr/fallback.py b/backend/airweave/domains/ocr/fallback.py similarity index 96% rename from backend/airweave/adapters/ocr/fallback.py rename to backend/airweave/domains/ocr/fallback.py index d268a15c5..699c4b0f1 100644 --- a/backend/airweave/adapters/ocr/fallback.py +++ b/backend/airweave/domains/ocr/fallback.py @@ -12,7 +12,8 @@ from airweave.core.logging import logger if TYPE_CHECKING: - from airweave.core.protocols import CircuitBreaker, OcrProvider + from airweave.core.protocols import CircuitBreaker + from airweave.domains.ocr.protocols import OcrProvider class FallbackOcrProvider: diff --git a/backend/airweave/domains/ocr/mistral/__init__.py b/backend/airweave/domains/ocr/mistral/__init__.py new file mode 100644 index 000000000..7b3dc4af7 --- /dev/null +++ b/backend/airweave/domains/ocr/mistral/__init__.py @@ -0,0 +1 @@ +"""Mistral OCR implementation.""" diff --git a/backend/airweave/platform/ocr/mistral/compressor.py b/backend/airweave/domains/ocr/mistral/compressor.py similarity index 98% rename from backend/airweave/platform/ocr/mistral/compressor.py rename to backend/airweave/domains/ocr/mistral/compressor.py index e07196972..85e751ed0 100644 --- a/backend/airweave/platform/ocr/mistral/compressor.py +++ b/backend/airweave/domains/ocr/mistral/compressor.py @@ -15,8 +15,8 @@ if TYPE_CHECKING: from PIL import Image +from airweave.domains.ocr.mistral.models import CompressionResult from airweave.domains.sync_pipeline.exceptions import EntityProcessingError, SyncFailureError -from airweave.platform.ocr.mistral.models import CompressionResult # Quality levels to try, from highest to lowest. _QUALITY_STEPS = range(85, 19, -10) diff --git a/backend/airweave/platform/ocr/mistral/converter.py b/backend/airweave/domains/ocr/mistral/converter.py similarity index 98% rename from backend/airweave/platform/ocr/mistral/converter.py rename to backend/airweave/domains/ocr/mistral/converter.py index d4e10607c..243fcfa36 100644 --- a/backend/airweave/platform/ocr/mistral/converter.py +++ b/backend/airweave/domains/ocr/mistral/converter.py @@ -26,9 +26,8 @@ from airweave.core.logging import logger from airweave.domains.converters.text_extractors.pptx import extract_pptx_text -from airweave.domains.sync_pipeline.exceptions import EntityProcessingError, SyncFailureError -from airweave.platform.ocr.mistral.compressor import compress_image -from airweave.platform.ocr.mistral.models import ( +from airweave.domains.ocr.mistral.compressor import compress_image +from airweave.domains.ocr.mistral.models import ( IMAGE_EXTENSIONS, BatchFileGroup, DirectResult, @@ -36,12 +35,13 @@ OcrResult, PreparedBatch, ) -from airweave.platform.ocr.mistral.ocr_client import MistralOcrClient -from airweave.platform.ocr.mistral.splitters import ( +from airweave.domains.ocr.mistral.ocr_client import MistralOcrClient +from airweave.domains.ocr.mistral.splitters import ( DocxSplitter, PdfSplitter, RecursiveSplitter, ) +from airweave.domains.sync_pipeline.exceptions import EntityProcessingError, SyncFailureError # Mistral upload limit. MAX_FILE_SIZE_BYTES = 50_000_000 # 50 MB diff --git a/backend/airweave/platform/ocr/mistral/models.py b/backend/airweave/domains/ocr/mistral/models.py similarity index 100% rename from backend/airweave/platform/ocr/mistral/models.py rename to backend/airweave/domains/ocr/mistral/models.py diff --git a/backend/airweave/platform/ocr/mistral/ocr_client.py b/backend/airweave/domains/ocr/mistral/ocr_client.py similarity index 99% rename from backend/airweave/platform/ocr/mistral/ocr_client.py rename to backend/airweave/domains/ocr/mistral/ocr_client.py index 93b3b641b..40fe8ce6a 100644 --- a/backend/airweave/platform/ocr/mistral/ocr_client.py +++ b/backend/airweave/domains/ocr/mistral/ocr_client.py @@ -20,11 +20,11 @@ from airweave.core.config import settings from airweave.core.logging import logger -from airweave.domains.sync_pipeline.exceptions import SyncFailureError -from airweave.platform.ocr.mistral.models import ( +from airweave.domains.ocr.mistral.models import ( FileChunk, OcrResult, ) +from airweave.domains.sync_pipeline.exceptions import SyncFailureError from airweave.platform.rate_limiters import MistralRateLimiter # --------------------------------------------------------------------------- diff --git a/backend/airweave/platform/ocr/mistral/splitters.py b/backend/airweave/domains/ocr/mistral/splitters.py similarity index 100% rename from backend/airweave/platform/ocr/mistral/splitters.py rename to backend/airweave/domains/ocr/mistral/splitters.py diff --git a/backend/airweave/core/protocols/ocr.py b/backend/airweave/domains/ocr/protocols.py similarity index 100% rename from backend/airweave/core/protocols/ocr.py rename to backend/airweave/domains/ocr/protocols.py diff --git a/backend/airweave/domains/sync_pipeline/entity/__init__.py b/backend/airweave/domains/sync_pipeline/entity/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/airweave/domains/sync_pipeline/types/entity_actions.py b/backend/airweave/domains/sync_pipeline/entity/actions.py similarity index 100% rename from backend/airweave/domains/sync_pipeline/types/entity_actions.py rename to backend/airweave/domains/sync_pipeline/entity/actions.py diff --git a/backend/airweave/domains/sync_pipeline/entity_action_dispatcher.py b/backend/airweave/domains/sync_pipeline/entity/dispatcher.py similarity index 97% rename from backend/airweave/domains/sync_pipeline/entity_action_dispatcher.py rename to backend/airweave/domains/sync_pipeline/entity/dispatcher.py index 04ef5f571..4e47402eb 100644 --- a/backend/airweave/domains/sync_pipeline/entity_action_dispatcher.py +++ b/backend/airweave/domains/sync_pipeline/entity/dispatcher.py @@ -7,9 +7,9 @@ import asyncio from typing import TYPE_CHECKING, List, Optional +from airweave.domains.sync_pipeline.entity.actions import EntityActionBatch +from airweave.domains.sync_pipeline.entity.handlers.protocol import EntityActionHandler from airweave.domains.sync_pipeline.exceptions import SyncFailureError -from airweave.domains.sync_pipeline.handlers.protocol import EntityActionHandler -from airweave.domains.sync_pipeline.types.entity_actions import EntityActionBatch if TYPE_CHECKING: from airweave.domains.sync_pipeline.contexts import SyncContext diff --git a/backend/airweave/domains/sync_pipeline/entity_dispatcher_builder.py b/backend/airweave/domains/sync_pipeline/entity/dispatcher_builder.py similarity index 92% rename from backend/airweave/domains/sync_pipeline/entity_dispatcher_builder.py rename to backend/airweave/domains/sync_pipeline/entity/dispatcher_builder.py index 8b0206fa7..de489e2f8 100644 --- a/backend/airweave/domains/sync_pipeline/entity_dispatcher_builder.py +++ b/backend/airweave/domains/sync_pipeline/entity/dispatcher_builder.py @@ -5,11 +5,11 @@ from airweave.core.logging import ContextualLogger from airweave.domains.entities.protocols import EntityRepositoryProtocol from airweave.domains.sync_pipeline.config import SyncConfig -from airweave.domains.sync_pipeline.entity_action_dispatcher import EntityActionDispatcher -from airweave.domains.sync_pipeline.handlers.arf import ArfHandler -from airweave.domains.sync_pipeline.handlers.destination import DestinationHandler -from airweave.domains.sync_pipeline.handlers.entity_postgres import EntityPostgresHandler -from airweave.domains.sync_pipeline.handlers.protocol import EntityActionHandler +from airweave.domains.sync_pipeline.entity.dispatcher import EntityActionDispatcher +from airweave.domains.sync_pipeline.entity.handlers.arf import ArfHandler +from airweave.domains.sync_pipeline.entity.handlers.destination import DestinationHandler +from airweave.domains.sync_pipeline.entity.handlers.postgres import EntityPostgresHandler +from airweave.domains.sync_pipeline.entity.handlers.protocol import EntityActionHandler from airweave.domains.sync_pipeline.protocols import ChunkEmbedProcessorProtocol from airweave.platform.destinations._base import BaseDestination diff --git a/backend/airweave/domains/sync_pipeline/entity/handlers/__init__.py b/backend/airweave/domains/sync_pipeline/entity/handlers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/airweave/domains/sync_pipeline/handlers/arf.py b/backend/airweave/domains/sync_pipeline/entity/handlers/arf.py similarity index 97% rename from backend/airweave/domains/sync_pipeline/handlers/arf.py rename to backend/airweave/domains/sync_pipeline/entity/handlers/arf.py index 3a12503e6..f518ecc62 100644 --- a/backend/airweave/domains/sync_pipeline/handlers/arf.py +++ b/backend/airweave/domains/sync_pipeline/entity/handlers/arf.py @@ -6,14 +6,14 @@ from typing import TYPE_CHECKING, List -from airweave.domains.sync_pipeline.exceptions import SyncFailureError -from airweave.domains.sync_pipeline.handlers.protocol import EntityActionHandler -from airweave.domains.sync_pipeline.types.entity_actions import ( +from airweave.domains.sync_pipeline.entity.actions import ( EntityActionBatch, EntityDeleteAction, EntityInsertAction, EntityUpdateAction, ) +from airweave.domains.sync_pipeline.entity.handlers.protocol import EntityActionHandler +from airweave.domains.sync_pipeline.exceptions import SyncFailureError if TYPE_CHECKING: from airweave.domains.sync_pipeline.contexts import SyncContext diff --git a/backend/airweave/domains/sync_pipeline/handlers/destination.py b/backend/airweave/domains/sync_pipeline/entity/handlers/destination.py similarity index 98% rename from backend/airweave/domains/sync_pipeline/handlers/destination.py rename to backend/airweave/domains/sync_pipeline/entity/handlers/destination.py index e76998f94..32e8dd315 100644 --- a/backend/airweave/domains/sync_pipeline/handlers/destination.py +++ b/backend/airweave/domains/sync_pipeline/entity/handlers/destination.py @@ -10,15 +10,15 @@ import httpcore import httpx -from airweave.domains.sync_pipeline.exceptions import SyncFailureError -from airweave.domains.sync_pipeline.handlers.protocol import EntityActionHandler -from airweave.domains.sync_pipeline.protocols import ChunkEmbedProcessorProtocol -from airweave.domains.sync_pipeline.types.entity_actions import ( +from airweave.domains.sync_pipeline.entity.actions import ( EntityActionBatch, EntityDeleteAction, EntityInsertAction, EntityUpdateAction, ) +from airweave.domains.sync_pipeline.entity.handlers.protocol import EntityActionHandler +from airweave.domains.sync_pipeline.exceptions import SyncFailureError +from airweave.domains.sync_pipeline.protocols import ChunkEmbedProcessorProtocol from airweave.platform.destinations._base import BaseDestination if TYPE_CHECKING: diff --git a/backend/airweave/domains/sync_pipeline/handlers/entity_postgres.py b/backend/airweave/domains/sync_pipeline/entity/handlers/postgres.py similarity index 98% rename from backend/airweave/domains/sync_pipeline/handlers/entity_postgres.py rename to backend/airweave/domains/sync_pipeline/entity/handlers/postgres.py index b16fd6ab4..ec1d0fc90 100644 --- a/backend/airweave/domains/sync_pipeline/handlers/entity_postgres.py +++ b/backend/airweave/domains/sync_pipeline/entity/handlers/postgres.py @@ -12,14 +12,14 @@ from airweave import schemas from airweave.db.session import get_db_context from airweave.domains.entities.protocols import EntityRepositoryProtocol -from airweave.domains.sync_pipeline.exceptions import SyncFailureError -from airweave.domains.sync_pipeline.handlers.protocol import EntityActionHandler -from airweave.domains.sync_pipeline.types.entity_actions import ( +from airweave.domains.sync_pipeline.entity.actions import ( EntityActionBatch, EntityDeleteAction, EntityInsertAction, EntityUpdateAction, ) +from airweave.domains.sync_pipeline.entity.handlers.protocol import EntityActionHandler +from airweave.domains.sync_pipeline.exceptions import SyncFailureError if TYPE_CHECKING: from airweave.domains.sync_pipeline.contexts import SyncContext diff --git a/backend/airweave/domains/sync_pipeline/handlers/protocol.py b/backend/airweave/domains/sync_pipeline/entity/handlers/protocol.py similarity index 51% rename from backend/airweave/domains/sync_pipeline/handlers/protocol.py rename to backend/airweave/domains/sync_pipeline/entity/handlers/protocol.py index 80b510172..84d5595ff 100644 --- a/backend/airweave/domains/sync_pipeline/handlers/protocol.py +++ b/backend/airweave/domains/sync_pipeline/entity/handlers/protocol.py @@ -1,18 +1,11 @@ -"""Protocols for action handlers.""" +"""Protocols for entity action handlers.""" from typing import TYPE_CHECKING, Any, List, Protocol, runtime_checkable if TYPE_CHECKING: from airweave.domains.sync_pipeline.contexts import SyncContext from airweave.domains.sync_pipeline.contexts.runtime import SyncRuntime - from airweave.domains.sync_pipeline.types.access_control_actions import ( - ACActionBatch, - ACDeleteAction, - ACInsertAction, - ACUpdateAction, - ACUpsertAction, - ) - from airweave.domains.sync_pipeline.types.entity_actions import ( + from airweave.domains.sync_pipeline.entity.actions import ( EntityActionBatch, EntityDeleteAction, EntityInsertAction, @@ -87,71 +80,3 @@ async def handle_orphan_cleanup( ) -> Any: """Handle orphaned entity cleanup at sync end.""" ... - - -@runtime_checkable -class ACActionHandler(Protocol): - """Protocol for access control membership action handlers. - - Handlers receive resolved AC actions and persist them to their destination. - - Contract: - - Handlers MUST be idempotent (safe to retry on failure) - - Handlers MUST raise SyncFailureError for non-recoverable errors - """ - - @property - def name(self) -> str: - """Handler name for logging and debugging.""" - ... - - async def handle_batch( - self, - batch: "ACActionBatch", - sync_context: "SyncContext", - ) -> int: - """Handle a full action batch (main entry point). - - Args: - batch: Access control action batch - sync_context: Sync context - - Returns: - Number of memberships processed - - Raises: - SyncFailureError: If any operation fails - """ - ... - - async def handle_upserts( - self, - actions: List["ACUpsertAction"], - sync_context: "SyncContext", - ) -> int: - """Handle upsert actions.""" - ... - - async def handle_inserts( - self, - actions: List["ACInsertAction"], - sync_context: "SyncContext", - ) -> int: - """Handle insert actions.""" - ... - - async def handle_updates( - self, - actions: List["ACUpdateAction"], - sync_context: "SyncContext", - ) -> int: - """Handle update actions.""" - ... - - async def handle_deletes( - self, - actions: List["ACDeleteAction"], - sync_context: "SyncContext", - ) -> int: - """Handle delete actions.""" - ... diff --git a/backend/airweave/domains/sync_pipeline/entity_pipeline.py b/backend/airweave/domains/sync_pipeline/entity/pipeline.py similarity index 99% rename from backend/airweave/domains/sync_pipeline/entity_pipeline.py rename to backend/airweave/domains/sync_pipeline/entity/pipeline.py index 117f6b9ed..c036733a7 100644 --- a/backend/airweave/domains/sync_pipeline/entity_pipeline.py +++ b/backend/airweave/domains/sync_pipeline/entity/pipeline.py @@ -20,6 +20,7 @@ from airweave.domains.entities.protocols import EntityRepositoryProtocol from airweave.domains.sync_pipeline.contexts import SyncContext from airweave.domains.sync_pipeline.contexts.runtime import SyncRuntime +from airweave.domains.sync_pipeline.entity.actions import EntityActionBatch from airweave.domains.sync_pipeline.exceptions import SyncFailureError from airweave.domains.sync_pipeline.pipeline.cleanup_service import cleanup_service from airweave.domains.sync_pipeline.pipeline.entity_tracker import EntityTracker @@ -28,7 +29,6 @@ EntityActionDispatcherProtocol, EntityActionResolverProtocol, ) -from airweave.domains.sync_pipeline.types.entity_actions import EntityActionBatch from airweave.platform.entities._base import BaseEntity if TYPE_CHECKING: diff --git a/backend/airweave/domains/sync_pipeline/entity_action_resolver.py b/backend/airweave/domains/sync_pipeline/entity/resolver.py similarity index 99% rename from backend/airweave/domains/sync_pipeline/entity_action_resolver.py rename to backend/airweave/domains/sync_pipeline/entity/resolver.py index 430dff8c8..aee76c423 100644 --- a/backend/airweave/domains/sync_pipeline/entity_action_resolver.py +++ b/backend/airweave/domains/sync_pipeline/entity/resolver.py @@ -10,14 +10,14 @@ from airweave import models from airweave.db.session import get_db_context from airweave.domains.entities.protocols import EntityRepositoryProtocol -from airweave.domains.sync_pipeline.exceptions import SyncFailureError -from airweave.domains.sync_pipeline.types.entity_actions import ( +from airweave.domains.sync_pipeline.entity.actions import ( EntityActionBatch, EntityDeleteAction, EntityInsertAction, EntityKeepAction, EntityUpdateAction, ) +from airweave.domains.sync_pipeline.exceptions import SyncFailureError from airweave.platform.entities._base import BaseEntity, DeletionEntity if TYPE_CHECKING: diff --git a/backend/airweave/domains/sync_pipeline/factory.py b/backend/airweave/domains/sync_pipeline/factory.py index 38b675890..aa0a51f3c 100644 --- a/backend/airweave/domains/sync_pipeline/factory.py +++ b/backend/airweave/domains/sync_pipeline/factory.py @@ -21,31 +21,31 @@ from airweave.core.exceptions import NotFoundException from airweave.core.logging import LoggerConfigurator, logger from airweave.core.protocols.event_bus import EventBus +from airweave.domains.access_control.dispatcher import ACActionDispatcher +from airweave.domains.access_control.membership_tracker import ACLMembershipTracker +from airweave.domains.access_control.pipeline import AccessControlPipeline +from airweave.domains.access_control.postgres_handler import ACPostgresHandler from airweave.domains.access_control.protocols import AccessControlMembershipRepositoryProtocol +from airweave.domains.access_control.resolver import ACActionResolver from airweave.domains.embedders.protocols import DenseEmbedderProtocol, SparseEmbedderProtocol from airweave.domains.entities.protocols import EntityRepositoryProtocol from airweave.domains.source_connections.protocols import SourceConnectionRepositoryProtocol -from airweave.domains.sync_pipeline.access_control_dispatcher import ACActionDispatcher -from airweave.domains.sync_pipeline.access_control_pipeline import AccessControlPipeline -from airweave.domains.sync_pipeline.access_control_resolver import ACActionResolver from airweave.domains.sync_pipeline.builders import SyncContextBuilder from airweave.domains.sync_pipeline.builders.destinations import DestinationsContextBuilder from airweave.domains.sync_pipeline.builders.tracking import TrackingContextBuilder from airweave.domains.sync_pipeline.config import SyncConfig, SyncConfigBuilder from airweave.domains.sync_pipeline.contexts.infra import InfraContext from airweave.domains.sync_pipeline.contexts.runtime import SyncRuntime -from airweave.domains.sync_pipeline.entity_dispatcher_builder import EntityDispatcherBuilder -from airweave.domains.sync_pipeline.handlers import ACPostgresHandler +from airweave.domains.sync_pipeline.entity.dispatcher_builder import EntityDispatcherBuilder from airweave.domains.sync_pipeline.orchestrator import SyncOrchestrator -from airweave.domains.sync_pipeline.pipeline.acl_membership_tracker import ACLMembershipTracker from airweave.domains.sync_pipeline.pipeline.entity_tracker import EntityTracker from airweave.domains.sync_pipeline.protocols import ChunkEmbedProcessorProtocol from airweave.domains.sync_pipeline.stream import AsyncSourceStream from airweave.domains.sync_pipeline.worker_pool import AsyncWorkerPool from airweave.domains.usage.protocols import UsageLimitCheckerProtocol -from .entity_action_resolver import EntityActionResolver -from .entity_pipeline import EntityPipeline +from .entity.pipeline import EntityPipeline +from .entity.resolver import EntityActionResolver class SyncFactory: diff --git a/backend/airweave/domains/sync_pipeline/handlers/__init__.py b/backend/airweave/domains/sync_pipeline/handlers/__init__.py deleted file mode 100644 index 3106c83a9..000000000 --- a/backend/airweave/domains/sync_pipeline/handlers/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Sync pipeline handlers — entity and access control action handlers.""" - -from airweave.domains.sync_pipeline.handlers.access_control_postgres import ACPostgresHandler -from airweave.domains.sync_pipeline.handlers.protocol import ACActionHandler, EntityActionHandler - -__all__ = [ - "EntityActionHandler", - "ACActionHandler", - "ACPostgresHandler", -] diff --git a/backend/airweave/domains/sync_pipeline/orchestrator.py b/backend/airweave/domains/sync_pipeline/orchestrator.py index 34a5ec57f..3f2f3bafc 100644 --- a/backend/airweave/domains/sync_pipeline/orchestrator.py +++ b/backend/airweave/domains/sync_pipeline/orchestrator.py @@ -12,10 +12,10 @@ from airweave.core.sync_cursor_service import sync_cursor_service from airweave.core.sync_job_service import sync_job_service from airweave.db.session import get_db_context -from airweave.domains.sync_pipeline.access_control_pipeline import AccessControlPipeline +from airweave.domains.access_control.pipeline import AccessControlPipeline from airweave.domains.sync_pipeline.contexts import SyncContext from airweave.domains.sync_pipeline.contexts.runtime import SyncRuntime -from airweave.domains.sync_pipeline.entity_pipeline import EntityPipeline +from airweave.domains.sync_pipeline.entity.pipeline import EntityPipeline from airweave.domains.sync_pipeline.exceptions import EntityProcessingError, SyncFailureError from airweave.domains.sync_pipeline.stream import AsyncSourceStream from airweave.domains.sync_pipeline.worker_pool import AsyncWorkerPool diff --git a/backend/airweave/domains/sync_pipeline/protocols.py b/backend/airweave/domains/sync_pipeline/protocols.py index edcffe683..d3de2eb39 100644 --- a/backend/airweave/domains/sync_pipeline/protocols.py +++ b/backend/airweave/domains/sync_pipeline/protocols.py @@ -7,9 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from airweave import schemas -from airweave.domains.sync_pipeline.types.access_control_actions import ACActionBatch -from airweave.domains.sync_pipeline.types.entity_actions import EntityActionBatch -from airweave.platform.access_control.schemas import MembershipTuple +from airweave.domains.sync_pipeline.entity.actions import EntityActionBatch from airweave.platform.entities._base import BaseEntity if TYPE_CHECKING: @@ -93,30 +91,6 @@ async def cleanup_temp_files(self, sync_context: SyncContext, runtime: SyncRunti ... -class ACActionResolverProtocol(Protocol): - """Resolves membership tuples to action objects.""" - - async def resolve( - self, - memberships: List[MembershipTuple], - sync_context: SyncContext, - ) -> ACActionBatch: - """Resolve memberships to actions.""" - ... - - -class ACActionDispatcherProtocol(Protocol): - """Dispatches resolved AC membership actions to handlers.""" - - async def dispatch( - self, - batch: ACActionBatch, - sync_context: SyncContext, - ) -> int: - """Dispatch action batch to all handlers.""" - ... - - class SyncFactoryProtocol(Protocol): """Builds a SyncOrchestrator for a given sync run.""" diff --git a/backend/airweave/domains/sync_pipeline/tests/test_destination_handler.py b/backend/airweave/domains/sync_pipeline/tests/test_destination_handler.py index ff0910faf..f83d90330 100644 --- a/backend/airweave/domains/sync_pipeline/tests/test_destination_handler.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_destination_handler.py @@ -13,9 +13,9 @@ import pytest from airweave.domains.sync_pipeline.exceptions import SyncFailureError -from airweave.domains.sync_pipeline.handlers.destination import DestinationHandler +from airweave.domains.sync_pipeline.entity.handlers.destination import DestinationHandler -_ASYNC_SLEEP = "airweave.domains.sync_pipeline.handlers.destination.asyncio.sleep" +_ASYNC_SLEEP = "airweave.domains.sync_pipeline.entity.handlers.destination.asyncio.sleep" def _make_mock_destination(soft_fail=False): diff --git a/backend/airweave/domains/sync_pipeline/tests/test_entity_action_resolver.py b/backend/airweave/domains/sync_pipeline/tests/test_entity_action_resolver.py index 5de9f8e75..e2cfcc7eb 100644 --- a/backend/airweave/domains/sync_pipeline/tests/test_entity_action_resolver.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_entity_action_resolver.py @@ -5,9 +5,9 @@ import pytest -from airweave.domains.sync_pipeline.entity_action_resolver import EntityActionResolver +from airweave.domains.sync_pipeline.entity.resolver import EntityActionResolver from airweave.domains.sync_pipeline.exceptions import SyncFailureError -from airweave.domains.sync_pipeline.types.entity_actions import ( +from airweave.domains.sync_pipeline.entity.actions import ( EntityInsertAction, EntityKeepAction, EntityUpdateAction, diff --git a/backend/airweave/domains/sync_pipeline/tests/test_entity_pipeline.py b/backend/airweave/domains/sync_pipeline/tests/test_entity_pipeline.py index eff579a5c..e08c633c8 100644 --- a/backend/airweave/domains/sync_pipeline/tests/test_entity_pipeline.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_entity_pipeline.py @@ -5,7 +5,7 @@ import pytest -from airweave.domains.sync_pipeline.entity_pipeline import EntityPipeline +from airweave.domains.sync_pipeline.entity.pipeline import EntityPipeline # --------------------------------------------------------------------------- # Constructor diff --git a/backend/airweave/domains/sync_pipeline/types/__init__.py b/backend/airweave/domains/sync_pipeline/types/__init__.py deleted file mode 100644 index a42fc90a7..000000000 --- a/backend/airweave/domains/sync_pipeline/types/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Sync pipeline types — action dataclasses for entity and access control.""" diff --git a/backend/airweave/platform/access_control/__init__.py b/backend/airweave/platform/access_control/__init__.py deleted file mode 100644 index c72868f79..000000000 --- a/backend/airweave/platform/access_control/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Access control module for permission resolution and filtering.""" - -from .schemas import AccessContext, MembershipTuple - -__all__ = ["AccessContext", "MembershipTuple"] diff --git a/backend/airweave/platform/ocr/__init__.py b/backend/airweave/platform/ocr/__init__.py deleted file mode 100644 index 014de836a..000000000 --- a/backend/airweave/platform/ocr/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -"""OCR provider implementations. - -Re-exports :class:`MistralOCR` so consumers can do:: - - from airweave.platform.ocr import MistralOCR -""" - -from airweave.platform.ocr.mistral.converter import MistralOCR - -__all__ = ["MistralOCR"] diff --git a/backend/airweave/platform/ocr/mistral/__init__.py b/backend/airweave/platform/ocr/mistral/__init__.py deleted file mode 100644 index a9fef77ad..000000000 --- a/backend/airweave/platform/ocr/mistral/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""Mistral OCR provider package. - -Re-exports :class:`MistralOCR` for convenience. -""" - -from airweave.platform.ocr.mistral.converter import MistralOCR - -__all__ = ["MistralOCR"] diff --git a/backend/airweave/platform/sources/_base.py b/backend/airweave/platform/sources/_base.py index a3095d134..d3b17a682 100644 --- a/backend/airweave/platform/sources/_base.py +++ b/backend/airweave/platform/sources/_base.py @@ -20,7 +20,7 @@ ) if TYPE_CHECKING: - from airweave.platform.access_control.schemas import MembershipTuple + from airweave.domains.access_control.schemas import MembershipTuple import httpx from pydantic import BaseModel diff --git a/backend/airweave/platform/sources/sharepoint2019v2/ldap.py b/backend/airweave/platform/sources/sharepoint2019v2/ldap.py index 2ffaad781..c067cb659 100644 --- a/backend/airweave/platform/sources/sharepoint2019v2/ldap.py +++ b/backend/airweave/platform/sources/sharepoint2019v2/ldap.py @@ -24,7 +24,7 @@ from pyasn1.codec.ber import decoder as ber_decoder from pyasn1.type.univ import Sequence -from airweave.platform.access_control.schemas import ( +from airweave.domains.access_control.schemas import ( ACLChangeType, MembershipChange, MembershipTuple, diff --git a/backend/airweave/platform/sources/sharepoint2019v2/source.py b/backend/airweave/platform/sources/sharepoint2019v2/source.py index 73be18992..36c5656fd 100644 --- a/backend/airweave/platform/sources/sharepoint2019v2/source.py +++ b/backend/airweave/platform/sources/sharepoint2019v2/source.py @@ -24,9 +24,9 @@ from dataclasses import dataclass from typing import Any, AsyncGenerator, Dict, List, Optional, Union +from airweave.domains.access_control.schemas import MembershipTuple from airweave.domains.browse_tree.types import BrowseNode, NodeSelectionData from airweave.domains.sync_pipeline.exceptions import EntityProcessingError -from airweave.platform.access_control.schemas import MembershipTuple from airweave.platform.configs.auth import SharePoint2019V2AuthConfig from airweave.platform.configs.config import SharePoint2019V2Config from airweave.platform.cursors.sharepoint2019v2 import SharePoint2019V2Cursor diff --git a/backend/airweave/platform/sources/sharepoint_online/graph_groups.py b/backend/airweave/platform/sources/sharepoint_online/graph_groups.py index a86a44d92..9f86bc425 100644 --- a/backend/airweave/platform/sources/sharepoint_online/graph_groups.py +++ b/backend/airweave/platform/sources/sharepoint_online/graph_groups.py @@ -11,7 +11,7 @@ import httpx -from airweave.platform.access_control.schemas import MembershipTuple +from airweave.domains.access_control.schemas import MembershipTuple from airweave.platform.sources.sharepoint_online.acl import format_entra_group_id GRAPH_BASE_URL = "https://graph.microsoft.com/v1.0" diff --git a/backend/airweave/platform/sources/sharepoint_online/source.py b/backend/airweave/platform/sources/sharepoint_online/source.py index 5b389711f..6b832c8e1 100644 --- a/backend/airweave/platform/sources/sharepoint_online/source.py +++ b/backend/airweave/platform/sources/sharepoint_online/source.py @@ -27,9 +27,9 @@ import httpx +from airweave.domains.access_control.schemas import MembershipTuple from airweave.domains.browse_tree.types import BrowseNode, NodeSelectionData from airweave.domains.sync_pipeline.exceptions import EntityProcessingError -from airweave.platform.access_control.schemas import MembershipTuple from airweave.platform.configs.config import SharePointOnlineConfig from airweave.platform.cursors.sharepoint_online import SharePointOnlineCursor from airweave.platform.decorators import source diff --git a/backend/conftest.py b/backend/conftest.py index 907d503fc..26914e317 100644 --- a/backend/conftest.py +++ b/backend/conftest.py @@ -106,7 +106,7 @@ def fake_circuit_breaker(): @pytest.fixture def fake_ocr_provider(): """Fake OcrProvider that returns canned markdown.""" - from airweave.adapters.ocr.fake import FakeOcrProvider + from airweave.domains.ocr.fakes.provider import FakeOcrProvider return FakeOcrProvider() diff --git a/backend/tests/unit/api/test_admin_user_principals.py b/backend/tests/unit/api/test_admin_user_principals.py index 544a028ab..f304cc9e9 100644 --- a/backend/tests/unit/api/test_admin_user_principals.py +++ b/backend/tests/unit/api/test_admin_user_principals.py @@ -7,7 +7,7 @@ from airweave.api.v1.endpoints.admin import admin_get_user_principals from airweave.domains.collections.fakes.repository import FakeCollectionRepository -from airweave.platform.access_control.schemas import AccessContext +from airweave.domains.access_control.schemas import AccessContext @pytest.fixture diff --git a/backend/tests/unit/core/test_ocr_provider_factory.py b/backend/tests/unit/core/test_ocr_provider_factory.py index 11314159b..22d57ed10 100644 --- a/backend/tests/unit/core/test_ocr_provider_factory.py +++ b/backend/tests/unit/core/test_ocr_provider_factory.py @@ -3,12 +3,10 @@ import types from unittest.mock import patch -import pytest - from airweave.adapters.circuit_breaker.fake import FakeCircuitBreaker -from airweave.adapters.ocr.fake import FakeOcrProvider -from airweave.adapters.ocr.fallback import FallbackOcrProvider from airweave.core.container.factory import _create_ocr_provider +from airweave.domains.ocr.fakes.provider import FakeOcrProvider +from airweave.domains.ocr.fallback import FallbackOcrProvider def _make_settings(docling_base_url=None): @@ -23,7 +21,7 @@ def test_returns_none_when_no_providers(self): settings = _make_settings(docling_base_url=None) with patch( - "airweave.core.container.factory.MistralOcrAdapter", + "airweave.core.container.factory.MistralOCR", side_effect=RuntimeError("no key"), ): result = _create_ocr_provider(cb, settings) @@ -36,7 +34,7 @@ def test_returns_fallback_with_mistral_only(self): settings = _make_settings(docling_base_url=None) with patch( - "airweave.core.container.factory.MistralOcrAdapter", + "airweave.core.container.factory.MistralOCR", return_value=FakeOcrProvider(), ): result = _create_ocr_provider(cb, settings) @@ -50,7 +48,7 @@ def test_returns_fallback_with_both_providers(self): with ( patch( - "airweave.core.container.factory.MistralOcrAdapter", + "airweave.core.container.factory.MistralOCR", return_value=FakeOcrProvider(), ), patch( diff --git a/backend/tests/unit/domains/access_control/test_broker.py b/backend/tests/unit/domains/access_control/test_broker.py index 5f014bfd4..328341f1c 100644 --- a/backend/tests/unit/domains/access_control/test_broker.py +++ b/backend/tests/unit/domains/access_control/test_broker.py @@ -6,7 +6,7 @@ import pytest from airweave.domains.access_control.broker import AccessBroker -from airweave.platform.access_control.schemas import AccessContext +from airweave.domains.access_control.schemas import AccessContext from airweave.platform.entities._base import AccessControl diff --git a/backend/tests/unit/domains/ocr/__init__.py b/backend/tests/unit/domains/ocr/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/tests/unit/adapters/test_docling_ocr.py b/backend/tests/unit/domains/ocr/test_docling.py similarity index 92% rename from backend/tests/unit/adapters/test_docling_ocr.py rename to backend/tests/unit/domains/ocr/test_docling.py index cce3e315c..8339df801 100644 --- a/backend/tests/unit/adapters/test_docling_ocr.py +++ b/backend/tests/unit/domains/ocr/test_docling.py @@ -7,11 +7,11 @@ import httpx import pytest -from airweave.adapters.ocr.docling import DoclingOcrAdapter -from airweave.core.protocols.ocr import OcrProvider +from airweave.domains.ocr.docling import DoclingOcrAdapter +from airweave.domains.ocr.protocols import OcrProvider # All tests patch the sync health-check in __init__ so they don't need a live docling instance. -_HEALTH_PATCH = "airweave.adapters.ocr.docling.httpx.get" +_HEALTH_PATCH = "airweave.domains.ocr.docling.httpx.get" def _mock_health_response(): @@ -98,7 +98,7 @@ async def test_docling_ocr(case: Case, tmp_path): else: mock_post.return_value = _build_mock_response(case) - with patch("airweave.adapters.ocr.docling.httpx.AsyncClient") as mock_client_cls: + with patch("airweave.domains.ocr.docling.httpx.AsyncClient") as mock_client_cls: mock_client = AsyncMock() mock_client.post = mock_post mock_client.__aenter__ = AsyncMock(return_value=mock_client) @@ -122,7 +122,7 @@ async def test_unsupported_extension(tmp_path): with patch(_HEALTH_PATCH, return_value=_mock_health_response()): adapter = DoclingOcrAdapter(base_url="http://docling:5001") - with patch("airweave.adapters.ocr.docling.httpx.AsyncClient") as mock_client_cls: + with patch("airweave.domains.ocr.docling.httpx.AsyncClient") as mock_client_cls: mock_client = AsyncMock() mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=False) diff --git a/backend/tests/unit/adapters/test_fallback_ocr.py b/backend/tests/unit/domains/ocr/test_fallback.py similarity index 95% rename from backend/tests/unit/adapters/test_fallback_ocr.py rename to backend/tests/unit/domains/ocr/test_fallback.py index d1927c5d9..d407385a9 100644 --- a/backend/tests/unit/adapters/test_fallback_ocr.py +++ b/backend/tests/unit/domains/ocr/test_fallback.py @@ -6,9 +6,9 @@ import pytest from airweave.adapters.circuit_breaker.fake import FakeCircuitBreaker -from airweave.adapters.ocr.fake import FakeOcrProvider -from airweave.adapters.ocr.fallback import FallbackOcrProvider -from airweave.core.protocols.ocr import OcrProvider +from airweave.domains.ocr.fakes.provider import FakeOcrProvider +from airweave.domains.ocr.fallback import FallbackOcrProvider +from airweave.domains.ocr.protocols import OcrProvider class TestProtocolConformance: diff --git a/backend/tests/unit/platform/sources/test_sharepoint2019v2_dirsync.py b/backend/tests/unit/platform/sources/test_sharepoint2019v2_dirsync.py index 636879a55..df92addf2 100644 --- a/backend/tests/unit/platform/sources/test_sharepoint2019v2_dirsync.py +++ b/backend/tests/unit/platform/sources/test_sharepoint2019v2_dirsync.py @@ -18,7 +18,7 @@ import pytest -from airweave.platform.access_control.schemas import ACLChangeType, MembershipChange +from airweave.domains.access_control.schemas import ACLChangeType, MembershipChange from airweave.platform.sources.sharepoint2019v2.ldap import ( DIRSYNC_FLAGS_BASIC, DIRSYNC_FLAGS_FULL, diff --git a/backend/tests/unit/search/operations/test_access_control_filter.py b/backend/tests/unit/search/operations/test_access_control_filter.py index 46160140f..4cfbd4575 100644 --- a/backend/tests/unit/search/operations/test_access_control_filter.py +++ b/backend/tests/unit/search/operations/test_access_control_filter.py @@ -5,7 +5,7 @@ import pytest -from airweave.platform.access_control.schemas import AccessContext +from airweave.domains.access_control.schemas import AccessContext from airweave.search.operations.access_control_filter import AccessControlFilter from airweave.search.state import SearchState