diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..39ac8cb --- /dev/null +++ b/.flake8 @@ -0,0 +1,4 @@ +[flake8] +extend-ignore = + # Handled by black + E501 diff --git a/client/neuralake/core/tables/deltalake_table.py b/client/neuralake/core/tables/deltalake_table.py index ef86b8e..48b3bee 100644 --- a/client/neuralake/core/tables/deltalake_table.py +++ b/client/neuralake/core/tables/deltalake_table.py @@ -12,6 +12,7 @@ from deltalake.warnings import ExperimentalWarning import polars as pl import pyarrow as pa +from pypika import Query from neuralake.core.dataframe import NlkDataFrame from neuralake.core.tables.filters import InputFilters, normalize_filters @@ -23,6 +24,7 @@ from neuralake.core.tables.util import ( DeltaRoapiOptions, Filter, + RawCriterion, RoapiOptions, filters_to_sql_predicate, get_storage_options, @@ -159,8 +161,6 @@ def construct_df( # Use schema defined on this table, the physical schema in deltalake metadata might be different schema = self.schema - predicate_str = datafusion_predicate_from_filters(schema, filters) - # These should not be read because they don't exist in the delta table extra_col_exprs = [expr for expr, _ in self.extra_cols] extra_column_names = set(expr.meta.output_name() for expr in extra_col_exprs) @@ -172,18 +172,20 @@ def construct_df( (set(columns) | unique_column_names) - extra_column_names ) - # TODO(peter): consider a sql builder for more complex queries? - select_cols = ( - ", ".join([f'"{col}"' for col in columns_to_read]) - if columns_to_read - else "*" - ) - condition = f"WHERE {predicate_str}" if predicate_str else "" - query_string = f""" - SELECT {select_cols} - FROM "{self.name}" - {condition} - """ + select_cols = columns_to_read or ["*"] + + query = Query.from_(self.name).select(*select_cols) + + if filters: + if isinstance(filters, str): + criterion = RawCriterion(filters) + else: + normalized_filters = normalize_filters(filters) + criterion = filters_to_sql_predicate(schema, normalized_filters) + query = query.where(criterion) + + query_string = str(query) + with warnings.catch_warnings(): # Ignore ExperimentalWarning emitted from QueryBuilder warnings.filterwarnings("ignore", category=ExperimentalWarning) @@ -324,15 +326,3 @@ def _normalize_df( .cast(polars_schema) .select(schema_columns) ) - - -def datafusion_predicate_from_filters( - schema: pa.Schema, filters: DeltaInputFilters | None -) -> str | None: - if not filters: - return None - elif isinstance(filters, str): - return filters - - normalized_filters = normalize_filters(filters) - return filters_to_sql_predicate(schema, normalized_filters) diff --git a/client/neuralake/core/tables/util.py b/client/neuralake/core/tables/util.py index d40ce95..ba5af1c 100644 --- a/client/neuralake/core/tables/util.py +++ b/client/neuralake/core/tables/util.py @@ -5,6 +5,7 @@ import boto3 import polars as pl import pyarrow as pa +from pypika import Field, CustomFunction, Criterion from neuralake.core.tables.filters import Filter, NormalizedFilters @@ -108,48 +109,52 @@ def get_pyarrow_filesystem_args( return pyarrow_filesystem_args -def filters_to_sql_predicate(schema: pa.Schema, filters: NormalizedFilters) -> str: - if not filters: - return "true" +class RawCriterion(Criterion): + def __init__(self, expr: str) -> None: + super().__init__() + self.expr = expr - return " or ".join( + def get_sql(self, **kwargs: Any) -> str: + return self.expr + + +def filters_to_sql_predicate( + schema: pa.Schema, filters: NormalizedFilters +) -> Criterion: + return Criterion.any( filters_to_sql_conjunction(schema, filter_set) for filter_set in filters ) -def filters_to_sql_conjunction(schema: pa.Schema, filters: list[Filter]) -> str: - if not filters: - return "true" - - exprs = (filter_to_sql_expr(schema, f) for f in filters) - conjunction_expr = " and ".join(exprs) - return f"({conjunction_expr})" +def filters_to_sql_conjunction(schema: pa.Schema, filters: list[Filter]) -> Criterion: + return Criterion.all(filter_to_sql_expr(schema, f) for f in filters) -def filter_to_sql_expr(schema: pa.Schema, f: Filter) -> str: +def filter_to_sql_expr(schema: pa.Schema, f: Filter) -> Criterion: column = f.column if column not in schema.names: raise ValueError(f"Invalid column name {column}") column_type = schema.field(column).type - if f.operator in ( - "=", - "!=", - "<", - "<=", - ">", - ">=", - "in", - "not in", - ): - value_str = value_to_sql_expr(f.value, column_type) - return f"({column} {f.operator} {value_str})" - + if f.operator == "=": + return Field(column) == f.value + elif f.operator == "!=": + return Field(column) != f.value + elif f.operator == "<": + return Field(column) < f.value + elif f.operator == "<=": + return Field(column) <= f.value + elif f.operator == ">": + return Field(column) > f.value + elif f.operator == ">=": + return Field(column) >= f.value + elif f.operator == "in": + return Field(column).isin(f.value) + elif f.operator == "not in": + return Field(column).notin(f.value) elif f.operator == "contains": assert isinstance(f.value, str) - escaped_str = escape_str_for_sql(f.value) - like_str = f"'%{escaped_str}%'" - return f"({column} like {like_str})" + return Field(column).like(f"%{f.value}%") elif f.operator in ("includes", "includes any", "includes all"): assert pa.types.is_list(column_type) or pa.types.is_large_list(column_type) @@ -160,37 +165,18 @@ def filter_to_sql_expr(schema: pa.Schema, f: Filter) -> str: else: assert isinstance(f.value, list | tuple) values = list(f.value) + assert values # NOTE: for includes any/all, we join multiple array_contains with or/and - value_exprs = ( - value_to_sql_expr(value, column_type.value_type) for value in values - ) - include_exprs = ( - f"array_contains({column}, {value_expr})" for value_expr in value_exprs + array_contains = CustomFunction("array_contains", ["table", "value"]) + include_exprs = [array_contains(Field(column), value) for value in values] + final_expr = ( + Criterion.all(include_exprs) + if f.operator == "includes all" + else Criterion.any(include_exprs) ) - join_operator = " or " if f.operator == "includes any" else " and " - conjunction_expr = join_operator.join(include_exprs) - return f"({conjunction_expr})" + return final_expr else: raise ValueError(f"Invalid operator {f.operator}") - - -def value_to_sql_expr(value: Any, value_type: pa.DataType) -> str: - if isinstance(value, list | tuple): - elements_str = ", ".join( - value_to_sql_expr(element, value_type) for element in value - ) - value_str = f"({elements_str})" - else: - value_str = str(value) - # Escape the string so the user doesn't need to filter like ("col", "=", "'value'") - if pa.types.is_string(value_type): - escaped_str = escape_str_for_sql(value_str) - value_str = f"'{escaped_str}'" - return value_str - - -def escape_str_for_sql(value: str) -> str: - return value.replace("'", "''") diff --git a/client/neuralake/test/tables/test_deltalake_table.py b/client/neuralake/test/tables/test_deltalake_table.py index 88ec485..238a369 100644 --- a/client/neuralake/test/tables/test_deltalake_table.py +++ b/client/neuralake/test/tables/test_deltalake_table.py @@ -1,4 +1,3 @@ -import os from pathlib import Path from typing import Any from unittest.mock import MagicMock, patch @@ -9,7 +8,6 @@ import pytest from neuralake.core.tables.deltalake_table import ( - DeltaCacheOptions, DeltalakeTable, Filter, fetch_df_by_partition, diff --git a/client/neuralake/test/tables/test_util.py b/client/neuralake/test/tables/test_util.py index 5e027cb..0e7c013 100644 --- a/client/neuralake/test/tables/test_util.py +++ b/client/neuralake/test/tables/test_util.py @@ -1,3 +1,5 @@ +from datetime import datetime, timezone + import pyarrow as pa import pytest @@ -18,6 +20,7 @@ ("int_col", pa.int64()), ("list_col", pa.list_(pa.int64())), ("list_str_col", pa.list_(pa.string())), + ("date_col", pa.timestamp("us", tz="UTC")), ] ) @@ -26,54 +29,64 @@ class TestUtil: @pytest.mark.parametrize( ("schema", "f", "expected"), [ - (test_schema, Filter("int_col", "=", 123), "(int_col = 123)"), - (test_schema, Filter("int_col", "=", "123"), "(int_col = 123)"), + (test_schema, Filter("int_col", "=", 123), '"int_col"=123'), + (test_schema, Filter("int_col", "=", "123"), "\"int_col\"='123'"), # A tuple with a single element should not have a comma in SQL - (test_schema, Filter("int_col", "in", (1,)), "(int_col in (1))"), - (test_schema, Filter("int_col", "in", (1, 2)), "(int_col in (1, 2))"), + (test_schema, Filter("int_col", "in", (1,)), '"int_col" IN (1)'), + (test_schema, Filter("int_col", "in", (1, 2)), '"int_col" IN (1,2)'), ( test_schema, Filter("int_col", "not in", (1, 2)), - "(int_col not in (1, 2))", + '"int_col" NOT IN (1,2)', ), # String columns should be handled to add single quotes - (test_schema, Filter("str_col", "=", "x"), "(str_col = 'x')"), - (test_schema, Filter("str_col", "in", ("val1",)), "(str_col in ('val1'))"), + (test_schema, Filter("str_col", "=", "x"), "\"str_col\"='x'"), + # Filtering using datetime columns + ( + test_schema, + Filter("date_col", ">=", datetime(2024, 4, 5, tzinfo=timezone.utc)), + "\"date_col\">='2024-04-05T00:00:00+00:00'", + ), + ( + test_schema, + Filter("str_col", "in", ("val1",)), + "\"str_col\" IN ('val1')", + ), ( test_schema, Filter("str_col", "in", ("val1", "val2")), - "(str_col in ('val1', 'val2'))", + "\"str_col\" IN ('val1','val2')", ), ( test_schema, Filter("str_col", "contains", "x'"), - "(str_col like '%x''%')", + "\"str_col\" LIKE '%x''%'", ), # Test list columns ( test_schema, Filter("list_col", "includes", 1), - "(array_contains(list_col, 1))", + 'array_contains("list_col",1)', ), ( test_schema, Filter("list_str_col", "includes", "x"), - "(array_contains(list_str_col, 'x'))", + "array_contains(\"list_str_col\",'x')", ), ( test_schema, Filter("list_col", "includes all", (1, 2, 3)), - "(array_contains(list_col, 1) and array_contains(list_col, 2) and array_contains(list_col, 3))", + 'array_contains("list_col",1) AND array_contains("list_col",2) AND array_contains("list_col",3)', ), ( test_schema, Filter("list_col", "includes any", (1, 2, 3)), - "(array_contains(list_col, 1) or array_contains(list_col, 2) or array_contains(list_col, 3))", + 'array_contains("list_col",1) OR array_contains("list_col",2) OR array_contains("list_col",3)', ), ], ) def test_filter_to_expr(self, schema: pa.Schema, f: Filter, expected: str): - assert filter_to_sql_expr(schema, f) == expected + assert str(filter_to_sql_expr(schema, f)) == expected def test_filter_to_expr_raises(self): with pytest.raises(ValueError) as e: @@ -84,11 +97,11 @@ def test_filter_to_expr_raises(self): @pytest.mark.parametrize( ("schema", "filters", "expected"), [ - (test_schema, [[Filter("str_col", "=", "x")]], "((str_col = 'x'))"), + (test_schema, [[Filter("str_col", "=", "x")]], "\"str_col\"='x'"), ( test_schema, [[Filter("str_col", "=", "x"), Filter("int_col", "=", 123)]], - "((str_col = 'x') and (int_col = 123))", + '"str_col"=\'x\' AND "int_col"=123', ), ( test_schema, @@ -96,14 +109,26 @@ def test_filter_to_expr_raises(self): [Filter("str_col", "=", "x")], [Filter("int_col", "=", 123), Filter("int_col", "<", 456)], ], - "((str_col = 'x')) or ((int_col = 123) and (int_col < 456))", + '"str_col"=\'x\' OR ("int_col"=123 AND "int_col"<456)', + ), + ( + test_schema, + # Ensures filters are properly nested and grouped + [ + [Filter("str_col", "=", "x")], + [ + Filter("list_col", "includes any", (1, 2, 3)), + Filter("list_col", "includes all", (1, 2, 3)), + ], + ], + '"str_col"=\'x\' OR ((array_contains("list_col",1) OR array_contains("list_col",2) OR array_contains("list_col",3)) AND array_contains("list_col",1) AND array_contains("list_col",2) AND array_contains("list_col",3))', ), ], ) def test_filters_to_sql_predicate( self, schema: pa.Schema, filters: NormalizedFilters, expected: str ): - assert filters_to_sql_predicate(schema, filters) == expected + assert str(filters_to_sql_predicate(schema, filters)) == expected @pytest.mark.parametrize( ("filters", "expected"), diff --git a/client/neuralake/web_export.py b/client/neuralake/web_export.py index 3302108..97c8e6b 100644 --- a/client/neuralake/web_export.py +++ b/client/neuralake/web_export.py @@ -2,14 +2,10 @@ from neuralake.core.tables.deltalake_table import DeltalakeTable from neuralake.core.tables.metadata import TableProtocol import json -import tempfile -import subprocess from pathlib import Path -import os import sys import logging import shutil -import importlib.resources # Configure logging logging.basicConfig(level=logging.INFO, format="%(message)s", stream=sys.stderr) @@ -89,7 +85,7 @@ def export_and_generate_site( logger.info(f"Copying precompiled directory {precompiled_dir} to {output_path}") if not precompiled_dir.exists(): raise FileNotFoundError( - f"Could not find precompiled directory. Make sure you're running from the project root or the package is properly installed." + "Could not find precompiled directory. Make sure you're running from the project root or the package is properly installed." ) shutil.copytree(precompiled_dir, output_path, dirs_exist_ok=True) diff --git a/client/pyproject.toml b/client/pyproject.toml index 2c72574..ff48016 100644 --- a/client/pyproject.toml +++ b/client/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "pyarrow==17.0.0", "ipython==8.5.0", "typing_extensions==4.13.2", + "pypika>=0.48.0", ] [project.optional-dependencies]