Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[flake8]
extend-ignore =
# Handled by black
E501
42 changes: 16 additions & 26 deletions client/neuralake/core/tables/deltalake_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from deltalake.warnings import ExperimentalWarning
import polars as pl
import pyarrow as pa
from pypika import Query

from neuralake.core.dataframe import NlkDataFrame
from neuralake.core.tables.filters import InputFilters, normalize_filters
Expand All @@ -23,6 +24,7 @@
from neuralake.core.tables.util import (
DeltaRoapiOptions,
Filter,
RawCriterion,
RoapiOptions,
filters_to_sql_predicate,
get_storage_options,
Expand Down Expand Up @@ -159,8 +161,6 @@ def construct_df(
# Use schema defined on this table, the physical schema in deltalake metadata might be different
schema = self.schema

predicate_str = datafusion_predicate_from_filters(schema, filters)

# These should not be read because they don't exist in the delta table
extra_col_exprs = [expr for expr, _ in self.extra_cols]
extra_column_names = set(expr.meta.output_name() for expr in extra_col_exprs)
Expand All @@ -172,18 +172,20 @@ def construct_df(
(set(columns) | unique_column_names) - extra_column_names
)

# TODO(peter): consider a sql builder for more complex queries?
select_cols = (
", ".join([f'"{col}"' for col in columns_to_read])
if columns_to_read
else "*"
)
condition = f"WHERE {predicate_str}" if predicate_str else ""
query_string = f"""
SELECT {select_cols}
FROM "{self.name}"
{condition}
"""
select_cols = columns_to_read or ["*"]

query = Query.from_(self.name).select(*select_cols)

if filters:
if isinstance(filters, str):
criterion = RawCriterion(filters)
else:
normalized_filters = normalize_filters(filters)
criterion = filters_to_sql_predicate(schema, normalized_filters)
query = query.where(criterion)

query_string = str(query)

with warnings.catch_warnings():
# Ignore ExperimentalWarning emitted from QueryBuilder
warnings.filterwarnings("ignore", category=ExperimentalWarning)
Expand Down Expand Up @@ -324,15 +326,3 @@ def _normalize_df(
.cast(polars_schema)
.select(schema_columns)
)


def datafusion_predicate_from_filters(
schema: pa.Schema, filters: DeltaInputFilters | None
) -> str | None:
if not filters:
return None
elif isinstance(filters, str):
return filters

normalized_filters = normalize_filters(filters)
return filters_to_sql_predicate(schema, normalized_filters)
96 changes: 41 additions & 55 deletions client/neuralake/core/tables/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import boto3
import polars as pl
import pyarrow as pa
from pypika import Field, CustomFunction, Criterion

from neuralake.core.tables.filters import Filter, NormalizedFilters

Expand Down Expand Up @@ -108,48 +109,52 @@ def get_pyarrow_filesystem_args(
return pyarrow_filesystem_args


def filters_to_sql_predicate(schema: pa.Schema, filters: NormalizedFilters) -> str:
if not filters:
return "true"
class RawCriterion(Criterion):
def __init__(self, expr: str) -> None:
super().__init__()
self.expr = expr

return " or ".join(
def get_sql(self, **kwargs: Any) -> str:
return self.expr


def filters_to_sql_predicate(
schema: pa.Schema, filters: NormalizedFilters
) -> Criterion:
return Criterion.any(
filters_to_sql_conjunction(schema, filter_set) for filter_set in filters
)


def filters_to_sql_conjunction(schema: pa.Schema, filters: list[Filter]) -> str:
if not filters:
return "true"

exprs = (filter_to_sql_expr(schema, f) for f in filters)
conjunction_expr = " and ".join(exprs)
return f"({conjunction_expr})"
def filters_to_sql_conjunction(schema: pa.Schema, filters: list[Filter]) -> Criterion:
return Criterion.all(filter_to_sql_expr(schema, f) for f in filters)


def filter_to_sql_expr(schema: pa.Schema, f: Filter) -> str:
def filter_to_sql_expr(schema: pa.Schema, f: Filter) -> Criterion:
column = f.column
if column not in schema.names:
raise ValueError(f"Invalid column name {column}")

column_type = schema.field(column).type
if f.operator in (
"=",
"!=",
"<",
"<=",
">",
">=",
"in",
"not in",
):
value_str = value_to_sql_expr(f.value, column_type)
return f"({column} {f.operator} {value_str})"

if f.operator == "=":
return Field(column) == f.value
elif f.operator == "!=":
return Field(column) != f.value
elif f.operator == "<":
return Field(column) < f.value
elif f.operator == "<=":
return Field(column) <= f.value
elif f.operator == ">":
return Field(column) > f.value
elif f.operator == ">=":
return Field(column) >= f.value
elif f.operator == "in":
return Field(column).isin(f.value)
elif f.operator == "not in":
return Field(column).notin(f.value)
elif f.operator == "contains":
assert isinstance(f.value, str)
escaped_str = escape_str_for_sql(f.value)
like_str = f"'%{escaped_str}%'"
return f"({column} like {like_str})"
return Field(column).like(f"%{f.value}%")

elif f.operator in ("includes", "includes any", "includes all"):
assert pa.types.is_list(column_type) or pa.types.is_large_list(column_type)
Expand All @@ -160,37 +165,18 @@ def filter_to_sql_expr(schema: pa.Schema, f: Filter) -> str:
else:
assert isinstance(f.value, list | tuple)
values = list(f.value)
assert values

# NOTE: for includes any/all, we join multiple array_contains with or/and
value_exprs = (
value_to_sql_expr(value, column_type.value_type) for value in values
)
include_exprs = (
f"array_contains({column}, {value_expr})" for value_expr in value_exprs
array_contains = CustomFunction("array_contains", ["table", "value"])
include_exprs = [array_contains(Field(column), value) for value in values]
final_expr = (
Criterion.all(include_exprs)
if f.operator == "includes all"
else Criterion.any(include_exprs)
)
join_operator = " or " if f.operator == "includes any" else " and "
conjunction_expr = join_operator.join(include_exprs)

return f"({conjunction_expr})"
return final_expr

else:
raise ValueError(f"Invalid operator {f.operator}")


def value_to_sql_expr(value: Any, value_type: pa.DataType) -> str:
if isinstance(value, list | tuple):
elements_str = ", ".join(
value_to_sql_expr(element, value_type) for element in value
)
value_str = f"({elements_str})"
else:
value_str = str(value)
# Escape the string so the user doesn't need to filter like ("col", "=", "'value'")
if pa.types.is_string(value_type):
escaped_str = escape_str_for_sql(value_str)
value_str = f"'{escaped_str}'"
return value_str


def escape_str_for_sql(value: str) -> str:
return value.replace("'", "''")
2 changes: 0 additions & 2 deletions client/neuralake/test/tables/test_deltalake_table.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, patch
Expand All @@ -9,7 +8,6 @@
import pytest

from neuralake.core.tables.deltalake_table import (
DeltaCacheOptions,
DeltalakeTable,
Filter,
fetch_df_by_partition,
Expand Down
61 changes: 43 additions & 18 deletions client/neuralake/test/tables/test_util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import datetime, timezone

import pyarrow as pa
import pytest

Expand All @@ -18,6 +20,7 @@
("int_col", pa.int64()),
("list_col", pa.list_(pa.int64())),
("list_str_col", pa.list_(pa.string())),
("date_col", pa.timestamp("us", tz="UTC")),
]
)

Expand All @@ -26,54 +29,64 @@ class TestUtil:
@pytest.mark.parametrize(
("schema", "f", "expected"),
[
(test_schema, Filter("int_col", "=", 123), "(int_col = 123)"),
(test_schema, Filter("int_col", "=", "123"), "(int_col = 123)"),
(test_schema, Filter("int_col", "=", 123), '"int_col"=123'),
(test_schema, Filter("int_col", "=", "123"), "\"int_col\"='123'"),
# A tuple with a single element should not have a comma in SQL
(test_schema, Filter("int_col", "in", (1,)), "(int_col in (1))"),
(test_schema, Filter("int_col", "in", (1, 2)), "(int_col in (1, 2))"),
(test_schema, Filter("int_col", "in", (1,)), '"int_col" IN (1)'),
(test_schema, Filter("int_col", "in", (1, 2)), '"int_col" IN (1,2)'),
(
test_schema,
Filter("int_col", "not in", (1, 2)),
"(int_col not in (1, 2))",
'"int_col" NOT IN (1,2)',
),
# String columns should be handled to add single quotes
(test_schema, Filter("str_col", "=", "x"), "(str_col = 'x')"),
(test_schema, Filter("str_col", "in", ("val1",)), "(str_col in ('val1'))"),
(test_schema, Filter("str_col", "=", "x"), "\"str_col\"='x'"),
# Filtering using datetime columns
(
test_schema,
Filter("date_col", ">=", datetime(2024, 4, 5, tzinfo=timezone.utc)),
"\"date_col\">='2024-04-05T00:00:00+00:00'",
),
(
test_schema,
Filter("str_col", "in", ("val1",)),
"\"str_col\" IN ('val1')",
),
(
test_schema,
Filter("str_col", "in", ("val1", "val2")),
"(str_col in ('val1', 'val2'))",
"\"str_col\" IN ('val1','val2')",
),
(
test_schema,
Filter("str_col", "contains", "x'"),
"(str_col like '%x''%')",
"\"str_col\" LIKE '%x''%'",
),
# Test list columns
(
test_schema,
Filter("list_col", "includes", 1),
"(array_contains(list_col, 1))",
'array_contains("list_col",1)',
),
(
test_schema,
Filter("list_str_col", "includes", "x"),
"(array_contains(list_str_col, 'x'))",
"array_contains(\"list_str_col\",'x')",
),
(
test_schema,
Filter("list_col", "includes all", (1, 2, 3)),
"(array_contains(list_col, 1) and array_contains(list_col, 2) and array_contains(list_col, 3))",
'array_contains("list_col",1) AND array_contains("list_col",2) AND array_contains("list_col",3)',
),
(
test_schema,
Filter("list_col", "includes any", (1, 2, 3)),
"(array_contains(list_col, 1) or array_contains(list_col, 2) or array_contains(list_col, 3))",
'array_contains("list_col",1) OR array_contains("list_col",2) OR array_contains("list_col",3)',
),
],
)
def test_filter_to_expr(self, schema: pa.Schema, f: Filter, expected: str):
assert filter_to_sql_expr(schema, f) == expected
assert str(filter_to_sql_expr(schema, f)) == expected

def test_filter_to_expr_raises(self):
with pytest.raises(ValueError) as e:
Expand All @@ -84,26 +97,38 @@ def test_filter_to_expr_raises(self):
@pytest.mark.parametrize(
("schema", "filters", "expected"),
[
(test_schema, [[Filter("str_col", "=", "x")]], "((str_col = 'x'))"),
(test_schema, [[Filter("str_col", "=", "x")]], "\"str_col\"='x'"),
(
test_schema,
[[Filter("str_col", "=", "x"), Filter("int_col", "=", 123)]],
"((str_col = 'x') and (int_col = 123))",
'"str_col"=\'x\' AND "int_col"=123',
),
(
test_schema,
[
[Filter("str_col", "=", "x")],
[Filter("int_col", "=", 123), Filter("int_col", "<", 456)],
],
"((str_col = 'x')) or ((int_col = 123) and (int_col < 456))",
'"str_col"=\'x\' OR ("int_col"=123 AND "int_col"<456)',
),
(
test_schema,
# Ensures filters are properly nested and grouped
[
[Filter("str_col", "=", "x")],
[
Filter("list_col", "includes any", (1, 2, 3)),
Filter("list_col", "includes all", (1, 2, 3)),
],
],
'"str_col"=\'x\' OR ((array_contains("list_col",1) OR array_contains("list_col",2) OR array_contains("list_col",3)) AND array_contains("list_col",1) AND array_contains("list_col",2) AND array_contains("list_col",3))',
),
],
)
def test_filters_to_sql_predicate(
self, schema: pa.Schema, filters: NormalizedFilters, expected: str
):
assert filters_to_sql_predicate(schema, filters) == expected
assert str(filters_to_sql_predicate(schema, filters)) == expected

@pytest.mark.parametrize(
("filters", "expected"),
Expand Down
6 changes: 1 addition & 5 deletions client/neuralake/web_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,10 @@
from neuralake.core.tables.deltalake_table import DeltalakeTable
from neuralake.core.tables.metadata import TableProtocol
import json
import tempfile
import subprocess
from pathlib import Path
import os
import sys
import logging
import shutil
import importlib.resources

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(message)s", stream=sys.stderr)
Expand Down Expand Up @@ -89,7 +85,7 @@ def export_and_generate_site(
logger.info(f"Copying precompiled directory {precompiled_dir} to {output_path}")
if not precompiled_dir.exists():
raise FileNotFoundError(
f"Could not find precompiled directory. Make sure you're running from the project root or the package is properly installed."
"Could not find precompiled directory. Make sure you're running from the project root or the package is properly installed."
)

shutil.copytree(precompiled_dir, output_path, dirs_exist_ok=True)
1 change: 1 addition & 0 deletions client/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies = [
"pyarrow==17.0.0",
"ipython==8.5.0",
"typing_extensions==4.13.2",
"pypika>=0.48.0",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a huge dependency — are we sure we want to use it here? SQL injection is something we should check for when processing untrusted user content. I would expect internal code to be trusted code; after all, if the programmer is malicious, they can also submit code to bypass all these safety checks in the same PR.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haven't reviewed this PR yet but previously the attack vector was type confusion.

If a value was passed in that wasn't of a supported type, it would get stringified and could then be susceptible to injection.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay, that's a better reason for doing more validation here, but can this be checked with type annotation instead?

]

[project.optional-dependencies]
Expand Down