diff --git a/DESCRIPTION b/DESCRIPTION index d7489579..c18ec24b 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -29,6 +29,7 @@ Imports: utils, rlang (>= 0.3.0), ggridges, + tibble, hexbin Suggests: gridExtra (>= 2.2.1), diff --git a/NAMESPACE b/NAMESPACE index 0436c4cb..747af0d6 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -70,6 +70,8 @@ export(mcmc_nuts_treedepth) export(mcmc_pairs) export(mcmc_parcoord) export(mcmc_parcoord_data) +export(mcmc_rank_hist) +export(mcmc_rank_overlay) export(mcmc_recover_hist) export(mcmc_recover_intervals) export(mcmc_recover_scatter) @@ -78,6 +80,7 @@ export(mcmc_rhat_data) export(mcmc_rhat_hist) export(mcmc_scatter) export(mcmc_trace) +export(mcmc_trace_data) export(mcmc_trace_highlight) export(mcmc_violin) export(neff_ratio) diff --git a/NEWS.md b/NEWS.md index 6472863a..5d3860ae 100644 --- a/NEWS.md +++ b/NEWS.md @@ -6,6 +6,23 @@ +* Two new plots have been added for inspecting the distribution of ranks. + Rank-normalized histograms were introduced by the Stan team's [new paper on + MCMC diagnostics](https://arxiv.org/abs/1903.08008). (#178, #179) + + `mcmc_rank_hist()`: A traditional traceplot (`mcmc_trace()`) visualizes how + sampled values the MCMC chains mix over the course of sampling. A + rank-normalized histogram (`mcmc_rank_hist()`) visualizes how the *ranks* of + values from the chains mix together. An ideal plot would show the ranks mixing + or overlapping in a uniform distribution. + + `mcmc_rank_overlay()`: Instead of drawing each chain's histogram in a separate + panel, this plot draws the top edge of the chains' histograms in a single + panel. + +* Added `mcmc_trace_data()`, which returns the data used for plotting the trace + plots and rank histograms. (Advances #97) + * [ColorBrewer](http://colorbrewer2.org) palettes are now available as color schemes via [`color_scheme_set()`](https://mc-stan.org/bayesplot/reference/bayesplot-colors.html). @@ -49,6 +66,12 @@ * The examples in [`?ppc_loo_pit_overlay()`](https://mc-stan.org/bayesplot/reference/PPC-loo.html) now work as expected. (#166, #167) + +* Added `"viridisD"` as an alternative name for `"viridis"` to the supported + colors. + +* Added `"viridisE"` (the [cividis](https://github.com/marcosci/cividis) + version of viridis) to the supported colors. * `ppc_bars()` and `ppc_bars_grouped()` now allow negative integers as input. (#172, @jeffpollock9) diff --git a/R/bayesplot-colors.R b/R/bayesplot-colors.R index 7c2bd063..b61eba23 100644 --- a/R/bayesplot-colors.R +++ b/R/bayesplot-colors.R @@ -45,7 +45,7 @@ #' * `"teal"` #' * `"yellow"` #' * [`"viridis"`](https://CRAN.R-project.org/package=viridis), `"viridisA"`, -#' `"viridisB"`, `"viridisC"` +#' `"viridisB"`, `"viridisC"`, `"viridisD"`, `"viridisE"` #' * `"mix-x-y"`, replacing `x` and `y` with any two of #' the scheme names listed above (e.g. "mix-teal-pink", "mix-blue-red", #' etc.). The order of `x` and `y` matters, i.e., the color schemes @@ -395,7 +395,12 @@ master_color_list <- list( viridisB = list("#FCFFA4FF", "#FCA50AFF", "#DD513AFF", "#932667FF", "#420A68FF", "#000004FF"), viridisC = - list("#F0F921FF", "#FCA636FF", "#E16462FF", "#B12A90FF", "#6A00A8FF", "#0D0887FF") + list("#F0F921FF", "#FCA636FF", "#E16462FF", "#B12A90FF", "#6A00A8FF", "#0D0887FF"), + # popular form of viridis is viridis option D + viridisD = + list("#FDE725FF", "#7AD151FF", "#22A884FF", "#2A788EFF", "#414487FF", "#440154FF"), + viridisE = + list("#FFEA46FF", "#CBBA69FF", "#958F78FF", "#666970FF", "#31446BFF", "#00204DFF") ) # instantiate aesthetics -------------------------------------------------- diff --git a/R/bayesplot-package.R b/R/bayesplot-package.R index 9a9f1e40..5c89d6d0 100644 --- a/R/bayesplot-package.R +++ b/R/bayesplot-package.R @@ -5,7 +5,7 @@ #' @aliases bayesplot #' #' @import ggplot2 stats rlang -#' @importFrom dplyr %>% +#' @importFrom dplyr %>% summarise group_by select #' #' @description #' \if{html}{ diff --git a/R/helpers-gg.R b/R/helpers-gg.R index 7254b880..f084b9b4 100644 --- a/R/helpers-gg.R +++ b/R/helpers-gg.R @@ -64,12 +64,26 @@ dont_expand_axes <- function() { } force_axes_in_facets <- function() { thm <- bayesplot_theme_get() - annotate("segment", - x = c(-Inf, -Inf), xend = c(Inf,-Inf), - y = c(-Inf,-Inf), yend = c(-Inf, Inf), - color = thm$axis.line$colour %||% "black", - size = thm$axis.line$size %||% 0.5) + annotate( + "segment", + x = c(-Inf, -Inf), xend = c(Inf,-Inf), + y = c(-Inf,-Inf), yend = c(-Inf, Inf), + color = thm$axis.line$colour %||% thm$line$colour %||% "black", + size = thm$axis.line$size %||% thm$line$size %||% 0.5 + ) } + +force_x_axis_in_facets <- function() { + thm <- bayesplot_theme_get() + annotate( + "segment", + x = -Inf, xend = Inf, + y = -Inf, yend = -Inf, + color = thm$axis.line$colour %||% thm$line$colour %||% "black", + size = thm$axis.line$size %||% thm$line$size %||% 0.5 + ) +} + no_legend_spacing <- function() { theme(legend.spacing.y = unit(0, "cm")) } diff --git a/R/helpers-ppc.R b/R/helpers-ppc.R index f7ad8769..5abebbc6 100644 --- a/R/helpers-ppc.R +++ b/R/helpers-ppc.R @@ -151,7 +151,7 @@ validate_x <- function(x = NULL, y, unique_x = FALSE) { melt_yrep <- function(yrep) { out <- yrep %>% reshape2::melt(varnames = c("rep_id", "y_id")) %>% - dplyr::as_data_frame() + tibble::as_tibble() id <- create_yrep_ids(out$rep_id) out$rep_label <- factor(id, levels = unique(id)) out[c("y_id", "rep_id", "rep_label", "value")] @@ -178,7 +178,7 @@ melt_and_stack <- function(y, yrep) { # Add a level in the labels for the observed y values levels(molten_yrep$rep_label) <- c(levels(molten_yrep$rep_label), y_text) - ydat <- dplyr::data_frame( + ydat <- tibble::tibble( rep_label = factor(y_text, levels = levels(molten_yrep$rep_label)), rep_id = NA_integer_, y_id = seq_along(y), diff --git a/R/mcmc-diagnostics.R b/R/mcmc-diagnostics.R index 750b9115..8ff84fc4 100644 --- a/R/mcmc-diagnostics.R +++ b/R/mcmc-diagnostics.R @@ -379,7 +379,7 @@ diagnostic_data_frame <- function(x) { stopifnot(!anyDuplicated(names(x))) diagnostic <- class(x)[1] - d <- dplyr::data_frame( + d <- tibble::tibble( diagnostic = diagnostic, parameter = factor(seq_along(x), labels = names(x)), value = as.numeric(x), diff --git a/R/mcmc-traces.R b/R/mcmc-traces.R index 79f2e54b..462cc5e2 100644 --- a/R/mcmc-traces.R +++ b/R/mcmc-traces.R @@ -1,11 +1,10 @@ -#' Trace plot (time series plot) of MCMC draws +#' Trace plots of MCMC draws #' #' Trace plot (or traceplot) of MCMC draws. See the **Plot Descriptions** #' section, below, for details. #' #' @name MCMC-traces #' @family MCMC -#' #' @template args-mcmc-x #' @template args-pars #' @template args-regex_pars @@ -13,8 +12,7 @@ #' @template args-facet_args #' @param ... Currently ignored. #' @param size An optional value to override the default line size -#' (`mcmc_trace()`) or the default point size -#' (`mcmc_trace_highlight()`). +#' for `mcmc_trace()` or the default point size for `mcmc_trace_highlight()`. #' @param alpha For `mcmc_trace_highlight()`, passed to #' [ggplot2::geom_point()] to control the transparency of the points #' for the chains not highlighted. @@ -40,7 +38,9 @@ #' divergences (if the `np` argument is specified). #' @param divergences Deprecated. Use the `np` argument instead. #' -#' @template return-ggplot +#' @template return-ggplot-or-data +#' @return `mcmc_trace_data()` returns the data for the trace *and* rank plots +#' in the same data frame. #' #' @section Plot Descriptions: #' \describe{ @@ -52,8 +52,20 @@ #' Traces are plotted using points rather than lines and the opacity of all #' chains but one (specified by the `highlight` argument) is reduced. #' } +#' \item{`mcmc_rank_hist()`}{ +#' Whereas traditional trace plots visualize how the chains mix over the +#' course of sampling, rank-normalized histograms visualize how the values +#' from the chains mix together in terms of ranking. An ideal plot would +#' show the rankings mixing or overlapping in a uniform distribution. +#' See Vehtari et al. (2019) for details. +#' } +#' \item{`mcmc_rank_overlay()`}{ +#' Ranks from `mcmc_rank_hist()` are plotted using overlaid lines in a +#' single panel. +#' } #' } #' +#' @template reference-improved-rhat #' @examples #' # some parameter draws to use for demonstration #' x <- example_mcmc_draws(chains = 4, params = 6) @@ -86,6 +98,13 @@ #' panel_bg(fill = "gray90", color = NA) + #' legend_move("top") #' +#' # Rank-normalized histogram plots. Instead of showing how chains mix over +#' # time, look at how the ranking of MCMC samples mixed between chains. +#' color_scheme_set("viridisE") +#' mcmc_rank_hist(x, "alpha") +#' mcmc_rank_hist(x, pars = c("alpha", "sigma"), ref_line = TRUE) +#' mcmc_rank_overlay(x, "alpha") +#' #' \dontrun{ #' # parse facet label text #' color_scheme_set("purple") @@ -126,7 +145,6 @@ #' np_style = trace_style_np(div_color = "black", div_size = 0.5) #' ) #' -#' color_scheme_set("viridis") #' mcmc_trace( #' posterior, #' pars = c("wt", "sigma"), @@ -156,71 +174,71 @@ mcmc_trace <- np_style = trace_style_np(), divergences = NULL) { - # deprecate 'divergences' arg in favor of 'np' (for consistency across functions) - if (!is.null(np) && !is.null(divergences)) { - abort(paste( - "'np' and 'divergences' can't both be specified.", - "Use only 'np' (the 'divergences' argument is deprecated)." - )) - } else if (!is.null(divergences)) { - warn(paste( - "The 'divergences' argument is deprecated", - "and will be removed in a future release.", - "Use the 'np' argument instead." - )) - np <- divergences - } - - check_ignored_arguments(...) - .mcmc_trace( - x, - pars = pars, - regex_pars = regex_pars, - transformations = transformations, - facet_args = facet_args, - n_warmup = n_warmup, - window = window, - size = size, - style = "line", - np = np, - np_style = np_style, - iter1 = iter1, - ... - ) + # deprecate 'divergences' arg in favor of 'np' + # (for consistency across functions) + if (!is.null(np) && !is.null(divergences)) { + abort(paste0( + "'np' and 'divergences' can't both be specified. ", + "Use only 'np' (the 'divergences' argument is deprecated)." + )) + } else if (!is.null(divergences)) { + warn(paste0( + "The 'divergences' argument is deprecated ", + "and will be removed in a future release. ", + "Use the 'np' argument instead." + )) + np <- divergences } + check_ignored_arguments(...) + .mcmc_trace( + x, + pars = pars, + regex_pars = regex_pars, + transformations = transformations, + facet_args = facet_args, + n_warmup = n_warmup, + window = window, + size = size, + style = "line", + np = np, + np_style = np_style, + iter1 = iter1, + ... + ) +} + #' @rdname MCMC-traces #' @export #' @param highlight For `mcmc_trace_highlight()`, an integer specifying one #' of the chains that will be more visible than the others in the plot. -mcmc_trace_highlight <- - function(x, - pars = character(), - regex_pars = character(), - transformations = list(), - facet_args = list(), - ..., - n_warmup = 0, - window = NULL, - size = NULL, - alpha = 0.2, - highlight = 1) { - check_ignored_arguments(...) - .mcmc_trace( - x, - pars = pars, - regex_pars = regex_pars, - transformations = transformations, - facet_args = facet_args, - n_warmup = n_warmup, - window = window, - size = size, - alpha = alpha, - highlight = highlight, - style = "point", - ... - ) - } +mcmc_trace_highlight <- function(x, + pars = character(), + regex_pars = character(), + transformations = list(), + facet_args = list(), + ..., + n_warmup = 0, + window = NULL, + size = NULL, + alpha = 0.2, + highlight = 1) { + check_ignored_arguments(...) + .mcmc_trace( + x, + pars = pars, + regex_pars = regex_pars, + transformations = transformations, + facet_args = facet_args, + n_warmup = n_warmup, + window = window, + size = size, + alpha = alpha, + highlight = highlight, + style = "point", + ... + ) +} #' @rdname MCMC-traces @@ -230,42 +248,192 @@ mcmc_trace_highlight <- #' [ggplot2::geom_rug()] if the `np` argument is also specified. They control #' the color, size, and transparency specifications for showing divergences in #' the plot. The default values are displayed in the **Usage** section above. -trace_style_np <- - function(div_color = "red", - div_size = 0.25, - div_alpha = 1) { - stopifnot( - is.character(div_color), - is.numeric(div_size), - is.numeric(div_alpha) && div_alpha >= 0 && div_alpha <= 1 +#' +trace_style_np <- function(div_color = "red", div_size = 0.25, div_alpha = 1) { + stopifnot( + is.character(div_color), + is.numeric(div_size), + is.numeric(div_alpha) && div_alpha >= 0 && div_alpha <= 1 + ) + + style <- list( + color = c(div = div_color), + size = c(div = div_size), + alpha = c(div = div_alpha) + ) + + structure(style, class = c(class(style), "nuts_style")) +} + +#' @rdname MCMC-traces +#' @param n_bins For the rank plots, the number of bins to use for the histogram +#' of rank-normalized MCMC samples. Defaults to `20`. +#' @param ref_line For the rank plots, whether to draw a horizontal line at the +#' average number of ranks per bin. Defaults to `FALSE`. +#' @export +mcmc_rank_overlay <- function(x, + pars = character(), + regex_pars = character(), + transformations = list(), + ..., + n_bins = 20, + ref_line = FALSE) { + check_ignored_arguments(...) + data <- mcmc_trace_data( + x, + pars = pars, + regex_pars = regex_pars, + transformations = transformations + ) + + n_chains <- unique(data$n_chains) + + # We have to bin and count the data ourselves because + # ggplot2::stat_bin(geom = "step") does not draw the final bin. + histobins <- data %>% + dplyr::distinct(.data$value_rank) %>% + mutate(cut = cut(.data$value_rank, n_bins)) %>% + group_by(.data$cut) %>% + mutate(bin_start = min(.data$value_rank)) %>% + ungroup() %>% + select(-.data$cut) + + d_bin_counts <- data %>% + left_join(histobins, by = "value_rank") %>% + count(.data$parameter, .data$chain, .data$bin_start) + + # Duplicate the final bin, setting the left edge to the greatest x value, so + # that the entire x-axis is used, + right_edge <- max(data$value_rank) + + d_bin_counts <- d_bin_counts %>% + dplyr::filter(.data$bin_start == max(.data$bin_start)) %>% + mutate(bin_start = right_edge) %>% + dplyr::bind_rows(d_bin_counts) + + scale_color <- scale_color_manual("Chain", values = chain_colors(n_chains)) + + layer_ref_line <- if (ref_line) { + geom_hline( + yintercept = (right_edge / n_bins) / n_chains, + color = get_color("dark_highlight"), + size = 1, + linetype = "dashed" ) - style <- list( - color = c(div = div_color), - size = c(div = div_size), - alpha = c(div = div_alpha) + } else { + NULL + } + + ggplot(d_bin_counts) + + aes_(x = ~ bin_start, y = ~ n, color = ~ chain) + + geom_step() + + layer_ref_line + + facet_wrap("parameter") + + scale_color + + ylim(c(0, NA)) + + bayesplot_theme_get() + + force_x_axis_in_facets() + + labs(x = "Rank", y = NULL) +} + +#' @rdname MCMC-traces +#' @export +mcmc_rank_hist <- function(x, + pars = character(), + regex_pars = character(), + transformations = list(), + facet_args = list(), + ..., + n_bins = 20, + ref_line = FALSE) { + check_ignored_arguments(...) + data <- mcmc_trace_data( + x, + pars = pars, + regex_pars = regex_pars, + transformations = transformations + ) + + n_iter <- unique(data$n_iterations) + n_chains <- unique(data$n_chains) + n_param <- unique(data$n_parameters) + + # Create a dataframe with chain x parameter x min(rank) x max(rank) to set + # x axis range in each facet + data_boundaries <- data %>% + dplyr::distinct(.data$chain, .data$parameter) + + data_boundaries <- dplyr::bind_rows( + mutate(data_boundaries, value_rank = min(data$value_rank)), + mutate(data_boundaries, value_rank = max(data$value_rank)) + ) + + right_edge <- max(data_boundaries$value_rank) + + facet_args[["scales"]] <- facet_args[["scales"]] %||% "fixed" + facet_args[["facets"]] <- facet_args[["facets"]] %||% (parameter ~ chain) + + # If there is one parameter, put the chains in one row. + # Otherwise, use a grid. + if (n_param > 1) { + facet_f <- facet_grid + } else { + facet_f <- facet_wrap + facet_args[["nrow"]] <- facet_args[["nrow"]] %||% 1 + labeller <- function(x) label_value(x, multi_line = FALSE) + facet_args[["labeller"]] <- facet_args[["labeller"]] %||% labeller + } + + layer_ref_line <- if (ref_line) { + geom_hline( + yintercept = (right_edge / n_bins) / n_chains, + color = get_color("dark_highlight"), + size = .5, + linetype = "dashed" ) - structure(style, class = c(class(style), "nuts_style")) + } else { + NULL } + facet_call <- do.call(facet_f, facet_args) + + ggplot(data) + + aes_(x = ~ value_rank) + + geom_histogram( + color = get_color("mid_highlight"), + fill = get_color("mid"), + binwidth = right_edge / n_bins, + boundary = right_edge, + size = .25 + ) + + layer_ref_line + + geom_blank(data = data_boundaries) + + facet_call + + force_x_axis_in_facets() + + dont_expand_y_axis(c(0.005, 0)) + + bayesplot_theme_get() + + theme( + axis.line.y = element_blank(), + axis.title.y = element_blank(), + axis.text.y = element_blank(), + axis.ticks = element_blank() + ) + + labs(x = "Rank") +} -# internal ----------------------------------------------------------------- -.mcmc_trace <- function(x, - pars = character(), - regex_pars = character(), - transformations = list(), - n_warmup = 0, - window = NULL, - size = NULL, - facet_args = list(), - highlight = NULL, - style = c("line", "point"), - alpha = 0.2, - np = NULL, - np_style = trace_style_np(), - iter1 = 0, - ...) { - style <- match.arg(style) +#' @rdname MCMC-traces +#' @export +mcmc_trace_data <- function(x, + pars = character(), + regex_pars = character(), + transformations = list(), + ..., + highlight = NULL, + n_warmup = 0, + iter1 = 0) { + check_ignored_arguments(...) + x <- prepare_mcmc_array(x, pars, regex_pars, transformations) if (iter1 < 0) { @@ -277,7 +445,9 @@ trace_style_np <- } if (!is.null(highlight)) { - if (!has_multiple_chains(x)) { + stopifnot(length(highlight) == 1) + + if (!has_multiple_chains(x)){ STOP_need_multiple_chains() } @@ -291,79 +461,141 @@ trace_style_np <- data <- melt_mcmc(x) data$Chain <- factor(data$Chain) - n_chain <- num_chains(data) - n_iter <- num_iters(data) - n_param <- num_params(data) + data$n_chains <- num_chains(data) + data$n_iterations <- num_iters(data) + data$n_parameters <- num_params(data) + data <- rlang::set_names(data, tolower) + + first_cols <- syms(c("parameter", "value", "value_rank")) + data <- data %>% + group_by(.data$parameter) %>% + mutate(value_rank = dplyr::row_number(.data$value)) %>% + ungroup() %>% + select(!!! first_cols, dplyr::everything()) + + data$highlight <- if (!is.null(highlight)) { + data$chain == highlight + } else { + FALSE + } - geom_args <- list() - geom_args$size <- size %||% ifelse(style == "line", 1/3, 1) + data$warmup <- data$iteration <= n_warmup + data$iteration <- data$iteration + as.integer(iter1) + tibble::as_tibble(data) +} - if (is.null(highlight)) { - mapping <- aes_(x = ~ Iteration + iter1, y = ~ Value, color = ~ Chain) - } else { - stopifnot(length(highlight) == 1) - mapping <- aes_(x = ~ Iteration + iter1, - y = ~ Value, - alpha = ~ Chain == highlight, - color = ~ Chain == highlight) + +# internal ----------------------------------------------------------------- +.mcmc_trace <- function(x, + pars = character(), + regex_pars = character(), + transformations = list(), + n_warmup = 0, + window = NULL, + size = NULL, + facet_args = list(), + highlight = NULL, + style = c("line", "point"), + alpha = 0.2, + np = NULL, + np_style = trace_style_np(), + iter1 = 0, + ...) { + style <- match.arg(style) + data <- mcmc_trace_data( + x, + pars = pars, + regex_pars = regex_pars, + transformations = transformations, + highlight = highlight, + n_warmup = n_warmup, + iter1 = iter1 + ) + n_iter <- unique(data$n_iterations) + n_chain <- unique(data$n_chains) + n_param <- unique(data$n_parameters) + + mapping <- aes_( + x = ~ iteration, + y = ~ value, + color = ~ chain + ) + + if (!is.null(highlight)) { + mapping <- modify_aes_( + mapping, + alpha = ~ highlight, + color = ~ highlight + ) } - graph <- ggplot(data, mapping) + - bayesplot_theme_get() - - if (n_warmup > 0) { - graph <- graph + - annotate("rect", - xmin = -Inf, xmax = n_warmup, - ymin = -Inf, ymax = Inf, - size = 1, - color = "gray88", - fill = "gray88", - alpha = 0.5) + + layer_warmup <- if (n_warmup > 0) { + layer_warmup <- annotate( + "rect", xmin = -Inf, xmax = n_warmup, ymin = -Inf, ymax = Inf, size = 1, + color = "gray88", fill = "gray88", alpha = 0.5 + ) + } else { + NULL } - if (!is.null(window)) { + geom_args <- list() + geom_args$size <- size %||% ifelse(style == "line", 1/3, 1) + layer_draws <- do.call(paste0("geom_", style), geom_args) + + coord_window <- if (!is.null(window)) { stopifnot(length(window) == 2) - graph <- graph + coord_cartesian(xlim = window) + coord_cartesian(xlim = window) + } else { + NULL } - graph <- graph + do.call(paste0("geom_", style), geom_args) + scale_alpha <- NULL + scale_color <- NULL + div_rug <- NULL + div_guides <- NULL if (!is.null(highlight)) { - graph <- graph + - scale_alpha_discrete(range = c(alpha, 1), guide = "none") + - scale_color_manual("", - values = get_color(c("lh", "d")), - labels = c("Other chains", paste("Chain", highlight))) + ## scale_alpha_discrete() warns on default + scale_alpha <- scale_alpha_ordinal(range = c(alpha, 1), guide = "none") + scale_color <- scale_color_manual( + "", + values = get_color(c("lh", "d")), + labels = c("Other chains", paste("Chain", highlight))) } else { - graph <- graph + - scale_color_manual("Chain", values = chain_colors(n_chain)) + scale_color <- scale_color_manual("Chain", values = chain_colors(n_chain)) if (!is.null(np)) { div_rug <- divergence_rug(np, np_style, n_iter, n_chain) - if (!is.null(div_rug)) - graph <- graph + - div_rug + - guides( - color = guide_legend(order = 1), - linetype = guide_legend(order = 2, - title = NULL, - keywidth = rel(1/2), - override.aes = list(size = rel(1/2))) - ) + if (!is.null(div_rug)) { + div_guides <- guides( + color = guide_legend(order = 1), + linetype = guide_legend( + order = 2, title = NULL, keywidth = rel(1/2), + override.aes = list(size = rel(1/2))) + ) + } } } - + facet_call <- NULL if (n_param == 1) { - graph <- graph + ylab(levels(data$Parameter)) + facet_call <- ylab(levels(data$parameter)) } else { - facet_args$facets <- ~ Parameter - if (is.null(facet_args$scales)) - facet_args$scales <- "free" - graph <- graph + do.call("facet_wrap", facet_args) + facet_args$facets <- ~ parameter + facet_args$scales <- facet_args$scales %||% "free" + facet_call <- do.call("facet_wrap", facet_args) } - graph + + ggplot(data, mapping) + + bayesplot_theme_get() + + layer_warmup + + layer_draws + + coord_window + + scale_alpha + + scale_color + + div_rug + + div_guides + + facet_call + scale_x_continuous(breaks = pretty) + legend_move(ifelse(n_chain > 1, "right", "none")) + xaxis_title(FALSE) + @@ -394,8 +626,8 @@ chain_colors <- function(n) { #' @param np_style User's `np_style` argument, if specified. #' @param n_iter Number of iterations in the trace plot (to check against number #' of iterations provided in `np`). -#' @param n_chain Number of chains in the trace plot (to check against number -#' of chains provided in `np`). +#' @param n_chain Number of chains in the trace plot (to check against number of +#' chains provided in `np`). #' @return Object returned by `ggplot2::geom_rug()`. #' #' @importFrom dplyr summarise group_by select diff --git a/R/ppc-distributions.R b/R/ppc-distributions.R index 031969bf..92dee239 100644 --- a/R/ppc-distributions.R +++ b/R/ppc-distributions.R @@ -109,7 +109,7 @@ ppc_data <- function(y, yrep, group = NULL) { if (!is.null(group)) { group <- validate_group(group, y) - group_indices <- dplyr::data_frame(group, y_id = seq_along(group)) + group_indices <- tibble::tibble(group, y_id = seq_along(group)) data <- data %>% left_join(group_indices, by = "y_id") %>% select(.data$group, dplyr::everything()) diff --git a/man-roxygen/reference-improved-rhat.R b/man-roxygen/reference-improved-rhat.R new file mode 100644 index 00000000..59da344c --- /dev/null +++ b/man-roxygen/reference-improved-rhat.R @@ -0,0 +1,4 @@ +#' @references Vehtari, A., Gelman, A., Simpson, D., Carpenter, B., Bürkner, P. +#' (2019). Rank-normalization, folding, and localization: An improved *R*-hat +#' for assessing convergence of MCMC. [arXiv +#' preprint](https://arxiv.org/abs/1903.08008). diff --git a/man-roxygen/return-ggplot-or-data.R b/man-roxygen/return-ggplot-or-data.R index 486ddd54..17ab8593 100644 --- a/man-roxygen/return-ggplot-or-data.R +++ b/man-roxygen/return-ggplot-or-data.R @@ -1,3 +1,4 @@ -#' @return A ggplot object that can be further customized using the **ggplot2** -#' package. The functions with suffix `_data` return the data that would have -#' been drawn by the plotting function. +#' @return The plotting functions return a ggplot object that can be further +#' customized using the **ggplot2** package. The functions with suffix +#' `_data()` return the data that would have been drawn by the plotting +#' function. diff --git a/man/MCMC-diagnostics.Rd b/man/MCMC-diagnostics.Rd index 4beb768a..ae8a6db9 100644 --- a/man/MCMC-diagnostics.Rd +++ b/man/MCMC-diagnostics.Rd @@ -69,9 +69,10 @@ to control faceting.} \item{lags}{The number of lags to show in the autocorrelation plot.} } \value{ -A ggplot object that can be further customized using the \strong{ggplot2} -package. The functions with suffix \code{_data} return the data that would have -been drawn by the plotting function. +The plotting functions return a ggplot object that can be further +customized using the \strong{ggplot2} package. The functions with suffix +\code{_data()} return the data that would have been drawn by the plotting +function. } \description{ Plots of Rhat statistics, ratios of effective sample size to total sample diff --git a/man/MCMC-intervals.Rd b/man/MCMC-intervals.Rd index ab211a2c..e1870c4f 100644 --- a/man/MCMC-intervals.Rd +++ b/man/MCMC-intervals.Rd @@ -103,9 +103,10 @@ points across the curves are the same height. The method \code{"scaled height"} parameters. \code{n_dens} defaults to \code{1024}.} } \value{ -A ggplot object that can be further customized using the \strong{ggplot2} -package. The functions with suffix \code{_data} return the data that would have -been drawn by the plotting function. +The plotting functions return a ggplot object that can be further +customized using the \strong{ggplot2} package. The functions with suffix +\code{_data()} return the data that would have been drawn by the plotting +function. } \description{ Plot central (quantile-based) posterior interval estimates from MCMC draws. diff --git a/man/MCMC-parcoord.Rd b/man/MCMC-parcoord.Rd index e29546fb..043836a6 100644 --- a/man/MCMC-parcoord.Rd +++ b/man/MCMC-parcoord.Rd @@ -75,9 +75,10 @@ the color, size, and transparency specifications for showing divergences in the plot. The default values are displayed in the \strong{Usage} section above.} } \value{ -A ggplot object that can be further customized using the \strong{ggplot2} -package. The functions with suffix \code{_data} return the data that would have -been drawn by the plotting function. +The plotting functions return a ggplot object that can be further +customized using the \strong{ggplot2} package. The functions with suffix +\code{_data()} return the data that would have been drawn by the plotting +function. } \description{ Parallel coordinates plot of MCMC draws (one dimension per parameter). diff --git a/man/MCMC-traces.Rd b/man/MCMC-traces.Rd index c15c8f25..7437c734 100644 --- a/man/MCMC-traces.Rd +++ b/man/MCMC-traces.Rd @@ -5,7 +5,10 @@ \alias{mcmc_trace} \alias{mcmc_trace_highlight} \alias{trace_style_np} -\title{Trace plot (time series plot) of MCMC draws} +\alias{mcmc_rank_overlay} +\alias{mcmc_rank_hist} +\alias{mcmc_trace_data} +\title{Trace plots of MCMC draws} \usage{ mcmc_trace(x, pars = character(), regex_pars = character(), transformations = list(), facet_args = list(), ..., n_warmup = 0, @@ -17,6 +20,17 @@ mcmc_trace_highlight(x, pars = character(), regex_pars = character(), window = NULL, size = NULL, alpha = 0.2, highlight = 1) trace_style_np(div_color = "red", div_size = 0.25, div_alpha = 1) + +mcmc_rank_overlay(x, pars = character(), regex_pars = character(), + transformations = list(), ..., n_bins = 20, ref_line = FALSE) + +mcmc_rank_hist(x, pars = character(), regex_pars = character(), + transformations = list(), facet_args = list(), ..., n_bins = 20, + ref_line = FALSE) + +mcmc_trace_data(x, pars = character(), regex_pars = character(), + transformations = list(), ..., highlight = NULL, n_warmup = 0, + iter1 = 0) } \arguments{ \item{x}{A 3-D array, matrix, list of matrices, or data frame of MCMC draws. @@ -75,8 +89,7 @@ if \code{n_warmup} is also set to a positive value.} range of iterations to display.} \item{size}{An optional value to override the default line size -(\code{mcmc_trace()}) or the default point size -(\code{mcmc_trace_highlight()}).} +for \code{mcmc_trace()} or the default point size for \code{mcmc_trace_highlight()}.} \item{np}{For models fit using \link{NUTS} (more generally, any \href{https://en.wikipedia.org/wiki/Symplectic_integrator}{symplectic integrator}), @@ -104,9 +117,21 @@ of the chains that will be more visible than the others in the plot.} \code{\link[ggplot2:geom_rug]{ggplot2::geom_rug()}} if the \code{np} argument is also specified. They control the color, size, and transparency specifications for showing divergences in the plot. The default values are displayed in the \strong{Usage} section above.} + +\item{n_bins}{For the rank plots, the number of bins to use for the histogram +of rank-normalized MCMC samples. Defaults to \code{20}.} + +\item{ref_line}{For the rank plots, whether to draw a horizontal line at the +average number of ranks per bin. Defaults to \code{FALSE}.} } \value{ -A ggplot object that can be further customized using the \strong{ggplot2} package. +The plotting functions return a ggplot object that can be further +customized using the \strong{ggplot2} package. The functions with suffix +\code{_data()} return the data that would have been drawn by the plotting +function. + +\code{mcmc_trace_data()} returns the data for the trace \emph{and} rank plots +in the same data frame. } \description{ Trace plot (or traceplot) of MCMC draws. See the \strong{Plot Descriptions} @@ -123,6 +148,17 @@ the \code{np} argument can be used to also show divergences on the trace plot. Traces are plotted using points rather than lines and the opacity of all chains but one (specified by the \code{highlight} argument) is reduced. } +\item{\code{mcmc_rank_hist()}}{ +Whereas traditional trace plots visualize how the chains mix over the +course of sampling, rank-normalized histograms visualize how the values +from the chains mix together in terms of ranking. An ideal plot would +show the rankings mixing or overlapping in a uniform distribution. +See Vehtari et al. (2019) for details. +} +\item{\code{mcmc_rank_overlay()}}{ +Ranks from \code{mcmc_rank_hist()} are plotted using overlaid lines in a +single panel. +} } } @@ -158,6 +194,13 @@ mcmc_trace(x[,, 1:4], window = c(100, 130), size = 1) + panel_bg(fill = "gray90", color = NA) + legend_move("top") +# Rank-normalized histogram plots. Instead of showing how chains mix over +# time, look at how the ranking of MCMC samples mixed between chains. +color_scheme_set("viridisE") +mcmc_rank_hist(x, "alpha") +mcmc_rank_hist(x, pars = c("alpha", "sigma"), ref_line = TRUE) +mcmc_rank_overlay(x, "alpha") + \dontrun{ # parse facet label text color_scheme_set("purple") @@ -198,7 +241,6 @@ mcmc_trace( np_style = trace_style_np(div_color = "black", div_size = 0.5) ) -color_scheme_set("viridis") mcmc_trace( posterior, pars = c("wt", "sigma"), @@ -209,6 +251,11 @@ mcmc_trace( ) } +} +\references{ +Vehtari, A., Gelman, A., Simpson, D., Carpenter, B., Bürkner, P. +(2019). Rank-normalization, folding, and localization: An improved \emph{R}-hat +for assessing convergence of MCMC. \href{https://arxiv.org/abs/1903.08008}{arXiv preprint}. } \seealso{ Other MCMC: \code{\link{MCMC-combos}}, diff --git a/man/PPC-distributions.Rd b/man/PPC-distributions.Rd index b0d765a0..902d557b 100644 --- a/man/PPC-distributions.Rd +++ b/man/PPC-distributions.Rd @@ -99,9 +99,10 @@ to control the appearance of \code{y} points. The default of \code{y_jitter=NULL will let \strong{ggplot2} determine the amount of jitter.} } \value{ -A ggplot object that can be further customized using the \strong{ggplot2} -package. The functions with suffix \code{_data} return the data that would have -been drawn by the plotting function. +The plotting functions return a ggplot object that can be further +customized using the \strong{ggplot2} package. The functions with suffix +\code{_data()} return the data that would have been drawn by the plotting +function. } \description{ Compare the empirical distribution of the data \code{y} to the distributions diff --git a/man/PPC-intervals.Rd b/man/PPC-intervals.Rd index e7c44eef..5ab62ffb 100644 --- a/man/PPC-intervals.Rd +++ b/man/PPC-intervals.Rd @@ -61,9 +61,10 @@ and \code{size} are passed to \code{\link[ggplot2:geom_ribbon]{ggplot2::geom_rib \code{size} and \code{fatten} are passed to \code{\link[ggplot2:geom_pointrange]{ggplot2::geom_pointrange()}}.} } \value{ -A ggplot object that can be further customized using the \strong{ggplot2} -package. The functions with suffix \code{_data} return the data that would have -been drawn by the plotting function. +The plotting functions return a ggplot object that can be further +customized using the \strong{ggplot2} package. The functions with suffix +\code{_data()} return the data that would have been drawn by the plotting +function. } \description{ Medians and central interval estimates of \code{yrep} with \code{y} overlaid. diff --git a/man/bayesplot-colors.Rd b/man/bayesplot-colors.Rd index e8cbc9b2..00bae882 100644 --- a/man/bayesplot-colors.Rd +++ b/man/bayesplot-colors.Rd @@ -67,7 +67,7 @@ schemes are: \item \code{"teal"} \item \code{"yellow"} \item \href{https://CRAN.R-project.org/package=viridis}{"viridis"}, \code{"viridisA"}, -\code{"viridisB"}, \code{"viridisC"} +\code{"viridisB"}, \code{"viridisC"}, \code{"viridisD"}, \code{"viridisE"} \item \code{"mix-x-y"}, replacing \code{x} and \code{y} with any two of the scheme names listed above (e.g. "mix-teal-pink", "mix-blue-red", etc.). The order of \code{x} and \code{y} matters, i.e., the color schemes diff --git a/tests/figs/mcmc-traces/mcmc-rank-histogram-default.svg b/tests/figs/mcmc-traces/mcmc-rank-histogram-default.svg new file mode 100644 index 00000000..7a6aad97 --- /dev/null +++ b/tests/figs/mcmc-traces/mcmc-rank-histogram-default.svg @@ -0,0 +1,355 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +1 + + + + + + + + + + +2 + + + + + + + + + + +3 + + + + + + + + + + +4 + + + + + + + + + + +V1 + + + + + + + + + + +V2 + + + + + + +0 +500 +1000 +1500 +2000 + +0 +500 +1000 +1500 +2000 + +0 +500 +1000 +1500 +2000 + +0 +500 +1000 +1500 +2000 +Rank +mcmc rank histogram (default) + diff --git a/tests/figs/mcmc-traces/mcmc-rank-histogram-one-parameter.svg b/tests/figs/mcmc-traces/mcmc-rank-histogram-one-parameter.svg new file mode 100644 index 00000000..e1f41046 --- /dev/null +++ b/tests/figs/mcmc-traces/mcmc-rank-histogram-one-parameter.svg @@ -0,0 +1,209 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +V1, 1 + + + + + + + + + + +V1, 2 + + + + + + + + + + +V1, 3 + + + + + + + + + + +V1, 4 + + + + + + +0 +500 +1000 +1500 +2000 + +0 +500 +1000 +1500 +2000 + +0 +500 +1000 +1500 +2000 + +0 +500 +1000 +1500 +2000 +Rank +mcmc rank histogram (one parameter) + diff --git a/tests/figs/mcmc-traces/mcmc-rank-histogram-reference-line.svg b/tests/figs/mcmc-traces/mcmc-rank-histogram-reference-line.svg new file mode 100644 index 00000000..79366d1a --- /dev/null +++ b/tests/figs/mcmc-traces/mcmc-rank-histogram-reference-line.svg @@ -0,0 +1,363 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +1 + + + + + + + + + + +2 + + + + + + + + + + +3 + + + + + + + + + + +4 + + + + + + + + + + +V1 + + + + + + + + + + +V2 + + + + + + +0 +500 +1000 +1500 +2000 + +0 +500 +1000 +1500 +2000 + +0 +500 +1000 +1500 +2000 + +0 +500 +1000 +1500 +2000 +Rank +mcmc rank histogram (reference line) + diff --git a/tests/figs/mcmc-traces/mcmc-rank-histogram-wide-bins.svg b/tests/figs/mcmc-traces/mcmc-rank-histogram-wide-bins.svg new file mode 100644 index 00000000..8884cd27 --- /dev/null +++ b/tests/figs/mcmc-traces/mcmc-rank-histogram-wide-bins.svg @@ -0,0 +1,145 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +V1, 1 + + + + + + + + + + +V1, 2 + + + + + + + + + + +V1, 3 + + + + + + + + + + +V1, 4 + + + + + + +0 +500 +1000 +1500 +2000 + +0 +500 +1000 +1500 +2000 + +0 +500 +1000 +1500 +2000 + +0 +500 +1000 +1500 +2000 +Rank +mcmc rank histogram (wide bins) + diff --git a/tests/figs/mcmc-traces/mcmc-rank-overlay-default.svg b/tests/figs/mcmc-traces/mcmc-rank-overlay-default.svg new file mode 100644 index 00000000..cc020e24 --- /dev/null +++ b/tests/figs/mcmc-traces/mcmc-rank-overlay-default.svg @@ -0,0 +1,109 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +V1 + + + + + + + + + + +V2 + + + + + + + + + + + +0 +500 +1000 +1500 +2000 + + + + + + +0 +500 +1000 +1500 +2000 + +0 +10 +20 +30 + + + + +Rank +Chain + + + + +1 +2 +3 +4 +mcmc rank overlay (default) + diff --git a/tests/figs/mcmc-traces/mcmc-rank-overlay-one-parameter.svg b/tests/figs/mcmc-traces/mcmc-rank-overlay-one-parameter.svg new file mode 100644 index 00000000..bfc7c8a5 --- /dev/null +++ b/tests/figs/mcmc-traces/mcmc-rank-overlay-one-parameter.svg @@ -0,0 +1,72 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + +V1 + + + + + + + + + + + +0 +500 +1000 +1500 +2000 + +0 +10 +20 +30 + + + + +Rank +Chain + + + + +1 +2 +3 +4 +mcmc rank overlay (one parameter) + diff --git a/tests/figs/mcmc-traces/mcmc-rank-overlay-reference-line.svg b/tests/figs/mcmc-traces/mcmc-rank-overlay-reference-line.svg new file mode 100644 index 00000000..1bd95203 --- /dev/null +++ b/tests/figs/mcmc-traces/mcmc-rank-overlay-reference-line.svg @@ -0,0 +1,111 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +V1 + + + + + + + + + + +V2 + + + + + + + + + + + +0 +500 +1000 +1500 +2000 + + + + + + +0 +500 +1000 +1500 +2000 + +0 +10 +20 +30 + + + + +Rank +Chain + + + + +1 +2 +3 +4 +mcmc rank overlay (reference line) + diff --git a/tests/figs/mcmc-traces/mcmc-rank-overlay-wide-bins.svg b/tests/figs/mcmc-traces/mcmc-rank-overlay-wide-bins.svg new file mode 100644 index 00000000..43cf8a25 --- /dev/null +++ b/tests/figs/mcmc-traces/mcmc-rank-overlay-wide-bins.svg @@ -0,0 +1,72 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + +V1 + + + + + + + + + + + +0 +500 +1000 +1500 +2000 + +0 +50 +100 +150 + + + + +Rank +Chain + + + + +1 +2 +3 +4 +mcmc rank overlay (wide bins) + diff --git a/tests/testthat/test-mcmc-traces.R b/tests/testthat/test-mcmc-traces.R index 6a45fda9..474402fd 100644 --- a/tests/testthat/test-mcmc-traces.R +++ b/tests/testthat/test-mcmc-traces.R @@ -117,6 +117,65 @@ test_that("mcmc_trace renders correctly", { vdiffr::expect_doppelganger("mcmc trace (iter1 offset)", p_iter1) }) +test_that("mcmc_rank_overlay renders correctly", { + testthat::skip_on_cran() + + p_base <- mcmc_rank_overlay(vdiff_dframe_chains, pars = c("V1", "V2")) + p_base_ref <- mcmc_rank_overlay( + vdiff_dframe_chains, + pars = c("V1", "V2"), + ref_line = TRUE + ) + p_one_param <- mcmc_rank_overlay(vdiff_dframe_chains, pars = "V1") + p_one_param_wide_bins <- mcmc_rank_overlay( + vdiff_dframe_chains, + pars = "V1", + n_bins = 4 + ) + + vdiffr::expect_doppelganger("mcmc rank overlay (default)", p_base) + vdiffr::expect_doppelganger( + "mcmc rank overlay (reference line)", + p_base_ref + ) + vdiffr::expect_doppelganger("mcmc rank overlay (one parameter)", p_one_param) + vdiffr::expect_doppelganger( + "mcmc rank overlay (wide bins)", + p_one_param_wide_bins + ) +}) + +test_that("mcmc_rank_hist renders correctly", { + testthat::skip_on_cran() + + p_base <- mcmc_rank_hist(vdiff_dframe_chains, pars = c("V1", "V2")) + p_base_ref <- mcmc_rank_hist( + vdiff_dframe_chains, + pars = c("V1", "V2"), + ref_line = TRUE + ) + p_one_param <- mcmc_rank_hist(vdiff_dframe_chains, pars = "V1") + p_one_param_wide_bins <- mcmc_rank_hist( + vdiff_dframe_chains, + pars = "V1", + n_bins = 4 + ) + + vdiffr::expect_doppelganger("mcmc rank histogram (default)", p_base) + vdiffr::expect_doppelganger( + "mcmc rank histogram (reference line)", + p_base_ref + ) + vdiffr::expect_doppelganger( + "mcmc rank histogram (one parameter)", + p_one_param + ) + vdiffr::expect_doppelganger( + "mcmc rank histogram (wide bins)", + p_one_param_wide_bins + ) +}) + test_that("mcmc_trace_highlight renders correctly", { testthat::skip_on_cran()