-
Notifications
You must be signed in to change notification settings - Fork 247
feat(weaviate): add tenant support in write_documents with tests (syn… #3056
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 6 commits
862a160
73e97f2
2bc6722
e3f55cf
126d072
c938971
33da4b4
b18d575
5195d08
86365d7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -177,6 +177,7 @@ def _clean_connection_settings(self) -> None: | |||||
| _class_name = self._collection_settings.get("class", "Default") | ||||||
| _class_name = _class_name[0].upper() + _class_name[1:] | ||||||
| self._collection_settings["class"] = _class_name | ||||||
|
|
||||||
| # Set the properties if they're not set | ||||||
| self._collection_settings["properties"] = self._collection_settings.get( | ||||||
| "properties", DOCUMENT_COLLECTION_PROPERTIES | ||||||
|
|
@@ -934,14 +935,16 @@ def _handle_failed_objects(failed_objects: list[ErrorObject]) -> NoReturn: | |||||
| ) | ||||||
| raise DocumentStoreError(msg) | ||||||
|
|
||||||
| def _batch_write(self, documents: list[Document]) -> int: | ||||||
| def _batch_write(self, documents: list[Document], tenant: str | None = None) -> int: | ||||||
| """ | ||||||
| Writes document to Weaviate in batches. | ||||||
|
|
||||||
| Documents with the same id will be overwritten. | ||||||
| Raises in case of errors. | ||||||
| """ | ||||||
|
|
||||||
| # Handle tenant at collection level (NOT via kwargs) | ||||||
| collection = self.collection | ||||||
| if tenant is not None: | ||||||
| collection = collection.with_tenant(tenant) | ||||||
|
|
||||||
| with self.client.batch.dynamic() as batch: | ||||||
| for doc in documents: | ||||||
| if not isinstance(doc, Document): | ||||||
|
|
@@ -950,26 +953,28 @@ def _batch_write(self, documents: list[Document]) -> int: | |||||
|
|
||||||
| batch.add_object( | ||||||
| properties=WeaviateDocumentStore._to_data_object(doc), | ||||||
| collection=self.collection.name, | ||||||
| collection=collection.name, | ||||||
| uuid=generate_uuid5(doc.id), | ||||||
| vector=doc.embedding, | ||||||
| ) | ||||||
|
|
||||||
| if failed_objects := self.client.batch.failed_objects: | ||||||
| self._handle_failed_objects(failed_objects) | ||||||
|
|
||||||
| # If the document already exists we get no status message back from Weaviate. | ||||||
| # So we assume that all Documents were written. | ||||||
| return len(documents) | ||||||
|
|
||||||
| async def _batch_write_async(self, documents: list[Document]) -> int: | ||||||
| async def _batch_write_async(self, documents: list[Document], tenant: str | None = None) -> int: | ||||||
| """ | ||||||
| Asynchronously writes document to Weaviate in batches. | ||||||
|
|
||||||
| Documents with the same id will be overwritten. | ||||||
| Raises in case of errors. | ||||||
|
Comment on lines
-967
to
-969
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please keep comments that are not related to tenant_support unchanged in this pull request. |
||||||
| """ | ||||||
|
|
||||||
| client = await self.async_client | ||||||
|
|
||||||
| # Handle tenant properly | ||||||
| collection = await self.async_collection | ||||||
| if tenant is not None: | ||||||
| collection = collection.with_tenant(tenant) | ||||||
|
|
||||||
| async with client.batch.stream() as batch: | ||||||
| for doc in documents: | ||||||
| if not isinstance(doc, Document): | ||||||
|
|
@@ -978,39 +983,40 @@ async def _batch_write_async(self, documents: list[Document]) -> int: | |||||
|
|
||||||
| await batch.add_object( | ||||||
| properties=WeaviateDocumentStore._to_data_object(doc), | ||||||
| collection=(await self.async_collection).name, | ||||||
| collection=collection.name, | ||||||
| uuid=generate_uuid5(doc.id), | ||||||
| vector=doc.embedding, | ||||||
| ) | ||||||
|
|
||||||
| if failed_objects := client.batch.failed_objects: | ||||||
| self._handle_failed_objects(failed_objects) | ||||||
|
|
||||||
| # If the document already exists we get no status message back from Weaviate. | ||||||
| # So we assume that all Documents were written. | ||||||
| return len(documents) | ||||||
|
|
||||||
| def _write(self, documents: list[Document], policy: DuplicatePolicy) -> int: | ||||||
| def _write(self, documents: list[Document], policy: DuplicatePolicy, tenant: str | None = None) -> int: | ||||||
| """ | ||||||
| Writes documents to Weaviate using the specified policy. | ||||||
|
|
||||||
| This doesn't use the batch API, so it's slower than _batch_write. | ||||||
| If policy is set to SKIP it will skip any document that already exists. | ||||||
| If policy is set to FAIL it will raise an exception if any of the documents already exists. | ||||||
| """ | ||||||
| collection = self.collection | ||||||
| if tenant: | ||||||
| collection = collection.with_tenant(tenant) | ||||||
| written = 0 | ||||||
| duplicate_errors_ids = [] | ||||||
| for doc in documents: | ||||||
| if not isinstance(doc, Document): | ||||||
| msg = f"Expected a Document, got '{type(doc)}' instead." | ||||||
| raise ValueError(msg) | ||||||
|
|
||||||
| if policy == DuplicatePolicy.SKIP and self.collection.data.exists(uuid=generate_uuid5(doc.id)): | ||||||
| if policy == DuplicatePolicy.SKIP and collection.data.exists(uuid=generate_uuid5(doc.id)): | ||||||
| # This Document already exists, we skip it | ||||||
| continue | ||||||
|
|
||||||
| try: | ||||||
| self.collection.data.insert( | ||||||
| collection.data.insert( | ||||||
| uuid=generate_uuid5(doc.id), | ||||||
| properties=WeaviateDocumentStore._to_data_object(doc), | ||||||
| vector=doc.embedding, | ||||||
|
|
@@ -1025,7 +1031,12 @@ def _write(self, documents: list[Document], policy: DuplicatePolicy) -> int: | |||||
| raise DuplicateDocumentError(msg) | ||||||
| return written | ||||||
|
|
||||||
| async def _write_async(self, documents: list[Document], policy: DuplicatePolicy) -> int: | ||||||
| async def _write_async( | ||||||
| self, | ||||||
| documents: list[Document], | ||||||
| policy: DuplicatePolicy, | ||||||
| tenant: str | None = None, | ||||||
| ) -> int: | ||||||
| """ | ||||||
| Asynchronously writes documents to Weaviate using the specified policy. | ||||||
|
|
||||||
|
|
@@ -1034,16 +1045,15 @@ async def _write_async(self, documents: list[Document], policy: DuplicatePolicy) | |||||
| If policy is set to FAIL it will raise an exception if any of the documents already exists. | ||||||
| """ | ||||||
| collection = await self.async_collection | ||||||
|
|
||||||
| if tenant: | ||||||
| collection = collection.with_tenant(tenant) | ||||||
| duplicate_errors_ids = [] | ||||||
| for doc in documents: | ||||||
| if not isinstance(doc, Document): | ||||||
| msg = f"Expected a Document, got '{type(doc)}' instead." | ||||||
| raise ValueError(msg) | ||||||
|
|
||||||
| if policy == DuplicatePolicy.SKIP and await (await self.async_collection).data.exists( | ||||||
| uuid=generate_uuid5(doc.id) | ||||||
| ): | ||||||
| if policy == DuplicatePolicy.SKIP and await collection.data.exists(uuid=generate_uuid5(doc.id)): | ||||||
| # This Document already exists, continue | ||||||
| continue | ||||||
|
|
||||||
|
|
@@ -1063,7 +1073,12 @@ async def _write_async(self, documents: list[Document], policy: DuplicatePolicy) | |||||
| raise DuplicateDocumentError(msg) | ||||||
| return len(documents) | ||||||
|
|
||||||
| def write_documents(self, documents: list[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int: | ||||||
| def write_documents( | ||||||
| self, | ||||||
| documents: list[Document], | ||||||
| policy: DuplicatePolicy = DuplicatePolicy.NONE, | ||||||
| tenant: str | None = None, | ||||||
| ) -> int: | ||||||
| """ | ||||||
| Writes documents to Weaviate using the specified policy. | ||||||
|
|
||||||
|
|
@@ -1089,12 +1104,15 @@ def write_documents(self, documents: list[Document], policy: DuplicatePolicy = D | |||||
| The number of documents written. | ||||||
| """ | ||||||
| if policy in [DuplicatePolicy.NONE, DuplicatePolicy.OVERWRITE]: | ||||||
| return self._batch_write(documents) | ||||||
| return self._batch_write(documents, tenant) | ||||||
|
|
||||||
| return self._write(documents, policy) | ||||||
| return self._write(documents, policy, tenant) | ||||||
|
|
||||||
| async def write_documents_async( | ||||||
| self, documents: list[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE | ||||||
| self, | ||||||
| documents: list[Document], | ||||||
| policy: DuplicatePolicy = DuplicatePolicy.NONE, | ||||||
| tenant: str | None = None, | ||||||
| ) -> int: | ||||||
| """ | ||||||
| Asynchronously writes documents to Weaviate using the specified policy. | ||||||
|
|
@@ -1121,9 +1139,9 @@ async def write_documents_async( | |||||
| The number of documents written. | ||||||
| """ | ||||||
| if policy in [DuplicatePolicy.NONE, DuplicatePolicy.OVERWRITE]: | ||||||
| return await self._batch_write_async(documents) | ||||||
| return await self._batch_write_async(documents, tenant) | ||||||
|
|
||||||
| return await self._write_async(documents, policy) | ||||||
| return await self._write_async(documents, policy, tenant) | ||||||
|
|
||||||
| def delete_documents(self, document_ids: list[str]) -> None: | ||||||
| """ | ||||||
|
|
@@ -1249,21 +1267,26 @@ def delete_by_filter(self, filters: dict[str, Any]) -> int: | |||||
| Deletes all documents that match the provided filters. | ||||||
|
|
||||||
| :param filters: The filters to apply to select documents for deletion. | ||||||
| For filter syntax, see [Haystack metadata filtering](https://docs.haystack.deepset.ai/docs/metadata-filtering) | ||||||
| For filter syntax, see Haystack metadata filtering docs. | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's keep this link. |
||||||
| :returns: The number of documents deleted. | ||||||
| """ | ||||||
| validate_filters(filters) | ||||||
|
|
||||||
| try: | ||||||
| collection = self.collection | ||||||
|
|
||||||
| weaviate_filter = convert_filters(filters) | ||||||
| result = self.collection.data.delete_many(where=weaviate_filter) | ||||||
| result = collection.data.delete_many(where=weaviate_filter) | ||||||
| deleted_count = result.successful | ||||||
|
|
||||||
| logger.info( | ||||||
| "Deleted {n_docs} documents from collection '{collection}' using filters.", | ||||||
| n_docs=deleted_count, | ||||||
| collection=self.collection.name, | ||||||
| collection=collection.name, | ||||||
| ) | ||||||
|
|
||||||
| return deleted_count | ||||||
|
|
||||||
| except weaviate.exceptions.WeaviateQueryError as e: | ||||||
| msg = f"Failed to delete documents by filter in Weaviate. Error: {e.message}" | ||||||
| raise DocumentStoreError(msg) from e | ||||||
|
|
@@ -1275,23 +1298,26 @@ async def delete_by_filter_async(self, filters: dict[str, Any]) -> int: | |||||
| """ | ||||||
| Asynchronously deletes all documents that match the provided filters. | ||||||
|
|
||||||
| :param filters: The filters to apply to select documents for deletion. | ||||||
| For filter syntax, see [Haystack metadata filtering](https://docs.haystack.deepset.ai/docs/metadata-filtering) | ||||||
| :returns: The number of documents deleted. | ||||||
| :param filters: Filters to select documents for deletion. | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's keep this link. |
||||||
| :returns: Number of deleted documents. | ||||||
| """ | ||||||
| validate_filters(filters) | ||||||
|
|
||||||
| try: | ||||||
| collection = await self.async_collection | ||||||
|
|
||||||
| weaviate_filter = convert_filters(filters) | ||||||
| result = await collection.data.delete_many(where=weaviate_filter) | ||||||
| deleted_count = result.successful | ||||||
|
|
||||||
| logger.info( | ||||||
| "Deleted {n_docs} documents from collection '{collection}' using filters.", | ||||||
| n_docs=deleted_count, | ||||||
| collection=collection.name, | ||||||
| ) | ||||||
|
|
||||||
| return deleted_count | ||||||
|
|
||||||
| except weaviate.exceptions.WeaviateQueryError as e: | ||||||
| msg = f"Failed to delete documents by filter in Weaviate. Error: {e.message}" | ||||||
| raise DocumentStoreError(msg) from e | ||||||
|
|
@@ -1301,12 +1327,11 @@ async def delete_by_filter_async(self, filters: dict[str, Any]) -> int: | |||||
|
|
||||||
| def update_by_filter(self, filters: dict[str, Any], meta: dict[str, Any]) -> int: | ||||||
| """ | ||||||
| Updates the metadata of all documents that match the provided filters. | ||||||
| Updates metadata of all documents that match the provided filters. | ||||||
|
|
||||||
| :param filters: The filters to apply to select documents for updating. | ||||||
| For filter syntax, see [Haystack metadata filtering](https://docs.haystack.deepset.ai/docs/metadata-filtering) | ||||||
| :param meta: The metadata fields to update. These will be merged with existing metadata. | ||||||
| :returns: The number of documents updated. | ||||||
| :param filters: Filters to select documents for updating. | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's keep this link. |
||||||
| :param meta: Metadata fields to update. | ||||||
| :returns: Number of updated documents. | ||||||
| """ | ||||||
| validate_filters(filters) | ||||||
|
|
||||||
|
|
@@ -1315,39 +1340,35 @@ def update_by_filter(self, filters: dict[str, Any], meta: dict[str, Any]) -> int | |||||
| raise ValueError(msg) | ||||||
|
|
||||||
| try: | ||||||
| collection = self.collection # ✅ FIX | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
|
||||||
| matching_objects = self._query_with_filters(filters) | ||||||
| if not matching_objects: | ||||||
| return 0 | ||||||
|
|
||||||
| # Update each object with the new metadata | ||||||
| # Since metadata is stored flattened in Weaviate properties, we update properties directly | ||||||
|
Comment on lines
-1322
to
-1323
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's keep the comments here unchanged, and also in the following lines of code. |
||||||
| updated_count = 0 | ||||||
| failed_updates = [] | ||||||
|
|
||||||
| for obj in matching_objects: | ||||||
| try: | ||||||
| # Get current properties | ||||||
| current_properties = obj.properties.copy() if obj.properties else {} | ||||||
|
|
||||||
| # Update with new metadata values | ||||||
| # Note: metadata fields are stored directly in properties (flattened) | ||||||
| for key, value in meta.items(): | ||||||
| current_properties[key] = value | ||||||
|
|
||||||
| # Update the object, preserving the vector | ||||||
| # Get the vector from the object to preserve it during replace | ||||||
| vector: VECTORS | None = None | ||||||
| if isinstance(obj.vector, (list, dict)): | ||||||
| if isinstance(obj.vector, list | dict): | ||||||
| vector = obj.vector | ||||||
|
|
||||||
| self.collection.data.replace( | ||||||
| collection.data.replace( # ✅ FIX | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| uuid=obj.uuid, | ||||||
| properties=current_properties, | ||||||
| vector=vector, | ||||||
| ) | ||||||
|
|
||||||
| updated_count += 1 | ||||||
|
|
||||||
| except Exception as e: | ||||||
| # Collect failed updates but continue with others | ||||||
| obj_properties = obj.properties or {} | ||||||
| id_ = obj_properties.get("_original_id", obj.uuid) | ||||||
| failed_updates.append((id_, str(e))) | ||||||
|
|
@@ -1361,9 +1382,11 @@ def update_by_filter(self, filters: dict[str, Any], meta: dict[str, Any]) -> int | |||||
| logger.info( | ||||||
| "Updated {n_docs} documents in collection '{collection}' using filters.", | ||||||
| n_docs=updated_count, | ||||||
| collection=self.collection.name, | ||||||
| collection=collection.name, | ||||||
| ) | ||||||
|
|
||||||
| return updated_count | ||||||
|
|
||||||
| except weaviate.exceptions.WeaviateQueryError as e: | ||||||
| msg = f"Failed to update documents by filter in Weaviate. Error: {e.message}" | ||||||
| raise DocumentStoreError(msg) from e | ||||||
|
|
@@ -1431,7 +1454,7 @@ async def update_by_filter_async(self, filters: dict[str, Any], meta: dict[str, | |||||
| # Update the object, preserving the vector | ||||||
| # Get the vector from the object to preserve it during replace | ||||||
| vector: VECTORS | None = None | ||||||
| if isinstance(obj.vector, (list, dict)): | ||||||
| if isinstance(obj.vector, list | dict): | ||||||
| vector = obj.vector | ||||||
|
|
||||||
| await collection.data.replace( | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -5,6 +5,7 @@ | |||
| import base64 | ||||
| import logging | ||||
| import os | ||||
| import platform | ||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||
| from collections.abc import Generator | ||||
| from unittest.mock import MagicMock, patch | ||||
|
|
||||
|
|
@@ -377,6 +378,15 @@ def test_write_documents(self, document_store): | |||
| assert document_store.write_documents([doc]) == 1 | ||||
| assert document_store.count_documents() == 1 | ||||
|
|
||||
| def test_write_documents_with_tenant(self, document_store): | ||||
| doc = Document(content="tenant test doc") | ||||
|
|
||||
| # Write with tenant | ||||
| written = document_store.write_documents([doc], tenant="tenant1") | ||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This test could check that |
||||
|
|
||||
| assert written == 1 | ||||
| assert document_store.count_documents() == 1 | ||||
|
|
||||
| def test_write_documents_with_blob_data(self, document_store, test_files_path): | ||||
| image = ByteStream.from_file_path(test_files_path / "robot1.jpg", mime_type="image/jpeg") | ||||
| doc = Document(content="test doc", blob=image) | ||||
|
|
@@ -824,6 +834,7 @@ def test_connect_to_local(self): | |||
| ) | ||||
| assert document_store.client | ||||
|
|
||||
| @pytest.mark.skipif(platform.system() == "Windows", reason="EmbeddedDB not supported on Windows") | ||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
Let's not change this part of the code in this PR. If this change makes sense, we should do it in a separate pull request please. |
||||
| def test_connect_to_embedded(self): | ||||
| document_store = WeaviateDocumentStore(embedded_options=EmbeddedOptions()) | ||||
| assert document_store.client | ||||
|
|
||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -104,6 +104,18 @@ async def test_write_documents_async(self, document_store: WeaviateDocumentStore | |
| assert await document_store.write_documents_async([doc]) == 1 | ||
| assert await document_store.count_documents_async() == 1 | ||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_write_documents_with_tenant_async(self, document_store): | ||
| doc = Document(content="tenant test doc") | ||
|
|
||
| written = await document_store.write_documents_async([doc], tenant="tenant1") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This test could check if |
||
|
|
||
| assert written == 1 | ||
|
|
||
| docs = await document_store.filter_documents_async() | ||
| assert len(docs) == 1 | ||
| assert docs[0].content == "tenant test doc" | ||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_write_documents_with_blob_data_async( | ||
| self, document_store: WeaviateDocumentStore, test_files_path: Path | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please keep comments that are not related to tenant_support unchanged in this pull request.