From 6ea15a37097058c7f6d1ab25c7d589c48b3298f7 Mon Sep 17 00:00:00 2001 From: VisruthSK <67435125+VisruthSK@users.noreply.github.com> Date: Tue, 7 Oct 2025 09:40:44 -0700 Subject: [PATCH 1/4] Interpolate NAs with 1 --- R/psis.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/psis.R b/R/psis.R index 1b321d70..48b275ed 100644 --- a/R/psis.R +++ b/R/psis.R @@ -362,7 +362,8 @@ prepare_psis_r_eff <- function(r_eff, len) { } else if (length(r_eff) != len) { stop("'r_eff' must have one value or one value per observation.", call. = FALSE) } else if (anyNA(r_eff)) { - stop("Can't mix NA and not NA values in 'r_eff'.", call. = FALSE) + message("If `r_eff` has length `len` but has `NA`s then `NA`s are replaced with 1's.") + r_eff[is.na(r_eff)] <- 1 } r_eff } From cb48ff9328f52761bec37f4a8e106012e3c6dffe Mon Sep 17 00:00:00 2001 From: VisruthSK <67435125+VisruthSK@users.noreply.github.com> Date: Tue, 7 Oct 2025 10:13:21 -0700 Subject: [PATCH 2/4] Updated snapshots --- tests/testthat/_snaps/psis.md | 13 +++++++++++++ tests/testthat/_snaps/tisis.md | 13 +++++++++++++ tests/testthat/test_psis.R | 4 ++-- tests/testthat/test_tisis.R | 4 ++-- 4 files changed, 30 insertions(+), 4 deletions(-) create mode 100644 tests/testthat/_snaps/tisis.md diff --git a/tests/testthat/_snaps/psis.md b/tests/testthat/_snaps/psis.md index 9836dc49..b812ba31 100644 --- a/tests/testthat/_snaps/psis.md +++ b/tests/testthat/_snaps/psis.md @@ -4782,6 +4782,19 @@ # psis throws correct errors and warnings + Code + psis(-LLarr, r_eff = r_eff_arr) + Message + If `r_eff` has length `len` but has `NA`s then `NA`s are replaced with 1's. + Output + Computed from 1000 by 32 log-weights matrix. + MCSE and ESS estimates assume MCMC draws (r_eff in [0.6, 1.0]). + + All Pareto k estimates are good (k < 0.67). + See help('pareto-k-diagnostic') for details. + +--- + Code psis(-LLarr[1:5, , ]) Condition diff --git a/tests/testthat/_snaps/tisis.md b/tests/testthat/_snaps/tisis.md new file mode 100644 index 00000000..4bf08d6c --- /dev/null +++ b/tests/testthat/_snaps/tisis.md @@ -0,0 +1,13 @@ +# tis throws correct errors and warnings + + Code + psis(-LLarr, r_eff = r_eff_arr) + Message + If `r_eff` has length `len` but has `NA`s then `NA`s are replaced with 1's. + Output + Computed from 1000 by 32 log-weights matrix. + MCSE and ESS estimates assume MCMC draws (r_eff in [0.6, 1.0]). + + All Pareto k estimates are good (k < 0.67). + See help('pareto-k-diagnostic') for details. + diff --git a/tests/testthat/test_psis.R b/tests/testthat/test_psis.R index de0a5512..d43b6ed0 100644 --- a/tests/testthat/test_psis.R +++ b/tests/testthat/test_psis.R @@ -70,9 +70,9 @@ test_that("psis throws correct errors and warnings", { # r_eff non-scalar wrong length is error expect_error(psis(-LLarr, r_eff = r_eff_arr[-1]), "one value per observation") - # r_eff has some NA values causes error + # r_eff has some NA values which are replaced with 1 r_eff_arr[2] <- NA - expect_error(psis(-LLarr, r_eff = r_eff_arr), "mix NA and not NA values") + expect_snapshot(psis(-LLarr, r_eff = r_eff_arr)) # tail length warnings expect_snapshot(psis(-LLarr[1:5, , ])) diff --git a/tests/testthat/test_tisis.R b/tests/testthat/test_tisis.R index 8616ef28..44fcc12d 100644 --- a/tests/testthat/test_tisis.R +++ b/tests/testthat/test_tisis.R @@ -107,9 +107,9 @@ test_that("tis throws correct errors and warnings", { # r_eff wrong length is error expect_error(tis(-LLarr, r_eff = r_eff_arr[-1]), "one value per observation") - # r_eff has some NA values causes error + # r_eff has some NA values which are replaced with 1 r_eff_arr[2] <- NA - expect_error(tis(-LLarr, r_eff = r_eff_arr), "mix NA and not NA values") + expect_snapshot(psis(-LLarr, r_eff = r_eff_arr)) # no NAs or non-finite values allowed LLmat[1, 1] <- NA From d77ced9ff1a4473d70b0c2f2c16b033f1dc34adc Mon Sep 17 00:00:00 2001 From: VisruthSK <67435125+VisruthSK@users.noreply.github.com> Date: Tue, 7 Oct 2025 12:12:35 -0700 Subject: [PATCH 3/4] Updated error message --- R/psis.R | 101 +++++++++++++++++++-------------- tests/testthat/_snaps/psis.md | 2 +- tests/testthat/_snaps/tisis.md | 2 +- 3 files changed, 59 insertions(+), 46 deletions(-) diff --git a/R/psis.R b/R/psis.R index 48b275ed..29b0dfe6 100644 --- a/R/psis.R +++ b/R/psis.R @@ -98,13 +98,14 @@ psis <- function(log_ratios, ...) UseMethod("psis") #' @template array #' psis.array <- - function(log_ratios, ..., - r_eff = 1, - cores = getOption("mc.cores", 1)) { - importance_sampling.array(log_ratios = log_ratios, ..., - r_eff = r_eff, - cores = cores, - method = "psis") + function(log_ratios, ..., r_eff = 1, cores = getOption("mc.cores", 1)) { + importance_sampling.array( + log_ratios = log_ratios, + ..., + r_eff = r_eff, + cores = cores, + method = "psis" + ) } @@ -113,15 +114,14 @@ psis.array <- #' @template matrix #' psis.matrix <- - function(log_ratios, - ..., - r_eff = 1, - cores = getOption("mc.cores", 1)) { - importance_sampling.matrix(log_ratios, - ..., - r_eff = r_eff, - cores = cores, - method = "psis") + function(log_ratios, ..., r_eff = 1, cores = getOption("mc.cores", 1)) { + importance_sampling.matrix( + log_ratios, + ..., + r_eff = r_eff, + cores = cores, + method = "psis" + ) } #' @export @@ -130,9 +130,12 @@ psis.matrix <- #' psis.default <- function(log_ratios, ..., r_eff = 1) { - importance_sampling.default(log_ratios = log_ratios, ..., - r_eff = r_eff, - method = "psis") + importance_sampling.default( + log_ratios = log_ratios, + ..., + r_eff = r_eff, + method = "psis" + ) } @@ -149,25 +152,26 @@ is.psis <- function(x) { #' @noRd #' @seealso importance_sampling_object psis_object <- - function(unnormalized_log_weights, - pareto_k, - tail_len, - r_eff) { - importance_sampling_object(unnormalized_log_weights = unnormalized_log_weights, - pareto_k = pareto_k, - tail_len = tail_len, - r_eff = r_eff, - method = "psis") + function(unnormalized_log_weights, pareto_k, tail_len, r_eff) { + importance_sampling_object( + unnormalized_log_weights = unnormalized_log_weights, + pareto_k = pareto_k, + tail_len = tail_len, + r_eff = r_eff, + method = "psis" + ) } #' @noRd #' @seealso do_importance_sampling -do_psis <- function(log_ratios, r_eff, cores, method){ - do_importance_sampling(log_ratios = log_ratios, - r_eff = r_eff, - cores = cores, - method = "psis") +do_psis <- function(log_ratios, r_eff, cores, method) { + do_importance_sampling( + log_ratios = log_ratios, + r_eff = r_eff, + cores = cores, + method = "psis" + ) } #' Extract named components from each list in the list of lists obtained by @@ -181,7 +185,9 @@ do_psis <- function(log_ratios, r_eff, cores, method){ #' @return Numeric vector or matrix. #' psis_apply <- function(x, item, fun = c("[[", "attr"), fun_val = numeric(1)) { - if (!is.list(x)) stop("Internal error ('x' must be a list for psis_apply)") + if (!is.list(x)) { + stop("Internal error ('x' must be a list for psis_apply)") + } vapply(x, FUN = match.arg(fun), FUN.VALUE = fun_val, item) } @@ -212,7 +218,7 @@ do_psis_i <- function(log_ratios_i, tail_len_i, ...) { ord <- sort.int(lw_i, index.return = TRUE) tail_ids <- seq(S - tail_len_i + 1, S) lw_tail <- ord$x[tail_ids] - if (abs(max(lw_tail) - min(lw_tail)) < .Machine$double.eps/100) { + if (abs(max(lw_tail) - min(lw_tail)) < .Machine$double.eps / 100) { warning( "Can't fit generalized Pareto distribution ", "because all tail values are the same.", @@ -252,11 +258,11 @@ psis_smooth_tail <- function(x, cutoff) { k <- fit$k sigma <- fit$sigma if (is.finite(k)) { - p <- (seq_len(len) - 0.5) / len - qq <- qgpd(p, k, sigma) + exp_cutoff - tail <- log(qq) + p <- (seq_len(len) - 0.5) / len + qq <- qgpd(p, k, sigma) + exp_cutoff + tail <- log(qq) } else { - tail <- x + tail <- x } list(tail = tail, k = k) } @@ -322,7 +328,8 @@ throw_tail_length_warnings <- function(tail_lengths) { if (length(tail_lengths) == 1) { warning( "Not enough tail samples to fit the generalized Pareto distribution.", - call. = FALSE, immediate. = TRUE + call. = FALSE, + immediate. = TRUE ) } else { bad <- which(tail_len_bad) @@ -332,7 +339,11 @@ throw_tail_length_warnings <- function(tail_lengths) { "in some or all columns of matrix of log importance ratios. ", "Skipping the following columns: ", paste(if (Nbad <= 10) bad else bad[1:10], collapse = ", "), - if (Nbad > 10) paste0(", ... [", Nbad - 10, " more not printed].\n") else "\n", + if (Nbad > 10) { + paste0(", ... [", Nbad - 10, " more not printed].\n") + } else { + "\n" + }, call. = FALSE, immediate. = TRUE ) @@ -360,9 +371,12 @@ prepare_psis_r_eff <- function(r_eff, len) { } else if (length(r_eff) == 1) { r_eff <- rep(r_eff, len) } else if (length(r_eff) != len) { - stop("'r_eff' must have one value or one value per observation.", call. = FALSE) + stop( + "'r_eff' must have one value or one value per observation.", + call. = FALSE + ) } else if (anyNA(r_eff)) { - message("If `r_eff` has length `len` but has `NA`s then `NA`s are replaced with 1's.") + message("Replacing NAs in `r_eff` with 1s") r_eff[is.na(r_eff)] <- 1 } r_eff @@ -391,4 +405,3 @@ throw_psis_r_eff_warning <- function() { call. = FALSE ) } - diff --git a/tests/testthat/_snaps/psis.md b/tests/testthat/_snaps/psis.md index b812ba31..9f2deaf3 100644 --- a/tests/testthat/_snaps/psis.md +++ b/tests/testthat/_snaps/psis.md @@ -4785,7 +4785,7 @@ Code psis(-LLarr, r_eff = r_eff_arr) Message - If `r_eff` has length `len` but has `NA`s then `NA`s are replaced with 1's. + Replacing NAs in `r_eff` with 1s Output Computed from 1000 by 32 log-weights matrix. MCSE and ESS estimates assume MCMC draws (r_eff in [0.6, 1.0]). diff --git a/tests/testthat/_snaps/tisis.md b/tests/testthat/_snaps/tisis.md index 4bf08d6c..57cf7e30 100644 --- a/tests/testthat/_snaps/tisis.md +++ b/tests/testthat/_snaps/tisis.md @@ -3,7 +3,7 @@ Code psis(-LLarr, r_eff = r_eff_arr) Message - If `r_eff` has length `len` but has `NA`s then `NA`s are replaced with 1's. + Replacing NAs in `r_eff` with 1s Output Computed from 1000 by 32 log-weights matrix. MCSE and ESS estimates assume MCMC draws (r_eff in [0.6, 1.0]). From 06a70bab8bf83acdec6fb992f89eac3a3958cf31 Mon Sep 17 00:00:00 2001 From: VisruthSK <67435125+VisruthSK@users.noreply.github.com> Date: Tue, 7 Oct 2025 12:12:35 -0700 Subject: [PATCH 4/4] Updated error message --- R/psis.R | 103 +++++++++++++++++++-------------- tests/testthat/_snaps/psis.md | 2 +- tests/testthat/_snaps/tisis.md | 2 +- 3 files changed, 60 insertions(+), 47 deletions(-) diff --git a/R/psis.R b/R/psis.R index 48b275ed..ce5207d8 100644 --- a/R/psis.R +++ b/R/psis.R @@ -98,13 +98,14 @@ psis <- function(log_ratios, ...) UseMethod("psis") #' @template array #' psis.array <- - function(log_ratios, ..., - r_eff = 1, - cores = getOption("mc.cores", 1)) { - importance_sampling.array(log_ratios = log_ratios, ..., - r_eff = r_eff, - cores = cores, - method = "psis") + function(log_ratios, ..., r_eff = 1, cores = getOption("mc.cores", 1)) { + importance_sampling.array( + log_ratios = log_ratios, + ..., + r_eff = r_eff, + cores = cores, + method = "psis" + ) } @@ -113,15 +114,14 @@ psis.array <- #' @template matrix #' psis.matrix <- - function(log_ratios, - ..., - r_eff = 1, - cores = getOption("mc.cores", 1)) { - importance_sampling.matrix(log_ratios, - ..., - r_eff = r_eff, - cores = cores, - method = "psis") + function(log_ratios, ..., r_eff = 1, cores = getOption("mc.cores", 1)) { + importance_sampling.matrix( + log_ratios, + ..., + r_eff = r_eff, + cores = cores, + method = "psis" + ) } #' @export @@ -130,9 +130,12 @@ psis.matrix <- #' psis.default <- function(log_ratios, ..., r_eff = 1) { - importance_sampling.default(log_ratios = log_ratios, ..., - r_eff = r_eff, - method = "psis") + importance_sampling.default( + log_ratios = log_ratios, + ..., + r_eff = r_eff, + method = "psis" + ) } @@ -149,25 +152,26 @@ is.psis <- function(x) { #' @noRd #' @seealso importance_sampling_object psis_object <- - function(unnormalized_log_weights, - pareto_k, - tail_len, - r_eff) { - importance_sampling_object(unnormalized_log_weights = unnormalized_log_weights, - pareto_k = pareto_k, - tail_len = tail_len, - r_eff = r_eff, - method = "psis") + function(unnormalized_log_weights, pareto_k, tail_len, r_eff) { + importance_sampling_object( + unnormalized_log_weights = unnormalized_log_weights, + pareto_k = pareto_k, + tail_len = tail_len, + r_eff = r_eff, + method = "psis" + ) } #' @noRd #' @seealso do_importance_sampling -do_psis <- function(log_ratios, r_eff, cores, method){ - do_importance_sampling(log_ratios = log_ratios, - r_eff = r_eff, - cores = cores, - method = "psis") +do_psis <- function(log_ratios, r_eff, cores, method) { + do_importance_sampling( + log_ratios = log_ratios, + r_eff = r_eff, + cores = cores, + method = "psis" + ) } #' Extract named components from each list in the list of lists obtained by @@ -181,7 +185,9 @@ do_psis <- function(log_ratios, r_eff, cores, method){ #' @return Numeric vector or matrix. #' psis_apply <- function(x, item, fun = c("[[", "attr"), fun_val = numeric(1)) { - if (!is.list(x)) stop("Internal error ('x' must be a list for psis_apply)") + if (!is.list(x)) { + stop("Internal error ('x' must be a list for psis_apply)") + } vapply(x, FUN = match.arg(fun), FUN.VALUE = fun_val, item) } @@ -212,7 +218,7 @@ do_psis_i <- function(log_ratios_i, tail_len_i, ...) { ord <- sort.int(lw_i, index.return = TRUE) tail_ids <- seq(S - tail_len_i + 1, S) lw_tail <- ord$x[tail_ids] - if (abs(max(lw_tail) - min(lw_tail)) < .Machine$double.eps/100) { + if (abs(max(lw_tail) - min(lw_tail)) < .Machine$double.eps / 100) { warning( "Can't fit generalized Pareto distribution ", "because all tail values are the same.", @@ -252,11 +258,11 @@ psis_smooth_tail <- function(x, cutoff) { k <- fit$k sigma <- fit$sigma if (is.finite(k)) { - p <- (seq_len(len) - 0.5) / len - qq <- qgpd(p, k, sigma) + exp_cutoff - tail <- log(qq) + p <- (seq_len(len) - 0.5) / len + qq <- qgpd(p, k, sigma) + exp_cutoff + tail <- log(qq) } else { - tail <- x + tail <- x } list(tail = tail, k = k) } @@ -322,7 +328,8 @@ throw_tail_length_warnings <- function(tail_lengths) { if (length(tail_lengths) == 1) { warning( "Not enough tail samples to fit the generalized Pareto distribution.", - call. = FALSE, immediate. = TRUE + call. = FALSE, + immediate. = TRUE ) } else { bad <- which(tail_len_bad) @@ -332,7 +339,11 @@ throw_tail_length_warnings <- function(tail_lengths) { "in some or all columns of matrix of log importance ratios. ", "Skipping the following columns: ", paste(if (Nbad <= 10) bad else bad[1:10], collapse = ", "), - if (Nbad > 10) paste0(", ... [", Nbad - 10, " more not printed].\n") else "\n", + if (Nbad > 10) { + paste0(", ... [", Nbad - 10, " more not printed].\n") + } else { + "\n" + }, call. = FALSE, immediate. = TRUE ) @@ -352,7 +363,7 @@ throw_tail_length_warnings <- function(tail_lengths) { #' * If `r_eff` is `NA` then `rep(1, len)` is returned. #' * If `r_eff` is a scalar then `rep(r_eff, len)` is returned. #' * If `r_eff` is not a scalar but the length is not `len` then an error is thrown. -#' * If `r_eff` has length `len` but has `NA`s then an error is thrown. +#' * If `r_eff` has length `len` but has `NA`s then `NA`s are filled in with `1`s. #' prepare_psis_r_eff <- function(r_eff, len) { if (isTRUE(is.null(r_eff) || all(is.na(r_eff)))) { @@ -360,9 +371,12 @@ prepare_psis_r_eff <- function(r_eff, len) { } else if (length(r_eff) == 1) { r_eff <- rep(r_eff, len) } else if (length(r_eff) != len) { - stop("'r_eff' must have one value or one value per observation.", call. = FALSE) + stop( + "'r_eff' must have one value or one value per observation.", + call. = FALSE + ) } else if (anyNA(r_eff)) { - message("If `r_eff` has length `len` but has `NA`s then `NA`s are replaced with 1's.") + message("Replacing NAs in `r_eff` with 1s") r_eff[is.na(r_eff)] <- 1 } r_eff @@ -391,4 +405,3 @@ throw_psis_r_eff_warning <- function() { call. = FALSE ) } - diff --git a/tests/testthat/_snaps/psis.md b/tests/testthat/_snaps/psis.md index b812ba31..9f2deaf3 100644 --- a/tests/testthat/_snaps/psis.md +++ b/tests/testthat/_snaps/psis.md @@ -4785,7 +4785,7 @@ Code psis(-LLarr, r_eff = r_eff_arr) Message - If `r_eff` has length `len` but has `NA`s then `NA`s are replaced with 1's. + Replacing NAs in `r_eff` with 1s Output Computed from 1000 by 32 log-weights matrix. MCSE and ESS estimates assume MCMC draws (r_eff in [0.6, 1.0]). diff --git a/tests/testthat/_snaps/tisis.md b/tests/testthat/_snaps/tisis.md index 4bf08d6c..57cf7e30 100644 --- a/tests/testthat/_snaps/tisis.md +++ b/tests/testthat/_snaps/tisis.md @@ -3,7 +3,7 @@ Code psis(-LLarr, r_eff = r_eff_arr) Message - If `r_eff` has length `len` but has `NA`s then `NA`s are replaced with 1's. + Replacing NAs in `r_eff` with 1s Output Computed from 1000 by 32 log-weights matrix. MCSE and ESS estimates assume MCMC draws (r_eff in [0.6, 1.0]).