Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ export(get_forecast_counts)
export(get_forecast_unit)
export(get_grouping)
export(get_metrics)
export(get_non_monotonic_forecasts)
export(get_pairwise_comparisons)
export(get_pit_histogram)
export(interval_coverage)
Expand Down
7 changes: 7 additions & 0 deletions R/class-forecast-quantile.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@ assert_forecast.forecast_quantile <- function(
forecast <- assert_forecast_generic(forecast, verbose)
assert_forecast_type(forecast, actual = "quantile", desired = forecast_type)
assert_numeric(forecast$quantile_level, lower = 0, upper = 1)

# check for non-monotonic predictions
monotonicity_check <- check_monotonicity(forecast)
if (!isTRUE(monotonicity_check) && verbose) {
cli_warn(monotonicity_check)
}

return(invisible(NULL))
}

Expand Down
88 changes: 88 additions & 0 deletions R/get-non-monotonic-forecasts.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#' @title Find non-monotonic quantile predictions
#'
#' @description
#' Identifies forecast units where predicted values are not monotonically
#' non-decreasing with increasing quantile levels. This is a diagnostic helper
#' function modeled on [get_duplicate_forecasts()].
#'
#' Quantile forecasts should have non-decreasing predictions as quantile levels
#' increase (e.g., the predicted value at the 0.75 quantile should be at least
#' as large as the predicted value at the 0.5 quantile). Violations indicate
#' a problem with the forecasting method.
#'
#' @inheritParams as_forecast_doc_template
#' @param counts Should the output show the number of quantile rows per
#' affected forecast unit instead of the individual rows? Default is `FALSE`.
#' @returns A data.table with all rows for which a non-monotonic prediction
#' was found
#' @export
#' @importFrom checkmate assert_data_frame
#' @importFrom data.table setorderv
#' @keywords diagnose-inputs
#' @examples
#' # well-formed data returns 0 rows
#' get_non_monotonic_forecasts(example_quantile)
#'
#' # non-monotonic data
#' bad_data <- data.frame(
#' model = "m1", date = as.Date("2020-01-01"), observed = 5,
#' quantile_level = c(0.25, 0.5, 0.75), predicted = c(3, 7, 4)
#' )
#' get_non_monotonic_forecasts(bad_data)
get_non_monotonic_forecasts <- function(
data,
forecast_unit = NULL,
counts = FALSE
) {
assert_data_frame(data)
data <- ensure_data.table(data)

if (!is.null(forecast_unit)) {
data <- set_forecast_unit(data, forecast_unit)
}
forecast_unit <- get_forecast_unit(data)

data <- as.data.table(data)
setorderv(data, cols = c(forecast_unit, "quantile_level"))

# For each forecast unit, check if predictions are non-decreasing
data[, scoringutils_InternalMonoCheck := {
ordered_pred <- predicted[order(quantile_level)]
any(diff(ordered_pred) < 0)
}, by = forecast_unit]

out <- data[scoringutils_InternalMonoCheck == TRUE]

Check warning on line 54 in R/get-non-monotonic-forecasts.R

View workflow job for this annotation

GitHub Actions / lint-changed-files

file=R/get-non-monotonic-forecasts.R,line=54,col=15,[redundant_equals_linter] Using == on a logical vector is redundant. Well-named logical vectors can be used directly in filtering. For data.table's `i` argument, wrap the column name in (), like `DT[(is_treatment)]`.
out[, scoringutils_InternalMonoCheck := NULL]
data[, scoringutils_InternalMonoCheck := NULL]

if (counts) {
out <- out[, .(n_non_monotonic = .N), by = c(get_forecast_unit(out))]
}

return(out[])
}


#' Check that predictions are monotonically non-decreasing with quantile level
#'
#' @description
#' Runs [get_non_monotonic_forecasts()] and returns a message if an issue is
#' encountered
#' @inheritParams get_non_monotonic_forecasts
#' @inherit document_check_functions return
#' @keywords internal_input_check
check_monotonicity <- function(data) {
non_monotonic <- get_non_monotonic_forecasts(data)

if (nrow(non_monotonic) > 0) {
msg <- paste0(
"Some forecasts have predictions that are not monotonically ",
"non-decreasing with increasing quantile level. ",
"This may cause issues with some scoring metrics (e.g. bias). ",
"Use the function get_non_monotonic_forecasts() to identify ",
"affected forecast units."
)
return(msg)
}
return(TRUE)
}
1 change: 1 addition & 0 deletions R/z-globalVariables.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ globalVariables(c(
"rn",
"sample_id",
"scoringutils_InternalDuplicateCheck",
"scoringutils_InternalMonoCheck",
"scoringutils_InternalNumCheck",
"se_mean",
"sharpness",
Expand Down
22 changes: 22 additions & 0 deletions man/check_monotonicity.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

49 changes: 49 additions & 0 deletions man/get_non_monotonic_forecasts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

60 changes: 60 additions & 0 deletions tests/testthat/test-class-forecast-quantile.R
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,66 @@ test_that("as_forecast_quantile handles rounding issues correctly", {



# ==============================================================================
# assert_forecast.forecast_quantile() monotonicity warning
# ==============================================================================

test_that("assert_forecast.forecast_quantile() warns about non-monotonic predictions", {
data <- data.table(
model = "model1",
date = as.Date("2020-01-01"),
observed = 5,
quantile_level = c(0.25, 0.5, 0.75),
predicted = c(3, 7, 4)
)
expect_warning(
as_forecast_quantile(data),
"predictions that are not monotonically non-decreasing"
)
# Should still succeed and return a valid forecast_quantile object
result <- suppressWarnings(as_forecast_quantile(data))
expect_s3_class(result, "forecast_quantile")
})

test_that("assert_forecast.forecast_quantile() does not warn for well-formed data", {
test <- na.omit(data.table::copy(example_quantile))
expect_no_condition(
as_forecast_quantile(test,
forecast_unit = c(
"location", "model", "target_type",
"target_end_date", "horizon"
)
)
)
})

test_that("assert_forecast.forecast_quantile() suppresses monotonicity warning when verbose = FALSE", {
data <- data.table(
model = "model1",
date = as.Date("2020-01-01"),
observed = 5,
quantile_level = c(0.25, 0.5, 0.75),
predicted = c(3, 7, 4)
)
forecast_obj <- suppressWarnings(as_forecast_quantile(data))
expect_no_condition(assert_forecast(forecast_obj, verbose = FALSE))
})

test_that("score() works end-to-end with non-monotonic predictions when bias is excluded", {
data <- data.table(
model = "model1",
date = as.Date("2020-01-01"),
observed = 5,
quantile_level = c(0.25, 0.5, 0.75),
predicted = c(3, 7, 4)
)
data <- suppressWarnings(as_forecast_quantile(data))
result <- score(data, metrics = list(wis = wis))
expect_s3_class(result, "scores")
expect_true("wis" %in% colnames(result))
})


# ==============================================================================
# is_forecast_quantile() # nolint: commented_code_linter
# ==============================================================================
Expand Down
123 changes: 123 additions & 0 deletions tests/testthat/test-get-non-monotonic-forecasts.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# ==============================================================================
# get_non_monotonic_forecasts() # nolint: commented_code_linter
# ==============================================================================

test_that("get_non_monotonic_forecasts() returns empty data.table for well-formed quantile forecasts", {
result <- get_non_monotonic_forecasts(example_quantile)
expect_identical(nrow(result), 0L)
expect_s3_class(result, c("data.table", "data.frame"))
})

test_that("get_non_monotonic_forecasts() detects predictions that decrease with increasing quantile level", {
data <- data.table(
model = "model1",
date = as.Date("2020-01-01"),
observed = 5,
quantile_level = c(0.25, 0.5, 0.75),
predicted = c(3, 7, 4)
)
data <- suppressWarnings(suppressMessages(as_forecast_quantile(data)))
result <- get_non_monotonic_forecasts(data)
expect_gt(nrow(result), 0L)
expect_identical(nrow(result), 3L)
})

test_that("get_non_monotonic_forecasts() handles multiple forecast units with mixed monotonicity", {
data <- data.table(
model = rep(c("good_model", "bad_model"), each = 3),
date = as.Date("2020-01-01"),
observed = 5,
quantile_level = rep(c(0.25, 0.5, 0.75), 2),
predicted = c(2, 5, 8, 3, 7, 4)
)
data <- suppressWarnings(suppressMessages(as_forecast_quantile(data)))
result <- get_non_monotonic_forecasts(data)
expect_identical(nrow(result), 3L)
expect_identical(unique(result$model), "bad_model")
})

test_that("get_non_monotonic_forecasts() works with counts argument", {
data <- data.table(
model = rep(c("bad1", "bad2", "good"), each = 3),
date = as.Date("2020-01-01"),
observed = 5,
quantile_level = rep(c(0.25, 0.5, 0.75), 3),
predicted = c(3, 7, 4, 5, 9, 6, 2, 5, 8)
)
data <- suppressWarnings(suppressMessages(as_forecast_quantile(data)))
result <- get_non_monotonic_forecasts(data, counts = TRUE)
expect_identical(nrow(result), 2L)
})

test_that("get_non_monotonic_forecasts() accepts custom forecast_unit argument", {
expect_no_condition(
get_non_monotonic_forecasts(
example_quantile,
forecast_unit = c(
"location", "target_end_date", "target_type",
"location_name", "forecast_date", "model"
)
)
)
result <- get_non_monotonic_forecasts(
example_quantile,
forecast_unit = c(
"location", "target_end_date", "target_type",
"location_name", "forecast_date", "model"
)
)
expect_identical(nrow(result), 0L)
})

test_that("get_non_monotonic_forecasts() returns expected class", {
result <- get_non_monotonic_forecasts(example_quantile)
expect_s3_class(result, c("data.table", "data.frame"))
})

test_that("get_non_monotonic_forecasts() works with a plain data.frame input", {
data <- data.frame(
model = "model1",
date = as.Date("2020-01-01"),
observed = 5,
quantile_level = c(0.25, 0.5, 0.75),
predicted = c(3, 7, 4)
)
result <- get_non_monotonic_forecasts(data)
expect_gt(nrow(result), 0L)
expect_s3_class(result, c("data.table", "data.frame"))
})

test_that("get_non_monotonic_forecasts() handles equal predictions at adjacent quantile levels", {
data <- data.table(
model = "model1",
date = as.Date("2020-01-01"),
observed = 5,
quantile_level = c(0.25, 0.5, 0.75),
predicted = c(3, 5, 5)
)
data <- suppressWarnings(suppressMessages(as_forecast_quantile(data)))
result <- get_non_monotonic_forecasts(data)
expect_identical(nrow(result), 0L)
})


# ==============================================================================
# check_monotonicity() # nolint: commented_code_linter
# ==============================================================================

test_that("check_monotonicity() returns TRUE for well-formed data", {
expect_true(check_monotonicity(example_quantile))
})

test_that("check_monotonicity() returns message string for non-monotonic data", {
data <- data.table(
model = "model1",
date = as.Date("2020-01-01"),
observed = 5,
quantile_level = c(0.25, 0.5, 0.75),
predicted = c(3, 7, 4)
)
data <- suppressWarnings(suppressMessages(as_forecast_quantile(data)))
result <- check_monotonicity(data)
expect_match(result, "non-monotonic|decrease|get_non_monotonic_forecasts", ignore.case = TRUE)
})
Loading