-
Notifications
You must be signed in to change notification settings - Fork 5
Support for variational, (optimize) #32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -267,6 +267,87 @@ SBC_fit_to_diagnostics.CmdStanMCMC <- function(fit, fit_output, fit_messages, fi | |
| res | ||
| } | ||
|
|
||
| #' Backend based on variational approximation via `cmdstanr`. | ||
| #' | ||
| #' @param model an object of class `CmdStanModel` (as created by `cmdstanr::cmdstan_model`) | ||
| #' @param ... other arguments passed to the `$variational()` method of the model. The `data` and | ||
| #' `parallel_chains` arguments cannot be set this way as they need to be controlled by the SBC | ||
| #' package. | ||
| #' @export | ||
| SBC_backend_cmdstan_variational <- function(model, ...) { | ||
| stopifnot(inherits(model, "CmdStanModel")) | ||
| if(length(model$exe_file()) == 0) { | ||
| stop("The model has to be already compiled, call $compile() first.") | ||
| } | ||
| args <- list(...) | ||
| unacceptable_params <- c("data", "parallel_chains ", "cores", "num_cores") | ||
| if(any(names(args) %in% unacceptable_params)) { | ||
| stop(paste0("Parameters ", paste0("'", unacceptable_params, "'", collapse = ", "), | ||
| " cannot be provided when defining a backend as they need to be set ", | ||
| "by the SBC package")) | ||
| } | ||
| structure(list(model = model, args = args), class = "SBC_backend_cmdstan_variational") | ||
| } | ||
|
|
||
| #' @export | ||
| SBC_fit.SBC_backend_cmdstan_variational <- function(backend, generated, cores) { | ||
| fit <- do.call(backend$model$variational, | ||
| combine_args(backend$args, | ||
| list( | ||
| data = generated))) | ||
|
|
||
| if(all(fit$return_codes() != 0)) { | ||
| stop("No chains finished succesfully") | ||
Dashadower marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| fit | ||
| } | ||
|
|
||
| #' @export | ||
| SBC_fit_to_draws_matrix.CmdStanVB <- function(fit) { | ||
| fit$draws(format = "draws_matrix") | ||
|
|
||
| } | ||
|
|
||
| #' Backend based on optimize approximation via `cmdstanr`. | ||
| #' | ||
| #' @param model an object of class `CmdStanModel` (as created by `cmdstanr::cmdstan_model`) | ||
| #' @param ... other arguments passed to the `$optimize()` method of the model. The `data` and | ||
| #' `parallel_chains` arguments cannot be set this way as they need to be controlled by the SBC | ||
| #' package. | ||
| #' @export | ||
| SBC_backend_cmdstan_optimize <- function(model, ...) { | ||
| stop("The optimize method is currently not supported.") | ||
|
||
| stopifnot(inherits(model, "CmdStanModel")) | ||
| if(length(model$exe_file()) == 0) { | ||
| stop("The model has to be already compiled, call $compile() first.") | ||
| } | ||
| args <- list(...) | ||
| unacceptable_params <- c("data") | ||
| if(any(names(args) %in% unacceptable_params)) { | ||
| stop(paste0("Parameters ", paste0("'", unacceptable_params, "'", collapse = ", "), | ||
| " cannot be provided when defining a backend as they need to be set ", | ||
| "by the SBC package")) | ||
| } | ||
| structure(list(model = model, args = args), class = "SBC_backend_cmdstan_optimize") | ||
| } | ||
| #' @export | ||
| SBC_fit.SBC_backend_cmdstan_optimize <- function(backend, generated, cores) { | ||
| fit <- do.call(backend$model$optimize, | ||
| combine_args(backend$args, | ||
| list(data = generated))) | ||
|
|
||
| if(all(fit$return_codes() != 0)) { | ||
| stop("Point optimization failed!") | ||
| } | ||
|
|
||
| fit | ||
| } | ||
| #' @export | ||
| SBC_fit_to_draws_matrix.CmdStanMLE <- function(fit) { | ||
| fit$draws(format = "draws_matrix") | ||
| } | ||
|
|
||
| # For internal use, creates brms backend. | ||
| new_SBC_backend_brms <- function(compiled_model, | ||
| args | ||
|
|
||
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since
variationalis always run on single core, I guessdatais the only argument we really want to forbid here.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a
threadargument (andthread_per_chainargument forsample). Is this argument not relative to SBC? What is the difference using four parallel chains each with single threads vs one four threads for a single chain? This reduce_sum doc introduces thread as similar to parallel.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Within-chain paralellization is another big can of worms here. I think it is sufficiently rare to let us expect people to just figure out the correct configuration (number of workers,
cores_per_fit) themselves and pass the correct threading-related arguments to the backend - the configuration would likely would be very much use-case dependent, although most often NOT using any within-chain paralellization would be the best choice. We however definitely need to document this. I've started #49 to make sure we don't forget.