## ----setup, include=FALSE-----------------------------------------------------
# Default chunk options: show the code of every chunk and centre all figures
knitr::opts_chunk$set(fig.align = "center", echo = TRUE)

## ----packages, message=FALSE, warning=FALSE-----------------------------------
# Import packages
library("dplyr")
library("scater")
library("ggplot2")
library("scRNAseq")
library("Coralysis")
library("ComplexHeatmap")
library("SingleCellExperiment")

## ----import data, message=FALSE, warning=FALSE--------------------------------
# Import the BachMammaryData SingleCellExperiment object
sce <- BachMammaryData()
# Give every cell a unique name: "cell1", "cell2", ...
colnames(sce) <- paste0("cell", seq_len(ncol(sce)))

## ----Seurat normalisation,  message=FALSE, warning=FALSE----------------------
## Normalize the data: add log-normalised counts ('logcounts' assay)
set.seed(123)
sce <- scater::logNormCounts(x = sce)
# Raw counts are not used again in this script; drop the assay
# (presumably to reduce memory use)
assay(sce, "counts") <- NULL

## ----hvg----------------------------------------------------------------------
# Feature selection with 'scran' package: keep the top 'nhvg' highly variable genes
nhvg <- 500
hvg <- scran::getTopHVGs(sce, n = nhvg)
# Subset with a logical match so the original row order is preserved
hvg.idx <- which(row.names(sce) %in% hvg)
sce <- sce[hvg.idx, ]
# Use gene symbols as row names. uniquifyFeatureNames() returns the symbol
# unchanged when it is unique and non-missing, and falls back to the original
# ID (or "<symbol>_<ID>" for duplicates) otherwise, so row names stay unique
# and never NA - genes are later looked up by symbol (e.g. "Krt5").
row.names(sce) <- scater::uniquifyFeatureNames(
    ID = row.names(sce), names = rowData(sce)$Symbol
)

## ----multi-level integration, warning=FALSE-----------------------------------
# Perform multi-level integration with Coralysis: L = 25 divisive ICP runs,
# ICP batches of 1000 cells, no separate training set built, 2 threads
set.seed(123)
sce <- RunParallelDivisiveICP(
    object = sce,
    L = 25,
    threads = 2,
    icp.batch.size = 1000,
    build.train.set = FALSE
)

## ----integrated dimred--------------------------------------------------------
# Dimensional reduction - integrated: PCA on the Coralysis "joint.probability"
# assay (presumably written by RunParallelDivisiveICP above)
set.seed(123)
sce <- RunPCA(
    object = sce, assay.name = "joint.probability",
    return.model = TRUE
)

# UMAP on the first 30 PCs of the integrated PCA embedding
set.seed(123)
sce <- RunUMAP(
    object = sce, umap.method = "uwot",
    dims = 1:30, n_neighbors = 15,
    min_dist = 0.5, return.model = TRUE
)

## ----basal vs luminal cells, fig.width=7, fig.height=2.75---------------------
# Detect basal versus luminal cells from the expression of two marker genes,
# shown side by side on the UMAP
markers <- c("Krt5", "Krt18")
plot.marker <- function(gene) {
    PlotExpression(sce,
        color.by = gene, point.size = 0.2,
        point.stroke = 0.2, scale.values = TRUE
    )
}
plots <- lapply(markers, plot.marker)
cowplot::plot_grid(plotlist = plots, ncol = 2)

## ----cell state identification------------------------------------------------
# Summarise the cell cluster probabilities of ICP round 4; the summary
# columns are stored in colData(sce)
sce <- SummariseCellClusterProbability(
    object = sce,
    icp.round = 4
)
# colData(sce) # inspect the new columns

## ----plot probability, fig.width=5, fig.height=4------------------------------
# Plot cell cluster probabilities - min-max scaled mean
# possible options: "mean_probs", "median_probs", "scaled_mean_probs",
# "scaled_median_probs"
PlotExpression(
    object = sce, color.by = "scaled_mean_probs",
    color.scale = "viridis", point.size = 0.2,
    point.stroke = 0.1, legend.title = "Mean prob.\n(min-max)"
)

## ----icp clusters, fig.width=5, fig.height=6----------------------------------
# ICP clusters: identify the ICP probability table with the highest standard deviation
probs <- GetCellClusterProbability(object = sce, icp.round = 4, concatenate = FALSE)
# One SD per probability table as a typed numeric vector, so which.max() works
# on numbers rather than relying on implicit list-to-double coercion
probs.sd <- vapply(X = probs, FUN = sd, FUN.VALUE = numeric(1))
icp.run <- which.max(probs.sd) # 7
clt <- paste0("icp_run_round_", icp.run, "_4_clusters")
# Fix the level order so the legend lists clusters 1..16.
# NOTE(review): assumes round-4 cluster labels are "1".."16"; any other
# label would silently become NA here.
sce[[clt]] <- factor(sce[[clt]], levels = as.character(seq_len(16)))
PlotDimRed(sce,
    color.by = clt, point.size = 0.1,
    point.stroke = 0.1, legend.nrow = 5
)

## ----gene coefficients, message=FALSE, warning=FALSE, fig.width = 5.5, fig.height=14----
# Get gene coefficients of the selected ICP run (round 4)
gene.coeff <- GetFeatureCoefficients(sce, icp.run = icp.run, icp.round = 4)
# NOTE(review): the element name "icp_56" is hard-coded and presumably matches
# icp.run = 7; it must be updated if 'icp.run' changes. Fail early with a clear
# message instead of an obscure error from setting row names on NULL.
if (!"icp_56" %in% names(gene.coeff)) {
    stop(
        "'icp_56' not found in 'gene.coeff'; available elements: ",
        paste(names(gene.coeff), collapse = ", "),
        call. = FALSE
    )
}
coeff.tbl <- gene.coeff$icp_56
row.names(coeff.tbl) <- coeff.tbl$feature
# The feature names now live in the row names: drop the 'feature' column by
# name and keep a data frame even if a single coefficient column remains
coeff.tbl <- coeff.tbl[, setdiff(colnames(coeff.tbl), "feature"), drop = FALSE]
gene.coeff$icp_56 <- coeff.tbl

# Plot genes with at least one positive coefficient in a heatmap
positive.coeff <- (rowSums(gene.coeff$icp_56 > 0) > 0)
heat <- ComplexHeatmap::Heatmap(
    # drop = FALSE keeps a matrix even when only one gene passes the filter
    matrix = as.matrix(gene.coeff$icp_56)[positive.coeff, , drop = FALSE],
    name = "Gene coef.",
    cluster_columns = FALSE, show_row_dend = FALSE, show_row_names = TRUE,
    row_names_gp = grid::gpar(fontsize = 7)
)
print(heat)

## ----gene coefficients per luminal cluster, fig.width=20, fig.height=10-------
# Check the top gene coefficients of clusters 3, 11 and 12 and plot the
# expression of each cluster's five highest-coefficient genes
# NOTE(review): 'gene.coeff.basal' says "basal" while the chunk label and the
# 'top5.luminal*' names say "luminal" - confirm which lineage is meant.
gene.coeff.basal <- gene.coeff$icp_56[positive.coeff, ] %>%
    mutate(gene = row.names(.))
top5.luminal <- list()
top5.luminal.plts <- list()
for (i in as.character(c(3, 11, 12))) {
    clt.no <- paste0("clt", i)
    coeff.col <- paste0("coeff_", clt.no)
    # Five genes with the largest coefficient for this cluster
    top.genes <- gene.coeff.basal %>%
        dplyr::select(all_of(c("gene", coeff.col))) %>%
        arrange(desc(.data[[coeff.col]])) %>%
        head(5) %>%
        pull(gene)
    top5.luminal[[clt.no]] <- top.genes
    panel.title <- gsub("clt", "cluster: ", clt.no)
    gene.panels <- lapply(top.genes, function(x) {
        PlotExpression(sce,
            color.by = x, point.size = 0.1, point.stroke = 0.1,
            scale.values = TRUE
        ) + ggplot2::ggtitle(panel.title)
    })
    top5.luminal.plts[[clt.no]] <- cowplot::plot_grid(
        plotlist = gene.panels, ncol = 5, align = "vh"
    )
}
cowplot::plot_grid(plotlist = top5.luminal.plts, nrow = 3, align = "vh")

## ----graph-based clustering, fig.width=5, fig.height=6------------------------
### Graph-based clustering with scran
## Louvain clustering of an SNN graph (k = 30) built on the Coralysis
## integrated PCA embedding
bluster.params <- bluster::SNNGraphParam(
    k = 30,
    cluster.fun = "louvain",
    cluster.args = list(resolution = 0.8)
)
set.seed(1024)
cluster.labels <- scran::clusterCells(
    sce,
    use.dimred = "PCA",
    BLUSPARAM = bluster.params
)
sce$Cluster <- cluster.labels
PlotDimRed(
    sce,
    color.by = "Cluster",
    point.size = 0.2,
    point.stroke = 0.2,
    legend.nrow = 4
)

## ----cell cluster prob. bins, fig.width=9, fig.height=4.5---------------------
# Cell-state SCE object: cells of each "Cluster" grouped into 30 bins
# (presumably by their cluster probability)
sce.bins <- BinCellClusterProbability(sce, label = "Cluster", bins = 30)

# Project the Coralysis bins onto the single-cell UMAP, then plot the bins
# coloured by label and by aggregated probability
sce.bins <- ReferenceMapping(sce, sce.bins, ref.label = "Cluster", project.umap = TRUE)
umap.bins.labels <- PlotDimRed(sce.bins, color.by = "coral_labels")
umap.bins.probs <- PlotExpression(
    sce.bins,
    color.by = "aggregated_probability_bins",
    color.scale = "viridis",
    legend.title = "Prob. bins"
)
cowplot::plot_grid(
    plotlist = list(umap.bins.labels, umap.bins.probs),
    ncol = 2, align = "vh"
)

## ----coralysis labels, fig.width=12, fig.height=6-----------------------------
# Coralysis labels: single cells (left) vs probability bins (right)
umap.cells.labels <- PlotDimRed(
    sce,
    color.by = "Cluster",
    point.size = 0.2,
    point.stroke = 0.2,
    legend.nrow = 2
)
cowplot::plot_grid(umap.cells.labels, umap.bins.labels)

## ----coralysis probability, fig.width=12, fig.height=5------------------------
# Coralysis probability: single cells (left) vs probability bins (right)
umap.cells.probs <- PlotExpression(
    object = sce,
    color.by = "scaled_mean_probs",
    color.scale = "viridis",
    point.size = 0.2,
    point.stroke = 0.1,
    legend.title = "Mean prob.\n(min-max)"
)
cowplot::plot_grid(umap.cells.probs, umap.bins.probs)

## ----differential expression, fig.width=10, fig.height=6----------------------
# Differential expression programs along the probability bins of one cluster
clt.label <- "10"
n.bins <- 30 # must match 'bins' passed to BinCellClusterProbability()
# Bin column names of the selected cluster: "10_bin1", ..., "10_bin30"
bin.cols <- paste0(clt.label, "_bin", seq_len(n.bins))

corr.features <- CellBinsFeatureCorrelation(object = sce.bins, labels = clt.label)
top30.corr.clt10 <- corr.features %>%
    arrange(desc(abs(.data[[clt.label]]))) %>%
    head(30)

# Heatmap of top 30 genes with the highest absolute Pearson correlation
pick.genes <- row.names(top30.corr.clt10)
mtx <- as.matrix(logcounts(sce.bins[pick.genes, bin.cols]))
# Row-wise z-scores: scale() standardises columns, hence the transposes
mtx <- t(scale(t(mtx)))
bin.probs <- colData(sce.bins[, bin.cols])$aggregated_probability_bins
# NOTE(review): colorRamp2() maps values outside [0.8, 1.0] to the end
# colours; confirm the bin probabilities of this cluster fall in that range.
col_fun2 <- circlize::colorRamp2(c(0.8, 0.9, 1.0), scales::viridis_pal()(3))
heat.diff.exp <- Heatmap(
    matrix = mtx, name = "Row Z-score\n(normalized data)",
    cluster_rows = TRUE, cluster_columns = FALSE,
    show_column_dend = FALSE, show_row_dend = FALSE,
    use_raster = TRUE, show_column_names = TRUE,
    top_annotation = HeatmapAnnotation(
        "Probability" = bin.probs,
        simple_anno_size = unit(0.25, "cm"),
        col = list("Probability" = col_fun2), show_annotation_name = FALSE,
        show_legend = TRUE,
        annotation_legend_param = list(Probability = list(
            direction = "horizontal",
            title_position = "topcenter"
        ))
    )
)
print(heat.diff.exp)

## ----plot expression, fig.width=12, fig.height=10-----------------------------
# Look into Glycam1 and Fabp3 expression: one row per gene, single cells
# (left column) vs probability bins (right column)
genes <- c("Glycam1", "Fabp3")
expr.plots <- lapply(genes, function(gene) {
    list(
        PlotExpression(sce,
            color.by = gene, point.size = 0.5,
            point.stroke = 0.5, scale.values = TRUE
        ),
        PlotExpression(sce.bins,
            color.by = gene, point.size = 1,
            point.stroke = 1, scale.values = TRUE
        )
    )
})
# Flatten one level only, so the plot objects themselves stay intact
cowplot::plot_grid(plotlist = unlist(expr.plots, recursive = FALSE), ncol = 2)

## ----rsession-----------------------------------------------------------------
# R session: report the R version, platform and loaded package versions so
# the results above can be reproduced
sessionInfo()

