Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 59 additions & 45 deletions R/psis.R
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,14 @@ psis <- function(log_ratios, ...) UseMethod("psis")
#' @template array
#'
psis.array <-
function(log_ratios, ...,
r_eff = 1,
cores = getOption("mc.cores", 1)) {
importance_sampling.array(log_ratios = log_ratios, ...,
r_eff = r_eff,
cores = cores,
method = "psis")
function(log_ratios, ..., r_eff = 1, cores = getOption("mc.cores", 1)) {
importance_sampling.array(
log_ratios = log_ratios,
...,
r_eff = r_eff,
cores = cores,
method = "psis"
)
}


Expand All @@ -113,15 +114,14 @@ psis.array <-
#' @template matrix
#'
psis.matrix <-
function(log_ratios,
...,
r_eff = 1,
cores = getOption("mc.cores", 1)) {
importance_sampling.matrix(log_ratios,
...,
r_eff = r_eff,
cores = cores,
method = "psis")
function(log_ratios, ..., r_eff = 1, cores = getOption("mc.cores", 1)) {
importance_sampling.matrix(
log_ratios,
...,
r_eff = r_eff,
cores = cores,
method = "psis"
)
}

#' @export
Expand All @@ -130,9 +130,12 @@ psis.matrix <-
#'
psis.default <-
function(log_ratios, ..., r_eff = 1) {
importance_sampling.default(log_ratios = log_ratios, ...,
r_eff = r_eff,
method = "psis")
importance_sampling.default(
log_ratios = log_ratios,
...,
r_eff = r_eff,
method = "psis"
)
}


Expand All @@ -149,25 +152,26 @@ is.psis <- function(x) {
#' @noRd
#' @seealso importance_sampling_object
psis_object <-
function(unnormalized_log_weights,
pareto_k,
tail_len,
r_eff) {
importance_sampling_object(unnormalized_log_weights = unnormalized_log_weights,
pareto_k = pareto_k,
tail_len = tail_len,
r_eff = r_eff,
method = "psis")
function(unnormalized_log_weights, pareto_k, tail_len, r_eff) {
importance_sampling_object(
unnormalized_log_weights = unnormalized_log_weights,
pareto_k = pareto_k,
tail_len = tail_len,
r_eff = r_eff,
method = "psis"
)
}


#' @noRd
#' @seealso do_importance_sampling
do_psis <- function(log_ratios, r_eff, cores, method){
do_importance_sampling(log_ratios = log_ratios,
r_eff = r_eff,
cores = cores,
method = "psis")
do_psis <- function(log_ratios, r_eff, cores, method) {
do_importance_sampling(
log_ratios = log_ratios,
r_eff = r_eff,
cores = cores,
method = "psis"
)
}

#' Extract named components from each list in the list of lists obtained by
Expand All @@ -181,7 +185,9 @@ do_psis <- function(log_ratios, r_eff, cores, method){
#' @return Numeric vector or matrix.
#'
psis_apply <- function(x, item, fun = c("[[", "attr"), fun_val = numeric(1)) {
if (!is.list(x)) stop("Internal error ('x' must be a list for psis_apply)")
if (!is.list(x)) {
stop("Internal error ('x' must be a list for psis_apply)")
}
vapply(x, FUN = match.arg(fun), FUN.VALUE = fun_val, item)
}

Expand Down Expand Up @@ -212,7 +218,7 @@ do_psis_i <- function(log_ratios_i, tail_len_i, ...) {
ord <- sort.int(lw_i, index.return = TRUE)
tail_ids <- seq(S - tail_len_i + 1, S)
lw_tail <- ord$x[tail_ids]
if (abs(max(lw_tail) - min(lw_tail)) < .Machine$double.eps/100) {
if (abs(max(lw_tail) - min(lw_tail)) < .Machine$double.eps / 100) {
warning(
"Can't fit generalized Pareto distribution ",
"because all tail values are the same.",
Expand Down Expand Up @@ -252,11 +258,11 @@ psis_smooth_tail <- function(x, cutoff) {
k <- fit$k
sigma <- fit$sigma
if (is.finite(k)) {
p <- (seq_len(len) - 0.5) / len
qq <- qgpd(p, k, sigma) + exp_cutoff
tail <- log(qq)
p <- (seq_len(len) - 0.5) / len
qq <- qgpd(p, k, sigma) + exp_cutoff
tail <- log(qq)
} else {
tail <- x
tail <- x
}
list(tail = tail, k = k)
}
Expand Down Expand Up @@ -322,7 +328,8 @@ throw_tail_length_warnings <- function(tail_lengths) {
if (length(tail_lengths) == 1) {
warning(
"Not enough tail samples to fit the generalized Pareto distribution.",
call. = FALSE, immediate. = TRUE
call. = FALSE,
immediate. = TRUE
)
} else {
bad <- which(tail_len_bad)
Expand All @@ -332,7 +339,11 @@ throw_tail_length_warnings <- function(tail_lengths) {
"in some or all columns of matrix of log importance ratios. ",
"Skipping the following columns: ",
paste(if (Nbad <= 10) bad else bad[1:10], collapse = ", "),
if (Nbad > 10) paste0(", ... [", Nbad - 10, " more not printed].\n") else "\n",
if (Nbad > 10) {
paste0(", ... [", Nbad - 10, " more not printed].\n")
} else {
"\n"
},
call. = FALSE,
immediate. = TRUE
)
Expand All @@ -352,17 +363,21 @@ throw_tail_length_warnings <- function(tail_lengths) {
#' * If `r_eff` is `NA` then `rep(1, len)` is returned.
#' * If `r_eff` is a scalar then `rep(r_eff, len)` is returned.
#' * If `r_eff` is not a scalar but the length is not `len` then an error is thrown.
#' * If `r_eff` has length `len` but has `NA`s then an error is thrown.
#' * If `r_eff` has length `len` but has `NA`s then `NA`s are filled in with `1`s.
#'
prepare_psis_r_eff <- function(r_eff, len) {
if (isTRUE(is.null(r_eff) || all(is.na(r_eff)))) {
r_eff <- rep(1, len)
} else if (length(r_eff) == 1) {
r_eff <- rep(r_eff, len)
} else if (length(r_eff) != len) {
stop("'r_eff' must have one value or one value per observation.", call. = FALSE)
stop(
"'r_eff' must have one value or one value per observation.",
call. = FALSE
)
} else if (anyNA(r_eff)) {
stop("Can't mix NA and not NA values in 'r_eff'.", call. = FALSE)
message("Replacing NAs in `r_eff` with 1s")
r_eff[is.na(r_eff)] <- 1
}
r_eff
}
Expand Down Expand Up @@ -390,4 +405,3 @@ throw_psis_r_eff_warning <- function() {
call. = FALSE
)
}

13 changes: 13 additions & 0 deletions tests/testthat/_snaps/psis.md
Original file line number Diff line number Diff line change
Expand Up @@ -4782,6 +4782,19 @@

# psis throws correct errors and warnings

Code
psis(-LLarr, r_eff = r_eff_arr)
Message
Replacing NAs in `r_eff` with 1s
Output
Computed from 1000 by 32 log-weights matrix.
MCSE and ESS estimates assume MCMC draws (r_eff in [0.6, 1.0]).

All Pareto k estimates are good (k < 0.67).
See help('pareto-k-diagnostic') for details.

---

Code
psis(-LLarr[1:5, , ])
Condition
Expand Down
13 changes: 13 additions & 0 deletions tests/testthat/_snaps/tisis.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# tis throws correct errors and warnings

Code
psis(-LLarr, r_eff = r_eff_arr)
Message
Replacing NAs in `r_eff` with 1s
Output
Computed from 1000 by 32 log-weights matrix.
MCSE and ESS estimates assume MCMC draws (r_eff in [0.6, 1.0]).

All Pareto k estimates are good (k < 0.67).
See help('pareto-k-diagnostic') for details.

4 changes: 2 additions & 2 deletions tests/testthat/test_psis.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ test_that("psis throws correct errors and warnings", {
# r_eff non-scalar wrong length is error
expect_error(psis(-LLarr, r_eff = r_eff_arr[-1]), "one value per observation")

# r_eff has some NA values causes error
# r_eff has some NA values which are replaced with 1
r_eff_arr[2] <- NA
expect_error(psis(-LLarr, r_eff = r_eff_arr), "mix NA and not NA values")
expect_snapshot(psis(-LLarr, r_eff = r_eff_arr))

# tail length warnings
expect_snapshot(psis(-LLarr[1:5, , ]))
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test_tisis.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ test_that("tis throws correct errors and warnings", {
# r_eff wrong length is error
expect_error(tis(-LLarr, r_eff = r_eff_arr[-1]), "one value per observation")

# r_eff has some NA values causes error
# r_eff has some NA values which are replaced with 1
r_eff_arr[2] <- NA
expect_error(tis(-LLarr, r_eff = r_eff_arr), "mix NA and not NA values")
expect_snapshot(psis(-LLarr, r_eff = r_eff_arr))

# no NAs or non-finite values allowed
LLmat[1, 1] <- NA
Expand Down