Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ S3method("[",SBC_datasets)
S3method("[",SBC_results)
S3method(SBC_fit,SBC_backend_brms)
S3method(SBC_fit,SBC_backend_cmdstan_sample)
S3method(SBC_fit,SBC_backend_cmdstan_variational)
S3method(SBC_fit,SBC_backend_rstan_sample)
S3method(SBC_fit_to_diagnostics,CmdStanMCMC)
S3method(SBC_fit_to_diagnostics,brmsfit)
S3method(SBC_fit_to_diagnostics,default)
S3method(SBC_fit_to_diagnostics,stanfit)
S3method(SBC_fit_to_draws_matrix,CmdStanMCMC)
S3method(SBC_fit_to_draws_matrix,CmdStanVB)
S3method(SBC_fit_to_draws_matrix,brmsfit)
S3method(SBC_fit_to_draws_matrix,default)
S3method(check_all_SBC_diagnostics,SBC_results)
Expand All @@ -34,6 +36,7 @@ S3method(summary,SBC_results)
export(SBC_backend_brms)
export(SBC_backend_brms_from_generator)
export(SBC_backend_cmdstan_sample)
export(SBC_backend_cmdstan_variational)
export(SBC_backend_rstan_sample)
export(SBC_datasets)
export(SBC_diagnostic_messages)
Expand Down
Empty file added R/.Rapp.history
Empty file.
42 changes: 42 additions & 0 deletions R/backends.R
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,48 @@ SBC_fit_to_diagnostics.CmdStanMCMC <- function(fit, fit_output, fit_messages, fi
res
}

#' Backend based on variational approximation via `cmdstanr`.
#'
#' @param model an object of class `CmdStanModel` (as created by `cmdstanr::cmdstan_model`)
#' @param ... other arguments passed to the `$variational()` method of the model. The `data` and
#' `parallel_chains` arguments cannot be set this way as they need to be controlled by the SBC
#' package.
#' @export
SBC_backend_cmdstan_variational <- function(model, ...) {
stopifnot(inherits(model, "CmdStanModel"))
if(length(model$exe_file()) == 0) {
stop("The model has to be already compiled, call $compile() first.")
}
args <- list(...)
unacceptable_params <- c("data")
if(any(names(args) %in% unacceptable_params)) {
stop(paste0("Parameters ", paste0("'", unacceptable_params, "'", collapse = ", "),
" cannot be provided when defining a backend as they need to be set ",
"by the SBC package"))
}
structure(list(model = model, args = args), class = "SBC_backend_cmdstan_variational")
}

#' @export
SBC_fit.SBC_backend_cmdstan_variational <- function(backend, generated, cores) {
fit <- do.call(backend$model$variational,
combine_args(backend$args,
list(
data = generated)))

if(all(fit$return_codes() != 0)) {
stop("Variational inference did not finish succesfully")
}

fit
}

#' @export
SBC_fit_to_draws_matrix.CmdStanVB <- function(fit) {
fit$draws(format = "draws_matrix")

}

# For internal use, creates brms backend.
new_SBC_backend_brms <- function(compiled_model,
args
Expand Down
12 changes: 6 additions & 6 deletions R/results.R
Original file line number Diff line number Diff line change
Expand Up @@ -212,13 +212,13 @@ length.SBC_results <- function(x) {
#'
#' Parallel processing is supported via the `future` package, for most uses, it is most sensible
#' to just call `plan(multisession)` once in your R session and all
#' cores your computer has will be used. For more details refer to the documentation
#' cores your computer will be used. For more details refer to the documentation
#' of the `future` package.
#'
#' @param datasets an object of class `SBC_datasets`
#' @param backend the model + sampling algorithm. The built-in backends can be constructed
#' using `SBC_backend_cmdstan_sample()`, `SBC_backend_rstan_sample()` and `SBC_backend_brms()`.
#' (more to come). The backend is an S3 class supporting at least the `SBC_fit`,
#' using `SBC_backend_cmdstan_sample()`, `SBC_backend_cmdstan_variational()`, `SBC_backend_rstan_sample()` and `SBC_backend_brms()`.
#' (more to come: issue 31, 38, 39). The backend is an S3 class supporting at least the `SBC_fit`,
#' `SBC_fit_to_draws_matrix` methods.
#' @param cores_per_fit how many cores should the backend be allowed to use for a single fit?
#' Defaults to the maximum number that does not produce more parallel chains
Expand Down Expand Up @@ -271,7 +271,6 @@ compute_results <- function(datasets, backend,
generated = datasets$generated[[i]]
)
}

if(is.null(gen_quants)) {
future.globals <- FALSE
} else {
Expand All @@ -288,7 +287,6 @@ compute_results <- function(datasets, backend,
future.globals = future.globals,
future.chunk.size = chunk_size)


# Combine, check and summarise
fits <- rep(list(NULL), length(datasets))
outputs <- rep(list(NULL), length(datasets))
Expand All @@ -307,7 +305,9 @@ compute_results <- function(datasets, backend,
stats_list[[i]] <- results_raw[[i]]$stats
stats_list[[i]]$dataset_id <- i
backend_diagnostics_list[[i]] <- results_raw[[i]]$backend_diagnostics
backend_diagnostics_list[[i]]$dataset_id <- i
if(!is.null(results_raw[[i]]$backend_diagnostics)){
backend_diagnostics_list[[i]]$dataset_id <- i
}
}
else {
if(n_errors < max_errors_to_show) {
Expand Down
18 changes: 18 additions & 0 deletions man/SBC_backend_cmdstan_variational.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions man/compute_results.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.