Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ Imports:
cli,
data.table (>= 1.16.0),
ggplot2 (>= 3.4.0),
lifecycle,
methods,
purrr,
scoringRules (>= 1.1.3),
Expand Down
9 changes: 9 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ S3method(assert_forecast,forecast_ordinal)
S3method(assert_forecast,forecast_point)
S3method(assert_forecast,forecast_quantile)
S3method(assert_forecast,forecast_sample)
S3method(get_forecast_type_ids,default)
S3method(get_forecast_type_ids,forecast_multivariate_sample)
S3method(get_forecast_type_ids,forecast_nominal)
S3method(get_forecast_type_ids,forecast_ordinal)
S3method(get_forecast_type_ids,forecast_quantile)
S3method(get_forecast_type_ids,forecast_sample)
S3method(get_metrics,forecast_binary)
S3method(get_metrics,forecast_multivariate_point)
S3method(get_metrics,forecast_multivariate_sample)
Expand Down Expand Up @@ -76,6 +82,7 @@ export(get_correlations)
export(get_coverage)
export(get_duplicate_forecasts)
export(get_forecast_counts)
export(get_forecast_type_ids)
export(get_forecast_unit)
export(get_grouping)
export(get_metrics)
Expand Down Expand Up @@ -208,6 +215,8 @@ importFrom(ggplot2,theme_minimal)
importFrom(ggplot2,unit)
importFrom(ggplot2,xlab)
importFrom(ggplot2,ylab)
importFrom(lifecycle,deprecate_warn)
importFrom(lifecycle,deprecated)
importFrom(methods,formalArgs)
importFrom(methods,hasArg)
importFrom(purrr,partial)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# scoringutils (development version)

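- Added a new exported S3 generic `get_forecast_type_ids()` so each forecast type declares the columns (beyond the forecast unit) that identify a unique row. `get_duplicate_forecasts()` now uses this instead of hard-coded column names and gains a `type` argument for use with plain `data.frame`s; calling it on a plain `data.frame` without `type` is deprecated (#888).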
- Removed the deprecated vignettes `Deprecated-functions` and `Deprecated-visualisations`. The code for removed functions (`plot_predictions()`, `make_NA()`, `plot_ranges()`, `plot_score_table()`, `merge_pred_and_obs()`) can still be found in the [git history](https://github.com/epiforecasts/scoringutils/tree/d0cd8e2/vignettes) (#1158).

# scoringutils 2.2.0
Expand Down
9 changes: 9 additions & 0 deletions R/class-forecast-multivariate-sample.R
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,15 @@ assert_forecast.forecast_multivariate_sample <- function(
}


#' @rdname get_forecast_type_ids
#' @export
# nolint start: object_name_linter
get_forecast_type_ids.forecast_multivariate_sample <- function(data) {
  # `data` is only used for S3 dispatch. For this forecast type the column
  # that, together with the forecast unit, identifies a unique row is the
  # sample identifier.
  return("sample_id")
}
# nolint end


#' @export
#' @rdname is_forecast
# nolint start: object_name_linter
Expand Down
7 changes: 7 additions & 0 deletions R/class-forecast-nominal.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,13 @@ assert_forecast.forecast_nominal <- function(
}


#' @rdname get_forecast_type_ids
#' @export
get_forecast_type_ids.forecast_nominal <- function(data) {
  # `data` is only used for S3 dispatch. For this forecast type the column
  # that, together with the forecast unit, identifies a unique row is the
  # predicted label.
  return("predicted_label")
}


#' @export
#' @rdname is_forecast
is_forecast_nominal <- function(x) {
Expand Down
7 changes: 7 additions & 0 deletions R/class-forecast-ordinal.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,13 @@ assert_forecast.forecast_ordinal <- function(
}


#' @rdname get_forecast_type_ids
#' @export
get_forecast_type_ids.forecast_ordinal <- function(data) {
  # `data` is only used for S3 dispatch. For this forecast type the column
  # that, together with the forecast unit, identifies a unique row is the
  # predicted label.
  return("predicted_label")
}


#' @export
#' @rdname is_forecast
is_forecast_ordinal <- function(x) {
Expand Down
7 changes: 7 additions & 0 deletions R/class-forecast-quantile.R
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,13 @@ assert_forecast.forecast_quantile <- function(
}


#' @rdname get_forecast_type_ids
#' @export
get_forecast_type_ids.forecast_quantile <- function(data) {
  # `data` is only used for S3 dispatch. For this forecast type the column
  # that, together with the forecast unit, identifies a unique row is the
  # quantile level.
  return("quantile_level")
}


#' @export
#' @rdname is_forecast
is_forecast_quantile <- function(x) {
Expand Down
7 changes: 7 additions & 0 deletions R/class-forecast-sample.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ assert_forecast.forecast_sample <- function(
}


#' @rdname get_forecast_type_ids
#' @export
get_forecast_type_ids.forecast_sample <- function(data) {
  # `data` is only used for S3 dispatch. For this forecast type the column
  # that, together with the forecast unit, identifies a unique row is the
  # sample identifier.
  return("sample_id")
}


#' @export
#' @rdname is_forecast
is_forecast_sample <- function(x) {
Expand Down
120 changes: 99 additions & 21 deletions R/get-duplicate-forecasts.R
Original file line number Diff line number Diff line change
@@ -1,53 +1,117 @@
#' @title Find duplicate forecasts
#'
#' @description
#' Identify duplicate forecasts, i.e. instances where there is more than
#' one forecast for the same prediction target.
#'
#' Uses [get_forecast_type_ids()] to determine the type-specific columns
#' (beyond the forecast unit) that identify a unique row. For `forecast`
#' objects the type is detected automatically. For plain `data.frame`s
#' you should pass `type` (e.g. `"quantile"`, `"sample"`) so that the
#' correct columns are used. Calling on a plain `data.frame` without
#' `type` is deprecated; it falls back to column-name detection but
#' this behaviour will be removed in a future version.
#'
#' @inheritParams as_forecast_doc_template
#' @param type Character string naming the forecast type, corresponding
#'   to the class suffix after `forecast_` (e.g. `"quantile"` for
#'   class `forecast_quantile`, `"sample"` for `forecast_sample`).
#'   Used to determine type-specific ID columns when `data` is not
#'   already a `forecast` object. Ignored when `data` already
#'   inherits from `forecast`.
#' @param counts Should the output show the number of duplicates per
#'   forecast unit instead of the individual duplicated rows?
#'   Default is `FALSE`.
#' @returns A data.frame with all rows for which a duplicate forecast
#'   was found. If `counts = TRUE`, a data.frame with one row per forecast
#'   unit and a column `n_duplicates` holding the number of duplicated rows.
#' @export
#' @importFrom checkmate assert_data_frame assert_subset
#' @importFrom data.table setorderv
#' @importFrom lifecycle deprecated deprecate_warn
#' @keywords diagnose-inputs
#' @examples
#' example <- rbind(example_quantile, example_quantile[1000:1010])
#' get_duplicate_forecasts(example)
#' get_duplicate_forecasts(example, type = "quantile")
get_duplicate_forecasts <- function(
  data,
  forecast_unit = NULL,
  type = NULL,
  counts = FALSE
) {
  assert_data_frame(data)
  checkmate::assert_string(type, null.ok = TRUE)
  data <- ensure_data.table(data)

  if (!is.null(forecast_unit)) {
    data <- set_forecast_unit(data, forecast_unit)
  }
  forecast_unit <- get_forecast_unit(data)

  # Determine the columns that, in addition to the forecast unit, identify a
  # unique row. The user-supplied `type` must reach this point unmodified so
  # that the three branches below can be told apart.
  if (inherits(data, "forecast")) {
    # Forecast objects dispatch on their own class; `type` is ignored.
    type_cols <- get_forecast_type_ids(data)
  } else if (!is.null(type)) {
    # Build a temporary object of class `forecast_<type>` only so that S3
    # dispatch selects the matching `get_forecast_type_ids()` method.
    tmp <- new_forecast(data, paste0("forecast_", type))
    type_cols <- get_forecast_type_ids(tmp)
  } else {
    lifecycle::deprecate_warn(
      "2.2.0",
      "get_duplicate_forecasts(type = )",
      details = paste(
        "Pass `type` (e.g. \"quantile\", \"sample\") to detect",
        "type-specific duplicates on plain data.frames."
      )
    )
    # deprecated fallback: detect type columns by name
    known <- c("sample_id", "quantile_level", "predicted_label")
    type_cols <- intersect(known, colnames(data))
  }

  data <- as.data.table(data)
  # Count rows per (forecast unit + type-specific id); more than one row in a
  # group means a duplicate forecast.
  data[,
    scoringutils_InternalDuplicateCheck := .N,
    by = c(forecast_unit, type_cols)
  ]
  out <- data[scoringutils_InternalDuplicateCheck > 1]

  setorderv(out, cols = c(forecast_unit, type_cols, "predicted"))
  # Drop the helper column so the output only contains the input columns.
  out[, scoringutils_InternalDuplicateCheck := NULL]

  if (counts) {
    out <- out[,
      .(n_duplicates = .N),
      by = c(get_forecast_unit(out))
    ]
  }

  # `[]` makes the data.table print after the preceding `:=` calls.
  return(out[])
}


#' @title Get type-specific ID columns for a forecast
#'
#' @description
#' S3 generic that returns the column names (beyond the forecast unit)
#' that identify a unique row for a given forecast type. Each forecast
#' type method returns the columns specific to that type. The default
#' returns `character(0)` (no type-specific columns).
#'
#' Custom forecast types should define a method returning the relevant
#' column names.
#'
#' @inheritParams as_forecast_doc_template
#' @returns A character vector of column names.
#' @export
#' @keywords as_forecast
get_forecast_type_ids <- function(data) {
  # Dispatch on the class of `data`; objects without a dedicated method fall
  # through to the default method.
  UseMethod("get_forecast_type_ids", data)
}


#' @export
get_forecast_type_ids.default <- function(data) {
  # No type-specific id columns: the forecast unit alone identifies a row.
  return(character())
}


#' Check that there are no duplicate forecasts
#'
#' @description
Expand All @@ -57,14 +121,28 @@ get_duplicate_forecasts <- function(
#' @inherit document_check_functions return
#' @keywords internal_input_check
check_duplicates <- function(data) {
check_duplicates <- get_duplicate_forecasts(data)
duplicates <- get_duplicate_forecasts(data)

if (nrow(check_duplicates) > 0) {
if (nrow(duplicates) > 0) {
hint_args <- ""
if (inherits(data, "forecast")) {
forecast_type <- get_forecast_type(data)
fu <- get_forecast_unit(data)
fu_str <- paste0(
"c(\"", paste(fu, collapse = "\", \""), "\")"
)
hint_args <- paste0(
", forecast_unit = ", fu_str,
", type = \"", forecast_type, "\""
)
}
msg <- paste0(
"There are instances with more than one forecast for the same target. ",
"This can't be right and needs to be resolved. Maybe you need to ",
"check the unit of a single forecast and add missing columns? Use ",
"the function get_duplicate_forecasts() to identify duplicate rows"
"There are instances with more than one forecast for the ",
"same target. This can't be right and needs to be resolved. ",
"Maybe you need to check the unit of a single forecast and ",
"add missing columns? Use ",
"`get_duplicate_forecasts(data", hint_args, ")` ",
"to identify duplicate rows"
)
return(msg)
}
Expand Down
37 changes: 29 additions & 8 deletions man/get_duplicate_forecasts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

44 changes: 44 additions & 0 deletions man/get_forecast_type_ids.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions tests/testthat/test-class-forecast-quantile.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ test_that("as_forecast_quantile() function throws an error with duplicate foreca

expect_error(
suppressMessages(suppressWarnings(as_forecast_quantile(example))),
"Assertion on 'data' failed: There are instances with more than one forecast for the same target. This can't be right and needs to be resolved. Maybe you need to check the unit of a single forecast and add missing columns? Use the function get_duplicate_forecasts() to identify duplicate rows.", # nolint
fixed = TRUE
"There are instances with more than one forecast for the same target"
)
})

Expand Down
Loading
Loading