## ----setup, include = FALSE---------------------------------------------------
# Global knitr chunk defaults used by every chunk in this document.
knitr::opts_chunk$set(
    # code execution / display
    echo = TRUE,
    eval = TRUE,
    cache = FALSE,
    # console noise
    warning = FALSE,
    message = FALSE,
    # figure rendering
    dpi = 70,
    fig.show = TRUE,
    fig.width = 6,
    fig.height = 4,
    fig.align = "center",
    out.width = "60%"
)

## -----------------------------------------------------------------------------
# CRAN
library(pheatmap)
library(vegan)
library(gridExtra)

# Bioconductor
library(mixOmics)
library(Biobase)
library(SummarizedExperiment)
library(PLSDAbatch)

# print the version of each analysis package, in the order they were attached
for (ad.pkg.name in c(
    "pheatmap", "vegan", "gridExtra",
    "mixOmics", "Biobase", "PLSDAbatch"
)) {
    # explicit print(): values are not auto-printed inside a loop
    print(package.version(ad.pkg.name))
}

## -----------------------------------------------------------------------------
# AD data: load the example object and extract the count table
data("AD_data")
ad.count <- assays(AD_data$FullData)$Count
dim(ad.count)

# sample annotations stored alongside the counts
ad.metadata <- rowData(AD_data$FullData)

# batch factor: sequencing run date, levels kept in order of first appearance
ad.batch <- factor(
    ad.metadata$sequencing_run_date,
    levels = unique(ad.metadata$sequencing_run_date)
)

# treatment factor: regrouped initial phenol concentration
ad.trt <- as.factor(ad.metadata$initial_phenol_concentration.regroup)

# name both factors by the metadata row names
names(ad.trt) <- rownames(ad.metadata)
names(ad.batch) <- rownames(ad.metadata)

## -----------------------------------------------------------------------------
# Pre-filter the raw count table with PreFL() (default settings) and keep the
# filtered table for the downstream transformation.
ad.filter.res <- PreFL(data = ad.count)
ad.filter <- ad.filter.res$data.filter
# dimensions of the filtered table
dim(ad.filter)

# zero proportion before filtering
ad.filter.res$zero.prob.before
# zero proportion after filtering
ad.filter.res$zero.prob.after

## -----------------------------------------------------------------------------
# Centred log-ratio (CLR) transform of the filtered counts; offset = 1 is
# added by logratio.transfo() (presumably to avoid log of zero counts).
ad.clr <- logratio.transfo(X = ad.filter, logratio = "CLR", offset = 1)
# Reset the class to a plain matrix.
# NOTE(review): presumably needed so later functions do not dispatch on the
# class returned by logratio.transfo() - confirm.
class(ad.clr) <- "matrix"

## ----ADpcaBefore, out.width = '80%', fig.align = 'center', fig.cap = 'PCA sample plot with density overlays for the AD data.'----
# AD data: PCA (3 components, scaled variables) on the CLR-transformed data
ad.pca.before <- mixOmics::pca(
    ad.clr,
    ncomp = 3,
    scale = TRUE
)

# sample plot of components 1 and 2 with density overlays,
# annotated by batch and treatment
Scatter_Density(
    components = ad.pca.before$variates$X,
    comp = c(1, 2),
    expl.var = ad.pca.before$prop_expl_var$X,
    batch = ad.batch,
    trt = ad.trt,
    title = "AD data",
    trt.legend.title = "Phenol conc."
)

## ----ADboxBefore, out.width = '80%', fig.align = 'center', fig.cap = 'Boxplots of sample values for OTU28 before batch-effect correction in the AD data.'----
# first variable name returned by selectVar() for PCA component 1
ad.OTU.name <- selectVar(ad.pca.before, comp = 1)$name[1]

# CLR values of that OTU together with the batch of each sample
ad.OTU_batch <- data.frame(
    value = ad.clr[, ad.OTU.name],
    batch = ad.batch
)

# box plot of the OTU values, with x-axis labels rotated by 30 degrees
box_plot(
    df = ad.OTU_batch,
    title = paste(ad.OTU.name, "(AD data)"),
    x.angle = 30
)

## ----ADdensityBefore, out.width = '80%', fig.align = 'center', fig.cap = 'Density plots of sample values for OTU28 before batch-effect correction in the AD data.'----
# density plot of the same OTU values, using the data frame built for the box plot
density_plot(
    df = ad.OTU_batch,
    title = paste(ad.OTU.name, "(AD data)")
)

## -----------------------------------------------------------------------------
# Fit the same linear model for the selected OTU twice, changing only the
# reference batch: first 14/04/2016, then 21/09/2017. After the loop,
# ad.batch keeps 21/09/2017 as its first level.
for (ad.ref.batch in c("14/04/2016", "21/09/2017")) {
    ad.batch <- relevel(x = ad.batch, ref = ad.ref.batch)

    ad.OTU.lm <- linear_regres(
        data = ad.clr[, ad.OTU.name],
        trt = ad.trt,
        batch.fix = ad.batch,
        type = "linear model"
    )

    # explicit print(): values are not auto-printed inside a loop
    print(summary(ad.OTU.lm$model$data))
}

## ----ADheatmap, out.width = '90%', fig.align = 'center', fig.cap = 'Hierarchical clustering of samples in the AD data.'----
# scale the clr data on both OTUs and samples
ad.clr.s <- scale(ad.clr, center = TRUE, scale = TRUE)
ad.clr.ss <- scale(t(ad.clr.s), center = TRUE, scale = TRUE)

# column annotations: one row per sample (samples are the heatmap columns
# after the transpose above)
ad.anno_col <- data.frame(Batch = ad.batch, Treatment = ad.trt)

# One colour per factor level. The number of colours is derived from the
# factors themselves so that colours and level names always have matching
# lengths, instead of relying on hard-coded counts.
ad.anno_colors <- list(
    Batch = setNames(
        color.mixo(seq_len(nlevels(ad.batch))),
        levels(ad.batch)
    ),
    Treatment = setNames(
        pb_color(seq_len(nlevels(ad.trt))),
        levels(ad.trt)
    )
)

pheatmap(ad.clr.ss,
    cluster_rows = FALSE,
    fontsize_row = 4,
    fontsize_col = 6,
    fontsize = 8,
    clustering_distance_rows = "euclidean",
    clustering_method = "ward.D",
    treeheight_row = 30,
    annotation_col = ad.anno_col,
    annotation_colors = ad.anno_colors,
    border_color = "NA",
    main = "AD data - Scaled"
)

## -----------------------------------------------------------------------------
# AD data: treatment and batch as explanatory factors for variance partitioning
ad.factors.df <- data.frame(
    trt = ad.trt,
    batch = ad.batch
)
class(ad.clr) <- "matrix"

# partition the variance of the scaled CLR data between treatment and batch
ad.rda.before <- varpart(
    ad.clr, ~trt, ~batch,
    data = ad.factors.df,
    scale = TRUE
)

# table of individual fractions
ad.rda.before$part$indfract

## -----------------------------------------------------------------------------
# optimal number of components = number of groups - 1
nlevels(ad.trt) - 1 # treatment components
nlevels(ad.batch) - 1 # batch components

# PLSDA-batch correction in regression mode
# (1 treatment component, 4 batch components)
ad.PLSDA_batch.res.reg <- PLSDA_batch(
    X = ad.clr,
    Y.trt = ad.trt,
    Y.bat = ad.batch,
    ncomp.trt = 1,
    ncomp.bat = 4,
    mode = "regression"
)

# keep the X.nobatch element of the result as the corrected data
ad.PLSDA_batch.reg <- ad.PLSDA_batch.res.reg$X.nobatch

## -----------------------------------------------------------------------------
# estimate the number of treatment components:
# canonical PLS of the CLR data against the dummy-coded treatment
ad.trt.mat <- unmap(ad.trt)
ad.trt.tune.cnc <- pls(
    X = ad.clr,
    Y = ad.trt.mat,
    ncomp = 2,
    mode = "canonical"
)

# proportion of explained variance in Y per component
ad.trt.tune.cnc$prop_expl_var$Y # 1

## -----------------------------------------------------------------------------
# estimate the number of batch components:
# canonical PLSDA-batch with more batch components (6) than finally used
ad.batch.tune.cnc <- PLSDA_batch(
    X = ad.clr,
    Y.trt = ad.trt,
    Y.bat = ad.batch,
    ncomp.trt = 1,
    ncomp.bat = 6,
    mode = "canonical"
)

# explained variance in the batch outcome per component
ad.batch.tune.cnc$explained_variance.bat$Y # 4

# cumulative explained variance of the first four batch components
sum(ad.batch.tune.cnc$explained_variance.bat$Y[seq_len(4)])

## -----------------------------------------------------------------------------
# optimal number of components = number of groups - 1
nlevels(ad.trt) - 1 # treatment components
nlevels(ad.batch) - 1 # batch components

# PLSDA-batch correction in canonical mode
# (1 treatment component, 4 batch components)
ad.PLSDA_batch.res.cnc <- PLSDA_batch(
    X = ad.clr,
    Y.trt = ad.trt,
    Y.bat = ad.batch,
    ncomp.trt = 1,
    ncomp.bat = 4,
    mode = "canonical"
)

# keep the X.nobatch element of the result as the corrected data
ad.PLSDA_batch.cnc <- ad.PLSDA_batch.res.cnc$X.nobatch

## ----eval = F-----------------------------------------------------------------
# # estimate the number of variables to select per treatment component
# ad.test.keepX <- c(
#     seq(1, 10, 1), seq(20, 100, 10),
#     seq(150, 231, 50), 231
# )
# ad.trt.tune.v <- tune.splsda(
#     X = ad.clr, Y = ad.trt,
#     ncomp = 1, test.keepX = ad.test.keepX,
#     folds = 4, nrepeat = 50, seed = 777
# )
# ad.trt.tune.v$choice.keepX # 100

## -----------------------------------------------------------------------------
# sparse PLSDA-batch in regression mode: keepX.trt = 100 on the single
# treatment component, 4 batch components
ad.sPLSDA_batch.res.reg <- PLSDA_batch(
    X = ad.clr,
    Y.trt = ad.trt,
    Y.bat = ad.batch,
    ncomp.trt = 1,
    keepX.trt = 100,
    ncomp.bat = 4,
    mode = "regression"
)

# keep the X.nobatch element of the result as the corrected data
ad.sPLSDA_batch.reg <- ad.sPLSDA_batch.res.reg$X.nobatch

## -----------------------------------------------------------------------------
# sparse PLSDA-batch in canonical mode: keepX.trt = 100 on the single
# treatment component, 4 batch components
ad.sPLSDA_batch.res.cnc <- PLSDA_batch(
    X = ad.clr,
    Y.trt = ad.trt,
    Y.bat = ad.batch,
    ncomp.trt = 1,
    keepX.trt = 100,
    ncomp.bat = 4,
    mode = "canonical"
)

# keep the X.nobatch element of the result as the corrected data
ad.sPLSDA_batch.cnc <- ad.sPLSDA_batch.res.cnc$X.nobatch

## -----------------------------------------------------------------------------
# PCA (3 components, scaled variables) on the uncorrected data and on each of
# the four batch-corrected data sets.
ad.pca.before <- pca(ad.clr, ncomp = 3, scale = TRUE)
ad.pca.PLSDA_batch.cnc <- pca(ad.PLSDA_batch.cnc, ncomp = 3, scale = TRUE)
ad.pca.sPLSDA_batch.cnc <- pca(ad.sPLSDA_batch.cnc, ncomp = 3, scale = TRUE)
# The regression-mode PCAs must use the regression-mode corrected matrices;
# feeding them the canonical ones makes the "regression" panels exact
# duplicates of the canonical panels.
ad.pca.PLSDA_batch.reg <- pca(ad.PLSDA_batch.reg, ncomp = 3, scale = TRUE)
ad.pca.sPLSDA_batch.reg <- pca(ad.sPLSDA_batch.reg, ncomp = 3, scale = TRUE)

## ----fig.show='hide'----------------------------------------------------------
# order batches: restore the level order of first appearance in the metadata
ad.batch <- factor(ad.metadata$sequencing_run_date,
    levels = unique(ad.metadata$sequencing_run_date)
)

# Build the components 1-2 sample plot with density overlays for one PCA
# result, annotated by the current batch and treatment factors.
ad.scatter.density <- function(pca.res, plot.title) {
    Scatter_Density(
        components = pca.res$variates$X,
        comp = c(1, 2),
        expl.var = pca.res$prop_expl_var$X,
        batch = ad.batch,
        trt = ad.trt,
        title = plot.title
    )
}

ad.pca.before.plot <- ad.scatter.density(
    ad.pca.before, "Before correction"
)
ad.pca.PLSDA_batch.cnc.plot <- ad.scatter.density(
    ad.pca.PLSDA_batch.cnc, "PLSDA-batch (canonical)"
)
ad.pca.sPLSDA_batch.cnc.plot <- ad.scatter.density(
    ad.pca.sPLSDA_batch.cnc, "sPLSDA-batch (canonical)"
)
ad.pca.PLSDA_batch.reg.plot <- ad.scatter.density(
    ad.pca.PLSDA_batch.reg, "PLSDA-batch (regression)"
)
ad.pca.sPLSDA_batch.reg.plot <- ad.scatter.density(
    ad.pca.sPLSDA_batch.reg, "sPLSDA-batch (regression)"
)

## ----ADpca, fig.height = 15, fig.width = 10, out.width = '100%', echo = FALSE, fig.align = 'center', fig.cap = 'PCA sample plots with density overlays before and after batch-effect correction in the AD data.'----
# lay out the five PCA panels in a two-column grid
grid.arrange(
    grobs = list(
        ad.pca.before.plot,
        ad.pca.PLSDA_batch.cnc.plot,
        ad.pca.sPLSDA_batch.cnc.plot,
        ad.pca.PLSDA_batch.reg.plot,
        ad.pca.sPLSDA_batch.reg.plot
    ),
    ncol = 2
)

## ----ADprda, fig.height = 6, fig.width = 4, out.width = '80%', fig.align = 'center', fig.cap = 'Global explained variance before and after batch effect correction for the AD data.'----
# AD data: uncorrected and batch-corrected data sets to compare
ad.corrected.list <- list(
    `Before correction` = ad.clr,
    `PLSDA-batch (canonical)` = ad.PLSDA_batch.cnc,
    `sPLSDA-batch (canonical)` = ad.sPLSDA_batch.cnc,
    `PLSDA-batch (regression)` = ad.PLSDA_batch.reg,
    `sPLSDA-batch (regression)` = ad.sPLSDA_batch.reg
)

# one row per data set, filled with the adjusted R2 of the individual
# fractions reported by varpart()
ad.prop.df <- data.frame(
    Treatment = NA, Batch = NA,
    Intersection = NA,
    Residuals = NA
)
for (i in seq_along(ad.corrected.list)) {
    rda.res <- varpart(ad.corrected.list[[i]], ~trt, ~batch,
        data = ad.factors.df, scale = TRUE
    )
    ad.prop.df[i, ] <- rda.res$part$indfract$Adj.R.squared
}

rownames(ad.prop.df) <- names(ad.corrected.list)

# swap the two middle columns
# NOTE(review): presumably the column order partVar_plot() expects - confirm
ad.prop.df <- ad.prop.df[, c(1, 3, 2, 4)]

# Adjusted R2 fractions can be negative and need not sum to one. Clamp
# negatives to zero and rescale each row so it is a set of non-negative
# proportions summing to 1 before plotting.
ad.prop.df[ad.prop.df < 0] <- 0
ad.prop.df <- ad.prop.df / rowSums(ad.prop.df)

partVar_plot(prop.df = ad.prop.df)

## ----ADr21, fig.height = 6, fig.width = 8, out.width = '100%', fig.align = 'center', fig.cap = 'AD study: $R^2$ values for each microbial variable before and after batch effect correction.'----
# AD data
# scale every variable (column) of each data set
ad.corr_scale.list <- lapply(
    ad.corrected.list,
    function(x) {
        apply(x, 2, scale)
    }
)

# For each data set, the R2 of a one-way linear model of every variable on
# treatment and on batch. Results are built in one step per data set instead
# of growing a data frame row by row.
ad.r_values.list <- lapply(
    ad.corr_scale.list,
    function(scaled.mat) {
        # R2 of lm(variable ~ grouping) for every column of scaled.mat
        r2_per_variable <- function(grouping) {
            vapply(
                seq_len(ncol(scaled.mat)),
                function(j) {
                    summary(lm(scaled.mat[, j] ~ grouping))$r.squared
                },
                numeric(1)
            )
        }
        data.frame(
            trt = r2_per_variable(ad.trt),
            batch = r2_per_variable(ad.batch)
        )
    }
)

# long format: treatment R2 values followed by batch R2 values. The label
# lengths come from the data, not from a hard-coded number of variables.
ad.boxp.list <- lapply(
    ad.r_values.list,
    function(r2.df) {
        data.frame(
            r2 = c(r2.df[, "trt"], r2.df[, "batch"]),
            Effects = as.factor(rep(c("Treatment", "Batch"),
                each = nrow(r2.df)
            ))
        )
    }
)

# stack all data sets (unname() keeps plain 1..n row names)
ad.r2.boxp <- do.call(rbind, unname(ad.boxp.list))

# method label for every stacked row, in list order
ad.r2.boxp$methods <- rep(
    names(ad.boxp.list),
    times = vapply(ad.boxp.list, nrow, integer(1))
)

ad.r2.boxp$methods <- factor(ad.r2.boxp$methods,
    levels = unique(ad.r2.boxp$methods)
)

ggplot(ad.r2.boxp, aes(x = Effects, y = r2, fill = Effects)) +
    geom_boxplot(alpha = 0.80) +
    theme_bw() +
    theme(
        text = element_text(size = 18),
        axis.title.x = element_blank(),
        axis.title.y = element_blank(),
        axis.text.x = element_text(angle = 60, hjust = 1, size = 18),
        axis.text.y = element_text(size = 18),
        panel.grid.minor.x = element_blank(),
        panel.grid.major.x = element_blank(),
        legend.position = "right"
    ) +
    facet_grid(~methods) +
    scale_fill_manual(values = pb_color(c(12, 14)))

## ----ADr22, fig.height = 6, fig.width = 8, out.width = '100%', fig.align = 'center', fig.cap = 'AD study: Sum of $R^2$ values for each microbial variable before and after batch effect correction.'----
##################################
# per data set: total R2 over all variables, for treatment and for batch
ad.barp.list <- lapply(
    ad.r_values.list,
    function(r2.df) {
        data.frame(
            r2 = c(
                sum(r2.df[, "trt"]),
                sum(r2.df[, "batch"])
            ),
            Effects = c("Treatment", "Batch")
        )
    }
)

# stack the five two-row tables in list order (plain 1..n row names)
ad.r2.barp <- do.call(rbind, unname(ad.barp.list))

# each data set contributes exactly two rows (Treatment, Batch)
ad.r2.barp$methods <- rep(names(ad.barp.list), each = 2)

# keep the facets in list order rather than alphabetical order
ad.r2.barp$methods <- factor(
    ad.r2.barp$methods,
    levels = unique(ad.r2.barp$methods)
)

ad.r2.barp.theme <- theme(
    text = element_text(size = 18),
    axis.title.x = element_blank(),
    axis.title.y = element_blank(),
    axis.text.x = element_text(angle = 60, hjust = 1, size = 18),
    axis.text.y = element_text(size = 18),
    panel.grid.minor.x = element_blank(),
    panel.grid.major.x = element_blank(),
    legend.position = "right"
)

ggplot(ad.r2.barp, aes(x = Effects, y = r2, fill = Effects)) +
    geom_bar(stat = "identity") +
    theme_bw() +
    ad.r2.barp.theme +
    facet_grid(~methods) +
    scale_fill_manual(values = pb_color(c(12, 14)))

## ----ADalignment, fig.height = 3, out.width = '90%', fig.align = 'center', fig.cap = 'Comparison of alignment scores before and after batch effect correction using different methods for the AD data.'----
# AD data
# alignment_score() is given a batch vector named by sample
names(ad.batch) <- rownames(ad.clr)

# one alignment score per data set; vapply() replaces growing a vector with
# c() inside a loop and checks that each result is a single number
ad.scores <- vapply(
    ad.corrected.list,
    function(corrected) {
        alignment_score(
            data = corrected,
            batch = ad.batch,
            var = 0.95,
            k = 8,
            ncomp = 50
        )
    },
    numeric(1),
    USE.NAMES = FALSE
)

ad.scores.df <- data.frame(
    scores = ad.scores,
    methods = names(ad.corrected.list)
)

# reversed levels so that, after coord_flip(), the first method is on top
ad.scores.df$methods <- factor(ad.scores.df$methods,
    levels = rev(names(ad.corrected.list))
)

# Aesthetics refer to columns of the data frame passed to ggplot() rather
# than to `ad.scores.df$...`, which ggplot2 discourages inside aes().
ggplot(ad.scores.df, aes(x = methods, y = scores)) +
    geom_col() +
    geom_text(
        aes(
            y = scores / 2,
            label = round(scores, 3)
        ),
        size = 3, col = "white"
    ) +
    coord_flip() +
    theme_bw() +
    ylab("Alignment Scores") +
    xlab("") +
    ylim(0, 0.85)

## -----------------------------------------------------------------------------
# sPLS-DA on the PLSDA-batch (regression) corrected data:
# 3 components, 50 variables kept on each
splsda.plsda_batch.reg <- splsda(
    X = ad.PLSDA_batch.reg,
    Y = ad.trt,
    ncomp = 3,
    keepX = rep(50, 3)
)
# variables selected on component 1 and the first of their values
select.plsda_batch.reg <- selectVar(splsda.plsda_batch.reg, comp = 1)
head(select.plsda_batch.reg$value)

# same model on the sPLSDA-batch (regression) corrected data
splsda.splsda_batch.reg <- splsda(
    X = ad.sPLSDA_batch.reg,
    Y = ad.trt,
    ncomp = 3,
    keepX = rep(50, 3)
)
select.splsda_batch.reg <- selectVar(splsda.splsda_batch.reg, comp = 1)
head(select.splsda_batch.reg$value)

# number of component-1 variables selected in both models
length(
    intersect(
        select.plsda_batch.reg$name,
        select.splsda_batch.reg$name
    )
)

## -----------------------------------------------------------------------------
# print the R version and the attached/loaded packages for reproducibility
sessionInfo()

