-
Notifications
You must be signed in to change notification settings - Fork 126
Add global option to skip households on simulation failure #1023
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 13 commits
a0ec6a8
8623516
397f250
fa08f11
f26cc80
76fd833
96e73ec
cf98cd2
baa47fa
6cffd9b
0e34b7d
5da9715
50034fd
a6ed5cb
91227b6
c986e5b
e6a8c1b
ee52916
5316890
f0a2581
f04bf14
9d3d018
06c6d75
c257e3c
a734c16
d28faa0
71cc5fd
dcb2864
0615cfc
93c3ba9
96e01f9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -307,6 +307,12 @@ def write_trip_matrices( | |
| .TAZ.tolist() | ||
| ) | ||
|
|
||
| # print out number of households skipped due to failed choices | ||
| if state.settings.skip_failed_choices: | ||
| logger.info( | ||
|
||
| f"\n!!!\nATTENTION: Skipped households with failed choices during simulation. Number of households skipped: {state.get('num_skipped_households', 0)}.\n!!!" | ||
| ) | ||
|
|
||
|
|
||
| def annotate_trips( | ||
| state: workflow.State, | ||
|
|
@@ -393,6 +399,21 @@ def write_matrices( | |
| if not matrix_settings: | ||
| logger.error("Missing MATRICES setting in write_trip_matrices.yaml") | ||
|
|
||
| hh_weight_col = model_settings.HH_EXPANSION_WEIGHT_COL | ||
| if hh_weight_col: | ||
| if state.get("num_skipped_households", 0) > 0: | ||
| logger.info( | ||
| f"Adjusting household expansion weights in {hh_weight_col} to account for {state.get('num_skipped_households', 0)} skipped households." | ||
| ) | ||
| # adjust the hh expansion weights to account for skipped households | ||
| adjustment_factor = state.get_dataframe("households").shape[0] / ( | ||
|
||
| state.get_dataframe("households").shape[0] | ||
| + state.get("num_skipped_households", 0) | ||
| ) | ||
| aggregate_trips[hh_weight_col] = ( | ||
| aggregate_trips[hh_weight_col] * adjustment_factor | ||
| ) | ||
i-am-sijia marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| for matrix in matrix_settings: | ||
| matrix_is_tap = matrix.is_tap | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1043,6 +1043,10 @@ def force_escortee_trip_modes_to_match_chauffeur(state: workflow.State, trips): | |
| f"Changed {diff.sum()} trip modes of school escortees to match their chauffeur" | ||
| ) | ||
|
|
||
| # trip_mode can be na if the run allows skipping failed choices and the trip mode choice has failed | ||
| if state.settings.skip_failed_choices: | ||
i-am-sijia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return trips | ||
|
||
|
|
||
| assert ( | ||
| ~trips.trip_mode.isna() | ||
| ).all(), f"Missing trip mode for {trips[trips.trip_mode.isna()]}" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -781,6 +781,13 @@ def _check_store_skims_in_shm(self): | |
| should catch many common errors early, including missing required configurations or specified coefficient labels without defined values. | ||
i-am-sijia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """ | ||
|
|
||
| skip_failed_choices: bool = True | ||
| """ | ||
| Skip households that cause errors during processing instead of failing the model run. | ||
|
|
||
| .. versionadded:: 1.6 | ||
| """ | ||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We need additional setting[s] that define thresholds for how many skipped households are acceptable and at what point there are too many, so that the run should fail with an error.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Added |
||
| other_settings: dict[str, Any] = None | ||
|
|
||
| def _get_attr(self, attr): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,6 +30,7 @@ def report_bad_choices( | |
| state: workflow.State, | ||
| bad_row_map, | ||
| df, | ||
| skip_failed_choices, | ||
| trace_label, | ||
| msg, | ||
| trace_choosers=None, | ||
|
|
@@ -87,6 +88,27 @@ def report_bad_choices( | |
|
|
||
| logger.warning(row_msg) | ||
|
|
||
| if skip_failed_choices: | ||
| # update counter in state | ||
| num_skipped_households = state.get("num_skipped_households", 0) | ||
| skipped_household_ids = state.get("skipped_household_ids", set()) | ||
| for hh_id in df[trace_col].unique(): | ||
| if hh_id is None: | ||
| continue | ||
| if hh_id not in skipped_household_ids: | ||
| skipped_household_ids.add(hh_id) | ||
| num_skipped_households += 1 | ||
| else: | ||
| continue | ||
i-am-sijia marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| state.set("num_skipped_households", num_skipped_households) | ||
| state.set("skipped_household_ids", skipped_household_ids) | ||
|
|
||
| logger.debug( | ||
|
||
| f"Skipping {bad_row_map.sum()} bad choices. Total skipped households so far: {num_skipped_households}. Skipped household IDs: {skipped_household_ids}" | ||
| ) | ||
|
|
||
| return | ||
|
|
||
| if raise_error: | ||
| raise InvalidTravelError(msg_with_count) | ||
|
|
||
|
|
@@ -136,6 +158,7 @@ def utils_to_probs( | |
| allow_zero_probs=False, | ||
| trace_choosers=None, | ||
| overflow_protection: bool = True, | ||
| skip_failed_choices: bool = True, | ||
| return_logsums: bool = False, | ||
| ): | ||
| """ | ||
|
|
@@ -176,6 +199,16 @@ def utils_to_probs( | |
| overflow_protection will have no benefit but impose a modest computational | ||
| overhead cost. | ||
|
|
||
| skip_failed_choices : bool, default True | ||
| If True, when bad choices are detected (all zero probabilities or infinite | ||
| probabilities), the entire household that's causing bad choices will be skipped instead of | ||
| being masked by overflow protection or causing an error. | ||
| A counter will be incremented for each skipped household. This is useful when running large | ||
| simulations where occasional bad choices are encountered and should not halt the process. | ||
| The counter can be accessed via `state.get("num_skipped_households", 0)`. | ||
| The number of skipped households and their IDs will be logged at the end of the simulation. | ||
| When `skip_failed_choices` is True, `overflow_protection` will be reverted to False to avoid conflicts. | ||
i-am-sijia marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| Returns | ||
| ------- | ||
| probs : pandas.DataFrame | ||
|
|
@@ -203,6 +236,13 @@ def utils_to_probs( | |
| utils_arr.dtype == np.float32 and utils_arr.max() > 85 | ||
| ) | ||
|
|
||
| if state.settings.skip_failed_choices is not None: | ||
| skip_failed_choices = state.settings.skip_failed_choices | ||
| # when skipping failed choices, we cannot use overflow protection | ||
| # because it would mask the underlying issue causing bad choices | ||
| if skip_failed_choices: | ||
| overflow_protection = False | ||
|
|
||
| if overflow_protection: | ||
| # exponentiated utils will overflow, downshift them | ||
| shifts = utils_arr.max(1, keepdims=True) | ||
|
|
@@ -240,6 +280,7 @@ def utils_to_probs( | |
| state, | ||
| zero_probs, | ||
| utils, | ||
| skip_failed_choices, | ||
| trace_label=tracing.extend_trace_label(trace_label, "zero_prob_utils"), | ||
| msg="all probabilities are zero", | ||
| trace_choosers=trace_choosers, | ||
|
|
@@ -251,6 +292,7 @@ def utils_to_probs( | |
| state, | ||
| inf_utils, | ||
| utils, | ||
| skip_failed_choices, | ||
| trace_label=tracing.extend_trace_label(trace_label, "inf_exp_utils"), | ||
| msg="infinite exponentiated utilities", | ||
| trace_choosers=trace_choosers, | ||
|
|
@@ -281,6 +323,7 @@ def make_choices( | |
| trace_label: str = None, | ||
| trace_choosers=None, | ||
| allow_bad_probs=False, | ||
| skip_failed_choices=True, | ||
| ) -> tuple[pd.Series, pd.Series]: | ||
| """ | ||
| Make choices for each chooser from among a set of alternatives. | ||
|
|
@@ -316,11 +359,15 @@ def make_choices( | |
| np.ones(len(probs.index)) | ||
| ).abs() > BAD_PROB_THRESHOLD * np.ones(len(probs.index)) | ||
|
|
||
| if state.settings.skip_failed_choices is not None: | ||
| skip_failed_choices = state.settings.skip_failed_choices | ||
|
|
||
| if bad_probs.any() and not allow_bad_probs: | ||
| report_bad_choices( | ||
| state, | ||
| bad_probs, | ||
| probs, | ||
| skip_failed_choices, | ||
i-am-sijia marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| trace_label=tracing.extend_trace_label(trace_label, "bad_probs"), | ||
| msg="probabilities do not add up to 1", | ||
| trace_choosers=trace_choosers, | ||
|
|
@@ -329,6 +376,8 @@ def make_choices( | |
| rands = state.get_rn_generator().random_for_df(probs) | ||
|
|
||
| choices = pd.Series(choice_maker(probs.values, rands), index=probs.index) | ||
| # mark bad choices with -99 | ||
| choices[bad_probs] = -99 | ||
|
|
||
| rands = pd.Series(np.asanyarray(rands).flatten(), index=probs.index) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should check here how many choices are getting dropped, and
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Both the warning and the check have been implemented in
core.workflow.state.update_table(). I implemented them in state.py rather than in the individual model components to keep the code base clean, i.e., to avoid code duplication.