Skip to content

Commit ceac3f9

Browse files
When ivn.rand = FALSE, check ivn argument carefully (fixes #4)
1 parent 10cdda4 commit ceac3f9

File tree

3 files changed

+70
-0
lines changed

3 files changed

+70
-0
lines changed

R/ccdrAlgorithm-mvn.R

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,27 @@
3838
#' params = gr.params,
3939
#' ivn = ivn)
4040
#'
41+
#' ### Use pre-specified values for interventions
42+
#' ### In this toy example, we assume that all intervened nodes were fixed to
43+
#' ### to the value 1, although this can be any number of course.
44+
#' ivn.vals <- lapply(ivn, function(x) sapply(x, function(x) 1)) # replace all entries with a 1
45+
#' data.ivn <- ccdrAlgorithm::generate_mvn_data(graph = gr,
46+
#' n = 100,
47+
#' params = gr.params,
48+
#' ivn = ivn.vals,
49+
#' ivn.rand = FALSE)
50+
#'
51+
#' ### If ivn.rand = FALSE, you must specify values
52+
#' ### The code below will fail because ivn does not contain any values
53+
#' ### (compare to ivn.vals above).
54+
#' \dontrun{
55+
#' data.ivn <- ccdrAlgorithm::generate_mvn_data(graph = gr,
56+
#' n = 100,
57+
#' params = gr.params,
58+
#' ivn = ivn,
59+
#' ivn.rand = FALSE)
60+
#' }
61+
#'
4162
#' @export
4263
generate_mvn_data <- function(graph, params, n = 1, ivn = NULL, ivn.rand = TRUE){
4364
### This function requires the 'igraph' package to be installed
@@ -61,6 +82,20 @@ generate_mvn_data <- function(graph, params, n = 1, ivn = NULL, ivn.rand = TRUE)
6182
if(ivn.rand){
6283
ivn <- lapply(ivn, function(x) sapply(x, function(x) rnorm(n = 1, mean = 0, sd = 1))) # assume standard normal
6384
# ivn <- lapply(ivn, function(x) sapply(x, function(x) 1)) # debugging
85+
} else{
86+
check_vals <- sparsebnUtils::check_list_class(ivn, c("NULL", "numeric")) # check to make sure list components are either numeric (ivn vals) or NULL (obs sample)
87+
check_names <- sapply(ivn, function(x) is.null(names(x))) # return TRUE if component has no names attribute (i.e. it is NULL)
88+
89+
if(!check_vals || all(check_names)){
90+
err_msg <- paste0("ivn.rand set to FALSE with invalid input for ivn: ",
91+
"If ivn.rand = FALSE, you must pass explicit values ",
92+
"for each intervention used in your experiments. ",
93+
"Please check that the ivn argument is a list whose ",
94+
"arguments are named numeric vectors whose names ",
95+
"correspond to the node under intervention or NULL ",
96+
"if the corresponding row is observational.")
97+
stop(err_msg)
98+
}
6499
}
65100
}
66101

man/generate_mvn_data.Rd

Lines changed: 21 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-mvn.R

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,17 @@ test_that("Generate MVN data w/ interventions", {
2323
expect_equal(as.vector(dat[, 1:nivn]), rep(1, nobs*nivn)) # first three columns are all ones
2424
})
2525

26+
test_that("ivn.rand = FALSE with invalid input returns an error (fixes issue #4)", {
27+
nobs <- 10
28+
nivn <- 3
29+
30+
### character / node names input
31+
ivn <- lapply(1:nobs, function(x) names(el)[1:nivn]) # only intervene on first 3 nodes
32+
expect_error(generate_mvn_data(el, params, n = nobs, ivn = ivn, ivn.rand = FALSE),
33+
"ivn.rand set to FALSE with invalid input")
34+
35+
### numeric / index input
36+
ivn <- lapply(1:nobs, function(x) 1:nivn) # only intervene on first 3 nodes
37+
expect_error(generate_mvn_data(el, params, n = nobs, ivn = ivn, ivn.rand = FALSE),
38+
"ivn.rand set to FALSE with invalid input")
39+
})

0 commit comments

Comments
 (0)