diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py index 45fa51fa84..7f92124f57 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py @@ -177,6 +177,7 @@ def _clean_connection_settings(self) -> None: _class_name = self._collection_settings.get("class", "Default") _class_name = _class_name[0].upper() + _class_name[1:] self._collection_settings["class"] = _class_name + # Set the properties if they're not set self._collection_settings["properties"] = self._collection_settings.get( "properties", DOCUMENT_COLLECTION_PROPERTIES @@ -934,14 +935,16 @@ def _handle_failed_objects(failed_objects: list[ErrorObject]) -> NoReturn: ) raise DocumentStoreError(msg) - def _batch_write(self, documents: list[Document]) -> int: + def _batch_write(self, documents: list[Document], tenant: str | None = None) -> int: """ Writes document to Weaviate in batches. - - Documents with the same id will be overwritten. - Raises in case of errors. """ + # Handle tenant at collection level (NOT via kwargs) + collection = self.collection + if tenant is not None: + collection = collection.with_tenant(tenant) + with self.client.batch.dynamic() as batch: for doc in documents: if not isinstance(doc, Document): @@ -950,26 +953,28 @@ def _batch_write(self, documents: list[Document]) -> int: batch.add_object( properties=WeaviateDocumentStore._to_data_object(doc), - collection=self.collection.name, + collection=collection.name, uuid=generate_uuid5(doc.id), vector=doc.embedding, ) + if failed_objects := self.client.batch.failed_objects: self._handle_failed_objects(failed_objects) - # If the document already exists we get no status message back from Weaviate. - # So we assume that all Documents were written. return len(documents) - async def _batch_write_async(self, documents: list[Document]) -> int: + async def _batch_write_async(self, documents: list[Document], tenant: str | None = None) -> int: """ Asynchronously writes document to Weaviate in batches. - - Documents with the same id will be overwritten. - Raises in case of errors. """ + client = await self.async_client + # Handle tenant properly + collection = await self.async_collection + if tenant is not None: + collection = collection.with_tenant(tenant) + async with client.batch.stream() as batch: for doc in documents: if not isinstance(doc, Document): @@ -978,7 +983,7 @@ async def _batch_write_async(self, documents: list[Document]) -> int: await batch.add_object( properties=WeaviateDocumentStore._to_data_object(doc), - collection=(await self.async_collection).name, + collection=collection.name, uuid=generate_uuid5(doc.id), vector=doc.embedding, ) @@ -986,11 +991,9 @@ async def _batch_write_async(self, documents: list[Document]) -> int: if failed_objects := client.batch.failed_objects: self._handle_failed_objects(failed_objects) - # If the document already exists we get no status message back from Weaviate. - # So we assume that all Documents were written. return len(documents) - def _write(self, documents: list[Document], policy: DuplicatePolicy) -> int: + def _write(self, documents: list[Document], policy: DuplicatePolicy, tenant: str | None = None) -> int: """ Writes documents to Weaviate using the specified policy. @@ -998,6 +1001,9 @@ def _write(self, documents: list[Document], policy: DuplicatePolicy) -> int: If policy is set to SKIP it will skip any document that already exists. If policy is set to FAIL it will raise an exception if any of the documents already exists. """ + collection = self.collection + if tenant: + collection = collection.with_tenant(tenant) written = 0 duplicate_errors_ids = [] for doc in documents: @@ -1005,12 +1011,12 @@ def _write(self, documents: list[Document], policy: DuplicatePolicy) -> int: msg = f"Expected a Document, got '{type(doc)}' instead." raise ValueError(msg) - if policy == DuplicatePolicy.SKIP and self.collection.data.exists(uuid=generate_uuid5(doc.id)): + if policy == DuplicatePolicy.SKIP and collection.data.exists(uuid=generate_uuid5(doc.id)): # This Document already exists, we skip it continue try: - self.collection.data.insert( + collection.data.insert( uuid=generate_uuid5(doc.id), properties=WeaviateDocumentStore._to_data_object(doc), vector=doc.embedding, @@ -1025,7 +1031,12 @@ def _write(self, documents: list[Document], policy: DuplicatePolicy) -> int: raise DuplicateDocumentError(msg) return written - async def _write_async(self, documents: list[Document], policy: DuplicatePolicy) -> int: + async def _write_async( + self, + documents: list[Document], + policy: DuplicatePolicy, + tenant: str | None = None, + ) -> int: """ Asynchronously writes documents to Weaviate using the specified policy. @@ -1034,16 +1045,15 @@ async def _write_async(self, documents: list[Document], policy: DuplicatePolicy) If policy is set to FAIL it will raise an exception if any of the documents already exists. """ collection = await self.async_collection - + if tenant: + collection = collection.with_tenant(tenant) duplicate_errors_ids = [] for doc in documents: if not isinstance(doc, Document): msg = f"Expected a Document, got '{type(doc)}' instead." raise ValueError(msg) - if policy == DuplicatePolicy.SKIP and await (await self.async_collection).data.exists( - uuid=generate_uuid5(doc.id) - ): + if policy == DuplicatePolicy.SKIP and await collection.data.exists(uuid=generate_uuid5(doc.id)): # This Document already exists, continue continue @@ -1063,7 +1073,12 @@ async def _write_async(self, documents: list[Document], policy: DuplicatePolicy) raise DuplicateDocumentError(msg) return len(documents) - def write_documents(self, documents: list[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int: + def write_documents( + self, + documents: list[Document], + policy: DuplicatePolicy = DuplicatePolicy.NONE, + tenant: str | None = None, + ) -> int: """ Writes documents to Weaviate using the specified policy. @@ -1089,12 +1104,15 @@ def write_documents(self, documents: list[Document], policy: DuplicatePolicy = D The number of documents written. """ if policy in [DuplicatePolicy.NONE, DuplicatePolicy.OVERWRITE]: - return self._batch_write(documents) + return self._batch_write(documents, tenant) - return self._write(documents, policy) + return self._write(documents, policy, tenant) async def write_documents_async( - self, documents: list[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE + self, + documents: list[Document], + policy: DuplicatePolicy = DuplicatePolicy.NONE, + tenant: str | None = None, ) -> int: """ Asynchronously writes documents to Weaviate using the specified policy. @@ -1121,9 +1139,9 @@ async def write_documents_async( The number of documents written. """ if policy in [DuplicatePolicy.NONE, DuplicatePolicy.OVERWRITE]: - return await self._batch_write_async(documents) + return await self._batch_write_async(documents, tenant) - return await self._write_async(documents, policy) + return await self._write_async(documents, policy, tenant) def delete_documents(self, document_ids: list[str]) -> None: """ @@ -1255,15 +1273,20 @@ def delete_by_filter(self, filters: dict[str, Any]) -> int: validate_filters(filters) try: + collection = self.collection + weaviate_filter = convert_filters(filters) - result = self.collection.data.delete_many(where=weaviate_filter) + result = collection.data.delete_many(where=weaviate_filter) deleted_count = result.successful + logger.info( "Deleted {n_docs} documents from collection '{collection}' using filters.", n_docs=deleted_count, - collection=self.collection.name, + collection=collection.name, ) + return deleted_count + except weaviate.exceptions.WeaviateQueryError as e: msg = f"Failed to delete documents by filter in Weaviate. Error: {e.message}" raise DocumentStoreError(msg) from e @@ -1277,21 +1300,25 @@ async def delete_by_filter_async(self, filters: dict[str, Any]) -> int: :param filters: The filters to apply to select documents for deletion. For filter syntax, see [Haystack metadata filtering](https://docs.haystack.deepset.ai/docs/metadata-filtering) - :returns: The number of documents deleted. + :returns: Number of deleted documents. """ validate_filters(filters) try: collection = await self.async_collection + weaviate_filter = convert_filters(filters) result = await collection.data.delete_many(where=weaviate_filter) deleted_count = result.successful + logger.info( "Deleted {n_docs} documents from collection '{collection}' using filters.", n_docs=deleted_count, collection=collection.name, ) + return deleted_count + except weaviate.exceptions.WeaviateQueryError as e: msg = f"Failed to delete documents by filter in Weaviate. Error: {e.message}" raise DocumentStoreError(msg) from e @@ -1301,7 +1328,7 @@ async def delete_by_filter_async(self, filters: dict[str, Any]) -> int: def update_by_filter(self, filters: dict[str, Any], meta: dict[str, Any]) -> int: """ - Updates the metadata of all documents that match the provided filters. + Updates metadata of all documents that match the provided filters. :param filters: The filters to apply to select documents for updating. For filter syntax, see [Haystack metadata filtering](https://docs.haystack.deepset.ai/docs/metadata-filtering) @@ -1315,39 +1342,34 @@ def update_by_filter(self, filters: dict[str, Any], meta: dict[str, Any]) -> int raise ValueError(msg) try: + collection = self.collection matching_objects = self._query_with_filters(filters) if not matching_objects: return 0 - # Update each object with the new metadata - # Since metadata is stored flattened in Weaviate properties, we update properties directly updated_count = 0 failed_updates = [] for obj in matching_objects: try: - # Get current properties current_properties = obj.properties.copy() if obj.properties else {} - # Update with new metadata values - # Note: metadata fields are stored directly in properties (flattened) for key, value in meta.items(): current_properties[key] = value - # Update the object, preserving the vector - # Get the vector from the object to preserve it during replace vector: VECTORS | None = None - if isinstance(obj.vector, (list, dict)): + if isinstance(obj.vector, list | dict): vector = obj.vector - self.collection.data.replace( + collection.data.replace( uuid=obj.uuid, properties=current_properties, vector=vector, ) + updated_count += 1 + except Exception as e: - # Collect failed updates but continue with others obj_properties = obj.properties or {} id_ = obj_properties.get("_original_id", obj.uuid) failed_updates.append((id_, str(e))) @@ -1361,9 +1383,11 @@ def update_by_filter(self, filters: dict[str, Any], meta: dict[str, Any]) -> int logger.info( "Updated {n_docs} documents in collection '{collection}' using filters.", n_docs=updated_count, - collection=self.collection.name, + collection=collection.name, ) + return updated_count + except weaviate.exceptions.WeaviateQueryError as e: msg = f"Failed to update documents by filter in Weaviate. Error: {e.message}" raise DocumentStoreError(msg) from e @@ -1431,7 +1455,7 @@ async def update_by_filter_async(self, filters: dict[str, Any], meta: dict[str, # Update the object, preserving the vector # Get the vector from the object to preserve it during replace vector: VECTORS | None = None - if isinstance(obj.vector, (list, dict)): + if isinstance(obj.vector, list | dict): vector = obj.vector await collection.data.replace( diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py index 8d1be5f5bc..80adbac2fb 100644 --- a/integrations/weaviate/tests/test_document_store.py +++ b/integrations/weaviate/tests/test_document_store.py @@ -377,6 +377,15 @@ def test_write_documents(self, document_store): assert document_store.write_documents([doc]) == 1 assert document_store.count_documents() == 1 + def test_write_documents_with_tenant(self, document_store): + doc = Document(content="tenant test doc") + + with patch.object(document_store, "_batch_write", return_value=1) as mock_write: + written = document_store.write_documents([doc], tenant="tenant1") + + assert written == 1 + mock_write.assert_called_once_with([doc], "tenant1") + def test_write_documents_with_blob_data(self, document_store, test_files_path): image = ByteStream.from_file_path(test_files_path / "robot1.jpg", mime_type="image/jpeg") doc = Document(content="test doc", blob=image) diff --git a/integrations/weaviate/tests/test_document_store_async.py b/integrations/weaviate/tests/test_document_store_async.py index 62bf334076..2cbb475096 100644 --- a/integrations/weaviate/tests/test_document_store_async.py +++ b/integrations/weaviate/tests/test_document_store_async.py @@ -104,6 +104,16 @@ async def test_write_documents_async(self, document_store: WeaviateDocumentStore assert await document_store.write_documents_async([doc]) == 1 assert await document_store.count_documents_async() == 1 + @pytest.mark.asyncio + async def test_write_documents_with_tenant_async(self, document_store): + doc = Document(content="tenant test doc") + + doc = Document(content="tenant test doc") + + written = await document_store.write_documents_async([doc], tenant="tenant1") + + assert written == 1 + @pytest.mark.asyncio async def test_write_documents_with_blob_data_async( self, document_store: WeaviateDocumentStore, test_files_path: Path