Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions packages/bigframes/bigframes/core/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,21 @@ class BigQuerySentEvent(ExecutionRunning):
location: Optional[str] = None
job_id: Optional[str] = None
request_id: Optional[str] = None
progress_bar: Optional[str] = "fallback_to_global"
Comment thread
shuoweil marked this conversation as resolved.
Outdated

@classmethod
def from_bqclient(cls, event: google.cloud.bigquery._job_helpers.QuerySentEvent):
def from_bqclient(
cls,
event: google.cloud.bigquery._job_helpers.QuerySentEvent,
progress_bar: Optional[str] = "fallback_to_global",
):
return cls(
query=event.query,
billing_project=event.billing_project,
location=event.location,
job_id=event.job_id,
request_id=event.request_id,
progress_bar=progress_bar,
)


Expand All @@ -146,15 +152,21 @@ class BigQueryRetryEvent(ExecutionRunning):
location: Optional[str] = None
job_id: Optional[str] = None
request_id: Optional[str] = None
progress_bar: Optional[str] = "fallback_to_global"

@classmethod
def from_bqclient(cls, event: google.cloud.bigquery._job_helpers.QueryRetryEvent):
def from_bqclient(
cls,
event: google.cloud.bigquery._job_helpers.QueryRetryEvent,
progress_bar: Optional[str] = "fallback_to_global",
):
return cls(
query=event.query,
billing_project=event.billing_project,
location=event.location,
job_id=event.job_id,
request_id=event.request_id,
progress_bar=progress_bar,
)


Expand All @@ -171,10 +183,13 @@ class BigQueryReceivedEvent(ExecutionRunning):
created: Optional[datetime.datetime] = None
started: Optional[datetime.datetime] = None
ended: Optional[datetime.datetime] = None
progress_bar: Optional[str] = "fallback_to_global"

@classmethod
def from_bqclient(
cls, event: google.cloud.bigquery._job_helpers.QueryReceivedEvent
cls,
event: google.cloud.bigquery._job_helpers.QueryReceivedEvent,
progress_bar: Optional[str] = "fallback_to_global",
):
return cls(
billing_project=event.billing_project,
Expand All @@ -186,6 +201,7 @@ def from_bqclient(
created=event.created,
started=event.started,
ended=event.ended,
progress_bar=progress_bar,
)


Expand All @@ -204,10 +220,13 @@ class BigQueryFinishedEvent(ExecutionRunning):
created: Optional[datetime.datetime] = None
started: Optional[datetime.datetime] = None
ended: Optional[datetime.datetime] = None
progress_bar: Optional[str] = "fallback_to_global"

@classmethod
def from_bqclient(
cls, event: google.cloud.bigquery._job_helpers.QueryFinishedEvent
cls,
event: google.cloud.bigquery._job_helpers.QueryFinishedEvent,
progress_bar: Optional[str] = "fallback_to_global",
):
return cls(
billing_project=event.billing_project,
Expand All @@ -221,6 +240,7 @@ def from_bqclient(
created=event.created,
started=event.started,
ended=event.ended,
progress_bar=progress_bar,
)


Expand Down
5 changes: 4 additions & 1 deletion packages/bigframes/bigframes/formatting_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,10 @@ def progress_callback(
# This will allow cleanup to continue.
return

progress_bar = bigframes._config.options.display.progress_bar
# Prioritize progress_bar set on the event, falling back to thread-local option.
progress_bar = getattr(event, "progress_bar", "fallback_to_global")
if progress_bar == "fallback_to_global":
progress_bar = bigframes._config.options.display.progress_bar

if progress_bar == "auto":
progress_bar = "notebook" if in_ipython() else "terminal"
Expand Down
32 changes: 24 additions & 8 deletions packages/bigframes/bigframes/session/_io/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,15 +245,27 @@ def add_and_trim_labels(job_config, session=None):


def create_bq_event_callback(publisher):
import bigframes._config

progress_bar = bigframes._config.options.display.progress_bar

def publish_bq_event(event):
if isinstance(event, google.cloud.bigquery._job_helpers.QueryFinishedEvent):
bf_event = bigframes.core.events.BigQueryFinishedEvent.from_bqclient(event)
bf_event = bigframes.core.events.BigQueryFinishedEvent.from_bqclient(
event, progress_bar=progress_bar
)
elif isinstance(event, google.cloud.bigquery._job_helpers.QueryReceivedEvent):
bf_event = bigframes.core.events.BigQueryReceivedEvent.from_bqclient(event)
bf_event = bigframes.core.events.BigQueryReceivedEvent.from_bqclient(
event, progress_bar=progress_bar
)
elif isinstance(event, google.cloud.bigquery._job_helpers.QueryRetryEvent):
bf_event = bigframes.core.events.BigQueryRetryEvent.from_bqclient(event)
bf_event = bigframes.core.events.BigQueryRetryEvent.from_bqclient(
event, progress_bar=progress_bar
)
elif isinstance(event, google.cloud.bigquery._job_helpers.QuerySentEvent):
bf_event = bigframes.core.events.BigQuerySentEvent.from_bqclient(event)
bf_event = bigframes.core.events.BigQuerySentEvent.from_bqclient(
event, progress_bar=progress_bar
)
else:
bf_event = bigframes.core.events.BigQueryUnknownEvent(event)
Comment thread
shuoweil marked this conversation as resolved.
Outdated

Expand All @@ -275,7 +287,8 @@ def start_query_with_client(
query_with_job: Literal[True],
publisher: bigframes.core.events.Publisher,
session=None,
) -> Tuple[google.cloud.bigquery.table.RowIterator, bigquery.QueryJob]: ...
) -> Tuple[google.cloud.bigquery.table.RowIterator, bigquery.QueryJob]:
...


@overload
Expand All @@ -291,7 +304,8 @@ def start_query_with_client(
query_with_job: Literal[False],
publisher: bigframes.core.events.Publisher,
session=None,
) -> Tuple[google.cloud.bigquery.table.RowIterator, Optional[bigquery.QueryJob]]: ...
) -> Tuple[google.cloud.bigquery.table.RowIterator, Optional[bigquery.QueryJob]]:
...


@overload
Expand All @@ -308,7 +322,8 @@ def start_query_with_client(
job_retry: google.api_core.retry.Retry,
publisher: bigframes.core.events.Publisher,
session=None,
) -> Tuple[google.cloud.bigquery.table.RowIterator, bigquery.QueryJob]: ...
) -> Tuple[google.cloud.bigquery.table.RowIterator, bigquery.QueryJob]:
...


@overload
Expand All @@ -325,7 +340,8 @@ def start_query_with_client(
job_retry: google.api_core.retry.Retry,
publisher: bigframes.core.events.Publisher,
session=None,
) -> Tuple[google.cloud.bigquery.table.RowIterator, Optional[bigquery.QueryJob]]: ...
) -> Tuple[google.cloud.bigquery.table.RowIterator, Optional[bigquery.QueryJob]]:
...


def start_query_with_client(
Expand Down
25 changes: 25 additions & 0 deletions packages/bigframes/tests/unit/test_formatting_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,28 @@ def test_get_job_url():
job_id=job_id, location=location, project_id=project_id
)
assert actual_url == expected_url


def test_progress_callback_respects_event_progress_bar():
event = bfevents.BigQuerySentEvent(
query="SELECT * FROM my_table",
progress_bar=None,
)

with mock.patch("bigframes._config.options.display.progress_bar", "terminal"):
with mock.patch("bigframes.formatting_helpers.in_ipython", return_value=False):
with mock.patch("builtins.print") as mock_print:
formatting_helpers.progress_callback(event)
mock_print.assert_not_called()


def test_progress_callback_falls_back_to_global():
event = bfevents.BigQuerySentEvent(
query="SELECT * FROM my_table",
)

with mock.patch("bigframes._config.options.display.progress_bar", "terminal"):
with mock.patch("bigframes.formatting_helpers.in_ipython", return_value=False):
with mock.patch("builtins.print") as mock_print:
formatting_helpers.progress_callback(event)
mock_print.assert_called_once()
Loading