Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@
S3method("[",SBC_datasets)
S3method("[",SBC_results)
S3method(SBC_fit,SBC_backend_brms)
S3method(SBC_fit,SBC_backend_cmdstan_optimize)
S3method(SBC_fit,SBC_backend_cmdstan_sample)
S3method(SBC_fit,SBC_backend_cmdstan_variational)
S3method(SBC_fit,SBC_backend_rstan_sample)
S3method(SBC_fit_to_diagnostics,CmdStanMCMC)
S3method(SBC_fit_to_diagnostics,brmsfit)
S3method(SBC_fit_to_diagnostics,default)
S3method(SBC_fit_to_diagnostics,stanfit)
S3method(SBC_fit_to_draws_matrix,CmdStanMCMC)
S3method(SBC_fit_to_draws_matrix,CmdStanMLE)
S3method(SBC_fit_to_draws_matrix,CmdStanVB)
S3method(SBC_fit_to_draws_matrix,brmsfit)
S3method(SBC_fit_to_draws_matrix,default)
S3method(check_all_SBC_diagnostics,SBC_results)
Expand All @@ -33,7 +37,9 @@ S3method(summary,SBC_nuts_diagnostics)
S3method(summary,SBC_results)
export(SBC_backend_brms)
export(SBC_backend_brms_from_generator)
export(SBC_backend_cmdstan_optimize)
export(SBC_backend_cmdstan_sample)
export(SBC_backend_cmdstan_variational)
export(SBC_backend_rstan_sample)
export(SBC_datasets)
export(SBC_diagnostic_messages)
Expand Down
81 changes: 81 additions & 0 deletions R/backends.R
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,87 @@ SBC_fit_to_diagnostics.CmdStanMCMC <- function(fit, fit_output, fit_messages, fi
res
}

#' Backend based on variational approximation via `cmdstanr`.
#'
#' @param model an object of class `CmdStanModel` (as created by `cmdstanr::cmdstan_model`)
#' @param ... other arguments passed to the `$variational()` method of the model. The `data` and
#' `parallel_chains` arguments cannot be set this way as they need to be controlled by the SBC
#' package.
#' @export
SBC_backend_cmdstan_variational <- function(model, ...) {
stopifnot(inherits(model, "CmdStanModel"))
if(length(model$exe_file()) == 0) {
stop("The model has to be already compiled, call $compile() first.")
}
args <- list(...)
unacceptable_params <- c("data", "parallel_chains ", "cores", "num_cores")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since variational is always run on single core, I guess data is the only argument we really want to forbid here.

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a thread argument (and thread_per_chain argument for sample). Is this argument not relative to SBC? What is the difference using four parallel chains each with single threads vs one four threads for a single chain? This reduce_sum doc introduces thread as similar to parallel.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Within-chain paralellization is another big can of worms here. I think it is sufficiently rare to let us expect people to just figure out the correct configuration (number of workers, cores_per_fit) themselves and pass the correct threading-related arguments to the backend - the configuration would likely would be very much use-case dependent, although most often NOT using any within-chain paralellization would be the best choice. We however definitely need to document this. I've started #49 to make sure we don't forget.

if(any(names(args) %in% unacceptable_params)) {
stop(paste0("Parameters ", paste0("'", unacceptable_params, "'", collapse = ", "),
" cannot be provided when defining a backend as they need to be set ",
"by the SBC package"))
}
structure(list(model = model, args = args), class = "SBC_backend_cmdstan_variational")
}

#' @export
SBC_fit.SBC_backend_cmdstan_variational <- function(backend, generated, cores) {
fit <- do.call(backend$model$variational,
combine_args(backend$args,
list(
data = generated)))

if(all(fit$return_codes() != 0)) {
stop("No chains finished succesfully")
}

fit
}

#' @export
SBC_fit_to_draws_matrix.CmdStanVB <- function(fit) {
fit$draws(format = "draws_matrix")

}

#' Backend based on optimize approximation via `cmdstanr`.
#'
#' @param model an object of class `CmdStanModel` (as created by `cmdstanr::cmdstan_model`)
#' @param ... other arguments passed to the `$optimize()` method of the model. The `data` and
#' `parallel_chains` arguments cannot be set this way as they need to be controlled by the SBC
#' package.
#' @export
SBC_backend_cmdstan_optimize <- function(model, ...) {
stop("The optimize method is currently not supported.")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is better to just remove the code for optimization now.

stopifnot(inherits(model, "CmdStanModel"))
if(length(model$exe_file()) == 0) {
stop("The model has to be already compiled, call $compile() first.")
}
args <- list(...)
unacceptable_params <- c("data")
if(any(names(args) %in% unacceptable_params)) {
stop(paste0("Parameters ", paste0("'", unacceptable_params, "'", collapse = ", "),
" cannot be provided when defining a backend as they need to be set ",
"by the SBC package"))
}
structure(list(model = model, args = args), class = "SBC_backend_cmdstan_optimize")
}
#' @export
SBC_fit.SBC_backend_cmdstan_optimize <- function(backend, generated, cores) {
fit <- do.call(backend$model$optimize,
combine_args(backend$args,
list(data = generated)))

if(all(fit$return_codes() != 0)) {
stop("Point optimization failed!")
}

fit
}
#' @export
SBC_fit_to_draws_matrix.CmdStanMLE <- function(fit) {
fit$draws(format = "draws_matrix")
}

# For internal use, creates brms backend.
new_SBC_backend_brms <- function(compiled_model,
args
Expand Down
8 changes: 4 additions & 4 deletions R/results.R
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ length.SBC_results <- function(x) {
#'
#' @param datasets an object of class `SBC_datasets`
#' @param backend the model + sampling algorithm. The built-in backends can be constructed
#' using `SBC_backend_cmdstan_sample()`, `SBC_backend_rstan_sample()` and `SBC_backend_brms()`.
#' using `SBC_backend_cmdstan_sample()`, `SBC_backend_cmdstan_variational()`,`SBC_backend_cmdstan_optimize()`, `SBC_backend_rstan_sample()` and `SBC_backend_brms()`.
#' (more to come). The backend is an S3 class supporting at least the `SBC_fit`,
#' `SBC_fit_to_draws_matrix` methods.
#' @param cores_per_fit how many cores should the backend be allowed to use for a single fit?
Expand Down Expand Up @@ -271,7 +271,6 @@ compute_results <- function(datasets, backend,
generated = datasets$generated[[i]]
)
}

if(is.null(gen_quants)) {
future.globals <- FALSE
} else {
Expand All @@ -288,7 +287,6 @@ compute_results <- function(datasets, backend,
future.globals = future.globals,
future.chunk.size = chunk_size)


# Combine, check and summarise
fits <- rep(list(NULL), length(datasets))
outputs <- rep(list(NULL), length(datasets))
Expand All @@ -307,7 +305,9 @@ compute_results <- function(datasets, backend,
stats_list[[i]] <- results_raw[[i]]$stats
stats_list[[i]]$dataset_id <- i
backend_diagnostics_list[[i]] <- results_raw[[i]]$backend_diagnostics
backend_diagnostics_list[[i]]$dataset_id <- i
if(!is.null(results_raw[[i]]$backend_diagnostics)){
backend_diagnostics_list[[i]]$dataset_id <- i
}
}
else {
if(n_errors < max_errors_to_show) {
Expand Down
18 changes: 18 additions & 0 deletions man/SBC_backend_cmdstan_optimize.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions man/SBC_backend_cmdstan_variational.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/compute_results.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.