Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
339 changes: 332 additions & 7 deletions sdk/python/feast/infra/online_stores/eg_valkey.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@

from feast import Entity, FeatureView, RepoConfig, utils
from feast.field import Field
from feast.infra.key_encoding_utils import serialize_entity_key
from feast.infra.key_encoding_utils import (
deserialize_entity_key,
serialize_entity_key,
)
from feast.infra.online_stores.helpers import _mmh3, _redis_key, _redis_key_prefix
from feast.infra.online_stores.online_store import OnlineStore
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
Expand All @@ -55,8 +58,9 @@
from valkey import Valkey
from valkey import asyncio as valkey_asyncio
from valkey.cluster import ClusterNode, ValkeyCluster
from valkey.commands.search.field import VectorField
from valkey.commands.search.field import TagField, VectorField
from valkey.commands.search.indexDefinition import IndexDefinition, IndexType
from valkey.commands.search.query import Query
from valkey.sentinel import Sentinel
except ImportError as e:
from feast.errors import FeastExtrasDependencyImportError
Expand Down Expand Up @@ -502,10 +506,14 @@ def _create_vector_index_if_not_exists(
online_store_config.vector_index_hnsw_ef_runtime
)

# Create the index with single vector field
# Create the index with vector field and project tag for filtering
# __project__ TAG field enables filtering by project in hybrid queries
try:
client.ft(index_name).create_index(
fields=[VectorField(field_name, algorithm, attributes)],
fields=[
VectorField(field_name, algorithm, attributes),
TagField("__project__"),
],
definition=definition,
)
logger.info(f"Created vector index {index_name} for field {field_name}")
Expand Down Expand Up @@ -703,9 +711,12 @@ def online_write_batch(
# flattening the list of lists. `hmget` does the lookup assuming a list of keys in the key bin
prev_event_timestamps = [i[0] for i in prev_event_timestamps]

for valkey_key_bin, prev_event_time, (_, values, timestamp, _) in zip(
keys, prev_event_timestamps, data
):
for valkey_key_bin, prev_event_time, (
entity_key,
values,
timestamp,
_,
) in zip(keys, prev_event_timestamps, data):
event_time_seconds = int(utils.make_tzaware(timestamp).timestamp())

# ignore if event_timestamp is before the event features that are currently in the feature store
Expand All @@ -722,6 +733,12 @@ def online_write_batch(
ts.seconds = event_time_seconds
entity_hset = dict()
entity_hset[ts_key] = ts.SerializeToString()
# Store project and entity key for vector search
entity_hset["__project__"] = project.encode()
entity_hset["__entity_key__"] = serialize_entity_key(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
)

for feature_name, val in values.items():
if feature_name in vector_fields:
Expand Down Expand Up @@ -979,3 +996,311 @@ def _get_features_for_entity(
else:
timestamp = datetime.fromtimestamp(res_ts.seconds, tz=timezone.utc)
return timestamp, res

def retrieve_online_documents_v2(
    self,
    config: RepoConfig,
    table: FeatureView,
    requested_features: List[str],
    embedding: Optional[List[float]],
    top_k: int,
    distance_metric: Optional[str] = None,
    query_string: Optional[str] = None,
) -> List[
    Tuple[
        Optional[datetime],
        Optional[EntityKeyProto],
        Optional[Dict[str, ValueProto]],
    ]
]:
    """
    Retrieve documents using vector similarity search from Valkey.

    Two-step retrieval: FT.SEARCH finds the top-k nearest documents for the
    query embedding (filtered to this project), then a pipelined HMGET pulls
    the stored feature values for each hit.

    Args:
        config: Feast configuration object
        table: FeatureView to search
        requested_features: List of feature names to return
        embedding: Query embedding vector
        top_k: Number of results to return
        distance_metric: Optional override for distance metric (COSINE, L2, IP)
        query_string: Not supported in V2 (reserved for future BM25 search)

    Returns:
        List of tuples containing (timestamp, entity_key, features_dict)

    Raises:
        ValueError: If embedding is None, or the FeatureView has no field
            with vector_index=True.
        NotImplementedError: If query_string is provided; keyword search is
            not yet supported for Valkey.
    """
    if embedding is None:
        raise ValueError("embedding must be provided for vector search")

    if query_string is not None:
        raise NotImplementedError(
            "Keyword search (query_string) is not yet supported for Valkey. "
            "Only vector similarity search is available."
        )

    online_store_config = config.online_store
    assert isinstance(online_store_config, EGValkeyOnlineStoreConfig)

    client = self._get_client(online_store_config)
    project = config.project

    # Find the vector field to search against
    vector_field = self._get_vector_field_for_search(table, requested_features)
    if vector_field is None:
        raise ValueError(
            f"No vector field found in FeatureView {table.name}. "
            "Ensure the FeatureView has a field with vector_index=True."
        )

    # Determine distance metric: explicit override > field config > COSINE
    metric = distance_metric or vector_field.vector_search_metric or "COSINE"

    # Serialize query embedding to bytes matching the field's dtype
    embedding_bytes = self._serialize_embedding_for_search(embedding, vector_field)

    # Build and execute FT.SEARCH query
    index_name = _get_vector_index_name(project, table.name, vector_field.name)
    search_results = self._execute_vector_search(
        client=client,
        index_name=index_name,
        project=project,
        vector_field_name=vector_field.name,
        embedding_bytes=embedding_bytes,
        top_k=top_k,
        metric=metric,
    )

    if not search_results:
        return []

    # Fetch features for each result using pipeline HMGET
    return self._fetch_features_for_search_results(
        client=client,
        config=config,
        table=table,
        requested_features=requested_features,
        search_results=search_results,
    )

def _get_vector_field_for_search(
self,
table: FeatureView,
requested_features: Optional[List[str]],
) -> Optional[Field]:
"""Find the vector field to use for search."""
vector_fields = [f for f in table.features if f.vector_index]

if not vector_fields:
return None

# If requested_features specified, prefer a vector field from that list
if requested_features:
# Convert to set for O(1) lookup instead of O(n) list search
requested_set = set(requested_features)
for f in vector_fields:
if f.name in requested_set:
return f

# Default to first vector field
return vector_fields[0]

def _serialize_embedding_for_search(
    self,
    embedding: List[float],
    vector_field: Field,
) -> bytes:
    """Pack the query embedding into raw bytes matching the field's dtype.

    Raises:
        ValueError: when the embedding length differs from the field's
            configured vector_length.
    """
    expected = vector_field.vector_length
    actual = len(embedding)
    if actual != expected:
        raise ValueError(
            f"Embedding dimension {actual} does not match "
            f"vector field '{vector_field.name}' dimension {expected}"
        )

    # Only Array(Float64) fields are stored as float64; everything else
    # defaults to float32, so the query vector must be packed to match.
    dtype = np.float64 if vector_field.dtype == Array(Float64) else np.float32
    return np.array(embedding, dtype=dtype).tobytes()

def _execute_vector_search(
    self,
    client: Union[Valkey, ValkeyCluster],
    index_name: str,
    project: str,
    vector_field_name: str,
    embedding_bytes: bytes,
    top_k: int,
    metric: str,
) -> List[Tuple[bytes, float]]:
    """
    Execute FT.SEARCH with a KNN query, filtered to the given project.

    Args:
        client: Valkey (or cluster) client to run the query on.
        index_name: Name of the FT index to search.
        project: Feast project name used for the tag filter.
        vector_field_name: Name of the vector field in the index.
        embedding_bytes: Query embedding serialized to raw bytes.
        top_k: Number of nearest neighbors to return.
        metric: Distance metric (COSINE, L2, IP) — controls sort order only.

    Returns:
        List of (doc_key, distance) tuples.

    Raises:
        ValueError: if the vector index does not exist.
    """
    # Escape double quotes in project name for DIALECT 2 quoted tag syntax.
    # Quoting handles special characters like hyphens which would otherwise
    # be interpreted as operators (e.g. "my-project" -> "my NOT project").
    escaped_project = project.replace('"', '\\"')

    # Build KNN query with project filter using quoted tag syntax (DIALECT 2)
    query_str = (
        f'(@__project__:{{"{escaped_project}"}})'
        f"=>[KNN {top_k} @{vector_field_name} $vec AS __distance__]"
    )

    # Determine sort order based on metric:
    # - COSINE, L2: lower distance = more similar → ascending
    # - IP (Inner Product): higher score = more similar → descending
    sort_ascending = metric.upper() != "IP"

    query = (
        Query(query_str)
        .return_fields("__distance__")
        .sort_by("__distance__", asc=sort_ascending)
        .paging(0, top_k)
        .dialect(2)
    )

    try:
        results = client.ft(index_name).search(
            query,
            query_params={"vec": embedding_bytes},
        )
    except ResponseError as e:
        if "no such index" in str(e).lower():
            # Chain the original error so the server response stays visible.
            raise ValueError(
                f"Vector index '{index_name}' does not exist. "
                "Ensure data has been materialized with 'feast materialize'."
            ) from e
        raise

    # Parse results: extract doc keys and distances
    search_results = []
    for doc in results.docs:
        doc_key = doc.id.encode() if isinstance(doc.id, str) else doc.id
        # Default to inf (worst distance) if __distance__ is missing;
        # 0.0 would incorrectly indicate a perfect match.
        distance = float(getattr(doc, "__distance__", float("inf")))
        search_results.append((doc_key, distance))

    return search_results

def _fetch_features_for_search_results(
    self,
    client: Union[Valkey, ValkeyCluster],
    config: RepoConfig,
    table: FeatureView,
    requested_features: List[str],
    search_results: List[Tuple[bytes, float]],
) -> List[
    Tuple[
        Optional[datetime],
        Optional[EntityKeyProto],
        Optional[Dict[str, ValueProto]],
    ]
]:
    """
    Resolve search hits into (timestamp, entity_key, features) tuples.

    Second step of the two-step retrieval: FT.SEARCH produced document keys
    and distances; a single pipelined HMGET round-trip now pulls the stored
    feature values, event timestamp, and serialized entity key for each hit.
    """
    vector_field_by_name = {f.name: f for f in table.features if f.vector_index}

    # Map each requested feature to the hash field it is stored under:
    # vector fields live under their plain name, everything else under the
    # mmh3 hash of "<view>:<feature>".
    hash_field_for: Dict[str, Any] = {
        name: (name if name in vector_field_by_name else _mmh3(f"{table.name}:{name}"))
        for name in requested_features
    }
    event_ts_field = f"_ts:{table.name}"
    fields_to_fetch = [hash_field_for[name] for name in requested_features]
    fields_to_fetch.append(event_ts_field)
    fields_to_fetch.append("__entity_key__")

    ordered_keys = [key for key, _ in search_results]
    distance_by_key = dict(search_results)

    # Single round-trip: HMGET every document via a non-transactional pipeline.
    with client.pipeline(transaction=False) as pipeline:
        for key in ordered_keys:
            pipeline.hmget(
                key.decode() if isinstance(key, bytes) else key,
                fields_to_fetch,
            )
        raw_rows = pipeline.execute()

    serialization_version = config.entity_key_serialization_version

    output: List[
        Tuple[
            Optional[datetime],
            Optional[EntityKeyProto],
            Optional[Dict[str, ValueProto]],
        ]
    ] = []

    for key, row in zip(ordered_keys, raw_rows):
        row_map = dict(zip(fields_to_fetch, row))

        # Event timestamp, if the document carries one.
        event_time = None
        raw_ts = row_map.get(event_ts_field)
        if raw_ts:
            ts = Timestamp()
            ts.ParseFromString(bytes(raw_ts))
            event_time = datetime.fromtimestamp(ts.seconds, tz=timezone.utc)

        # Entity key, stored serialized at write time.
        entity_key = None
        raw_entity_key = row_map.get("__entity_key__")
        if raw_entity_key:
            entity_key = deserialize_entity_key(
                bytes(raw_entity_key),
                entity_key_serialization_version=serialization_version,
            )

        features: Dict[str, ValueProto] = {}

        # Surface the similarity score under the reserved "distance" feature.
        score = ValueProto()
        score.double_val = distance_by_key[key]
        features["distance"] = score

        for name in requested_features:
            raw_value = row_map.get(hash_field_for[name])
            if not raw_value:
                # Missing value — return an empty proto rather than dropping it.
                features[name] = ValueProto()
            elif name in vector_field_by_name:
                # Vector field: deserialize from raw bytes.
                features[name] = _deserialize_vector_from_bytes(
                    bytes(raw_value), vector_field_by_name[name]
                )
            else:
                # Regular field: parse protobuf.
                value = ValueProto()
                value.ParseFromString(bytes(raw_value))
                features[name] = value

        output.append((event_time, entity_key, features))

    return output
Loading
Loading