diff --git a/NAMESPACE b/NAMESPACE index c0b5cae..6e6d874 100755 --- a/NAMESPACE +++ b/NAMESPACE @@ -4,12 +4,14 @@ S3method("[",SBC_datasets) S3method("[",SBC_results) S3method(SBC_fit,SBC_backend_brms) S3method(SBC_fit,SBC_backend_cmdstan_sample) +S3method(SBC_fit,SBC_backend_cmdstan_variational) S3method(SBC_fit,SBC_backend_rstan_sample) S3method(SBC_fit_to_diagnostics,CmdStanMCMC) S3method(SBC_fit_to_diagnostics,brmsfit) S3method(SBC_fit_to_diagnostics,default) S3method(SBC_fit_to_diagnostics,stanfit) S3method(SBC_fit_to_draws_matrix,CmdStanMCMC) +S3method(SBC_fit_to_draws_matrix,CmdStanVB) S3method(SBC_fit_to_draws_matrix,brmsfit) S3method(SBC_fit_to_draws_matrix,default) S3method(check_all_SBC_diagnostics,SBC_results) @@ -34,6 +36,7 @@ S3method(summary,SBC_results) export(SBC_backend_brms) export(SBC_backend_brms_from_generator) export(SBC_backend_cmdstan_sample) +export(SBC_backend_cmdstan_variational) export(SBC_backend_rstan_sample) export(SBC_datasets) export(SBC_diagnostic_messages) diff --git a/R/.Rapp.history b/R/.Rapp.history new file mode 100644 index 0000000..e69de29 diff --git a/R/backends.R b/R/backends.R index f2d113a..7366dbb 100644 --- a/R/backends.R +++ b/R/backends.R @@ -267,6 +267,48 @@ SBC_fit_to_diagnostics.CmdStanMCMC <- function(fit, fit_output, fit_messages, fi res } +#' Backend based on variational approximation via `cmdstanr`. +#' +#' @param model an object of class `CmdStanModel` (as created by `cmdstanr::cmdstan_model`) +#' @param ... other arguments passed to the `$variational()` method of the model. The `data` and +#' `parallel_chains` arguments cannot be set this way as they need to be controlled by the SBC +#' package. +#' @export +SBC_backend_cmdstan_variational <- function(model, ...) { + stopifnot(inherits(model, "CmdStanModel")) + if(length(model$exe_file()) == 0) { + stop("The model has to be already compiled, call $compile() first.") + } + args <- list(...) + unacceptable_params <- c("data") + if(any(names(args) %in% unacceptable_params)) { + stop(paste0("Parameters ", paste0("'", unacceptable_params, "'", collapse = ", "), + " cannot be provided when defining a backend as they need to be set ", + "by the SBC package")) + } + structure(list(model = model, args = args), class = "SBC_backend_cmdstan_variational") +} + +#' @export +SBC_fit.SBC_backend_cmdstan_variational <- function(backend, generated, cores) { + fit <- do.call(backend$model$variational, + combine_args(backend$args, + list( + data = generated))) + + if(all(fit$return_codes() != 0)) { + stop("Variational inference did not finish succesfully") + } + + fit +} + +#' @export +SBC_fit_to_draws_matrix.CmdStanVB <- function(fit) { + fit$draws(format = "draws_matrix") + +} + # For internal use, creates brms backend. new_SBC_backend_brms <- function(compiled_model, args diff --git a/R/results.R b/R/results.R index 2d94a68..385121b 100644 --- a/R/results.R +++ b/R/results.R @@ -212,13 +212,13 @@ length.SBC_results <- function(x) { #' #' Parallel processing is supported via the `future` package, for most uses, it is most sensible #' to just call `plan(multisession)` once in your R session and all -#' cores your computer has will be used. For more details refer to the documentation +#' cores your computer will be used. For more details refer to the documentation #' of the `future` package. #' #' @param datasets an object of class `SBC_datasets` #' @param backend the model + sampling algorithm. The built-in backends can be constructed -#' using `SBC_backend_cmdstan_sample()`, `SBC_backend_rstan_sample()` and `SBC_backend_brms()`. -#' (more to come). The backend is an S3 class supporting at least the `SBC_fit`, +#' using `SBC_backend_cmdstan_sample()`, `SBC_backend_cmdstan_variational()`, `SBC_backend_rstan_sample()` and `SBC_backend_brms()`. +#' (more to come: issue 31, 38, 39). The backend is an S3 class supporting at least the `SBC_fit`, #' `SBC_fit_to_draws_matrix` methods. #' @param cores_per_fit how many cores should the backend be allowed to use for a single fit? #' Defaults to the maximum number that does not produce more parallel chains @@ -271,7 +271,6 @@ compute_results <- function(datasets, backend, generated = datasets$generated[[i]] ) } - if(is.null(gen_quants)) { future.globals <- FALSE } else { @@ -288,7 +287,6 @@ compute_results <- function(datasets, backend, future.globals = future.globals, future.chunk.size = chunk_size) - # Combine, check and summarise fits <- rep(list(NULL), length(datasets)) outputs <- rep(list(NULL), length(datasets)) @@ -307,7 +305,9 @@ compute_results <- function(datasets, backend, stats_list[[i]] <- results_raw[[i]]$stats stats_list[[i]]$dataset_id <- i backend_diagnostics_list[[i]] <- results_raw[[i]]$backend_diagnostics - backend_diagnostics_list[[i]]$dataset_id <- i + if(!is.null(results_raw[[i]]$backend_diagnostics)){ + backend_diagnostics_list[[i]]$dataset_id <- i + } } else { if(n_errors < max_errors_to_show) { diff --git a/man/SBC_backend_cmdstan_variational.Rd b/man/SBC_backend_cmdstan_variational.Rd new file mode 100644 index 0000000..ad07463 --- /dev/null +++ b/man/SBC_backend_cmdstan_variational.Rd @@ -0,0 +1,18 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/backends.R +\name{SBC_backend_cmdstan_variational} +\alias{SBC_backend_cmdstan_variational} +\title{Backend based on variational approximation via \code{cmdstanr}.} +\usage{ +SBC_backend_cmdstan_variational(model, ...) +} +\arguments{ +\item{model}{an object of class \code{CmdStanModel} (as created by \code{cmdstanr::cmdstan_model})} + +\item{...}{other arguments passed to the \verb{$variational()} method of the model. The \code{data} and +\code{parallel_chains} arguments cannot be set this way as they need to be controlled by the SBC +package.} +} +\description{ +Backend based on variational approximation via \code{cmdstanr}. +} diff --git a/man/compute_results.Rd b/man/compute_results.Rd index af6a5ca..18cffab 100644 --- a/man/compute_results.Rd +++ b/man/compute_results.Rd @@ -18,8 +18,8 @@ compute_results( \item{datasets}{an object of class \code{SBC_datasets}} \item{backend}{the model + sampling algorithm. The built-in backends can be constructed -using \code{SBC_backend_cmdstan_sample()}, \code{SBC_backend_rstan_sample()} and \code{SBC_backend_brms()}. -(more to come). The backend is an S3 class supporting at least the \code{SBC_fit}, +using \code{SBC_backend_cmdstan_sample()}, \code{SBC_backend_cmdstan_variational()}, \code{SBC_backend_rstan_sample()} and \code{SBC_backend_brms()}. +(more to come: issue 31, 38, 39). The backend is an S3 class supporting at least the \code{SBC_fit}, \code{SBC_fit_to_draws_matrix} methods.} \item{cores_per_fit}{how many cores should the backend be allowed to use for a single fit? @@ -44,7 +44,7 @@ the work may be distributed less equally across workers. We recommend setting th enough that a single batch takes at least several seconds, i.e. for small models, you can often reduce computation time noticeably by increasing this value. You can use \code{options(SBC.min_chunk_size = value)} to set a minimum chunk size globally. -See documentation of \code{future.chunk.size} argument for \code{future_lapply()} for more details.} +See documentation of \code{future.chunk.size} argument for \code{future.apply::future_lapply()} for more details.} } \value{ An object of class \code{SBC_results} that holds: @@ -60,6 +60,6 @@ An object of class \code{SBC_results} that holds: \description{ Parallel processing is supported via the \code{future} package, for most uses, it is most sensible to just call \code{plan(multisession)} once in your R session and all -cores your computer has will be used. For more details refer to the documentation +cores your computer will be used. For more details refer to the documentation of the \code{future} package. }