diff --git a/sdk/python/feast/infra/online_stores/eg_valkey.py b/sdk/python/feast/infra/online_stores/eg_valkey.py index 02abe1eb25..deadf9c41f 100644 --- a/sdk/python/feast/infra/online_stores/eg_valkey.py +++ b/sdk/python/feast/infra/online_stores/eg_valkey.py @@ -38,7 +38,10 @@ from feast import Entity, FeatureView, RepoConfig, utils from feast.field import Field -from feast.infra.key_encoding_utils import serialize_entity_key +from feast.infra.key_encoding_utils import ( + deserialize_entity_key, + serialize_entity_key, +) from feast.infra.online_stores.helpers import _mmh3, _redis_key, _redis_key_prefix from feast.infra.online_stores.online_store import OnlineStore from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto @@ -55,8 +58,9 @@ from valkey import Valkey from valkey import asyncio as valkey_asyncio from valkey.cluster import ClusterNode, ValkeyCluster - from valkey.commands.search.field import VectorField + from valkey.commands.search.field import TagField, VectorField from valkey.commands.search.indexDefinition import IndexDefinition, IndexType + from valkey.commands.search.query import Query from valkey.sentinel import Sentinel except ImportError as e: from feast.errors import FeastExtrasDependencyImportError @@ -502,10 +506,14 @@ def _create_vector_index_if_not_exists( online_store_config.vector_index_hnsw_ef_runtime ) - # Create the index with single vector field + # Create the index with vector field and project tag for filtering + # __project__ TAG field enables filtering by project in hybrid queries try: client.ft(index_name).create_index( - fields=[VectorField(field_name, algorithm, attributes)], + fields=[ + VectorField(field_name, algorithm, attributes), + TagField("__project__"), + ], definition=definition, ) logger.info(f"Created vector index {index_name} for field {field_name}") @@ -703,9 +711,12 @@ def online_write_batch( # flattening the list of lists. `hmget` does the lookup assuming a list of keys in the key bin prev_event_timestamps = [i[0] for i in prev_event_timestamps] - for valkey_key_bin, prev_event_time, (_, values, timestamp, _) in zip( - keys, prev_event_timestamps, data - ): + for valkey_key_bin, prev_event_time, ( + entity_key, + values, + timestamp, + _, + ) in zip(keys, prev_event_timestamps, data): event_time_seconds = int(utils.make_tzaware(timestamp).timestamp()) # ignore if event_timestamp is before the event features that are currently in the feature store @@ -722,6 +733,12 @@ def online_write_batch( ts.seconds = event_time_seconds entity_hset = dict() entity_hset[ts_key] = ts.SerializeToString() + # Store project and entity key for vector search + entity_hset["__project__"] = project.encode() + entity_hset["__entity_key__"] = serialize_entity_key( + entity_key, + entity_key_serialization_version=config.entity_key_serialization_version, + ) for feature_name, val in values.items(): if feature_name in vector_fields: @@ -979,3 +996,311 @@ def _get_features_for_entity( else: timestamp = datetime.fromtimestamp(res_ts.seconds, tz=timezone.utc) return timestamp, res + + def retrieve_online_documents_v2( + self, + config: RepoConfig, + table: FeatureView, + requested_features: List[str], + embedding: Optional[List[float]], + top_k: int, + distance_metric: Optional[str] = None, + query_string: Optional[str] = None, + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ]: + """ + Retrieve documents using vector similarity search from Valkey. + + Args: + config: Feast configuration object + table: FeatureView to search + requested_features: List of feature names to return + embedding: Query embedding vector + top_k: Number of results to return + distance_metric: Optional override for distance metric (COSINE, L2, IP) + query_string: Not supported in V1 (reserved for future BM25 search) + + Returns: + List of tuples containing (timestamp, entity_key, features_dict) + """ + if embedding is None: + raise ValueError("embedding must be provided for vector search") + + if query_string is not None: + raise NotImplementedError( + "Keyword search (query_string) is not yet supported for Valkey. " + "Only vector similarity search is available." + ) + + online_store_config = config.online_store + assert isinstance(online_store_config, EGValkeyOnlineStoreConfig) + + client = self._get_client(online_store_config) + project = config.project + + # Find the vector field to search against + vector_field = self._get_vector_field_for_search(table, requested_features) + if vector_field is None: + raise ValueError( + f"No vector field found in FeatureView {table.name}. " + "Ensure the FeatureView has a field with vector_index=True." + ) + + # Determine distance metric + metric = distance_metric or vector_field.vector_search_metric or "COSINE" + + # Serialize query embedding to bytes + embedding_bytes = self._serialize_embedding_for_search(embedding, vector_field) + + # Build and execute FT.SEARCH query + index_name = _get_vector_index_name(project, table.name, vector_field.name) + search_results = self._execute_vector_search( + client=client, + index_name=index_name, + project=project, + vector_field_name=vector_field.name, + embedding_bytes=embedding_bytes, + top_k=top_k, + metric=metric, + ) + + if not search_results: + return [] + + # Fetch features for each result using pipeline HMGET + return self._fetch_features_for_search_results( + client=client, + config=config, + table=table, + requested_features=requested_features, + search_results=search_results, + ) + + def _get_vector_field_for_search( + self, + table: FeatureView, + requested_features: Optional[List[str]], + ) -> Optional[Field]: + """Find the vector field to use for search.""" + vector_fields = [f for f in table.features if f.vector_index] + + if not vector_fields: + return None + + # If requested_features specified, prefer a vector field from that list + if requested_features: + # Convert to set for O(1) lookup instead of O(n) list search + requested_set = set(requested_features) + for f in vector_fields: + if f.name in requested_set: + return f + + # Default to first vector field + return vector_fields[0] + + def _serialize_embedding_for_search( + self, + embedding: List[float], + vector_field: Field, + ) -> bytes: + """Serialize query embedding to bytes matching the field's dtype.""" + # Validate embedding dimension matches field configuration + if len(embedding) != vector_field.vector_length: + raise ValueError( + f"Embedding dimension {len(embedding)} does not match " + f"vector field '{vector_field.name}' dimension {vector_field.vector_length}" + ) + + if vector_field.dtype == Array(Float64): + return np.array(embedding, dtype=np.float64).tobytes() + else: + # Default to float32 + return np.array(embedding, dtype=np.float32).tobytes() + + def _execute_vector_search( + self, + client: Union[Valkey, ValkeyCluster], + index_name: str, + project: str, + vector_field_name: str, + embedding_bytes: bytes, + top_k: int, + metric: str, + ) -> List[Tuple[bytes, float]]: + """ + Execute FT.SEARCH with KNN query. + + Returns: + List of (doc_key, distance) tuples + """ + # Escape double quotes in project name for DIALECT 2 quoted tag syntax + # This handles special characters like hyphens which would otherwise + # be interpreted as operators (e.g., "my-project" -> "my NOT project") + escaped_project = project.replace('"', '\\"') + + # Build KNN query with project filter using quoted tag syntax (DIALECT 2) + query_str = ( + f'(@__project__:{{"{escaped_project}"}})' + f"=>[KNN {top_k} @{vector_field_name} $vec AS __distance__]" + ) + + # Determine sort order based on metric: + # - COSINE, L2: lower distance = more similar → ascending + # - IP (Inner Product): higher score = more similar → descending + sort_ascending = metric.upper() != "IP" + + query = ( + Query(query_str) + .return_fields("__distance__") + .sort_by("__distance__", asc=sort_ascending) + .paging(0, top_k) + .dialect(2) + ) + + try: + results = client.ft(index_name).search( + query, + query_params={"vec": embedding_bytes}, + ) + except ResponseError as e: + if "no such index" in str(e).lower(): + raise ValueError( + f"Vector index '{index_name}' does not exist. " + "Ensure data has been materialized with 'feast materialize'." + ) + raise + + # Parse results: extract doc keys and distances + search_results = [] + for doc in results.docs: + doc_key = doc.id.encode() if isinstance(doc.id, str) else doc.id + # Default to inf (worst distance) if __distance__ is missing + # 0.0 would incorrectly indicate a perfect match + distance = float(getattr(doc, "__distance__", float("inf"))) + search_results.append((doc_key, distance)) + + return search_results + + def _fetch_features_for_search_results( + self, + client: Union[Valkey, ValkeyCluster], + config: RepoConfig, + table: FeatureView, + requested_features: List[str], + search_results: List[Tuple[bytes, float]], + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ]: + """ + Fetch features for search results using pipeline HMGET. + + This is the second step of two-step retrieval: + 1. FT.SEARCH returns doc keys and distances + 2. HMGET fetches the actual feature values + """ + # Pre-compute mappings once (avoid repeated dict/hash operations in loops) + vector_fields_dict = {f.name: f for f in table.features if f.vector_index} + + # Build feature_name -> hset_key mapping and hset_keys list in single pass + feature_to_hset_key: Dict[str, Any] = {} + hset_keys = [] + for feature_name in requested_features: + if feature_name in vector_fields_dict: + hset_key = feature_name + else: + hset_key = _mmh3(f"{table.name}:{feature_name}") + feature_to_hset_key[feature_name] = hset_key + hset_keys.append(hset_key) + + # Add timestamp and entity key + ts_key = f"_ts:{table.name}" + hset_keys.append(ts_key) + hset_keys.append("__entity_key__") + + # Extract doc_keys and distances in single pass + doc_keys = [] + distances = {} + for doc_key, dist in search_results: + doc_keys.append(doc_key) + distances[doc_key] = dist + + # Pipeline HMGET for all results (single round-trip to Valkey) + with client.pipeline(transaction=False) as pipe: + for doc_key in doc_keys: + key_str = doc_key.decode() if isinstance(doc_key, bytes) else doc_key + pipe.hmget(key_str, hset_keys) + fetched_values = pipe.execute() + + # Pre-fetch serialization version once + entity_key_serialization_version = config.entity_key_serialization_version + + # Build result list + results: List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ] = [] + + for doc_key, values in zip(doc_keys, fetched_values): + # Parse values into dict + val_dict = dict(zip(hset_keys, values)) + + # Parse timestamp + timestamp = None + ts_val = val_dict.get(ts_key) + if ts_val: + ts_proto = Timestamp() + ts_proto.ParseFromString(bytes(ts_val)) + timestamp = datetime.fromtimestamp(ts_proto.seconds, tz=timezone.utc) + + # Parse entity key + entity_key_proto = None + entity_key_bytes = val_dict.get("__entity_key__") + if entity_key_bytes: + entity_key_proto = deserialize_entity_key( + bytes(entity_key_bytes), + entity_key_serialization_version=entity_key_serialization_version, + ) + + # Build feature dict with pre-allocated capacity hint + feature_dict: Dict[str, ValueProto] = {} + + # Add distance as a feature + distance_proto = ValueProto() + distance_proto.double_val = distances[doc_key] + feature_dict["distance"] = distance_proto + + # Parse requested features using pre-computed mappings + for feature_name in requested_features: + hset_key = feature_to_hset_key[feature_name] + val_bin = val_dict.get(hset_key) + + if not val_bin: + feature_dict[feature_name] = ValueProto() + continue + + if feature_name in vector_fields_dict: + # Vector field: deserialize from raw bytes + feature_dict[feature_name] = _deserialize_vector_from_bytes( + bytes(val_bin), vector_fields_dict[feature_name] + ) + else: + # Regular field: parse protobuf + val = ValueProto() + val.ParseFromString(bytes(val_bin)) + feature_dict[feature_name] = val + + results.append((timestamp, entity_key_proto, feature_dict)) + + return results diff --git a/sdk/python/tests/unit/infra/online_store/test_valkey.py b/sdk/python/tests/unit/infra/online_store/test_valkey.py index a52eaf063d..83a66759bd 100644 --- a/sdk/python/tests/unit/infra/online_store/test_valkey.py +++ b/sdk/python/tests/unit/infra/online_store/test_valkey.py @@ -926,6 +926,70 @@ def test_vector_field_with_negative_vector_length_raises_error( ) +class TestVectorIndexCreation: + """Tests for vector index creation with correct schema.""" + + def test_index_includes_project_tag_field( + self, valkey_online_store, repo_config_without_docker_connection_string + ): + """Test that index schema includes TagField for __project__ filtering.""" + from unittest.mock import MagicMock + + from valkey.exceptions import ResponseError + + fv = FeatureView( + name="test_with_project_tag", + source=FileSource( + name="test_source", + path="test.parquet", + timestamp_field="event_timestamp", + ), + entities=[Entity(name="item_id")], + ttl=timedelta(days=1), + schema=[ + Field(name="item_id", dtype=Int64), + Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + vector_length=4, + ), + ], + ) + + vector_fields = {f.name: f for f in fv.features if f.vector_index} + + mock_client = MagicMock() + # Simulate index doesn't exist + mock_client.ft.return_value.info.side_effect = ResponseError("Unknown index") + + valkey_online_store._create_vector_index_if_not_exists( + mock_client, + repo_config_without_docker_connection_string, + fv, + vector_fields, + ) + + # Verify create_index was called + mock_client.ft.return_value.create_index.assert_called_once() + + # Get the fields argument + call_kwargs = mock_client.ft.return_value.create_index.call_args + fields = call_kwargs.kwargs.get("fields") or call_kwargs.args[0] + + # Verify we have both VectorField and TagField + field_types = [type(f).__name__ for f in fields] + assert "VectorField" in field_types, "Index should include VectorField" + assert "TagField" in field_types, ( + "Index should include TagField for __project__" + ) + + # Verify TagField is for __project__ + tag_fields = [f for f in fields if type(f).__name__ == "TagField"] + assert len(tag_fields) == 1 + assert tag_fields[0].name == "__project__" + + # ============================================================================ # Vector Support Integration Tests (Docker Required) # ============================================================================ @@ -1044,6 +1108,14 @@ def test_valkey_online_write_batch_with_vector_field( vector = np.frombuffer(embedding_bytes, dtype=np.float32) np.testing.assert_array_almost_equal(vector, [0.1, 0.2, 0.3, 0.4], decimal=5) + # Verify __project__ is stored for vector search filtering + assert b"__project__" in stored_data + # Should be stored as string (valkey-py encodes to bytes, but value should match project) + assert stored_data[b"__project__"] == repo_config.project.encode() + + # Verify __entity_key__ is stored for entity key retrieval + assert b"__entity_key__" in stored_data + @pytest.mark.docker def test_valkey_online_read_with_vector_field( @@ -1261,3 +1333,456 @@ def test_valkey_online_read_with_requested_features_mixed( # Verify string value assert features["item_name"].string_val == "item_2" + + +class TestGetVectorFieldForSearch: + """Tests for _get_vector_field_for_search helper method.""" + + @pytest.fixture + def feature_view_with_vector(self): + """Create a FeatureView with vector field for testing.""" + return FeatureView( + name="test_fv", + source=FileSource( + name="test_source", + path="test.parquet", + timestamp_field="event_timestamp", + ), + entities=[Entity(name="item_id", value_type=ValueType.INT64)], + ttl=timedelta(days=1), + schema=[ + Field(name="item_id", dtype=Int64), + Field(name="scalar_feature", dtype=Float32), + Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + vector_length=4, + vector_search_metric="COSINE", + ), + ], + ) + + @pytest.fixture + def feature_view_no_vector(self): + """Create a FeatureView without vector fields.""" + return FeatureView( + name="test_fv_no_vector", + source=FileSource( + name="test_source", + path="test.parquet", + timestamp_field="event_timestamp", + ), + entities=[Entity(name="item_id", value_type=ValueType.INT64)], + ttl=timedelta(days=1), + schema=[ + Field(name="item_id", dtype=Int64), + Field(name="scalar_feature", dtype=Float32), + ], + ) + + def test_returns_vector_field_from_requested_features( + self, feature_view_with_vector + ): + """Test that vector field is returned when in requested_features.""" + store = EGValkeyOnlineStore() + result = store._get_vector_field_for_search( + feature_view_with_vector, + requested_features=["embedding", "scalar_feature"], + ) + assert result is not None + assert result.name == "embedding" + + def test_returns_first_vector_field_when_not_in_requested( + self, feature_view_with_vector + ): + """Test that first vector field is returned when not in requested_features.""" + store = EGValkeyOnlineStore() + result = store._get_vector_field_for_search( + feature_view_with_vector, requested_features=["scalar_feature"] + ) + assert result is not None + assert result.name == "embedding" + + def test_returns_none_for_no_vector_fields(self, feature_view_no_vector): + """Test that None is returned when no vector fields exist.""" + store = EGValkeyOnlineStore() + result = store._get_vector_field_for_search( + feature_view_no_vector, requested_features=["scalar_feature"] + ) + assert result is None + + +class TestSerializeEmbeddingForSearch: + """Tests for _serialize_embedding_for_search helper method.""" + + @pytest.fixture + def float32_vector_field(self): + """Create a Float32 vector field.""" + return Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + vector_length=4, + ) + + @pytest.fixture + def float64_vector_field(self): + """Create a Float64 vector field.""" + return Field( + name="embedding", + dtype=Array(Float64), + vector_index=True, + vector_length=4, + ) + + def test_serializes_to_float32_bytes(self, float32_vector_field): + """Test that embedding is serialized to float32 bytes.""" + store = EGValkeyOnlineStore() + embedding = [0.1, 0.2, 0.3, 0.4] + result = store._serialize_embedding_for_search(embedding, float32_vector_field) + + # Verify it's bytes + assert isinstance(result, bytes) + + # Verify length (4 floats * 4 bytes each = 16 bytes) + assert len(result) == 16 + + # Verify values can be deserialized back + arr = np.frombuffer(result, dtype=np.float32) + np.testing.assert_array_almost_equal(arr, embedding, decimal=5) + + def test_serializes_to_float64_bytes(self, float64_vector_field): + """Test that embedding is serialized to float64 bytes for Float64 fields.""" + store = EGValkeyOnlineStore() + embedding = [0.1, 0.2, 0.3, 0.4] + result = store._serialize_embedding_for_search(embedding, float64_vector_field) + + # Verify it's bytes + assert isinstance(result, bytes) + + # Verify length (4 doubles * 8 bytes each = 32 bytes) + assert len(result) == 32 + + # Verify values can be deserialized back + arr = np.frombuffer(result, dtype=np.float64) + np.testing.assert_array_almost_equal(arr, embedding, decimal=10) + + def test_raises_error_on_dimension_mismatch(self, float32_vector_field): + """Test that ValueError is raised when embedding dimension doesn't match field.""" + store = EGValkeyOnlineStore() + # Field expects 4 dimensions, but we provide 3 + embedding = [0.1, 0.2, 0.3] + with pytest.raises(ValueError, match="dimension .* does not match"): + store._serialize_embedding_for_search(embedding, float32_vector_field) + + +class TestRetrieveOnlineDocumentsV2Validation: + """Tests for retrieve_online_documents_v2 input validation.""" + + @pytest.fixture + def feature_view_with_vector(self): + """Create a FeatureView with vector field for testing.""" + return FeatureView( + name="test_fv", + source=FileSource( + name="test_source", + path="test.parquet", + timestamp_field="event_timestamp", + ), + entities=[Entity(name="item_id", value_type=ValueType.INT64)], + ttl=timedelta(days=1), + schema=[ + Field(name="item_id", dtype=Int64), + Field( + name="embedding", + dtype=Array(Float32), + vector_index=True, + vector_length=4, + vector_search_metric="COSINE", + ), + ], + ) + + @pytest.fixture + def feature_view_no_vector(self): + """Create a FeatureView without vector fields.""" + return FeatureView( + name="test_fv_no_vector", + source=FileSource( + name="test_source", + path="test.parquet", + timestamp_field="event_timestamp", + ), + entities=[Entity(name="item_id", value_type=ValueType.INT64)], + ttl=timedelta(days=1), + schema=[ + Field(name="item_id", dtype=Int64), + Field(name="scalar_feature", dtype=Float32), + ], + ) + + @pytest.fixture + def repo_config(self): + """Create a minimal RepoConfig for testing.""" + return RepoConfig( + project="test_project", + provider="local", + registry="test_registry.db", + online_store=EGValkeyOnlineStoreConfig( + type="eg-valkey", + connection_string="localhost:6379", + ), + entity_key_serialization_version=3, + ) + + def test_raises_error_when_embedding_is_none( + self, repo_config, feature_view_with_vector + ): + """Test that ValueError is raised when embedding is None.""" + store = EGValkeyOnlineStore() + with pytest.raises(ValueError, match="embedding must be provided"): + store.retrieve_online_documents_v2( + config=repo_config, + table=feature_view_with_vector, + requested_features=["embedding"], + embedding=None, + top_k=10, + ) + + def test_raises_error_when_query_string_provided( + self, repo_config, feature_view_with_vector + ): + """Test that NotImplementedError is raised when query_string is provided.""" + store = EGValkeyOnlineStore() + with pytest.raises(NotImplementedError, match="Keyword search"): + store.retrieve_online_documents_v2( + config=repo_config, + table=feature_view_with_vector, + requested_features=["embedding"], + embedding=[0.1, 0.2, 0.3, 0.4], + top_k=10, + query_string="test query", + ) + + def test_raises_error_when_no_vector_field( + self, repo_config, feature_view_no_vector + ): + """Test that ValueError is raised when FeatureView has no vector fields.""" + store = EGValkeyOnlineStore() + with pytest.raises(ValueError, match="No vector field found"): + store.retrieve_online_documents_v2( + config=repo_config, + table=feature_view_no_vector, + requested_features=["scalar_feature"], + embedding=[0.1, 0.2, 0.3, 0.4], + top_k=10, + ) + + def test_raises_error_when_dimension_mismatch( + self, repo_config, feature_view_with_vector + ): + """Test that ValueError is raised when embedding dimension doesn't match field.""" + store = EGValkeyOnlineStore() + # feature_view_with_vector has vector_length=4, so 3-dim embedding should fail + with pytest.raises(ValueError, match="Embedding dimension .* does not match"): + store.retrieve_online_documents_v2( + config=repo_config, + table=feature_view_with_vector, + requested_features=["embedding"], + embedding=[0.1, 0.2, 0.3], # Wrong dimension (3 instead of 4) + top_k=10, + ) + + def test_raises_error_when_index_does_not_exist( + self, repo_config, feature_view_with_vector + ): + """Test that ValueError is raised when vector index doesn't exist.""" + from unittest.mock import MagicMock, patch + + from valkey.exceptions import ResponseError + + store = EGValkeyOnlineStore() + + # Mock the client to simulate "no such index" error + mock_client = MagicMock() + mock_client.ft.return_value.search.side_effect = ResponseError("no such index") + + with patch.object(store, "_get_client", return_value=mock_client): + with pytest.raises(ValueError, match="does not exist.*materialize"): + store.retrieve_online_documents_v2( + config=repo_config, + table=feature_view_with_vector, + requested_features=["embedding"], + embedding=[0.1, 0.2, 0.3, 0.4], + top_k=10, + ) + + +class TestExecuteVectorSearch: + """Tests for _execute_vector_search helper method.""" + + @pytest.fixture + def store(self): + return EGValkeyOnlineStore() + + def test_project_name_with_hyphen_is_escaped(self, store): + """Test that project names with hyphens are properly escaped in queries.""" + from unittest.mock import MagicMock + + mock_client = MagicMock() + mock_result = MagicMock() + mock_result.docs = [] + mock_client.ft.return_value.search.return_value = mock_result + + store._execute_vector_search( + client=mock_client, + index_name="test_index", + project="my-project", # Hyphen in project name + vector_field_name="embedding", + embedding_bytes=b"\x00" * 16, + top_k=10, + metric="COSINE", + ) + + # Verify the query was called + mock_client.ft.return_value.search.assert_called_once() + call_args = mock_client.ft.return_value.search.call_args + query = call_args[0][0] + + # The query string should have quoted project name for DIALECT 2 + # This prevents hyphen from being interpreted as negation + assert '"my-project"' in query.query_string() + + def test_project_name_with_double_quote_is_escaped(self, store): + """Test that double quotes in project names are escaped.""" + from unittest.mock import MagicMock + + mock_client = MagicMock() + mock_result = MagicMock() + mock_result.docs = [] + mock_client.ft.return_value.search.return_value = mock_result + + store._execute_vector_search( + client=mock_client, + index_name="test_index", + project='my"project', # Double quote in project name + vector_field_name="embedding", + embedding_bytes=b"\x00" * 16, + top_k=10, + metric="COSINE", + ) + + mock_client.ft.return_value.search.assert_called_once() + call_args = mock_client.ft.return_value.search.call_args + query = call_args[0][0] + + # Double quote should be escaped + assert r"\"" in query.query_string() + + def test_sort_ascending_for_cosine_metric(self, store): + """Test that COSINE metric uses ascending sort (lower = better).""" + from unittest.mock import MagicMock + + mock_client = MagicMock() + mock_result = MagicMock() + mock_result.docs = [] + mock_client.ft.return_value.search.return_value = mock_result + + store._execute_vector_search( + client=mock_client, + index_name="test_index", + project="test_project", + vector_field_name="embedding", + embedding_bytes=b"\x00" * 16, + top_k=10, + metric="COSINE", + ) + + call_args = mock_client.ft.return_value.search.call_args + query = call_args[0][0] + + # COSINE should sort ascending (lower distance = more similar) + # Query._sortby is a SortbyField object with .args = [field, "ASC"/"DESC"] + assert query._sortby.args[0] == "__distance__" + assert query._sortby.args[1] == "ASC" + + def test_sort_ascending_for_l2_metric(self, store): + """Test that L2 metric uses ascending sort (lower = better).""" + from unittest.mock import MagicMock + + mock_client = MagicMock() + mock_result = MagicMock() + mock_result.docs = [] + mock_client.ft.return_value.search.return_value = mock_result + + store._execute_vector_search( + client=mock_client, + index_name="test_index", + project="test_project", + vector_field_name="embedding", + embedding_bytes=b"\x00" * 16, + top_k=10, + metric="L2", + ) + + call_args = mock_client.ft.return_value.search.call_args + query = call_args[0][0] + + # L2 should sort ascending (lower distance = more similar) + assert query._sortby.args[1] == "ASC" + + def test_sort_descending_for_ip_metric(self, store): + """Test that IP (Inner Product) metric uses descending sort (higher = better).""" + from unittest.mock import MagicMock + + mock_client = MagicMock() + mock_result = MagicMock() + mock_result.docs = [] + mock_client.ft.return_value.search.return_value = mock_result + + store._execute_vector_search( + client=mock_client, + index_name="test_index", + project="test_project", + vector_field_name="embedding", + embedding_bytes=b"\x00" * 16, + top_k=10, + metric="IP", + ) + + call_args = mock_client.ft.return_value.search.call_args + query = call_args[0][0] + + # IP should sort descending (higher score = more similar) + assert query._sortby.args[1] == "DESC" + + def test_default_distance_is_infinity_not_zero(self, store): + """Test that missing __distance__ defaults to infinity, not 0.0.""" + from unittest.mock import MagicMock + + mock_client = MagicMock() + mock_doc = MagicMock() + mock_doc.id = "test_key" + # Simulate missing __distance__ attribute + del mock_doc.__distance__ + + mock_result = MagicMock() + mock_result.docs = [mock_doc] + mock_client.ft.return_value.search.return_value = mock_result + + results = store._execute_vector_search( + client=mock_client, + index_name="test_index", + project="test_project", + vector_field_name="embedding", + embedding_bytes=b"\x00" * 16, + top_k=10, + metric="COSINE", + ) + + # Distance should default to infinity, not 0.0 + # 0.0 would incorrectly indicate a perfect match + assert len(results) == 1 + doc_key, distance = results[0] + assert distance == float("inf")