This vignette demonstrates the application of glmnet and varbvs to the Leukemia data set. Its main aim is to illustrate some of the different properties of Bayesian variable selection and penalized sparse regression (as implemented by varbvs and glmnet, respectively).
We use the preprocessed data of Dettling (2004), retrieved from the
supplementary materials accompanying Friedman et al. (2010). The
data are represented as a 72 x 3,571 matrix of gene expression values
(variable X), and a vector of 72 binary disease outcomes (variable y).
Begin by loading these packages into your R environment.
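The code chunk for this step is not reproduced here. As a minimal sketch, the packages below are the ones the later code appears to rely on (xyplot from lattice, as.layer from latticeExtra, plus glmnet and varbvs). The sketch also assumes that the Leukemia data ship with varbvs as a list named leukemia with elements x and y; treat that as an assumption to check against your installed version.

```r
# Packages that the code in this vignette appears to use.
pkgs <- c("lattice","latticeExtra","glmnet","varbvs")

# Check which packages are available before attaching them.
installed <- vapply(pkgs,requireNamespace,logical(1),quietly = TRUE)
if (!all(installed))
  message("Not installed: ",paste(pkgs[!installed],collapse = ", "))
for (pkg in pkgs[installed])
  library(pkg,character.only = TRUE)

# Assumption: varbvs provides the data as "leukemia", a list holding
# the expression matrix (x) and the binary outcomes (y).
if (installed[["varbvs"]]) {
  data(leukemia,package = "varbvs")
  if (exists("leukemia")) {
    X <- leukemia$x
    y <- leukemia$y
  }
}
```

If the data load as assumed, dim(X) should report 72 rows and 3,571 columns, matching the description above.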
Specify settings for the glmnet analysis.
Also set the random number generator seed.
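The settings chunk is not reproduced here, so the following is only a sketch of what it might contain. The number of folds (20) is stated in the text. The elastic net mixing parameter, the lambda sequence, and the seed are assumptions, chosen to be consistent with the axis limits used in the plots further down.

```r
# Number of cross-validation folds (stated in the text).
nfolds <- 20

# Assumption: elastic net mixing parameter close to the lasso (alpha = 1).
alpha <- 0.95

# Assumption: decreasing sequence of penalty strengths from 1 down to
# 0.01, evenly spaced on the log10 scale.
lambda <- 10^(seq(0,-2,-0.05))

# Assumption: any fixed seed makes the cross-validation folds reproducible.
set.seed(1)
```

glmnet expects user-supplied lambda values in decreasing order, which is why the sequence runs from large to small.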
Here, we also run 20-fold cross-validation to select the largest setting of the L1-penalty strength (lambda) that is within 1 standard error of the minimum classification error.
# This is the model fitting step.
r <- system.time(fit.glmnet <-
       glmnet(X,y,family = "binomial",lambda = lambda,alpha = alpha))
cat(sprintf("Model fitting took %0.2f seconds.\n",r["elapsed"]))
# Model fitting took 0.02 seconds.
# This is the cross-validation step.
r <- system.time(out.cv.glmnet <-
       cv.glmnet(X,y,family = "binomial",type.measure = "class",
                 alpha = alpha,nfolds = nfolds,lambda = lambda))
lambda <- out.cv.glmnet$lambda
cat(sprintf("Cross-validation took %0.2f seconds.\n",r["elapsed"]))
# Cross-validation took 0.32 seconds.
# Choose the largest value of lambda that is within 1 standard error
# of the smallest misclassification error.
lambda.opt <- out.cv.glmnet$lambda.1se
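To make the one-standard-error rule concrete, here is a small base R sketch that reproduces the selection by hand. The numbers are made up for illustration and are not results from the Leukemia data; cvm and cvsd stand in for the mean cross-validation error and its standard error, as returned by cv.glmnet.

```r
# Toy cross-validation summary (hypothetical values).
lambda <- 10^seq(0,-2,-0.5)             # Penalty strengths, decreasing.
cvm    <- c(0.35,0.06,0.04,0.05,0.08)   # Mean CV error at each lambda.
cvsd   <- c(0.05,0.04,0.03,0.03,0.03)   # Standard error of each mean.

# Setting with the smallest cross-validation error.
i.min      <- which.min(cvm)
lambda.min <- lambda[i.min]

# Largest lambda whose error is within one standard error of the minimum.
threshold  <- cvm[i.min] + cvsd[i.min]
lambda.1se <- max(lambda[cvm <= threshold])
print(c(lambda.min = lambda.min,lambda.1se = lambda.1se))
```

Because larger lambda means a stronger penalty, this rule trades a small increase in estimated error for a sparser model.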
Compute estimates of the disease outcome using the fitted model, and compare against the observed values.
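The comparison can be summarized as a confusion table. The sketch below uses made-up label vectors (y.true and y.pred are hypothetical names) so that it runs on its own. With the fitted model, the predictions would come from a call such as predict(fit.glmnet,X,s = lambda.opt,type = "class").

```r
# Hypothetical observed and predicted binary outcomes.
y.true <- c(0,0,0,1,1,1,0,1)
y.pred <- c(0,0,1,1,1,0,0,1)

# Cross-tabulate observed against predicted labels.
confusion <- table(true = factor(y.true,levels = 0:1),
                   pred = factor(y.pred,levels = 0:1))
print(confusion)

# Proportion of samples that are misclassified.
err <- mean(y.true != y.pred)
cat(sprintf("Classification error: %0.2f\n",err))
```

Note that this is the error on the training data, so it will generally be more optimistic than the cross-validation error.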
The first plot shows the evolution of regression coefficients at different settings of lambda. (Note that the intercept is not shown.) Only the curves for the variables that are selected at the optimal setting of lambda ("lambda.opt") are labeled.
The second plot shows the classification error at different settings of lambda.
The third plot shows the number of nonzero regression coefficients at different settings of lambda.
trellis.par.set(par.xlab.text = list(cex = 0.85),
                par.ylab.text = list(cex = 0.85),
                axis.text     = list(cex = 0.75))
# Choose the largest value of lambda that is within 1 standard error
# of the smallest misclassification error.
lambda.opt <- out.cv.glmnet$lambda.1se
# Plot regression coefficients.
lambda <- fit.glmnet$lambda
# Keep variables with a nonzero coefficient at some lambda, excluding
# the intercept (row 1).
vars <- setdiff(which(rowSums(abs(coef(fit.glmnet))) > 0),1)
n <- length(vars)
b <- as.matrix(t(coef(fit.glmnet)[vars,]))
# Names of the variables selected at lambda.opt; the first entry (the
# intercept) is dropped.
i <- coef(fit.glmnet,s = lambda.opt)
i <- rownames(i)[which(i != 0)]
i <- i[-1]
# Label only the selected variables, dropping the first character of
# each name.
vars.opt <- colnames(b)
vars.opt[!is.element(vars.opt,i)] <- ""
vars.opt <- substring(vars.opt,2)
lab <- expression("more complex" %<-% paste(log[10],lambda) %->%
                  "less complex")
r <- xyplot(y ~ x,data.frame(x = log10(lambda),y = b[,1]),type = "l",
            col = "blue",xlab = lab,ylab = "regression coefficient",
            scales = list(x = list(limits = c(-2.35,0.1)),
                          y = list(limits = c(-0.8,1.2))),
            panel = function(x, y, ...) {
              panel.xyplot(x,y,...)
              panel.abline(v = log10(lambda.opt),col = "orangered",
                           lwd = 2,lty = "dotted")
              ltext(x = -2,y = b[nrow(b),],labels = vars.opt,pos = 2,
                    offset = 0.5,cex = 0.5)
            })
for (i in 2:n)
  r <- r + as.layer(xyplot(y ~ x,data.frame(x = log10(lambda),y = b[,i]),
                           type = "l",col = "blue"))
print(r,split = c(2,1,2,1),more = TRUE)
# Plot classification error.
Y <- predict(fit.glmnet,X,type = "class")
mode(Y) <- "numeric"
print(with(out.cv.glmnet,
           xyplot(y ~ x,data.frame(x = log10(lambda),y = cvm),type = "l",
                  col = "blue",xlab = lab,
                  ylab = "20-fold cross-validation \n classification error",
                  scales = list(y = list(limits = c(-0.02,0.45))),
                  panel = function(x, y, ...) {
                    panel.xyplot(x,y,...)
                    panel.abline(v = log10(lambda.opt),col = "orangered",
                                 lwd = 2,lty = "dotted")
                  }) +
           as.layer(xyplot(y ~ x,data.frame(x = log10(lambda),y = cvm),
                           pch = 20,cex = 0.6,col = "blue")) +
           as.layer(xyplot(y ~ x,data.frame(x = log10(lambda),y = cvup),
                           type = "l",col = "blue",lty = "solid")) +
           as.layer(xyplot(y ~ x,data.frame(x = log10(lambda),y = cvlo),
                           type = "l",col = "blue",lty = "solid")) +
           as.layer(xyplot(y ~ x,data.frame(x = log10(lambda),
                                            y = colMeans(abs(Y - y))),
                           type = "l",col = "darkorange",lwd = 2,
                           lty = "solid"))),
      split = c(1,1,2,2),more = TRUE)
# Plot number of non-zero regression coefficients.
print(with(out.cv.glmnet,
           xyplot(y ~ x,data.frame(x = log10(lambda),y = nzero),type = "l",
                  col = "blue",xlab = lab,
                  ylab = "number of non-zero \n coefficients",
                  panel = function(x, y, ...) {
                    panel.xyplot(x,y,...)
                    panel.abline(v = log10(lambda.opt),col = "orangered",
                                 lwd = 2,lty = "dotted")
                  }) +
           as.layer(xyplot(y ~ x,data.frame(x = log10(lambda),y = nzero),
                           pch = 20,cex = 0.6,col = "blue"))),
      split = c(1,2,2,2),more = FALSE)
Fit the fully-factorized variational approximation to the posterior distribution of the coefficients for a logistic regression model of the binary outcome (the type of leukemia), with spike-and-slab priors on the coefficients.
Compute estimates of the disease outcome using the fitted model, and compare against the observed values.
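The fitting and prediction chunks are not reproduced here. As a minimal sketch of the fit-then-compare workflow, the code below runs varbvs on a small simulated data set instead of the Leukemia data. The simulated data, the grid of prior log-odds settings, and the prior variance sa = 1 are all assumptions made for illustration. The names X.sim, y.sim, fit.sim and y.hat are hypothetical.

```r
# Run only if varbvs is installed.
if (requireNamespace("varbvs",quietly = TRUE)) {

  # Simulate a small logistic regression data set in which only the
  # first two variables affect the outcome.
  set.seed(1)
  n     <- 100
  p     <- 20
  X.sim <- matrix(rnorm(n*p),n,p)
  colnames(X.sim) <- paste0("g",1:p)
  y.sim <- as.numeric(X.sim[,1] - X.sim[,2] + rnorm(n) > 0)

  # Fit the variational approximation over a grid of prior log-odds
  # settings (assumed grid), with no additional covariates (Z = NULL).
  fit.sim <- varbvs::varbvs(X.sim,NULL,y.sim,family = "binomial",
                            logodds = seq(-2,-0.5,0.5),sa = 1,
                            verbose = FALSE)

  # Predict the binary outcome and compare against the observed values.
  y.hat <- predict(fit.sim,X.sim,type = "class")
  print(table(true = factor(y.sim),pred = factor(y.hat)))
}
```

For the Leukemia analysis, the same two calls would be applied to X and y, with the result stored in fit.varbvs as the plotting code below expects.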
The first plot shows the classification error at each setting of the prior log-odds.
The second plot shows the expected number of included variables (the sum of the posterior inclusion probabilities) at each setting of the prior log-odds.
The third plot shows the (approximate) posterior probability of each setting of the prior log-odds parameter.
In all three plots, the horizontal axis shows the prior inclusion probability on the log10 scale.
trellis.par.set(par.xlab.text = list(cex = 0.85),
                par.ylab.text = list(cex = 0.85),
                axis.text     = list(cex = 0.75))
# Get the normalized importance weights.
w <- fit.varbvs$w
# Plot classification error at each hyperparameter setting.
sigmoid10 <- function (x)
  1/(1 + 10^(-x))
logodds <- fit.varbvs$logodds
log10q <- log10(sigmoid10(logodds))
m <- length(logodds)
err <- rep(0,m)
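for (i in 1:m) {
  # Restrict the fit to the i-th hyperparameter setting, then request
  # class (0/1) predictions so they can be compared directly with y.
  r <- subset(fit.varbvs,logodds == logodds[i])
  ypred <- predict(r,X,type = "class")
  err[i] <- mean(y != ypred)
}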
lab <- expression("more complex" %<-% paste(log[10],pi) %->% "less complex")
print(xyplot(y ~ x,data.frame(x = log10q,y = err),type = "l",
             col = "blue",xlab = lab,ylab = "classification error",
             scales = list(x = list(limits = c(-0.9,-3.65)))) +
      as.layer(xyplot(y ~ x,data.frame(x = log10q,y = err),
                      col = "blue",pch = 20,cex = 0.65)),
      split = c(1,1,2,2),more = TRUE)
# Plot expected number of included variables at each hyperparameter
# setting.
r <- colSums(fit.varbvs$alpha)
print(xyplot(y ~ x,data.frame(x = log10q,y = r),type = "l",col = "blue",
             xlab = lab,ylab = "expected number of\nincluded variables",
             scales = list(x = list(limits = c(-0.9,-3.65)),
                           y = list(log = 10,at = c(1,10,100)))) +
      as.layer(xyplot(y ~ x,data.frame(x = log10q,y = r),
                      col = "blue",pch = 20,cex = 0.65,
                      scales = list(x = list(limits = c(-0.9,-3.65)),
                                    y = list(log = 10)))),
      split = c(1,2,2,2),more = TRUE)
# Plot density of prior inclusion probability hyperparameter.
print(xyplot(y ~ x,data.frame(x = log10q,y = w),type = "l",col = "blue",
             xlab = lab,
             ylab = expression(paste("posterior probability of ",pi)),
             scales = list(x = list(limits = c(-0.9,-3.65)))) +
      as.layer(xyplot(y ~ x,data.frame(x = log10q,y = w),
                      col = "blue",pch = 20,cex = 0.65)),
      split = c(2,1,2,1),more = FALSE)
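The horizontal axis in the plots above is obtained by converting each prior log-odds setting (base 10) into a prior inclusion probability with sigmoid10, and then taking log10. The short sketch below shows that conversion on its own. The grid of log-odds values is an assumption for illustration, and prior.prob is a hypothetical name.

```r
# Base-10 sigmoid: maps log10 odds to a probability.
sigmoid10 <- function (x)
  1/(1 + 10^(-x))

# Assumed grid of prior log-odds settings.
logodds <- seq(-3.5,-1,0.5)

# Prior inclusion probability at each setting, and its log10.
prior.prob <- sigmoid10(logodds)
log10q     <- log10(prior.prob)
print(data.frame(logodds,prior.prob,log10q))
```

When the prior log-odds are strongly negative, the odds and the probability nearly coincide, so log10q is only slightly smaller than logodds. The two diverge as the log-odds approach zero.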
Dettling, M. (2004). BagBoosting for tumor classification with gene expression data. Bioinformatics 20, 3583–3593.
Friedman, J., Hastie, T., Tibshirani, R. (2010). Regularization paths for generalized linear models via coordinate descent. Journal of Statistical Software 33, 1–22.
This is the version of R and the packages that were used to generate these results.
sessionInfo()
# R version 4.4.2 (2024-10-31)
# Platform: x86_64-pc-linux-gnu
# Running under: Ubuntu 24.04.1 LTS
#
# Matrix products: default
# BLAS: /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3
# LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.26.so; LAPACK version 3.12.0
#
# locale:
# [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
# [3] LC_TIME=en_US.UTF-8 LC_COLLATE=C
# [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
# [7] LC_PAPER=en_US.UTF-8 LC_NAME=C
# [9] LC_ADDRESS=C LC_TELEPHONE=C
# [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
#
# time zone: Etc/UTC
# tzcode source: system (glibc)
#
# attached base packages:
# [1] stats graphics grDevices utils datasets methods base
#
# other attached packages:
# [1] glmnet_4.1-8 Matrix_1.7-1 latticeExtra_0.6-30
# [4] curl_6.0.0 varbvs_2.6-10 lattice_0.22-6
# [7] rmarkdown_2.29
#
# loaded via a namespace (and not attached):
# [1] cli_3.6.3 knitr_1.49 rlang_1.1.4 xfun_0.49
# [5] nor1mix_1.3-3 png_0.1-8 jsonlite_1.8.9 buildtools_1.0.0
# [9] htmltools_0.5.8.1 maketools_1.3.1 sys_3.4.3 sass_0.4.9
# [13] grid_4.4.2 evaluate_1.0.1 jquerylib_0.1.4 fastmap_1.2.0
# [17] foreach_1.5.2 interp_1.1-6 yaml_2.3.10 lifecycle_1.0.4
# [21] compiler_4.4.2 codetools_0.2-20 RColorBrewer_1.1-3 Rcpp_1.0.13-1
# [25] digest_0.6.37 R6_2.5.1 splines_4.4.2 shape_1.4.6.1
# [29] bslib_0.8.0 tools_4.4.2 jpeg_0.1-10 iterators_1.0.14
# [33] deldir_2.0-4 survival_3.7-0 cachem_1.1.0