This repository was archived by the owner on Jul 3, 2023. It is now read-only.
131 changes: 131 additions & 0 deletions examples/pyspark_native/possible_api.py
@@ -0,0 +1,131 @@
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

from hamilton.function_modifiers import extract_columns, tag


# Data Loading
# Filtering is part of data loading -- do we also expose columns like this?
@extract_columns(
    *["l_quantity", "l_extendedprice", "l_discount", "l_tax", "l_returnflag", "l_linestatus"]
)
def lineitem(
    sc: SparkSession,
    path: str,
    filter: str = "l_shipdate <= date '1998-12-01' - interval '90' day",
) -> pyspark.sql.DataFrame:
    """Loads and filters data from the lineitem table"""
    ds: pyspark.sql.DataFrame = (
        sc.read.option("inferSchema", True).option("header", True).csv(path, sep="|")
    )
    if filter:
        ds = ds.filter(filter)
    print(ds.schema)
    return ds
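
# Note: with @extract_columns above, each listed column (e.g. "l_quantity") becomes its own
# node in the graph, so downstream functions can request it by name as a pyspark.sql.Column.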


# transforms we want people to write
def disc_price(l_extendedprice: float, l_discount: float) -> float:
    """Computes the discounted price"""
    return l_extendedprice * (1 - l_discount)


def charge(l_extendedprice: float, l_discount: float, l_tax: float) -> float:
    """Computes the charge"""
    return l_extendedprice * (1 - l_discount) * (1 + l_tax)
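
# Under the PySparkGraphAdapter (see pyspark_adapter.py), these scalar functions get wrapped
# as Spark UDFs and applied column-wise via withColumn -- no Spark code needed here.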


# hacking things in via tags
@tag(group_by="l_returnflag,l_linestatus")
def grouped_lineitem(
    l_quantity: pyspark.sql.Column,
    l_extendedprice: pyspark.sql.Column,
    disc_price: pyspark.sql.Column,  # do we add some optional syntax here?
    # At run time we'd check whether the column exists and use it if so, else skip it.
    # That would let someone write one function and decide which columns are required,
    # saving them from updating this function when they don't want a particular column
    # passed through -- though downstream of this they'd still have to deal with it,
    # so maybe not that valuable?
    charge: pyspark.sql.Column,
    l_discount: pyspark.sql.Column,
    l_returnflag: pyspark.sql.Column,
    l_linestatus: pyspark.sql.Column,
) -> pyspark.sql.GroupedData:
    """This function declares the "schema" the datastream needs to have via its arguments.
    The body is blank because the graph adapter knows the actual logic to perform.

    An alternate syntax could be for the decorator to declare what's required; the function
    would then take the datastream and perform the group-by itself...
    """
    pass
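
# For illustration only -- a sketch of the alternate syntax floated in the docstring above,
# where a hypothetical @group_by decorator declares the grouping keys and the function takes
# the datastream directly. Neither the decorator nor this behavior exists today.
#
# @group_by("l_returnflag", "l_linestatus")
# def grouped_lineitem(lineitem: pyspark.sql.DataFrame) -> pyspark.sql.GroupedData:
#     """The decorator declares what's required; the group-by happens in the body."""
#     return lineitem.groupby("l_returnflag", "l_linestatus")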


# hack to get around https://github.com/marsupialtail/quokka/issues/23
# @tag(materialize="True")
def compute_aggregates(grouped_lineitem: pyspark.sql.GroupedData) -> pyspark.sql.DataFrame:
    # Thought: change these into individual functions?
    agg_map = {
        "l_quantity": ["sum", "avg"],
        "l_extendedprice": ["sum", "avg"],
        "disc_price": ["sum"],
        "charge": ["sum"],
        "l_discount": ["avg"],
        "*": ["count"],
    }
    agg_args = []
    for column_name, aggregates in agg_map.items():
        for aggregate in aggregates:
            func = getattr(F, aggregate)
            agg_args.append(func(F.column(column_name)))
    df = grouped_lineitem.agg(*agg_args)
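    # Spark's default aggregate column names look like "count(1)" and "sum(l_quantity)";
    # rename them to the names downstream nodes will request.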
    rename_map = {
        "l_returnflag": "al_returnflag",
        "l_linestatus": "al_linestatus",
        "count(1)": "row_count",
        "avg(l_quantity)": "l_quantity_mean",
        "sum(l_quantity)": "l_quantity_sum",
        "avg(l_extendedprice)": "l_extendedprice_mean",
        "sum(l_extendedprice)": "l_extendedprice_sum",
        "sum(disc_price)": "disc_price_sum",
        "sum(charge)": "charge_sum",
        "avg(l_discount)": "l_discount_mean",
    }
    for old_name, new_name in rename_map.items():
        df = df.withColumnRenamed(old_name, new_name)
    return df


# this doesn't seem like the right thing:
# def l_quantity_sum(grouped_lineitem: GroupedDataStream) -> pyspark.sql.Column:
#     pass


# hack to get around `@tag_outputs` not working with `@extract_columns` as expected.
@extract_columns(
    *[
        "row_count",
        "l_quantity_sum",
        "l_extendedprice_sum",
        "disc_price_sum",
        "charge_sum",
        "l_quantity_mean",
        "l_extendedprice_mean",
        "l_discount_mean",
        "al_returnflag",
        "al_linestatus",
    ]
)
def extract_aggregates(compute_aggregates: pyspark.sql.DataFrame) -> pyspark.sql.DataFrame:
    return compute_aggregates
115 changes: 115 additions & 0 deletions examples/pyspark_native/pyspark_adapter.py
@@ -0,0 +1,115 @@
import inspect
from typing import Any, Callable, Dict, Tuple, Type

from pyspark.sql import Column, DataFrame, GroupedData, types
from pyspark.sql.functions import column, udf

from hamilton import base, node


class PySparkGraphAdapter(base.SimplePythonDataFrameGraphAdapter):
    def __init__(self, result_builder: base.ResultMixin = base.DictResult()):
        self.df_objects = {}
        self.call_count = 0
        self.result_builder = result_builder

    @staticmethod
    def check_input_type(node_type: Type, input_value: Any) -> bool:
        return True

    @staticmethod
    def check_node_type_equivalence(node_type: Type, input_type: Type) -> bool:
        return True

    def _lambda_udf(self, df: DataFrame, hamilton_udf: Callable) -> DataFrame:
        sig = inspect.signature(hamilton_udf)
        input_parameters = dict(sig.parameters)
        return_type = sig.return_annotation
        print("lambda inputs", input_parameters, return_type, hamilton_udf.__name__)
        if return_type == float:
            spark_return_type = types.DoubleType()
        else:
            raise ValueError(f"Unsupported return type {return_type}")
        spark_udf = udf(hamilton_udf, spark_return_type)
        return df.withColumn(
            hamilton_udf.__name__, spark_udf(*[column(name) for name in sig.parameters.keys()])
        )
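
    # Note: for, e.g., disc_price(l_extendedprice, l_discount) -> float, _lambda_udf above
    # effectively runs:
    #   df.withColumn("disc_price",
    #                 udf(disc_price, types.DoubleType())(column("l_extendedprice"),
    #                                                     column("l_discount")))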

    def _sanitize_kwargs(self, kwargs: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Unwraps wrapped kwargs, returning the actual values along with a map of
        kwarg name to the name of the DataFrame each one came from."""
        df_names = {}
        actual_kwargs = {}
        for kwarg_key, kwarg_value in kwargs.items():
            if isinstance(kwarg_value, dict) and "__df_name__" in kwarg_value:
                df_names[kwarg_key] = kwarg_value["__df_name__"]
                actual_kwargs[kwarg_key] = kwarg_value["__result__"]
            else:
                actual_kwargs[kwarg_key] = kwarg_value
        return actual_kwargs, df_names
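
    # Example of the wrapping handled above (illustrative values):
    #   {"l_quantity": {"__df_name__": "lineitem", "__result__": <DataFrame>}}
    # unwraps to actual_kwargs={"l_quantity": <DataFrame>}, df_names={"l_quantity": "lineitem"}.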

    def execute_node(self, node: node.Node, kwargs: Dict[str, Any]) -> Any:
        self.call_count += 1
        actual_kwargs, df_names = self._sanitize_kwargs(kwargs)
        df_name_set = set(df_names.values())
        if node.type == DataFrame:
            # assumption is types into this function are only scalars, or other DataFrame/GroupedData objects
            df: DataFrame = node.callable(**actual_kwargs)
            self.df_objects[node.name] = df
            return {"__df_name__": node.name, "__result__": df}
        elif node.type == GroupedData:
            print("got group by", self.call_count, node.name)
            assert len(df_name_set) == 1, f"Error: groupby got multiple DataFrames {df_names}"
            df = self.df_objects[df_names.popitem()[1]]
            print("before select", df.schema)
            df = df.select(list(node.input_types.keys()))
            print("after select", df.schema)
            group_by_cols = node.tags["group_by"].split(",")
            df: GroupedData = df.groupby(group_by_cols)
            self.df_objects[node.name] = df
            return {"__df_name__": node.name, "__result__": df}
        elif (
            node.type == Column
            and len(node.input_types) == 1
            and node.tags.get("__generated_by__", None) == "extract_columns"
        ):
            assert (
                len(df_name_set) == 1
            ), f"Error: extract_columns got multiple DataFrames {df_names}"
            df_name = df_name_set.pop()
            print(node.name, node.tags)
            # return the current version of the DataFrame this column belongs to
            df = self.df_objects[df_name]
            return {"__df_name__": df_name, "__result__": df}
        else:
            assert len(df_name_set) == 1, f"Error: udf got multiple DataFrames {df_names}"
            print("got udf", self.call_count, node.name, kwargs)
            df_name = df_name_set.pop()
            df: DataFrame = self.df_objects[df_name]
            print(df.schema)
            df = self._lambda_udf(df, node.callable)
            self.df_objects[df_name] = df
            return {"__df_name__": df_name, "__result__": df}

    def build_result(self, **outputs: Dict[str, Any]) -> DataFrame:
        """Builds the result and brings it back to this running process.

        :param outputs: the dictionary of column name -> value (values may be wrapped
            with the name of the DataFrame they came from).
        :return: a pyspark DataFrame with the requested columns.
        """
        requested_ds_set = set(outputs.keys())
        actual_outputs, df_names = self._sanitize_kwargs(outputs)
        df_name_set = set(df_names.values())
        assert (
            len(df_name_set) == 1
        ), f"Error: got multiple DataFrames to build the result from: {df_names}"
        df = self.df_objects[df_name_set.pop()]
        df = df.select(list(actual_outputs.keys()))
        global_ds_set = set(df.columns)
        if requested_ds_set.intersection(global_ds_set) != requested_ds_set:
            raise ValueError(
                f"Error: requested columns not found in final dataframe. "
                f"Missing: {requested_ds_set.difference(global_ds_set)}."
            )
        print("final schema", df.schema)
        return df
2 changes: 2 additions & 0 deletions examples/pyspark_native/requirements.txt
@@ -0,0 +1,2 @@
pyspark
sf-hamilton
57 changes: 57 additions & 0 deletions examples/pyspark_native/run_adapter.py
@@ -0,0 +1,57 @@
import possible_api
import pyspark_adapter
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import functions as F
from pyspark.sql import types

from hamilton import driver, log_setup


def disc_price(l_extendedprice, l_discount) -> float:
    """Computes the discounted price"""
    return l_extendedprice * (1 - l_discount)


def grr():
    path = "/Users/stefankrawczyk/Downloads/tpc-h-public/lineitem.tbl"
    sc = SparkSession.builder.getOrCreate()
    sc.sparkContext.setLogLevel("WARN")
    log_setup.setup_logging()
    df: DataFrame = sc.read.option("inferSchema", True).option("header", True).csv(path, sep="|")
    df = df.filter("l_shipdate <= date '1998-12-01' - interval '90' day")
    print(df.schema)
    spark_udf = F.udf(disc_price, types.FloatType())
    df.withColumn("disc_price", spark_udf(df["l_extendedprice"], df["l_discount"])).show()


def main():
    log_setup.setup_logging(log_level=log_setup.LOG_LEVELS["INFO"])
    spark = SparkSession.builder.getOrCreate()
    path = "/Users/stefankrawczyk/Downloads/tpc-h-public/lineitem.tbl"
    adapter = pyspark_adapter.PySparkGraphAdapter()
    dr = driver.Driver({"sc": spark, "path": path}, possible_api, adapter=adapter)
    outputs = [
        "al_returnflag",  # can comment one of these out and they are dropped
        "al_linestatus",
        "row_count",
        "charge_sum",
        "disc_price_sum",
        "l_discount_mean",
        "l_extendedprice_mean",
        "l_extendedprice_sum",
        "l_quantity_mean",
        "l_quantity_sum",
        # "grouped_lineitem",  # calling collect on groups doesn't seem to work...
        # "l_returnflag",  # can also get intermediate results
        # "charge",
        # "disc_price",
    ]
    dr.visualize_execution(outputs, "./my_dag.dot", {})
    df = dr.execute(outputs)
    df.show()
    spark.stop()


if __name__ == "__main__":
    main()
    # grr()