## ----installation, eval=FALSE-------------------------------------------------
# if (!require("BiocManager", quietly = TRUE))
#     install.packages("BiocManager")
# BiocManager::install("SETA")

## ----load libraries, message=FALSE, warning=FALSE, echo=TRUE------------------
library(SingleCellExperiment)
library(SETA)
library(ggplot2)
library(dplyr)
library(tidyr)
library(corrplot)
# caret and glmnet are only needed by the optional machine-learning chunks.
# Probe for them first and attach caret only when both are available;
# attaching caret unconditionally would abort the whole script when it is
# missing, before the guard below had a chance to run.
have_ml <- all(vapply(c("caret", "glmnet"), requireNamespace, logical(1),
                        quietly = TRUE))
if (have_ml) {
    library(caret)
} else {
    warning("Some machine learning packages are not installed. ML chunks will be skipped.")
}
library(TabulaMurisSenisData)

## ----load data, message=FALSE, warning=FALSE, echo = FALSE--------------------
# Take the "Lung" entry of the object returned by TabulaMurisSenisDroplet()
sce <- TabulaMurisSenisDroplet(tissues = "Lung")[["Lung"]]

# Drop every cell whose subtissue is "immune-endo-depleted"
sce <- sce[, sce$subtissue != "immune-endo-depleted"]

## ----tabular data exploration-------------------------------------------------
# Cross-tabulate cell annotations against age
table(colData(sce)$free_annotation, colData(sce)$age)

## ----palettes-----------------------------------------------------------------
# Colour palettes shared by the plots below
# Discrete palette: one colour per age label (five labels)
age_palette <- setNames(
    c("#90EE90", "#4CBB17", "#228B22", "#355E3B", "#023020"),
    c("1m", "3m", "18m", "21m", "30m"))

# Continuous palette: 100 colours interpolated across five anchor colours
palette_anchors <- c("#3B9AB2", "#78B7C5", "#EBCC2A", "#E1AF00", "#F21A00")
c_palette <- colorRampPalette(palette_anchors)(100)

## ----setaCounts---------------------------------------------------------------
# Per-cell annotations of the experiment as a plain data.frame
df <- data.frame(colData(sce))

# Remove every "/" from the mouse identifiers
df[["mouse.id"]] <- gsub("/", "", df[["mouse.id"]], fixed = TRUE)

# Count cells with setaCounts(), using mouse.id as the sample column and
# free_annotation as the cell-type column
taxa_counts <- setaCounts(
    df,
    sample_col = "mouse.id",
    cell_type_col = "free_annotation",
    bc_col = "cell")

# Preview the top-left 5 x 5 corner
taxa_counts[seq_len(5), seq_len(5)]

## ----setaMetadata-------------------------------------------------------------
# Sample-level metadata (age and sex) keyed by mouse.id
meta_cols_wanted <- c("age", "sex")
meta_df <- setaMetadata(df, sample_col = "mouse.id",
                        meta_cols = meta_cols_wanted)

# Preview the first five rows
meta_df[seq_len(5), ]

## ----setaTransform------------------------------------------------------------
# CLR-transform the count table
clr_transformed <- taxa_counts |>
    setaTransform(method = "CLR")

## ----setaDistances------------------------------------------------------------
# Pairwise distances on the CLR-transformed data
dist_df <- setaDistances(clr_transformed)

# Attach the metadata of both samples in each pair; columns shared by the two
# joins are disambiguated with the ".from" / ".to" suffixes
merged_dist <- left_join(dist_df, meta_df, by = c("from" = "sample_id"))
merged_dist <- left_join(merged_dist, meta_df,
                        by = c("to" = "sample_id"),
                        suffix = c(".from", ".to"))

# Label every pair with the ages of its two samples, e.g. "<age>-<age>"
merged_dist$age_pair <- paste(
    merged_dist[["age.from"]],
    merged_dist[["age.to"]],
    sep = "-")

## ----viz distances, fig.width = 9, fig.height = 6-----------------------------
# Distribution of pairwise distances for every age pair, with the individual
# observations overlaid on the boxplots.
ggplot(merged_dist, aes(x = age_pair, y = distance)) +
    # Hide the boxplot's own outlier points: every observation is already
    # drawn by the jitter layer, so outliers would otherwise appear twice.
    geom_boxplot(fill = "grey90", outlier.shape = NA) +
    # Jitter horizontally only so each point stays at its true distance.
    geom_jitter(width = 0.2, height = 0, color = "black") +
    labs(title = "Aitchison Distances Between Age Groups",
        x = "Age Pair",
        y = "Aitchison Distance") +
    theme_minimal(base_size = 16) +
    theme(axis.text.x = element_text(angle = 45, vjust = 0.5))

## ----preprocess for comparisons-----------------------------------------------
# Long-form CLR values (one row per sample / cell type) joined with the
# sample metadata
clr_long <- clr_transformed$counts |>
    as.data.frame() |>
    setNames(c("sample", "Celltype", "CLR")) |>
    left_join(meta_df, by = c("sample" = "sample_id"))

## ----viz pairwise distances, fig.width = 15, fig.height = 10------------------
# Boxplots of the CLR values per age group, one facet per cell type, with the
# individual samples overlaid.
# NOTE: no significance tests are drawn here. Pairwise Wilcoxon tests could be
# added with ggpubr::stat_compare_means(), but that would require ggpubr and a
# `comparisons` list defined on the age groups of this data set.
ggplot(clr_long, aes(x = age, y = CLR,
                    fill = age, color = age)) +
    # Outliers are hidden on the boxplot because the jitter layer already
    # draws every observation; showing both would plot those points twice.
    geom_boxplot(position = position_dodge(0.8), alpha = 0.7,
                outlier.shape = NA) +
    # height = 0: jitter horizontally only, keeping CLR values exact.
    geom_jitter(size = 1.5, shape = 21, height = 0) +
    facet_wrap(~ Celltype) +
    theme_minimal(base_size = 12) +
    scale_fill_manual(values = age_palette) +
    scale_color_manual(values = age_palette) +
    theme(axis.text.x = element_text(angle = 45, hjust = 0.5, vjust = 1)) +
    labs(title = "CLR by Celltype and Age",
        x = "Age",
        y = "CLR-transformed Composition")

## ----metadata corr------------------------------------------------------------
# Reshape the CLR values to wide form: one row per sample (mouse.id) and one
# column per cell type.
clr_df <- clr_transformed$counts |>
    data.frame() |> # yields the long-form columns Var1, Var2 and Freq
    pivot_wider(names_from = "Var2",
                values_from = "Freq") |>
    rename(`mouse.id` = Var1)

clr_metadata <- clr_df |>
    left_join(meta_df, by = c("mouse.id" = "sample_id"))

# Keep only the numeric columns. Coerce to a plain data.frame before assigning
# row names, in case the pivoted result is a tibble (setting row names on a
# tibble is deprecated).
clr_data <- clr_metadata |>
    select(where(is.numeric)) |>
    as.data.frame()
rownames(clr_data) <- clr_metadata$mouse.id

# One-hot encode 'age' and strip the "age" prefix from the indicator names
oh <- model.matrix(~age - 1, data = clr_metadata)
colnames(oh) <- sub("^age", "", colnames(oh))

# Combine CLR data and the age indicators (samples in rows)
combined <- cbind(clr_data, oh)

# Compute full correlation matrix
full_cor_mat <- cor(combined, method = "pearson")

# The significance tests must be run on the observations themselves.
# Passing the correlation matrix to cor.mtest() would instead test how the
# columns of that matrix correlate with each other, giving p-values that do
# not correspond to the coefficients in full_cor_mat.
p_mat <- cor.mtest(combined)$p

## ----corr viz, fig.width = 12, fig.height = 12--------------------------------
# Draw the correlation matrix as circles, ordered by hierarchical clustering
# with four cluster rectangles; p_mat supplies the significance labels and the
# diagonal is left out.
corrplot(
    full_cor_mat,
    p.mat = p_mat,
    sig.level = c(.001, .01, .05),
    insig = "label_sig",
    pch.cex = 1.5,
    method = "circle",
    type = "full",
    diag = FALSE,
    order = "hclust",
    addrect = 4,
    col = c_palette,
    tl.col = "black",
    tl.cex = 0.8
)

## ----caret, warning=FALSE, eval = have_ml-------------------------------------
# This chunk is declared `eval = have_ml`, but that option only takes effect
# when knitting. Guard explicitly so the plain script also skips the model
# when caret/glmnet are unavailable.
if (have_ml) {
    # Fixed seed so the cross-validation folds are reproducible
    set.seed(687)

    # Predictors: every column except the sample identifier; outcome: age
    train_df <- clr_metadata |>
        select(-`mouse.id`) |>
        mutate(age = factor(age))

    # 5-fold cross-validation
    train_control <- trainControl(method = "cv", number = 5)

    # Fit a glmnet classifier of age on all remaining columns
    # (train_df is passed directly; wrapping it in subset() without a
    # condition was a no-op)
    model <- train(age ~ ., data = train_df,
                    method = "glmnet",
                    trControl = train_control)

    importance <- varImp(model)
}

## ----plot caret, fig.width = 12, fig.height = 12, eval = have_ml--------------
# `importance` only exists when the ML chunk ran, so apply the same guard the
# chunk header declares (`eval = have_ml`) when run as a plain script.
if (have_ml) {
    plot(importance,
        main = "Cross-Validated Variable Importance",
        sub = "Caret GLMnet model"
    )
}

## ----packages used------------------------------------------------------------
# Record the R version and the attached/loaded packages for reproducibility
utils::sessionInfo()

