diff --git a/openrag/components/indexer/vectordb/utils.py b/openrag/components/indexer/vectordb/utils.py index 5bc582e0..0909e129 100644 --- a/openrag/components/indexer/vectordb/utils.py +++ b/openrag/components/indexer/vectordb/utils.py @@ -669,25 +669,30 @@ def get_existing_file_ids(self, partition: str, file_ids: list[str]) -> set[str] ) return {r[0] for r in result.all()} - def add_files_to_workspace(self, workspace_id: str, file_ids: list[str]): + def add_files_to_workspace(self, workspace_id: str, file_ids: list[str]) -> list[str]: + """Add files to a workspace. Returns list of file_ids that could not be resolved.""" with self.Session() as session: # Resolve the workspace's partition to scope the File lookup workspace = session.execute( select(Workspace).where(Workspace.workspace_id == workspace_id) ).scalar_one_or_none() if workspace is None: - return + return file_ids partition = workspace.partition_name - for fid in file_ids: - file_row = session.execute( - select(File.id).where(File.file_id == fid, File.partition_name == partition) - ).scalar_one_or_none() - if file_row is None: - continue - stmt = pg_insert(WorkspaceFile).values(workspace_id=workspace_id, file_id=file_row) + + # Bulk-resolve all file_ids → File.id in a single query + rows = session.execute( + select(File.file_id, File.id).where(File.file_id.in_(file_ids), File.partition_name == partition) + ).all() + id_map = {r[0]: r[1] for r in rows} + missing = [fid for fid in file_ids if fid not in id_map] + + for fid, file_pk in id_map.items(): + stmt = pg_insert(WorkspaceFile).values(workspace_id=workspace_id, file_id=file_pk) stmt = stmt.on_conflict_do_nothing(constraint="uix_workspace_file") session.execute(stmt) session.commit() + return missing def remove_file_from_workspace(self, workspace_id: str, file_id: str) -> bool: """Remove a file from a workspace. Returns True if the association existed, False otherwise.""" diff --git a/openrag/components/indexer/vectordb/vectordb.py b/openrag/components/indexer/vectordb/vectordb.py index b060826a..17a3a13b 100644 --- a/openrag/components/indexer/vectordb/vectordb.py +++ b/openrag/components/indexer/vectordb/vectordb.py @@ -1127,8 +1127,8 @@ async def get_existing_file_ids(self, partition: str, file_ids: list[str]) -> li """Return the subset of file_ids that exist in the given partition.""" return list(self.partition_file_manager.get_existing_file_ids(partition, file_ids)) - async def add_files_to_workspace(self, workspace_id: str, file_ids: list[str]): - self.partition_file_manager.add_files_to_workspace(workspace_id, file_ids) + async def add_files_to_workspace(self, workspace_id: str, file_ids: list[str]) -> list[str]: + return self.partition_file_manager.add_files_to_workspace(workspace_id, file_ids) async def remove_file_from_workspace(self, workspace_id: str, file_id: str) -> bool: return self.partition_file_manager.remove_file_from_workspace(workspace_id, file_id) diff --git a/openrag/routers/workspaces.py b/openrag/routers/workspaces.py index fefc1533..3ed9cfe3 100644 --- a/openrag/routers/workspaces.py +++ b/openrag/routers/workspaces.py @@ -170,11 +170,17 @@ async def add_files_to_workspace( status_code=status.HTTP_404_NOT_FOUND, detail=f"File IDs not found in partition '{partition}': {unknown_ids}", ) - await call_ray_actor_with_timeout( + missing = await call_ray_actor_with_timeout( vectordb.add_files_to_workspace.remote(workspace_id, body.file_ids), timeout=VECTORDB_TIMEOUT, task_description=f"add_files_to_workspace({workspace_id})", ) + if missing: + # TOCTOU: files were deleted between the pre-check and the insert. + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"File IDs not found in partition '{partition}': {sorted(missing)}", + ) return {"status": "added", "file_ids": body.file_ids}