diff --git a/integrations/astra/tests/test_astra_client.py b/integrations/astra/tests/test_astra_client.py new file mode 100644 index 0000000000..5ccc89fb29 --- /dev/null +++ b/integrations/astra/tests/test_astra_client.py @@ -0,0 +1,214 @@ +# SPDX-FileCopyrightText: 2023-present Anant Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +from types import SimpleNamespace +from unittest import mock + +import pytest +from astrapy.exceptions import CollectionAlreadyExistsException + +from haystack_integrations.document_stores.astra.astra_client import ( + AstraClient, + QueryResponse, + Response, +) + +CLIENT_PATH = "haystack_integrations.document_stores.astra.astra_client.AstraDBClient" + + +CLIENT_KWARGS = { + "api_endpoint": "http://example.com", + "token": "test_token", + "collection_name": "my_collection", + "embedding_dimension": 4, + "similarity_function": "cosine", +} + + +@pytest.fixture +def mock_db(): + with mock.patch(CLIENT_PATH) as patched_client: + yield patched_client.return_value.get_database.return_value + + +@pytest.fixture +def client(mock_db) -> AstraClient: # noqa: ARG001 + return AstraClient(**CLIENT_KWARGS) + + +def test_query_response_get_returns_value(): + match = Response("id1", "text", [0.1], {"k": "v"}, 0.9) + assert QueryResponse(matches=[match]).get("matches") == [match] + + +class TestAstraClientInit: + def test_creates_collection(self, client): + client._astra_db.create_collection.assert_called_once_with( + name="my_collection", + dimension=4, + indexing={"deny": ["metadata._node_content", "content"]}, + ) + assert client._astra_db_collection is client._astra_db.create_collection.return_value + + @pytest.mark.parametrize( + "pre_indexing,warning_match", + [ + (None, "having indexing turned on"), + ({"deny": ["something_else"]}, "unexpected 'indexing' settings"), + ], + ) + def test_preexisting_collection_with_mismatched_indexing_warns(self, mock_db, pre_indexing, warning_match): + mock_db.create_collection.side_effect = CollectionAlreadyExistsException( + text="exists", keyspace="default", collection_name="my_collection" + ) + mock_db.list_collections.return_value = [ + SimpleNamespace(name="my_collection", options=SimpleNamespace(indexing=pre_indexing)) + ] + with pytest.warns(UserWarning, match=warning_match): + AstraClient(**CLIENT_KWARGS) + mock_db.get_collection.assert_called_once_with("my_collection") + + def test_preexisting_collection_with_matching_indexing_reuses_silently(self, mock_db): + mock_db.create_collection.side_effect = CollectionAlreadyExistsException( + text="exists", keyspace="default", collection_name="my_collection" + ) + mock_db.list_collections.return_value = [ + SimpleNamespace( + name="my_collection", + options=SimpleNamespace(indexing={"deny": ["metadata._node_content", "content"]}), + ) + ] + AstraClient(**CLIENT_KWARGS) + mock_db.get_collection.assert_called_once_with("my_collection") + + def test_unrelated_already_exists_reraises(self, mock_db): + mock_db.create_collection.side_effect = CollectionAlreadyExistsException( + text="exists", keyspace="default", collection_name="my_collection" + ) + mock_db.list_collections.return_value = [] + with pytest.raises(CollectionAlreadyExistsException): + AstraClient(**CLIENT_KWARGS) + + +@pytest.mark.parametrize( + "include_metadata,include_values,expected_meta,expected_values", + [ + (True, True, {"meta": {"k": "v"}}, [0.1]), + (False, False, {}, []), + ], +) +def test_format_query_response(include_metadata, include_values, expected_meta, expected_values): + responses = [{"_id": "1", "$similarity": 0.5, "content": "hi", "$vector": [0.1], "meta": {"k": "v"}}] + result = AstraClient._format_query_response( + responses, include_metadata=include_metadata, include_values=include_values + ) + match = result.matches[0] + assert (match.document_id, match.score, match.text) == ("1", 0.5, "hi") + assert match.values == expected_values + assert match.metadata == expected_meta + + +def test_format_query_response_with_none_returns_empty_matches(): + assert AstraClient._format_query_response(None, include_metadata=True, include_values=True).matches == [] + + +class TestAstraClientMethods: + @pytest.mark.parametrize( + "query_kwargs,expected_find_kwargs", + [ + ( + {"vector": [0.1, 0.2, 0.3, 0.4], "top_k": 5}, + {"sort": {"$vector": [0.1, 0.2, 0.3, 0.4]}, "limit": 5, "include_similarity": True}, + ), + ( + {"vector": [0.1] * 4, "top_k": 3, "query_filter": {"meta.k": {"$eq": "v"}}}, + {"filter": {"meta.k": {"$eq": "v"}}}, + ), + ( + {"query_filter": {"meta.k": {"$eq": "v"}}, "top_k": 2}, + {"filter": {"meta.k": {"$eq": "v"}}, "limit": 2}, + ), + ], + ) + def test_query_forwards_args_to_find(self, client, query_kwargs, expected_find_kwargs): + client._astra_db_collection.find.return_value = iter([]) + client.query(**query_kwargs) + actual = client._astra_db_collection.find.call_args.kwargs + for key, value in expected_find_kwargs.items(): + assert actual[key] == value + + def test_find_documents_warns_when_empty(self, client, caplog): + client._astra_db_collection.find.return_value = iter([]) + assert client.find_documents({"filter": {"x": 1}}) == [] + assert "No documents found" in caplog.text + + @pytest.mark.parametrize( + "return_value,expected,log_substring", + [ + ({"_id": "x"}, {"_id": "x"}, ""), + (None, None, "No document found"), + ], + ) + def test_find_one_document(self, client, caplog, return_value, expected, log_substring): + client._astra_db_collection.find_one.return_value = return_value + assert client.find_one_document({"filter": {}}) == expected + if log_substring: + assert log_substring in caplog.text + + def test_get_documents_batches_ids(self, client): + client._astra_db_collection.find.side_effect = [ + iter([{"_id": str(i), "$similarity": None, "content": "a", "$vector": [0.0] * 4} for i in range(20)]), + iter([{"_id": "20", "$similarity": None, "content": "a", "$vector": [0.0] * 4}]), + ] + result = client.get_documents([str(i) for i in range(21)], batch_size=20) + assert len(result.matches) == 21 + assert client._astra_db_collection.find.call_count == 2 + + def test_insert_returns_ids(self, client): + client._astra_db_collection.insert_many.return_value = SimpleNamespace(inserted_ids=[1, 2]) + assert client.insert([{"_id": "1"}, {"_id": "2"}]) == ["1", "2"] + + @pytest.mark.parametrize( + "returned,expected,log_substring", + [ + ({"_id": "1", "content": "x"}, True, ""), + (None, False, "not updated"), + ], + ) + def test_update_document(self, client, caplog, returned, expected, log_substring): + client._astra_db_collection.find_one_and_update.return_value = returned + assert client.update_document({"_id": "1", "content": "x"}, "_id") is expected + if log_substring: + assert log_substring in caplog.text + + @pytest.mark.parametrize( + "delete_kwargs,expected_filter", + [ + ({"ids": ["a", "b"]}, {"_id": {"$in": ["a", "b"]}}), + ({"filters": {"meta.k": {"$eq": "v"}}}, {"meta.k": {"$eq": "v"}}), + ({}, {}), + ], + ) + def test_delete_builds_filter(self, client, delete_kwargs, expected_filter): + client._astra_db_collection.delete_many.return_value = SimpleNamespace(deleted_count=1) + client.delete(**delete_kwargs) + assert client._astra_db_collection.delete_many.call_args.kwargs["filter"] == expected_filter + + def test_delete_all_documents_returns_deleted_count(self, client): + client._astra_db_collection.delete_many.return_value = SimpleNamespace(deleted_count=5) + assert client.delete_all_documents() == 5 + + def test_count_documents_passes_defaults(self, client): + client._astra_db_collection.count_documents.return_value = 7 + assert client.count_documents() == 7 + client._astra_db_collection.count_documents.assert_called_once_with({}, upper_bound=10000) + + def test_distinct_forwards_filter(self, client): + client._astra_db_collection.distinct.return_value = ["a", "b"] + assert client.distinct("meta.k") == ["a", "b"] + client._astra_db_collection.distinct.assert_called_once_with("meta.k", filter=None) + + def test_update_returns_modified_count(self, client): + client._astra_db_collection.update_many.return_value = SimpleNamespace(update_info={"nModified": 4}) + assert client.update(filters={"meta.k": "v"}, update={"$set": {"meta.k": "v2"}}) == 4 diff --git a/integrations/astra/tests/test_document_store.py b/integrations/astra/tests/test_document_store.py index e39e828c92..a39bb4ae4c 100644 --- a/integrations/astra/tests/test_document_store.py +++ b/integrations/astra/tests/test_document_store.py @@ -8,7 +8,7 @@ import pytest from haystack import Document -from haystack.document_stores.errors import MissingDocumentError +from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError, MissingDocumentError from haystack.document_stores.types import DuplicatePolicy from haystack.testing.document_store import ( CountDocumentsByFilterTest, @@ -18,8 +18,10 @@ GetMetadataFieldsInfoTest, GetMetadataFieldUniqueValuesTest, ) +from haystack.utils import Secret from haystack_integrations.document_stores.astra import AstraDocumentStore +from haystack_integrations.document_stores.astra.errors import AstraDocumentStoreFilterError @pytest.fixture @@ -28,6 +30,15 @@ def mock_auth(monkeypatch): monkeypatch.setenv("ASTRA_DB_APPLICATION_TOKEN", "test_token") +@pytest.fixture +def mocked_store(mock_auth): # noqa: ARG001 + """Returns (store, mock_index) with AstraClient fully mocked out.""" + with mock.patch("haystack_integrations.document_stores.astra.document_store.AstraClient") as mock_client: + mock_index = mock_client.return_value + store = AstraDocumentStore() + yield store, mock_index + + @mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDBClient") def test_init_is_lazy(_mock_client, mock_auth): # noqa _ = AstraDocumentStore() @@ -50,14 +61,10 @@ def test_to_dict(mock_auth): # noqa } -@pytest.mark.usefixtures("mock_auth") -@mock.patch("haystack_integrations.document_stores.astra.document_store.AstraClient") -def test_count_documents_by_filter(mock_astra_client): - mock_index = mock_astra_client.return_value +def test_count_documents_by_filter(mocked_store): + store, mock_index = mocked_store mock_index.count_documents.return_value = 2 - store = AstraDocumentStore() - count = store.count_documents_by_filter({"field": "meta.status", "operator": "==", "value": "draft"}) assert count == 2 @@ -66,14 +73,10 @@ def test_count_documents_by_filter(mock_astra_client): ) -@pytest.mark.usefixtures("mock_auth") -@mock.patch("haystack_integrations.document_stores.astra.document_store.AstraClient") -def test_count_unique_metadata_by_filter(mock_astra_client): - mock_index = mock_astra_client.return_value +def test_count_unique_metadata_by_filter(mocked_store): + store, mock_index = mocked_store mock_index.distinct.side_effect = [["news", "docs", ["docs", "faq"], None], [1, 2, 2]] - store = AstraDocumentStore() - counts = store.count_unique_metadata_by_filter( {"field": "meta.status", "operator": "==", "value": "published"}, ["category", "priority"] ) @@ -85,17 +88,13 @@ def test_count_unique_metadata_by_filter(mock_astra_client): ] -@pytest.mark.usefixtures("mock_auth") -@mock.patch("haystack_integrations.document_stores.astra.document_store.AstraClient") -def test_get_metadata_fields_info(mock_astra_client): - mock_index = mock_astra_client.return_value +def test_get_metadata_fields_info(mocked_store): + store, mock_index = mocked_store mock_index.find_documents.return_value = [ {"content": "Doc 1", "meta": {"category": "news", "priority": 1, "active": True}}, {"content": "Doc 2", "meta": {"category": "docs", "priority": 2.5, "tags": ["a", "b"]}}, ] - store = AstraDocumentStore() - fields_info = store.get_metadata_fields_info() assert fields_info == { @@ -108,28 +107,18 @@ def test_get_metadata_fields_info(mock_astra_client): mock_index.find_documents.assert_called_once_with({}, projection={"content": 1, "meta": 1}) -@pytest.mark.usefixtures("mock_auth") -@mock.patch("haystack_integrations.document_stores.astra.document_store.AstraClient") -def test_get_metadata_field_min_max(mock_astra_client): - mock_index = mock_astra_client.return_value +def test_get_metadata_field_min_max(mocked_store): + store, mock_index = mocked_store mock_index.distinct.return_value = [10, 3, 7] - store = AstraDocumentStore() - - result = store.get_metadata_field_min_max("priority") - - assert result == {"min": 3, "max": 10} + assert store.get_metadata_field_min_max("priority") == {"min": 3, "max": 10} mock_index.distinct.assert_called_once_with("meta.priority") -@pytest.mark.usefixtures("mock_auth") -@mock.patch("haystack_integrations.document_stores.astra.document_store.AstraClient") -def test_get_metadata_field_unique_values(mock_astra_client): - mock_index = mock_astra_client.return_value +def test_get_metadata_field_unique_values(mocked_store): + store, mock_index = mocked_store mock_index.distinct.return_value = ["Beta", "alpha", ["gamma", "alphabet"], None] - store = AstraDocumentStore() - values, total_count = store.get_metadata_field_unique_values("category", search_term="alp", from_=0, size=5) assert values == ["alpha", "alphabet"] @@ -137,6 +126,97 @@ def test_get_metadata_field_unique_values(mock_astra_client): mock_index.distinct.assert_called_once_with("meta.category") +@pytest.mark.parametrize( + "api_endpoint,token,match", + [ + ( + Secret.from_env_var("ASTRA_DB_API_ENDPOINT", strict=False), + Secret.from_token("tok"), + "API endpoint", + ), + ( + Secret.from_token("http://example.com"), + Secret.from_env_var("ASTRA_DB_APPLICATION_TOKEN", strict=False), + "authentication token", + ), + ], +) +def test_init_raises_when_secret_resolves_to_none(monkeypatch, api_endpoint, token, match): + monkeypatch.delenv("ASTRA_DB_API_ENDPOINT", raising=False) + monkeypatch.delenv("ASTRA_DB_APPLICATION_TOKEN", raising=False) + with pytest.raises(ValueError, match=match): + AstraDocumentStore(api_endpoint=api_endpoint, token=token) + + +@pytest.mark.parametrize( + "doc,expected_exc,match", + [ + ({"id": "1", "_id": "1", "content": "x"}, Exception, "Duplicate id definitions"), + ({"_id": 42, "content": "x"}, Exception, "is not a string"), + ("not-a-doc", ValueError, "Unsupported type"), + ], +) +def test_write_documents_input_validation_errors(mocked_store, doc, expected_exc, match): + store, _ = mocked_store + with pytest.raises(expected_exc, match=match): + store.write_documents([doc]) + + +def test_write_documents_fail_policy_raises_on_duplicate(mocked_store): + store, mock_index = mocked_store + mock_index.find_documents.return_value = [{"_id": "1"}] + with pytest.raises(DuplicateDocumentError, match="already exists"): + store.write_documents([Document(id="1", content="a")], policy=DuplicatePolicy.FAIL) + + +def test_write_documents_sparse_embedding_is_dropped_with_warning(mocked_store, caplog): + store, mock_index = mocked_store + mock_index.find_documents.return_value = [] + mock_index.insert.return_value = ["1"] + store.write_documents([{"_id": "1", "content": "x", "sparse_embedding": {"indices": [0], "values": [1.0]}}]) + inserted = mock_index.insert.call_args.args[0][0] + assert "sparse_embedding" not in inserted + assert "sparse embeddings in Astra" in caplog.text + + +def test_delete_all_documents_wraps_exception(mocked_store): + store, mock_index = mocked_store + mock_index.delete_all_documents.side_effect = RuntimeError("boom") + with pytest.raises(DocumentStoreError, match="Failed to delete all documents"): + store.delete_all_documents() + + +@pytest.mark.parametrize( + "filters,meta,match", + [ + ("bad", {}, "Filters must be a dictionary"), + ({}, "bad", "Meta must be a dictionary"), + ], +) +def test_update_by_filter_validation_errors(mocked_store, filters, meta, match): + store, _ = mocked_store + with pytest.raises(AstraDocumentStoreFilterError, match=match): + store.update_by_filter(filters=filters, meta=meta) + + +def test_update_by_filter_applies_meta_with_dot_notation(mocked_store): + store, mock_index = mocked_store + mock_index.update.return_value = 4 + count = store.update_by_filter( + filters={"field": "meta.category", "operator": "==", "value": "news"}, + meta={"reviewed": True, "priority": 1}, + ) + assert count == 4 + kwargs = mock_index.update.call_args.kwargs + assert kwargs["filters"] == {"meta.category": {"$eq": "news"}} + assert kwargs["update"] == {"$set": {"meta.reviewed": True, "meta.priority": 1}} + + +def test_infer_metadata_field_type_mixed_types_warn_and_default_to_keyword(caplog): + assert AstraDocumentStore._infer_metadata_field_type([1, "a"]) == "keyword" + assert "mixed metadata types" in caplog.text + + @pytest.mark.integration @pytest.mark.skipif( os.environ.get("ASTRA_DB_APPLICATION_TOKEN", "") == "", reason="ASTRA_DB_APPLICATION_TOKEN env var not set" diff --git a/integrations/astra/tests/test_filters_unit.py b/integrations/astra/tests/test_filters_unit.py new file mode 100644 index 0000000000..ebc0727414 --- /dev/null +++ b/integrations/astra/tests/test_filters_unit.py @@ -0,0 +1,173 @@ +# SPDX-FileCopyrightText: 2023-present Anant Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +""" +Unit tests for filter conversion + +Integration tests for filters are included in test_document_store.py +""" + +import pytest +from haystack.errors import FilterError + +from haystack_integrations.document_stores.astra.filters import ( + _convert_filters, + _normalize_filters, + _normalize_ranges, + _parse_comparison_condition, + _parse_logical_condition, +) + + +class TestConvertFilters: + @pytest.mark.parametrize("empty", [None, {}]) + def test_empty_returns_none(self, empty): + assert _convert_filters(empty) is None + + @pytest.mark.parametrize( + "filters,expected", + [ + ( + {"field": "meta.year", "operator": "==", "value": 2024}, + {"meta.year": {"$eq": 2024}}, + ), + ( + {"field": "meta.year", "operator": "!=", "value": 2024}, + {"meta.year": {"$ne": 2024}}, + ), + ( + {"field": "meta.score", "operator": ">", "value": 1}, + {"meta.score": {"$gt": 1}}, + ), + ( + {"field": "meta.score", "operator": ">=", "value": 1}, + {"meta.score": {"$gte": 1}}, + ), + ( + {"field": "meta.score", "operator": "<", "value": 1}, + {"meta.score": {"$lt": 1}}, + ), + ( + {"field": "meta.score", "operator": "<=", "value": 1}, + {"meta.score": {"$lte": 1}}, + ), + ( + {"field": "meta.tag", "operator": "in", "value": ["a", "b"]}, + {"meta.tag": {"$in": ["a", "b"]}}, + ), + ( + {"field": "meta.tag", "operator": "not in", "value": ["a", "b"]}, + {"meta.tag": {"$nin": ["a", "b"]}}, + ), + ], + ) + def test_comparison_operators(self, filters, expected): + assert _convert_filters(filters) == expected + + def test_id_field_is_renamed_to_underscore_id(self): + filters = {"field": "id", "operator": "==", "value": "abc"} + assert _convert_filters(filters) == {"_id": {"$eq": "abc"}} + + def test_in_operator_with_non_list_raises(self): + filters = {"field": "meta.tag", "operator": "in", "value": "not-a-list"} + with pytest.raises(FilterError, match=r"\$in operator must have `ARRAY`"): + _convert_filters(filters) + + def test_logical_and(self): + filters = { + "operator": "AND", + "conditions": [ + {"field": "meta.a", "operator": "==", "value": 1}, + {"field": "meta.b", "operator": "==", "value": 2}, + ], + } + result = _convert_filters(filters) + assert result == { + "$and": [ + {"meta.a": {"$eq": 1}}, + {"meta.b": {"$eq": 2}}, + ] + } + + def test_logical_or(self): + filters = { + "operator": "OR", + "conditions": [ + {"field": "meta.a", "operator": "==", "value": 1}, + {"field": "meta.b", "operator": "==", "value": 2}, + ], + } + result = _convert_filters(filters) + assert result == { + "$or": [ + {"meta.a": {"$eq": 1}}, + {"meta.b": {"$eq": 2}}, + ] + } + + +class TestNormalizeFilters: + def test_non_dict_raises(self): + with pytest.raises(FilterError, match="Filters must be a dictionary"): + _normalize_filters("not_a_dict") + + +class TestParseLogicalCondition: + def test_missing_operator_raises(self): + with pytest.raises(FilterError, match="'operator' key missing"): + _parse_logical_condition({"conditions": []}) + + def test_missing_conditions_raises(self): + with pytest.raises(FilterError, match="'conditions' key missing"): + _parse_logical_condition({"operator": "AND"}) + + def test_unknown_operator_raises(self): + with pytest.raises(FilterError, match="Unknown operator"): + _parse_logical_condition( + { + "operator": "XOR", + "conditions": [{"field": "a", "operator": "==", "value": 1}], + } + ) + + +class TestParseComparisonCondition: + @pytest.mark.parametrize( + "condition,err", + [ + ({"operator": "==", "value": 1}, "'field' key missing"), + ({"field": "a", "value": 1}, "'operator' key missing"), + ({"field": "a", "operator": "=="}, "'value' key missing"), + ], + ) + def test_missing_keys_raise(self, condition, err): + with pytest.raises(FilterError, match=err): + _parse_comparison_condition(condition) + + +class TestNormalizeRanges: + def test_no_ranges_returns_unchanged(self): + conditions = [ + {"meta.a": {"$eq": 1}}, + {"meta.b": {"$eq": 2}}, + ] + assert _normalize_ranges(conditions) == conditions + + def test_merges_range_conditions_on_same_field(self): + conditions = [ + {"range": {"date": {"lt": "2021-01-01"}}}, + {"range": {"date": {"gte": "2015-01-01"}}}, + ] + result = _normalize_ranges(conditions) + assert result == [{"range": {"date": {"lt": "2021-01-01", "gte": "2015-01-01"}}}] + + def test_keeps_non_range_conditions_when_merging(self): + conditions = [ + {"meta.a": {"$eq": 1}}, + {"range": {"date": {"lt": "2021-01-01"}}}, + {"range": {"date": {"gte": "2015-01-01"}}}, + ] + result = _normalize_ranges(conditions) + assert {"meta.a": {"$eq": 1}} in result + assert {"range": {"date": {"lt": "2021-01-01", "gte": "2015-01-01"}}} in result diff --git a/integrations/astra/tests/test_retriever.py b/integrations/astra/tests/test_retriever.py index d8ad55df8e..874bcc3778 100644 --- a/integrations/astra/tests/test_retriever.py +++ b/integrations/astra/tests/test_retriever.py @@ -11,38 +11,64 @@ from haystack_integrations.document_stores.astra import AstraDocumentStore -@patch.dict( - "os.environ", - {"ASTRA_DB_APPLICATION_TOKEN": "fake-token", "ASTRA_DB_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com"}, -) -@patch("haystack_integrations.document_stores.astra.document_store.AstraClient") -def test_retriever_init(*_): - ds = AstraDocumentStore() - retriever = AstraEmbeddingRetriever(ds, filters={"foo": "bar"}, top_k=99, filter_policy="replace") +@pytest.fixture +def mocked_store(monkeypatch): + monkeypatch.setenv("ASTRA_DB_APPLICATION_TOKEN", "fake-token") + monkeypatch.setenv("ASTRA_DB_API_ENDPOINT", "http://fake-url.apps.astra.datastax.com") + with patch("haystack_integrations.document_stores.astra.document_store.AstraClient"): + yield AstraDocumentStore() + + +def _serialized_retriever(*, include_filter_policy: bool = True) -> dict: + init_parameters = { + "filters": {"bar": "baz"}, + "top_k": 42, + "document_store": { + "type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore", + "init_parameters": { + "api_endpoint": {"type": "env_var", "env_vars": ["ASTRA_DB_API_ENDPOINT"], "strict": True}, + "token": {"type": "env_var", "env_vars": ["ASTRA_DB_APPLICATION_TOKEN"], "strict": True}, + "collection_name": "documents", + "embedding_dimension": 768, + "duplicates_policy": "NONE", + "similarity": "cosine", + }, + }, + } + if include_filter_policy: + init_parameters["filter_policy"] = "replace" + return { + "type": "haystack_integrations.components.retrievers.astra.retriever.AstraEmbeddingRetriever", + "init_parameters": init_parameters, + } + + +def test_retriever_init(mocked_store): + retriever = AstraEmbeddingRetriever(mocked_store, filters={"foo": "bar"}, top_k=99, filter_policy="replace") assert retriever.filters == {"foo": "bar"} assert retriever.top_k == 99 - assert retriever.document_store == ds + assert retriever.document_store == mocked_store assert retriever.filter_policy == FilterPolicy.REPLACE - retriever = AstraEmbeddingRetriever(ds, filters={"foo": "bar"}, top_k=99, filter_policy=FilterPolicy.MERGE) + retriever = AstraEmbeddingRetriever( + mocked_store, filters={"foo": "bar"}, top_k=99, filter_policy=FilterPolicy.MERGE + ) assert retriever.filter_policy == FilterPolicy.MERGE with pytest.raises(ValueError): - AstraEmbeddingRetriever(ds, filters={"foo": "bar"}, top_k=99, filter_policy="unknown") + AstraEmbeddingRetriever(mocked_store, filters={"foo": "bar"}, top_k=99, filter_policy="unknown") with pytest.raises(ValueError): - AstraEmbeddingRetriever(ds, filters={"foo": "bar"}, top_k=99, filter_policy=None) + AstraEmbeddingRetriever(mocked_store, filters={"foo": "bar"}, top_k=99, filter_policy=None) + +def test_retriever_init_rejects_non_astra_document_store(): + with pytest.raises(Exception, match="document_store must be an instance of AstraDocumentStore"): + AstraEmbeddingRetriever(document_store="not-a-store") # type: ignore[arg-type] -@patch.dict( - "os.environ", - {"ASTRA_DB_APPLICATION_TOKEN": "fake-token", "ASTRA_DB_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com"}, -) -@patch("haystack_integrations.document_stores.astra.document_store.AstraClient") -def test_retriever_to_json(*_): - ds = AstraDocumentStore() - retriever = AstraEmbeddingRetriever(ds, filters={"foo": "bar"}, top_k=99) +def test_retriever_to_dict(mocked_store): + retriever = AstraEmbeddingRetriever(mocked_store, filters={"foo": "bar"}, top_k=99) assert retriever.to_dict() == { "type": "haystack_integrations.components.retrievers.astra.retriever.AstraEmbeddingRetriever", "init_parameters": { @@ -65,116 +91,60 @@ def test_retriever_to_json(*_): } -@patch.dict( - "os.environ", - {"ASTRA_DB_APPLICATION_TOKEN": "fake-token", "ASTRA_DB_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com"}, -) -@patch("haystack_integrations.document_stores.astra.document_store.AstraClient") -def test_retriever_from_json(*_): - data = { - "type": "haystack_integrations.components.retrievers.astra.retriever.AstraEmbeddingRetriever", - "init_parameters": { - "filters": {"bar": "baz"}, - "top_k": 42, - "filter_policy": "replace", - "document_store": { - "type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore", - "init_parameters": { - "api_endpoint": {"type": "env_var", "env_vars": ["ASTRA_DB_API_ENDPOINT"], "strict": True}, - "token": {"type": "env_var", "env_vars": ["ASTRA_DB_APPLICATION_TOKEN"], "strict": True}, - "collection_name": "documents", - "embedding_dimension": 768, - "duplicates_policy": "NONE", - "similarity": "cosine", - }, - }, - }, - } - retriever = AstraEmbeddingRetriever.from_dict(data) +@pytest.mark.parametrize("include_filter_policy", [True, False]) +def test_retriever_from_dict(mocked_store, include_filter_policy): # noqa: ARG001 + retriever = AstraEmbeddingRetriever.from_dict(_serialized_retriever(include_filter_policy=include_filter_policy)) assert retriever.top_k == 42 assert retriever.filters == {"bar": "baz"} + # filter_policy defaults to REPLACE when absent + assert retriever.filter_policy == FilterPolicy.REPLACE -@patch.dict( - "os.environ", - {"ASTRA_DB_APPLICATION_TOKEN": "fake-token", "ASTRA_DB_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com"}, -) -@patch("haystack_integrations.document_stores.astra.document_store.AstraClient") -def test_retriever_from_json_no_filter_policy(*_): - data = { - "type": "haystack_integrations.components.retrievers.astra.retriever.AstraEmbeddingRetriever", - "init_parameters": { - "filters": {"bar": "baz"}, - "top_k": 42, - "document_store": { - "type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore", - "init_parameters": { - "api_endpoint": {"type": "env_var", "env_vars": ["ASTRA_DB_API_ENDPOINT"], "strict": True}, - "token": {"type": "env_var", "env_vars": ["ASTRA_DB_APPLICATION_TOKEN"], "strict": True}, - "collection_name": "documents", - "embedding_dimension": 768, - "duplicates_policy": "NONE", - "similarity": "cosine", - }, - }, - }, - } - retriever = AstraEmbeddingRetriever.from_dict(data) - assert retriever.top_k == 42 - assert retriever.filters == {"bar": "baz"} - assert retriever.filter_policy == FilterPolicy.REPLACE # defaults to REPLACE +def test_run_uses_runtime_top_k_and_filters(mocked_store): + mock_doc = Document(content="test", id="1") + with patch.object(mocked_store, "search", return_value=[mock_doc]) as mocked_search: + retriever = AstraEmbeddingRetriever( + mocked_store, top_k=5, filters={"lang": "en"}, filter_policy=FilterPolicy.REPLACE + ) + result = retriever.run(query_embedding=[0.1] * 768, filters={"year": 2024}, top_k=3) + assert result == {"documents": [mock_doc]} + mocked_search.assert_called_once_with([0.1] * 768, 3, filters={"year": 2024}) -@patch.dict( - "os.environ", - {"ASTRA_DB_APPLICATION_TOKEN": "fake-token", "ASTRA_DB_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com"}, -) -@patch("haystack_integrations.document_stores.astra.document_store.AstraClient") @pytest.mark.asyncio -async def test_run_async(*_): - ds = AstraDocumentStore() +async def test_run_async(mocked_store): mock_doc = Document(content="test", id="1") - with patch.object(ds, "search", return_value=[mock_doc]): - retriever = AstraEmbeddingRetriever(ds, top_k=5) + with patch.object(mocked_store, "search", return_value=[mock_doc]): + retriever = AstraEmbeddingRetriever(mocked_store, top_k=5) result = await retriever.run_async(query_embedding=[0.1] * 768) assert result["documents"] == [mock_doc] - ds.search.assert_called_once() - call_args = ds.search.call_args[0] - assert call_args[0] == [0.1] * 768 - assert call_args[1] == 5 - assert ds.search.call_args.kwargs["filters"] == {} - - -@patch.dict( - "os.environ", - {"ASTRA_DB_APPLICATION_TOKEN": "fake-token", "ASTRA_DB_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com"}, -) -@patch("haystack_integrations.document_stores.astra.document_store.AstraClient") + call_args = mocked_store.search.call_args + assert call_args.args == ([0.1] * 768, 5) + assert call_args.kwargs == {"filters": {}} + + @pytest.mark.asyncio -async def test_run_async_filters_replace(*_): - ds = AstraDocumentStore() +async def test_run_async_filters_replace(mocked_store): mock_doc = Document(content="test", id="1") - with patch.object(ds, "search", return_value=[mock_doc]): - retriever = AstraEmbeddingRetriever(ds, top_k=5, filters={"lang": "en"}, filter_policy=FilterPolicy.REPLACE) + with patch.object(mocked_store, "search", return_value=[mock_doc]): + retriever = AstraEmbeddingRetriever( + mocked_store, top_k=5, filters={"lang": "en"}, filter_policy=FilterPolicy.REPLACE + ) await retriever.run_async(query_embedding=[0.1] * 768, filters={"year": 2024}) - assert ds.search.call_args.kwargs["filters"] == {"year": 2024} + assert mocked_store.search.call_args.kwargs["filters"] == {"year": 2024} -@patch.dict( - "os.environ", - {"ASTRA_DB_APPLICATION_TOKEN": "fake-token", "ASTRA_DB_API_ENDPOINT": "http://fake-url.apps.astra.datastax.com"}, -) -@patch("haystack_integrations.document_stores.astra.document_store.AstraClient") @pytest.mark.asyncio -async def test_run_async_filters_merge(*_): - ds = AstraDocumentStore() +async def test_run_async_filters_merge(mocked_store): mock_doc = Document(content="test", id="1") init_filters = {"field": "lang", "operator": "==", "value": "en"} runtime_filters = {"field": "year", "operator": "==", "value": 2024} - with patch.object(ds, "search", return_value=[mock_doc]): - retriever = AstraEmbeddingRetriever(ds, top_k=5, filters=init_filters, filter_policy=FilterPolicy.MERGE) + with patch.object(mocked_store, "search", return_value=[mock_doc]): + retriever = AstraEmbeddingRetriever( + mocked_store, top_k=5, filters=init_filters, filter_policy=FilterPolicy.MERGE + ) await retriever.run_async(query_embedding=[0.1] * 768, filters=runtime_filters) - merged = ds.search.call_args.kwargs["filters"] + merged = mocked_store.search.call_args.kwargs["filters"] assert merged["operator"] == "AND" assert init_filters in merged["conditions"] assert runtime_filters in merged["conditions"]