diff --git a/integrations/oracle/README.md b/integrations/oracle/README.md index e1afef02ba..3679e7a092 100644 --- a/integrations/oracle/README.md +++ b/integrations/oracle/README.md @@ -9,109 +9,6 @@ Haystack DocumentStore backed by [Oracle AI Vector Search](https://www.oracle.co --- -## Installation - -```bash -pip install oracle-haystack -``` - -Requires Python 3.10+ and Oracle Database 23ai (or later). No Oracle Instant Client is needed for direct TCP connections (thin mode). - -## Usage - -```python -from haystack.utils import Secret -from haystack_integrations.document_stores.oracle import OracleConnectionConfig, OracleDocumentStore -from haystack_integrations.components.retrievers.oracle import OracleEmbeddingRetriever - -# Configure the connection -config = OracleConnectionConfig( - user=Secret.from_env_var("ORACLE_USER"), - password=Secret.from_env_var("ORACLE_PASSWORD"), - dsn=Secret.from_env_var("ORACLE_DSN"), -) - -# Create the document store -store = OracleDocumentStore( - connection_config=config, - table_name="my_documents", - embedding_dim=768, - distance_metric="COSINE", - create_table_if_not_exists=True, -) - -# Write documents -from haystack.dataclasses import Document -store.write_documents([ - Document(content="Oracle 23ai supports native vector search."), -]) - -# Retrieve by embedding -retriever = OracleEmbeddingRetriever(document_store=store, top_k=5) -results = retriever.run(query_embedding=[0.1] * 768) -print(results["documents"]) -``` - -### Connecting to Oracle Autonomous Database (ADB-S / wallet) - -```python -config = OracleConnectionConfig( - user=Secret.from_env_var("ORACLE_USER"), - password=Secret.from_env_var("ORACLE_PASSWORD"), - dsn=Secret.from_env_var("ORACLE_DSN"), - wallet_location="/path/to/wallet", - wallet_password=Secret.from_env_var("WALLET_PASSWORD"), -) -``` - -### Optional HNSW index - -Pass `create_index=True` when constructing the store to build an HNSW vector index, which dramatically speeds up approximate nearest-neighbour search on large collections: - -```python -store = OracleDocumentStore( - connection_config=config, - table_name="my_documents", - embedding_dim=768, - create_index=True, - hnsw_neighbors=32, - hnsw_ef_construction=200, - hnsw_accuracy=95, -) -``` - ## Contributing Refer to the general [Contribution Guidelines](https://github.com/deepset-ai/haystack-core-integrations/blob/main/CONTRIBUTING.md). - -### Running tests - -#### Unit tests - -```bash -PYTHONPATH=src hatch run test:unit -vvv -``` - -#### Integration tests against a live Oracle instance - -Set `ORACLE_USER`, `ORACLE_PASSWORD`, and `ORACLE_DSN` environment variables to point at your Oracle 23ai instance, then: - -```bash -PYTHONPATH=src hatch run test:integration -vvv -``` - -#### Integration tests via Docker (local Oracle 23ai Free) - -A `docker-compose.yml` is provided that runs [`gvenzl/oracle-free:23-slim`](https://hub.docker.com/r/gvenzl/oracle-free) (Oracle Database 23ai Free edition). - -```bash -docker compose up -d --wait -``` - -`--wait` blocks until the Oracle healthcheck passes (the first boot takes 2–4 minutes while Oracle initialises its data files). - -Run the full integration test suite: - -```bash -PYTHONPATH=src hatch run test:integration -vvv -``` \ No newline at end of file diff --git a/integrations/oracle/pyproject.toml b/integrations/oracle/pyproject.toml index 66678a43eb..1df25cff12 100644 --- a/integrations/oracle/pyproject.toml +++ b/integrations/oracle/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", ] dependencies = [ - "haystack-ai>=2.26.1", + "haystack-ai>=2.28.0", "oracledb>=2.1.0,<3.0.0", ] diff --git a/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/__init__.py b/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/__init__.py index 90db6a0719..2cff16e4d6 100644 --- a/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/__init__.py +++ b/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/__init__.py @@ -3,5 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 from haystack_integrations.components.retrievers.oracle.embedding_retriever import OracleEmbeddingRetriever +from haystack_integrations.components.retrievers.oracle.keyword_retriever import OracleKeywordRetriever -__all__ = ["OracleEmbeddingRetriever"] +__all__ = ["OracleEmbeddingRetriever", "OracleKeywordRetriever"] diff --git a/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/keyword_retriever.py b/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/keyword_retriever.py new file mode 100644 index 0000000000..c4e6bfe185 --- /dev/null +++ b/integrations/oracle/src/haystack_integrations/components/retrievers/oracle/keyword_retriever.py @@ -0,0 +1,115 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +from haystack import component, default_from_dict, default_to_dict +from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy + +from haystack_integrations.document_stores.oracle import OracleDocumentStore + + +@component +class OracleKeywordRetriever: + """ + Retrieves documents from an OracleDocumentStore using keyword-based (BM25) similarity. + + Requires Oracle Database 23ai and an automatically created DBMS_SEARCH index. + + Use inside a Haystack pipeline:: + + pipeline.add_component("retriever", OracleKeywordRetriever(document_store=store, top_k=5)) + """ + + def __init__( + self, + *, + document_store: OracleDocumentStore, + filters: dict[str, Any] | None = None, + top_k: int = 10, + filter_policy: FilterPolicy = FilterPolicy.REPLACE, + ) -> None: + if not isinstance(document_store, OracleDocumentStore): + msg = "document_store must be an instance of OracleDocumentStore" + raise TypeError(msg) + self.document_store = document_store + self.filters = filters or {} + self.top_k = top_k + self.filter_policy = FilterPolicy.from_str(filter_policy) if isinstance(filter_policy, str) else filter_policy + + @component.output_types(documents=list[Document]) + def run( + self, + query: str, + filters: dict[str, Any] | None = None, + top_k: int | None = None, + ) -> dict[str, list[Document]]: + """ + Retrieve documents by keyword search. + + Args: + query: The keyword query string. + filters: Runtime filters, merged with constructor filters according to filter_policy. + top_k: Override the constructor top_k for this call. + + Returns: + ``{"documents": [Document, ...]}`` + """ + filters = apply_filter_policy(self.filter_policy, self.filters, filters) + docs = self.document_store._keyword_retrieval( + query, + filters=filters, + top_k=top_k if top_k is not None else self.top_k, + ) + return {"documents": docs} + + @component.output_types(documents=list[Document]) + async def run_async( + self, + query: str, + filters: dict[str, Any] | None = None, + top_k: int | None = None, + ) -> dict[str, list[Document]]: + """Async variant of :meth:`run`.""" + filters = apply_filter_policy(self.filter_policy, self.filters, filters) + docs = await self.document_store._keyword_retrieval_async( + query, + filters=filters, + top_k=top_k if top_k is not None else self.top_k, + ) + return {"documents": docs} + + def to_dict(self) -> dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + document_store=self.document_store.to_dict(), + filters=self.filters, + top_k=self.top_k, + filter_policy=self.filter_policy.value, + ) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "OracleKeywordRetriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ + params = data.get("init_parameters", {}) + if "document_store" in params: + params["document_store"] = OracleDocumentStore.from_dict(params["document_store"]) + if filter_policy := params.get("filter_policy"): + params["filter_policy"] = FilterPolicy.from_str(filter_policy) + return default_from_dict(cls, data) diff --git a/integrations/oracle/src/haystack_integrations/document_stores/oracle/document_store.py b/integrations/oracle/src/haystack_integrations/document_stores/oracle/document_store.py index ab6606af33..19391bab21 100644 --- a/integrations/oracle/src/haystack_integrations/document_stores/oracle/document_store.py +++ b/integrations/oracle/src/haystack_integrations/document_stores/oracle/document_store.py @@ -14,7 +14,7 @@ import oracledb from haystack import default_from_dict, default_to_dict from haystack.dataclasses import Document -from haystack.document_stores.errors import DuplicateDocumentError +from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy from haystack.utils import Secret, deserialize_secrets_inplace @@ -23,6 +23,31 @@ logger = logging.getLogger(__name__) _SAFE_TABLE_NAME = re.compile(r"^[A-Za-z_][A-Za-z0-9_$#]{0,127}$") +_SAFE_FIELD_PATH = re.compile(r"^[A-Za-z0-9_.]+$") +MAX_INDEX_NAME_LEN = 128 + + +def _validate_field_path(field_path: str) -> None: + if not _SAFE_FIELD_PATH.match(field_path): + msg = f"Invalid metadata field name: {field_path!r}" + raise ValueError(msg) + + +def _try_parse_number(value: Any) -> Any: + """ + Attempt to parse a string as a number. + + Returns int for whole numbers, float for decimals, or the + original value when conversion is not possible. + """ + if value is None: + return None + try: + f = float(value) + i = int(f) + return i if f == i else f + except (ValueError, TypeError): + return value @dataclass @@ -112,6 +137,30 @@ def __init__( hnsw_accuracy: int = 95, hnsw_parallel: int = 4, ) -> None: + """ + Initialise the document store and optionally create the backing table and indexes. + + :param connection_config: Oracle connection settings (user, password, DSN, optional wallet). + :param table_name: Name of the Oracle table used to store documents. Must be a valid Oracle + identifier (letters, digits, ``_``, ``$``, ``#``; max 128 chars; cannot start with a digit). + :param embedding_dim: Dimensionality of the embedding vectors. Must match the model producing them. + :param distance_metric: Vector distance function used for similarity search. + One of ``"COSINE"``, ``"EUCLIDEAN"``, or ``"DOT"``. + :param create_table_if_not_exists: When ``True`` (default), creates the table and the DBMS_SEARCH + keyword index on first use if they do not already exist. Set to ``False`` when connecting to a + pre-existing table. + :param create_index: When ``True``, creates an HNSW vector index on initialisation. Equivalent to + calling :meth:`create_hnsw_index` manually. Defaults to ``False``. + :param hnsw_neighbors: Number of neighbours in the HNSW graph. Higher values improve recall at the + cost of index size and build time. Defaults to ``32``. + :param hnsw_ef_construction: Size of the dynamic candidate list during HNSW index construction. + Higher values improve recall at the cost of build time. Defaults to ``200``. + :param hnsw_accuracy: Target recall accuracy percentage for the HNSW index (0-100). + Defaults to ``95``. + :param hnsw_parallel: Degree of parallelism used when building the HNSW index. Defaults to ``4``. + :raises ValueError: If ``table_name`` is not a valid Oracle identifier or ``embedding_dim`` is not + a positive integer. + """ if not _SAFE_TABLE_NAME.match(table_name): msg = ( f"Invalid table_name {table_name!r}. Must be a valid Oracle identifier " @@ -192,6 +241,33 @@ def _ensure_table(self) -> None: cur.execute(sql) conn.commit() + self._ensure_keyword_index() + + def _ensure_keyword_index(self) -> None: + index_name = f"{self.table_name}_search_idx" + if len(index_name) > MAX_INDEX_NAME_LEN: + index_name = index_name[:MAX_INDEX_NAME_LEN] + try: + with self._get_connection() as conn, conn.cursor() as cur: + cur.execute( + f"BEGIN DBMS_SEARCH.CREATE_INDEX('{index_name}'); " + f"DBMS_SEARCH.ADD_SOURCE('{index_name}', '{self.table_name}'); END;" + ) + conn.commit() + except oracledb.DatabaseError as e: + logger.debug("Could not create keyword index (may already exist): %s", e) + + def create_keyword_index(self) -> None: + """ + Create the DBMS_SEARCH keyword index on this table. + + Safe to call multiple times — silently skips if the index already exists. + Required for keyword retrieval. Called automatically when + ``create_table_if_not_exists=True``, but must be called explicitly + when connecting to a pre-existing table. + """ + self._ensure_keyword_index() + def create_hnsw_index(self) -> None: """ Create an HNSW vector index on the embedding column. @@ -213,7 +289,11 @@ def create_hnsw_index(self) -> None: conn.commit() async def create_hnsw_index_async(self) -> None: - """Async variant of create_hnsw_index.""" + """ + Asynchronously creates an HNSW vector index on the embedding column. + + Safe to call multiple times — uses ``IF NOT EXISTS``. + """ await asyncio.to_thread(self.create_hnsw_index) def write_documents( @@ -355,7 +435,7 @@ def filter_documents(self, filters: dict[str, Any] | None = None) -> list[Docume :returns: A list of Documents that match the given filters. """ where, params = OracleDocumentStore._build_where(filters) - sql = f"SELECT id, text, metadata FROM {self.table_name} {where}" + sql = f"SELECT id, text, JSON_SERIALIZE(metadata) AS metadata FROM {self.table_name} {where}" with self._get_connection() as conn, conn.cursor() as cur: cur.execute(sql, params) rows = cur.fetchall() @@ -418,6 +498,366 @@ async def count_documents_async(self) -> int: """ return await asyncio.to_thread(self.count_documents) + def delete_table(self) -> None: + """ + Permanently drops the document store table and its associated DBMS_SEARCH keyword index. + + Uses ``DROP TABLE ... PURGE`` which bypasses the Oracle recycle bin — the operation is + irreversible. The keyword index is dropped after the table; if either operation fails a + :class:`DocumentStoreError` is raised. + + :raises DocumentStoreError: If the table or keyword index cannot be dropped. + """ + with self._get_connection() as conn, conn.cursor() as cur: + sql = f"DROP TABLE {self.table_name} PURGE" + try: + cur.execute(sql) + except oracledb.DatabaseError as e: + logger.debug("Failed to drop table. SQL: %s", sql) + msg = ( + f"Failed to drop table '{self.table_name}'. Error: {e!r}. " + "You can find the SQL query in the debug logs." + ) + raise DocumentStoreError(msg) from e + index_name = f"{self.table_name}_search_idx" + if len(index_name) > MAX_INDEX_NAME_LEN: + index_name = index_name[:MAX_INDEX_NAME_LEN] + sql = f"BEGIN DBMS_SEARCH.DROP_INDEX('{index_name}'); END;" + try: + cur.execute(sql) + except oracledb.DatabaseError as e: + logger.debug("Failed to drop keyword index. SQL: %s", sql) + msg = ( + f"Failed to drop keyword index '{index_name}'. Error: {e!r}. " + "You can find the SQL query in the debug logs." + ) + raise DocumentStoreError(msg) from e + conn.commit() + + async def delete_table_async(self) -> None: + """ + Asynchronously permanently drops the document store table and its DBMS_SEARCH keyword index. + + Uses ``DROP TABLE ... PURGE`` which bypasses the Oracle recycle bin — the operation is + irreversible. + + :raises DocumentStoreError: If the table or keyword index cannot be dropped. + """ + await asyncio.to_thread(self.delete_table) + + def delete_all_documents(self) -> None: + """ + Removes all documents from the table using ``TRUNCATE``. + + ``TRUNCATE`` is non-recoverable — it cannot be rolled back and bypasses row-level triggers. + The table structure and indexes are preserved. + """ + with self._get_connection() as conn, conn.cursor() as cur: + cur.execute(f"TRUNCATE TABLE {self.table_name}") + conn.commit() + + async def delete_all_documents_async(self) -> None: + """ + Asynchronously removes all documents from the table using ``TRUNCATE``. + + ``TRUNCATE`` is non-recoverable — it cannot be rolled back and bypasses row-level triggers. + The table structure and indexes are preserved. + """ + await asyncio.to_thread(self.delete_all_documents) + + def count_documents_by_filter(self, filters: dict[str, Any]) -> int: + """ + Returns the number of documents that match the provided filters. + + :param filters: Haystack filter dict. An empty dict matches all documents. + See the `metadata filtering docs `_. + :returns: Count of matching documents. + """ + where, params = OracleDocumentStore._build_where(filters) + sql = f"SELECT COUNT(*) FROM {self.table_name} {where}" + with self._get_connection() as conn, conn.cursor() as cur: + cur.execute(sql, params) + row = cur.fetchone() + return row[0] if row else 0 + + async def count_documents_by_filter_async(self, filters: dict[str, Any]) -> int: + """ + Asynchronously returns the number of documents that match the provided filters. + + :param filters: Haystack filter dict. An empty dict matches all documents. + See the `metadata filtering docs `_. + :returns: Count of matching documents. + """ + return await asyncio.to_thread(self.count_documents_by_filter, filters) + + def delete_by_filter(self, filters: dict[str, Any]) -> int: + """ + Deletes all documents that match the provided filters. + + :param filters: Haystack filter dict. An empty dict is treated as a no-op and returns ``0`` + without touching the table. + See the `metadata filtering docs `_. + :returns: Number of deleted documents. + """ + if not filters: + return 0 + where, params = OracleDocumentStore._build_where(filters) + sql = f"DELETE FROM {self.table_name} {where}" + with self._get_connection() as conn, conn.cursor() as cur: + cur.execute(sql, params) + deleted = cur.rowcount + conn.commit() + return deleted + + async def delete_by_filter_async(self, filters: dict[str, Any]) -> int: + """ + Asynchronously deletes all documents that match the provided filters. + + :param filters: Haystack filter dict. An empty dict is treated as a no-op and returns ``0`` + without touching the table. + See the `metadata filtering docs `_. + :returns: Number of deleted documents. + """ + return await asyncio.to_thread(self.delete_by_filter, filters) + + def update_by_filter(self, filters: dict[str, Any], meta: dict[str, Any]) -> int: + """ + Merges ``meta`` into the metadata of all documents that match the provided filters. + + Uses Oracle's ``JSON_MERGEPATCH`` — existing keys are updated, new keys are added, + and keys set to ``null`` in ``meta`` are removed. + + :param filters: Haystack filter dict that selects which documents to update. + See the `metadata filtering docs `_. + :param meta: Metadata patch to apply. Must be a non-empty dictionary. + :returns: Number of updated documents. + :raises ValueError: If ``meta`` is empty. + """ + if not meta: + msg = "meta must be a non-empty dictionary" + raise ValueError(msg) + where, params = OracleDocumentStore._build_where(filters) + sql = f"UPDATE {self.table_name} SET metadata = JSON_MERGEPATCH(metadata, :meta_patch) {where}" + params["meta_patch"] = json.dumps(meta) + with self._get_connection() as conn, conn.cursor() as cur: + cur.execute(sql, params) + updated = cur.rowcount + conn.commit() + return updated + + async def update_by_filter_async(self, filters: dict[str, Any], meta: dict[str, Any]) -> int: + """ + Asynchronously merges ``meta`` into the metadata of all documents matching the provided filters. + + Uses Oracle's ``JSON_MERGEPATCH`` — existing keys are updated, new keys are added, + and keys set to ``null`` in ``meta`` are removed. + + :param filters: Haystack filter dict that selects which documents to update. + See the `metadata filtering docs `_. + :param meta: Metadata patch to apply. Must be a non-empty dictionary. + :returns: Number of updated documents. + :raises ValueError: If ``meta`` is empty. + """ + return await asyncio.to_thread(self.update_by_filter, filters, meta) + + def count_unique_metadata_by_filter(self, filters: dict[str, Any], metadata_fields: list[str]) -> dict[str, int]: + """ + Returns the number of distinct values for each requested metadata field among matching documents. + + :param filters: Haystack filter dict that scopes the document set. + See the `metadata filtering docs `_. + :param metadata_fields: List of metadata field names to count distinct values for. + Fields may be prefixed with ``"meta."`` (e.g. ``"meta.lang"`` or ``"lang"``). + Must be a non-empty list. + :returns: Dict mapping each field name to its distinct-value count. + :raises ValueError: If ``metadata_fields`` is empty. + :raises ValueError: If any field name contains characters outside ``[A-Za-z0-9_.]``. + """ + if not metadata_fields: + msg = "metadata_fields must be a non-empty list of strings" + raise ValueError(msg) + where, params = OracleDocumentStore._build_where(filters) + results = {} + with self._get_connection() as conn, conn.cursor() as cur: + for field in metadata_fields: + field_path = field[5:] if field.startswith("meta.") else field + _validate_field_path(field_path) + sql = f"SELECT COUNT(DISTINCT JSON_VALUE(metadata, '$.{field_path}')) FROM {self.table_name} {where}" + cur.execute(sql, params) + row = cur.fetchone() + results[field] = row[0] if row else 0 + return results + + async def count_unique_metadata_by_filter_async( + self, filters: dict[str, Any], metadata_fields: list[str] + ) -> dict[str, int]: + """ + Asynchronously returns the number of distinct values for each metadata field among matching documents. + + :param filters: Haystack filter dict that scopes the document set. + See the `metadata filtering docs `_. + :param metadata_fields: List of metadata field names to count distinct values for. + Fields may be prefixed with ``"meta."`` (e.g. ``"meta.lang"`` or ``"lang"``). + Must be a non-empty list. + :returns: Dict mapping each field name to its distinct-value count. + :raises ValueError: If ``metadata_fields`` is empty. + :raises ValueError: If any field name contains characters outside ``[A-Za-z0-9_.]``. + """ + return await asyncio.to_thread(self.count_unique_metadata_by_filter, filters, metadata_fields) + + def get_metadata_fields_info(self) -> dict[str, dict[str, str]]: + """ + Return a mapping of metadata field names to their detected types. + + Uses Oracle's ``JSON_DATAGUIDE`` aggregate to introspect the stored metadata column. + Returns an empty dict when the table has no documents. + + :returns: Dict of the form ``{"field_name": {"type": ""}, ...}`` where ```` + is one of ``"text"``, ``"number"``, or ``"boolean"``. + """ + sql = f"SELECT JSON_DATAGUIDE(metadata) FROM {self.table_name}" + with self._get_connection() as conn, conn.cursor() as cur: + cur.execute(sql) + row = cur.fetchone() + if not row or not row[0]: + return {} + raw_guide = row[0].read() if hasattr(row[0], "read") else row[0] + if not raw_guide: + return {} + fields: dict[str, dict[str, str]] = {} + dataguide = json.loads(raw_guide) + for path_info in dataguide: + path = path_info.get("o:path", "") + if path.startswith("$."): + field_name = path[2:] + type_str = path_info.get("type", "string") + if type_str == "string": + type_str = "text" + fields[field_name] = {"type": type_str} + return fields + + def get_metadata_field_min_max(self, metadata_field: str) -> dict[str, Any]: + """ + Return the minimum and maximum values of a metadata field across all documents. + + First attempts numeric comparison via ``TO_NUMBER`` so that ``MAX(1, 5, 10)`` returns ``10`` + rather than ``"5"`` (which would win under lexicographic ordering). Falls back to plain string + comparison when the field contains non-numeric values. Numeric strings are automatically + converted to ``int`` or ``float`` in the result. + + :param metadata_field: Metadata field name. May be prefixed with ``"meta."`` + (e.g. ``"meta.year"`` or ``"year"``). + :returns: ``{"min": , "max": }``. Both values are ``None`` when the table is + empty or the field does not exist. + :raises ValueError: If ``metadata_field`` contains characters outside ``[A-Za-z0-9_.]``. + """ + field_path = metadata_field[5:] if metadata_field.startswith("meta.") else metadata_field + _validate_field_path(field_path) + jv = f"JSON_VALUE(metadata, '$.{field_path}')" + # Try numeric comparison first — correct ordering for ints/floats. + sql_num = f"SELECT MIN(TO_NUMBER({jv})), MAX(TO_NUMBER({jv})) FROM {self.table_name} WHERE {jv} IS NOT NULL" + with self._get_connection() as conn, conn.cursor() as cur: + try: + cur.execute(sql_num) + row = cur.fetchone() + if row and row[0] is not None: + return {"min": _try_parse_number(row[0]), "max": _try_parse_number(row[1])} + except oracledb.DatabaseError: + pass + # Fall back to string comparison for non-numeric fields. + sql_str = f"SELECT MIN({jv}), MAX({jv}) FROM {self.table_name}" + cur.execute(sql_str) + row = cur.fetchone() + if not row or row[0] is None or row[1] is None: + return {"min": None, "max": None} + return {"min": _try_parse_number(row[0]), "max": _try_parse_number(row[1])} + + def get_metadata_field_unique_values( + self, metadata_field: str, search_term: str | None = None, from_: int = 0, size: int | None = None + ) -> tuple[list[str], int]: + """ + Return a paginated list of distinct values for a metadata field, plus the total distinct count. + + :param metadata_field: Metadata field name. May be prefixed with ``"meta."`` + (e.g. ``"meta.lang"`` or ``"lang"``). + :param search_term: Optional substring filter applied to both the document text and the field value. + :param from_: Zero-based offset for pagination. Defaults to ``0``. + :param size: Maximum number of values to return. When ``None`` all values from ``from_`` onward + are returned. + :returns: A tuple ``(values, total)`` where ``values`` is the paginated list of distinct field + values as strings and ``total`` is the overall distinct count (before pagination). + :raises ValueError: If ``metadata_field`` contains characters outside ``[A-Za-z0-9_.]``. + """ + field_path = metadata_field[5:] if metadata_field.startswith("meta.") else metadata_field + _validate_field_path(field_path) + base_sql = f"FROM {self.table_name} WHERE JSON_VALUE(metadata, '$.{field_path}') IS NOT NULL" + params: dict[str, Any] = {} + if search_term: + base_sql += f" AND (text LIKE :search OR JSON_VALUE(metadata, '$.{field_path}') LIKE :search)" + params["search"] = f"%{search_term}%" + + sql_count = f"SELECT COUNT(DISTINCT JSON_VALUE(metadata, '$.{field_path}')) {base_sql}" + with self._get_connection() as conn, conn.cursor() as cur: + cur.execute(sql_count, params) + total = cur.fetchone()[0] or 0 + + sql_vals = f"SELECT DISTINCT JSON_VALUE(metadata, '$.{field_path}') {base_sql} ORDER BY 1" + if size is not None: + sql_vals += " OFFSET :row_offset ROWS FETCH NEXT :row_limit ROWS ONLY" + params["row_offset"] = from_ + params["row_limit"] = size + else: + sql_vals += " OFFSET :row_offset ROWS" + params["row_offset"] = from_ + cur.execute(sql_vals, params) + rows = cur.fetchall() + return [str(r[0]) for r in rows], total + + async def get_metadata_fields_info_async(self) -> dict[str, dict[str, str]]: + """ + Asynchronously returns a mapping of metadata field names to their detected types. + + Uses Oracle's ``JSON_DATAGUIDE`` aggregate to introspect the stored metadata column. + Returns an empty dict when the table has no documents. + + :returns: Dict of the form ``{"field_name": {"type": ""}, ...}`` where ```` + is one of ``"text"``, ``"number"``, or ``"boolean"``. + """ + return await asyncio.to_thread(self.get_metadata_fields_info) + + async def get_metadata_field_min_max_async(self, metadata_field: str) -> dict[str, Any]: + """ + Asynchronously returns the minimum and maximum values of a metadata field across all documents. + + First attempts numeric comparison via ``TO_NUMBER``, falling back to string comparison for + non-numeric fields. Numeric strings are automatically converted to ``int`` or ``float``. + + :param metadata_field: Metadata field name. May be prefixed with ``"meta."`` + (e.g. ``"meta.year"`` or ``"year"``). + :returns: ``{"min": , "max": }``. Both values are ``None`` when the table is + empty or the field does not exist. + :raises ValueError: If ``metadata_field`` contains characters outside ``[A-Za-z0-9_.]``. + """ + return await asyncio.to_thread(self.get_metadata_field_min_max, metadata_field) + + async def get_metadata_field_unique_values_async( + self, metadata_field: str, search_term: str | None = None, from_: int = 0, size: int | None = None + ) -> tuple[list[str], int]: + """ + Asynchronously returns a paginated list of distinct values for a metadata field, plus the total count. + + :param metadata_field: Metadata field name. May be prefixed with ``"meta."`` + (e.g. ``"meta.lang"`` or ``"lang"``). + :param search_term: Optional substring filter applied to both the document text and the field value. + :param from_: Zero-based offset for pagination. Defaults to ``0``. + :param size: Maximum number of values to return. When ``None`` all values from ``from_`` onward + are returned. + :returns: A tuple ``(values, total)`` where ``values`` is the paginated list of distinct field + values as strings and ``total`` is the overall distinct count (before pagination). + :raises ValueError: If ``metadata_field`` contains characters outside ``[A-Za-z0-9_.]``. + """ + return await asyncio.to_thread(self.get_metadata_field_unique_values, metadata_field, search_term, from_, size) + def _embedding_retrieval( self, query_embedding: list[float], @@ -430,7 +870,7 @@ def _embedding_retrieval( order = "ASC" where, params = OracleDocumentStore._build_where(filters) sql = f""" - SELECT id, text, metadata, + SELECT id, text, JSON_SERIALIZE(metadata) AS metadata, vector_distance(embedding, :query_vec, {self.distance_metric}) AS score FROM {self.table_name} {where} @@ -440,7 +880,15 @@ def _embedding_retrieval( params["query_vec"] = _array.array("f", query_embedding) params["top_k"] = top_k with self._get_connection() as conn, conn.cursor() as cur: - cur.execute(sql, params) + try: + cur.execute(sql, params) + except oracledb.DatabaseError as e: + logger.debug("Embedding retrieval failed. SQL: %s\nParams: %s", sql, params) + msg = ( + f"Embedding retrieval failed. Error: {e!r}. " + "You can find the SQL query and the parameters in the debug logs." + ) + raise DocumentStoreError(msg) from e rows = cur.fetchall() return [OracleDocumentStore._row_to_document(r, with_score=True) for r in rows] @@ -458,6 +906,48 @@ async def _embedding_retrieval_async( top_k=top_k, ) + def _keyword_retrieval( + self, query: str, *, filters: dict[str, Any] | None = None, top_k: int = 10 + ) -> list[Document]: + index_name = f"{self.table_name}_search_idx" + if len(index_name) > MAX_INDEX_NAME_LEN: + index_name = index_name[:MAX_INDEX_NAME_LEN] + where, params = OracleDocumentStore._build_where(filters) + where_cond = where.replace("WHERE", "WHERE t.") if where else "" + sql = f""" + WITH hits AS ( + SELECT JSON_VALUE(METADATA, '$.KEY.ID') AS hit_id, SCORE(1) AS score + FROM {index_name} + WHERE CONTAINS(DATA, :query, 1) > 0 + ORDER BY score DESC + FETCH APPROX FIRST :top_k ROWS ONLY + ) + SELECT t.id, t.text, JSON_SERIALIZE(t.metadata) AS metadata, hits.score + FROM hits + JOIN {self.table_name} t ON t.id = hits.hit_id + {where_cond} + ORDER BY hits.score DESC + """ + params["query"] = query + params["top_k"] = top_k + with self._get_connection() as conn, conn.cursor() as cur: + try: + cur.execute(sql, params) + except oracledb.DatabaseError as e: + logger.debug("Keyword retrieval failed. SQL: %s\nParams: %s", sql, params) + msg = ( + f"Keyword retrieval failed. Error: {e!r}. " + "You can find the SQL query and the parameters in the debug logs." + ) + raise DocumentStoreError(msg) from e + rows = cur.fetchall() + return [OracleDocumentStore._row_to_document(r, with_score=True) for r in rows] + + async def _keyword_retrieval_async( + self, query: str, *, filters: dict[str, Any] | None = None, top_k: int = 10 + ) -> list[Document]: + return await asyncio.to_thread(self._keyword_retrieval, query, filters=filters, top_k=top_k) + @staticmethod def _row_to_document(row: tuple, *, with_score: bool = False) -> Document: if with_score: diff --git a/integrations/oracle/tests/conftest.py b/integrations/oracle/tests/conftest.py index 5d339248fc..aa85d651ab 100644 --- a/integrations/oracle/tests/conftest.py +++ b/integrations/oracle/tests/conftest.py @@ -100,6 +100,8 @@ def mock_store(): store.distance_metric = "COSINE" store._embedding_retrieval.return_value = [Document(id="A" * 32, content="hi")] store._embedding_retrieval_async.return_value = [Document(id="A" * 32, content="hi")] + store._keyword_retrieval.return_value = [Document(id="A" * 32, content="hi")] + store._keyword_retrieval_async.return_value = [Document(id="A" * 32, content="hi")] store.to_dict.return_value = { "type": "haystack_integrations.document_stores.oracle.document_store.OracleDocumentStore", "init_parameters": { diff --git a/integrations/oracle/tests/test_document_store.py b/integrations/oracle/tests/test_document_store.py index 6960699f35..fe855b7a5e 100644 --- a/integrations/oracle/tests/test_document_store.py +++ b/integrations/oracle/tests/test_document_store.py @@ -9,7 +9,24 @@ from haystack.dataclasses import Document from haystack.document_stores.errors import DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy -from haystack.testing.document_store import DocumentStoreBaseTests +from haystack.testing.document_store import ( + CountDocumentsByFilterTest, + CountUniqueMetadataByFilterTest, + DeleteAllTest, + DeleteByFilterTest, + DeleteDocumentsTest, + DocumentStoreBaseTests, + GetMetadataFieldMinMaxTest, + GetMetadataFieldsInfoTest, + GetMetadataFieldUniqueValuesTest, + UpdateByFilterTest, +) +from haystack.testing.document_store_async import ( + CountDocumentsByFilterAsyncTest, + CountUniqueMetadataByFilterAsyncTest, + FilterableDocsFixtureMixin, + UpdateByFilterAsyncTest, +) from haystack.utils import Secret from haystack_integrations.document_stores.oracle import OracleConnectionConfig, OracleDocumentStore @@ -31,7 +48,18 @@ def _uid(suffix: str = "") -> str: @pytest.mark.integration -class TestOracleDocumentStore(DocumentStoreBaseTests): +class TestOracleDocumentStore( + DocumentStoreBaseTests, + CountDocumentsByFilterTest, + CountUniqueMetadataByFilterTest, + DeleteAllTest, + DeleteByFilterTest, + DeleteDocumentsTest, + GetMetadataFieldMinMaxTest, + GetMetadataFieldsInfoTest, + GetMetadataFieldUniqueValuesTest, + UpdateByFilterTest, +): @staticmethod def _mock_doc(content="hello", embedding=None, doc_id="AABB" * 8): """Lightweight document builder for mock-based tests.""" @@ -285,7 +313,12 @@ def test_create_table_idempotent(self, document_store): @pytest.mark.integration -class TestOracleDocumentStoreAsync: +class TestOracleDocumentStoreAsync( + FilterableDocsFixtureMixin, + CountDocumentsByFilterAsyncTest, + CountUniqueMetadataByFilterAsyncTest, + UpdateByFilterAsyncTest, +): """Async API surface tests.""" @pytest.mark.asyncio diff --git a/integrations/oracle/tests/test_keyword_retriever.py b/integrations/oracle/tests/test_keyword_retriever.py new file mode 100644 index 0000000000..97c64d7f5f --- /dev/null +++ b/integrations/oracle/tests/test_keyword_retriever.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from haystack.document_stores.types import FilterPolicy + +from haystack_integrations.components.retrievers.oracle import OracleKeywordRetriever + + +def test_run_calls_keyword_retrieval(mock_store): + retriever = OracleKeywordRetriever(document_store=mock_store, top_k=5) + result = retriever.run(query="hello world") + mock_store._keyword_retrieval.assert_called_once_with("hello world", filters={}, top_k=5) + assert len(result["documents"]) == 1 + + +def test_run_replace_policy_uses_runtime_filters(mock_store): + retriever = OracleKeywordRetriever( + document_store=mock_store, + filters={"field": "meta.lang", "operator": "==", "value": "en"}, + filter_policy=FilterPolicy.REPLACE, + ) + runtime_filters = {"field": "meta.year", "operator": ">", "value": 2020} + retriever.run(query="hello world", filters=runtime_filters) + call_filters = mock_store._keyword_retrieval.call_args.kwargs["filters"] + assert call_filters == runtime_filters + + +def test_run_merge_policy_combines_filters(mock_store): + retriever = OracleKeywordRetriever( + document_store=mock_store, + filters={"field": "meta.lang", "operator": "==", "value": "en"}, + filter_policy=FilterPolicy.MERGE, + ) + retriever.run( + query="hello world", + filters={"field": "meta.year", "operator": ">", "value": 2020}, + ) + call_filters = mock_store._keyword_retrieval.call_args.kwargs["filters"] + assert call_filters["operator"] == "AND" + assert len(call_filters["conditions"]) == 2 + + +def test_run_top_k_override(mock_store): + retriever = OracleKeywordRetriever(document_store=mock_store, top_k=10) + retriever.run(query="hello world", top_k=3) + assert mock_store._keyword_retrieval.call_args.kwargs["top_k"] == 3 + + +def test_to_dict_from_dict_roundtrip(mock_store): + retriever = OracleKeywordRetriever( + document_store=mock_store, + top_k=7, + filters={"field": "meta.x", "operator": "==", "value": "y"}, + ) + d = retriever.to_dict() + assert d["init_parameters"]["top_k"] == 7 + assert d["init_parameters"]["filters"] == {"field": "meta.x", "operator": "==", "value": "y"} + assert d["init_parameters"]["filter_policy"] == "replace" + + restored = OracleKeywordRetriever.from_dict(d) + assert restored.top_k == 7 + assert restored.filters == {"field": "meta.x", "operator": "==", "value": "y"} + assert restored.filter_policy == FilterPolicy.REPLACE + assert restored.document_store.table_name == "test_docs" + assert restored.document_store.embedding_dim == 4 + + +def test_invalid_document_store_raises_type_error(): + with pytest.raises(TypeError, match="must be an instance of OracleDocumentStore"): + OracleKeywordRetriever(document_store="not_a_store") + + +@pytest.mark.asyncio +async def test_run_async_calls_async_retrieval(mock_store): + retriever = OracleKeywordRetriever(document_store=mock_store, top_k=5) + result = await retriever.run_async(query="hello world") + mock_store._keyword_retrieval_async.assert_called_once_with("hello world", filters={}, top_k=5) + assert "documents" in result