Comparison of glmnet and varbvs in Leukemia data set

This vignette demonstrates the application of glmnet and varbvs to the Leukemia data set. The main aim is to illustrate some of the different properties of Bayesian variable selection and penalized sparse regression (as implemented by varbvs and glmnet, respectively).

We use the preprocessed data of Dettling (2004), retrieved from the supplementary materials accompanying Friedman et al. (2010). The data are represented as a 72 x 3,571 matrix of gene expression values (variable X) and a vector of 72 binary disease outcomes (variable y).

Vignette parameters

Begin by loading these packages into your R environment.

library(lattice)
library(latticeExtra)
library(glmnet)
library(varbvs)

Specify settings for the glmnet analysis.

nfolds <- 20                    # Number of cross-validation folds.
alpha  <- 0.95                  # Elastic net mixing parameter.
lambda <- 10^(seq(0,-2,-0.05))  # Lambda sequence.

Load the Leukemia data

Load the data, and set the seed of the random number generator so that the results can be reproduced exactly.

data(leukemia)
X <- leukemia$x
y <- leukemia$y
set.seed(1)
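
As a quick sanity check, the dimensions of X and the counts of the two outcomes can be compared against the description above (72 samples, 3,571 genes, and 47 versus 25 cases of the two outcomes):

print(dim(X))
print(table(y))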

Fit elastic net model to data

Here, we also run 20-fold cross-validation to select the largest setting of the L1-penalty strength (lambda) at which the classification error is within 1 standard error of the minimum.

# This is the model fitting step.
r <- system.time(fit.glmnet <-
       glmnet(X,y,family = "binomial",lambda = lambda,alpha = alpha))
cat(sprintf("Model fitting took %0.2f seconds.\n",r["elapsed"]))
# Model fitting took 0.02 seconds.

# This is the cross-validation step.
r <- system.time(out.cv.glmnet <-
       cv.glmnet(X,y,family = "binomial",type.measure = "class",
                 alpha = alpha,nfolds = nfolds,lambda = lambda))
lambda <- out.cv.glmnet$lambda
cat(sprintf("Cross-validation took %0.2f seconds.\n",r["elapsed"]))
# Cross-validation took 0.32 seconds.

# Choose the largest value of lambda that is within 1 standard error
# of the smallest misclassification error.
lambda.opt <- out.cv.glmnet$lambda.1se
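
For reference, cv.glmnet also records the setting of lambda that minimizes the cross-validation error (stored as lambda.min); the one-standard-error choice used here is at least as large, so it yields a sparser (or equally sparse) model. A quick comparison:

print(out.cv.glmnet$lambda.min)
print(lambda.opt)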

Evaluate the glmnet predictions

Compute estimates of the disease outcome using the fitted model, and compare against the observed values.

cat("classification results with lambda = ",lambda.opt,":\n",sep="")
y.glmnet <- c(predict(fit.glmnet,X,s = lambda.opt,type = "class"))
print(table(true = factor(y),pred = factor(y.glmnet)))
# classification results with lambda = 0.02818383:
#     pred
# true  0  1
#    0 47  0
#    1  0 25
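
The table above shows that the selected model classifies all 72 training samples correctly. To see how many variables are used to achieve this, one can count the nonzero coefficients at lambda.opt (a minimal sketch; one is subtracted for the intercept, which is assumed to be nonzero):

b.opt <- as.matrix(coef(fit.glmnet,s = lambda.opt))
cat("Number of selected variables:",sum(b.opt != 0) - 1,"\n")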

Visualize results of glmnet analysis

The first plot shows the evolution of the regression coefficients at different settings of lambda. (Note that the intercept is not shown.) Only the curves for the variables that are selected at the optimal setting of lambda (“lambda.opt”) are labeled.

The second plot shows the classification error at different settings of lambda.

The third plot shows the number of nonzero regression coefficients at different settings of lambda.

trellis.par.set(par.xlab.text = list(cex = 0.85),
                par.ylab.text = list(cex = 0.85),
                axis.text     = list(cex = 0.75))

# Choose the largest value of lambda that is within 1 standard error
# of the smallest misclassification error.
lambda.opt <- out.cv.glmnet$lambda.1se

# Plot regression coefficients.
lambda   <- fit.glmnet$lambda
vars     <- setdiff(which(rowSums(abs(coef(fit.glmnet))) > 0),1)
n        <- length(vars)
b        <- as.matrix(t(coef(fit.glmnet)[vars,]))
i        <- coef(fit.glmnet,s = lambda.opt)
i        <- rownames(i)[which(i != 0)]
i        <- i[-1]
vars.opt <- colnames(b)
vars.opt[!is.element(vars.opt,i)] <- ""
vars.opt <- substring(vars.opt,2)
lab  <- expression("more complex" %<-% paste(log[10],lambda) %->% 
                   "less complex")
r    <- xyplot(y ~ x,data.frame(x = log10(lambda),y = b[,1]),type = "l",
               col = "blue",xlab = lab,ylab = "regression coefficient",
               scales = list(x = list(limits = c(-2.35,0.1)),
                             y = list(limits = c(-0.8,1.2))),
               panel = function(x, y, ...) {
                 panel.xyplot(x,y,...);
                 panel.abline(v = log10(lambda.opt),col = "orangered",
                              lwd = 2,lty = "dotted");
                 ltext(x = -2,y = b[nrow(b),],labels = vars.opt,pos = 2,
                       offset = 0.5,cex = 0.5);
               })
for (i in 2:n)
  r <- r + as.layer(xyplot(y ~ x,data.frame(x = log10(lambda),y = b[,i]),
                           type = "l",col = "blue"))
print(r,split = c(2,1,2,1),more = TRUE)

# Plot classification error.
Y       <- predict(fit.glmnet,X,type = "class")
mode(Y) <- "numeric"
print(with(out.cv.glmnet,
           xyplot(y ~ x,data.frame(x = log10(lambda),y = cvm),type = "l",
                  col = "blue",xlab = lab,
                  ylab = "20-fold cross-validation \n classification error",
                  scales = list(y = list(limits = c(-0.02,0.45))),
                  panel = function(x, y, ...) {
                    panel.xyplot(x,y,...);
                    panel.abline(v = log10(lambda.opt),col = "orangered",
                                 lwd = 2,lty = "dotted");
                  }) +
           as.layer(xyplot(y ~ x,data.frame(x = log10(lambda),y = cvm),
                           pch = 20,cex = 0.6,col = "blue")) +
           as.layer(xyplot(y ~ x,data.frame(x = log10(lambda),y = cvup),
                           type = "l",col = "blue",lty = "solid")) +
           as.layer(xyplot(y ~ x,data.frame(x = log10(lambda),y = cvlo),
                           type = "l",col = "blue",lty = "solid")) +
           as.layer(xyplot(y ~ x,data.frame(x = log10(lambda),
                                            y = colMeans(abs(Y - y))),
                           type = "l",col = "darkorange",lwd = 2,
                           lty = "solid"))),
           split = c(1,1,2,2),more = TRUE)

# Plot number of non-zero regression coefficients.
print(with(out.cv.glmnet,
           xyplot(y ~ x,data.frame(x = log10(lambda),y = nzero),type = "l",
                  col = "blue",xlab = lab,
                  ylab = "number of non-zero \n coefficients",
                  panel = function(x, y, ...) {
                    panel.xyplot(x,y,...)
                    panel.abline(v = log10(lambda.opt),col = "orangered",
                                 lwd = 2,lty = "dotted")
                  }) +
           as.layer(xyplot(y ~ x,data.frame(x = log10(lambda),y = nzero),
                           pch = 20,cex = 0.6,col = "blue"))),
      split = c(1,2,2,2),more = FALSE)

Fit variational approximation to posterior

Fit the fully-factorized variational approximation to the posterior distribution of the coefficients for a logistic regression model of the binary outcome (the type of leukemia), with spike-and-slab priors on the coefficients.

r <- system.time(fit.varbvs <- varbvs(X,NULL,y,"binomial",verbose = FALSE))
cat(sprintf("Model fitting took %0.2f seconds.\n",r["elapsed"]))
# Model fitting took 3.83 seconds.
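
For a concise overview of the fitted model (the hyperparameter settings and the variables with the largest posterior inclusion probabilities), varbvs provides a summary method for fitted model objects; for example:

print(summary(fit.varbvs))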

Evaluate the varbvs predictions

Compute estimates of the disease outcome using the fitted model, and compare against the observed values.

y.varbvs <- predict(fit.varbvs,X,type = "class")
print(table(true = factor(y),pred = factor(y.varbvs)))
#     pred
# true  0  1
#    0 46  1
#    1  3 22
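
Unlike the glmnet model selected above, the varbvs model misclassifies a few of the training samples. A quick side-by-side comparison of the training-set error rates (a minimal sketch, reusing the predictions computed above):

cat(sprintf("glmnet training error: %0.3f\n",mean(y != y.glmnet)))
cat(sprintf("varbvs training error: %0.3f\n",mean(y != y.varbvs)))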

Visualize results of varbvs analysis

The first plot shows the classification error at each setting of the prior log-odds.

The second plot shows the expected number of variables included in the model at each setting of the prior log-odds. This is computed by summing, over all candidate variables, the posterior inclusion probabilities (the alpha’s) at that hyperparameter setting.

The third plot shows the (approximate) posterior probability of the prior inclusion probability, computed from the normalized importance weights.

trellis.par.set(par.xlab.text = list(cex = 0.85),
                par.ylab.text = list(cex = 0.85),
                axis.text     = list(cex = 0.75))

# Get the normalized importance weights.
w <- fit.varbvs$w

# Plot classification error at each hyperparameter setting.
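# The prior inclusion probability is related to the prior log-odds by
# pi = 1/(1 + 10^(-logodds)); the sigmoid10 function below computes this
# transformation, so log10q is log10(pi).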
sigmoid10 <- function (x)
  1/(1 + 10^(-x))
logodds <- fit.varbvs$logodds
log10q  <- log10(sigmoid10(logodds))
m       <- length(logodds)
err     <- rep(0,m)
for (i in 1:m) {
  r      <- subset(fit.varbvs,logodds == logodds[i])
  ypred  <- predict(r,X)
  err[i] <- mean(y != ypred)
}
lab <- expression("more complex" %<-% paste(log[10],pi) %->% "less complex")
print(xyplot(y ~ x,data.frame(x = log10q,y = err),type = "l",
             col = "blue",xlab = lab,ylab = "classification error",
             scales = list(x = list(limits = c(-0.9,-3.65)))) +
      as.layer(xyplot(y ~ x,data.frame(x = log10q,y = err),
                      col = "blue",pch = 20,cex = 0.65)),
      split = c(1,1,2,2),more = TRUE)

# Plot expected number of included variables at each hyperparameter
# setting.
r <- colSums(fit.varbvs$alpha)
print(xyplot(y ~ x,data.frame(x = log10q,y = r),type = "l",col = "blue",
             xlab = lab,ylab = "expected number of\nincluded variables",
             scales = list(x = list(limits = c(-0.9,-3.65)),
                           y = list(log = 10,at = c(1,10,100)))) +
      as.layer(xyplot(y ~ x,data.frame(x = log10q,y = r),
                      col = "blue",pch = 20,cex = 0.65,
                      scales = list(x = list(limits = c(-0.9,-3.65)),
                                    y = list(log = 10)))),
      split = c(1,2,2,2),more = TRUE)

# Plot density of prior inclusion probability hyperparameter.
print(xyplot(y ~ x,data.frame(x = log10q,y = w),type = "l",col = "blue",
             xlab = lab,
             ylab = expression(paste("posterior probability of ",pi)),
             scales = list(x = list(limits = c(-0.9,-3.65)))) +
      as.layer(xyplot(y ~ x,data.frame(x = log10q,y = w),
                      col = "blue",pch = 20,cex = 0.65)),
      split = c(2,1,2,1),more = FALSE)
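
Finally, the variables with the largest posterior inclusion probabilities can be listed directly. A minimal sketch, averaging the inclusion probabilities over the hyperparameter settings (the columns of fit.varbvs$alpha) using the normalized importance weights w:

pip <- c(fit.varbvs$alpha %*% w)           # Average over hyperparameter settings.
print(head(order(pip,decreasing = TRUE)))  # Indices of the top-ranked variables.
print(head(sort(pip,decreasing = TRUE)))   # Their averaged inclusion probabilities.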

References

Dettling, M. (2004). BagBoosting for tumor classification with gene expression data. Bioinformatics 20, 3583–3593.

Friedman, J., Hastie, T. and Tibshirani, R. (2010). Regularization paths for generalized linear models via coordinate descent. Journal of Statistical Software 33, 1–22.

Session information

These are the versions of R and the packages that were used to generate these results.

sessionInfo()
# R version 4.4.2 (2024-10-31)
# Platform: x86_64-pc-linux-gnu
# Running under: Ubuntu 24.04.1 LTS
# 
# Matrix products: default
# BLAS:   /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3 
# LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.26.so;  LAPACK version 3.12.0
# 
# locale:
#  [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
#  [3] LC_TIME=en_US.UTF-8        LC_COLLATE=C              
#  [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
#  [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
#  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
# [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       
# 
# time zone: Etc/UTC
# tzcode source: system (glibc)
# 
# attached base packages:
# [1] stats     graphics  grDevices utils     datasets  methods   base     
# 
# other attached packages:
# [1] glmnet_4.1-8        Matrix_1.7-1        latticeExtra_0.6-30
# [4] curl_6.0.0          varbvs_2.6-10       lattice_0.22-6     
# [7] rmarkdown_2.29     
# 
# loaded via a namespace (and not attached):
#  [1] cli_3.6.3          knitr_1.49         rlang_1.1.4        xfun_0.49         
#  [5] nor1mix_1.3-3      png_0.1-8          jsonlite_1.8.9     buildtools_1.0.0  
#  [9] htmltools_0.5.8.1  maketools_1.3.1    sys_3.4.3          sass_0.4.9        
# [13] grid_4.4.2         evaluate_1.0.1     jquerylib_0.1.4    fastmap_1.2.0     
# [17] foreach_1.5.2      interp_1.1-6       yaml_2.3.10        lifecycle_1.0.4   
# [21] compiler_4.4.2     codetools_0.2-20   RColorBrewer_1.1-3 Rcpp_1.0.13-1     
# [25] digest_0.6.37      R6_2.5.1           splines_4.4.2      shape_1.4.6.1     
# [29] bslib_0.8.0        tools_4.4.2        jpeg_0.1-10        iterators_1.0.14  
# [33] deldir_2.0-4       survival_3.7-0     cachem_1.1.0