## ----setup, include = FALSE---------------------------------------------------
# Default chunk options for the vignette: show code and its output in one
# block, and prefix every output line with "#>"
knitr::opts_chunk$set(comment = "#>", collapse = TRUE)

## ----inst, eval=FALSE---------------------------------------------------------
# ## install if needed
# if (!requireNamespace("BiocManager", quietly = TRUE))
#     install.packages("BiocManager")
# 
# BiocManager::install("scConform")

## ----libraries, message=FALSE-------------------------------------------------
library(scConform)
library(SingleCellExperiment)
library(VGAM)
library(ontoProc)
library(MerfishData)
library(igraph)
library(scuttle)
library(scran)
library(BiocParallel)

# Negation of `%in%`: TRUE for every element of `x` that is absent from `table`
`%notin%` <- function(x, table) !(x %in% table)

## ----data---------------------------------------------------------------------
# MERFISH mouse ileum dataset with the Baysor cell segmentation; images and
# cell polygons are not needed for this analysis
spe_baysor <- MouseIleumPetukhov2021(
    segmentation = "baysor",
    use.polygons = FALSE,
    use.images = FALSE
)
# Cell Ontology ("2023" version), used to build the label hierarchy
cl <- getOnto("cellOnto", "2023")

## ----tags, echo=TRUE, fig.show='hide'-----------------------------------------
# Cell Ontology IDs of the annotated cell types. The names document which
# annotation each ID corresponds to and are dropped, so `tags` is a plain
# character vector of IDs (order matters and is kept as listed).
tags <- unname(c(
    "Stromal" = "CL:0009022",
    "B cell" = "CL:0000236",
    "Tuft" = "CL:0009080",
    "Endothelial" = "CL:1000411",
    "Enterocyte" = "CL:1000335",
    "Goblet" = "CL:1000326",
    "ICC" = "CL:0002088",
    "Macrophage + DC" = "CL:0009007",
    "Paneth" = "CL:1000343",
    "Pericyte" = "CL:0000669",
    "Smooth Muscle" = "CL:1000278",
    "Stem + TA" = "CL:0009017",
    "T (CD4+)" = "CL:0000492",
    "T (CD8+)" = "CL:0000625",
    "Telocyte" = "CL:0017004"
))
# Ontology subgraph spanning these terms, converted to an igraph object
opi <- graph_from_graphnel(onto_plot2(cl, tags))

## ----build-ontology-----------------------------------------------------------
## Delete vertices imported from the CARO and BFO ontologies. A single
## alternation pattern returns each matching vertex name once, even if it
## matches both.
sel_ver <- grep("CARO|BFO", V(opi)$name, value = TRUE)
opi1 <- opi - sel_ver

## Rename the tagged vertices to match the cell-type annotations. Vertex names
## embed the Cell Ontology ID (e.g. "connective\ntissue cell\nCL:0002320"),
## so each vertex is located by a literal match on its ID.
cl_labels <- c(
    "CL:0000236" = "B cell",
    "CL:1000411" = "Endothelial",
    "CL:1000335" = "Enterocyte",
    "CL:1000326" = "Goblet",
    "CL:0002088" = "ICC",
    "CL:0009007" = "Macrophage + DC",
    "CL:1000343" = "Paneth",
    "CL:0000669" = "Pericyte",
    "CL:1000278" = "Smooth Muscle",
    "CL:0009017" = "Stem + TA",
    "CL:0009022" = "Stromal",
    "CL:0000492" = "T (CD4+)",
    "CL:0000625" = "T (CD8+)",
    "CL:0017004" = "Telocyte",
    "CL:0009080" = "Tuft"
)
vertex_names <- V(opi1)$name
for (cl_id in names(cl_labels)) {
    hit <- grepl(cl_id, vertex_names, fixed = TRUE)
    vertex_names[hit] <- cl_labels[[cl_id]]
}
V(opi1)$name <- vertex_names

## Add an edge from "connective tissue cell" to "Telocyte", then delete nodes
## that are redundant for this hierarchy
opi1 <- add_edges(opi1, c("connective\ntissue cell\nCL:0002320", "Telocyte"))
opi2 <- opi1 - c(
    "somatic\ncell\nCL:0002371", "contractile\ncell\nCL:0000183",
    "native\ncell\nCL:0000003"
)

## Strip the remaining "CL:..." IDs and line breaks so vertex names are plain
## labels
V(opi2)$name <- trimws(gsub("CL:.*|\\n", " ", V(opi2)$name))

gr1 <- as_graphnel(opi2)

## Plot the final ontology
attrs <- list(node = list(shape = "box", fontsize = 15, fixedsize = FALSE))
plot(gr1, attrs = attrs)

## ----preprocess---------------------------------------------------------------
# Harmonise the Leiden annotations with the ontology leaf labels: merge the
# two B-cell populations into "B cell" ...
spe_baysor$cell_type <- replace(
    spe_baysor$leiden_final,
    spe_baysor$leiden_final %in% c(
        "B (Follicular, Circulating)",
        "B (Plasma)"
    ),
    "B cell"
)
# ... and collapse every enterocyte subtype into "Enterocyte"
spe_baysor$cell_type <- replace(
    spe_baysor$cell_type,
    grepl("Enterocyte", spe_baysor$cell_type),
    "Enterocyte"
)
# Discard cells labelled "Removed" or "Myenteric Plexus"
spe_baysor <- spe_baysor[, !(spe_baysor$cell_type %in% c(
    "Removed",
    "Myenteric Plexus"
))]
spe_baysor

# Number of cells per cell type
table(spe_baysor$cell_type)

## ----split--------------------------------------------------------------------
# Draw 600 random cells (reproducibly) as the reference set; every other
# cell is treated as query data
set.seed(1636)
ref <- sample.int(ncol(spe_baysor), size = 600)

spe_ref <- spe_baysor[, ref]
spe_query <- spe_baysor[, -ref]
# Reference data
spe_ref
# Query data
spe_query

## ----model, warning=FALSE-----------------------------------------------------
# Reproducibly pick 300 of the reference cells as training data
set.seed(1704)
train <- sample.int(ncol(spe_ref), size = 300)
spe_train <- spe_ref[, train]
spe_train

# Log-normalise the training cells, model per-gene variance and keep the 50
# most highly variable genes (HVGs) as predictors
spe_train <- logNormCounts(spe_train)
v <- modelGeneVar(spe_train)
hvg <- getTopHVGs(v, n = 50)

# One row per cell and one column per HVG (raw counts), plus the label Y
df_train <- as.data.frame(t(as.matrix(counts(spe_train)[hvg, , drop = FALSE])))
df_train$Y <- spe_train$cell_type
table(df_train$Y)

# Multinomial regression of cell type on the HVG counts, with "B cell" as
# the reference level
fit <- vglm(
    formula = Y ~ .,
    data = df_train,
    family = multinomial(refLevel = "B cell")
)

## ----predictions, warning=FALSE-----------------------------------------------
# Calibration data: the reference cells that were not used for training
spe_cal <- spe_ref[, -train]

# Predicted class probabilities for the calibration cells
df_cal <- as.data.frame(t(as.matrix(counts(spe_cal)[hvg, , drop = FALSE])))
p_cal <- predict(fit, newdata = df_cal, type = "response")
round(head(p_cal), 3)

# Predicted class probabilities for the query cells
df_test <- as.data.frame(t(as.matrix(counts(spe_query)[hvg, , drop = FALSE])))
p_test <- predict(fit, newdata = df_test, type = "response")
round(head(p_test), 3)

## ----predsets-----------------------------------------------------------------
# Conformal prediction sets at level alpha = 0.1 that ignore the ontology;
# the candidate labels are the classes of the fitted model
labels <- colnames(p_test)
sets <- getPredictionSets(
    x_query = p_test,
    x_cal = p_cal,
    y_cal = spe_cal$cell_type,
    labels = labels,
    follow_ontology = FALSE,
    alpha = 0.1
)

# Prediction sets of the first six query cells
sets[seq_len(6)]

## ----cvg----------------------------------------------------------------------
# Check coverage: the proportion of query cells whose true label is contained
# in their prediction set. The conformal procedure targets a marginal coverage
# of at least 1 - alpha = 0.9.
# Sets are matched to cells by position, so both must have the same length.
stopifnot(length(sets) == ncol(spe_query))
# Extract the labels once instead of going through the colData accessor on
# every iteration
true_labels <- spe_query$cell_type
cvg <- vapply(
    seq_along(sets),
    function(i) true_labels[i] %in% sets[[i]],
    logical(1)
)
mean(cvg)

## ----predsets-sc--------------------------------------------------------------
# Retrieve labels as leaf nodes of the ontology (vertices without outgoing
# edges)
labels <- V(opi2)$name[degree(opi2, mode = "out") == 0]

# Every leaf label needs a probability column. A cell type absent from the
# training cells would otherwise make the loop below fail with an opaque
# "subscript out of bounds" error.
missing_labels <- setdiff(
    labels,
    intersect(colnames(p_cal), colnames(p_test))
)
if (length(missing_labels) > 0) {
    stop(
        "No predicted probabilities for leaf label(s): ",
        paste(missing_labels, collapse = ", "),
        call. = FALSE
    )
}

# Store each class probability as a colData column named after its label
for (lab in labels) {
    colData(spe_cal)[[lab]] <- p_cal[, lab]
    colData(spe_query)[[lab]] <- p_test[, lab]
}

# Create prediction sets
spe_query <- getPredictionSets(
    x_query = spe_query,
    x_cal = spe_cal,
    y_cal = spe_cal$cell_type,
    alpha = 0.1,
    follow_ontology = FALSE,
    pr_name = "pred_set",
    labels = labels
)

# See the new variable pred_set in the colData
head(colData(spe_query))

## ----predset-hier-------------------------------------------------------------
# Ontology-aware prediction sets ("full" method), computed with two parallel
# workers and stored in the colData column "pred_set_hier"
spe_query <- getPredictionSets(
    x_query = spe_query,
    x_cal = spe_cal,
    y_cal = spe_cal$cell_type,
    onto = opi2,
    follow_ontology = TRUE,
    method = "full",
    alpha = 0.1,
    simplify = FALSE,
    pr_name = "pred_set_hier",
    BPPARAM = MulticoreParam(workers = 2)
)

# Hierarchical prediction sets of the first six query cells
head(spe_query$pred_set_hier)

## ----cvg1---------------------------------------------------------------------
# Check coverage of the hierarchical prediction sets: the proportion of query
# cells whose true label is contained in their set.
# Both columns are extracted once instead of going through the colData
# accessor twice per iteration.
true_labels <- spe_query$cell_type
hier_sets <- spe_query$pred_set_hier
cvg1 <- vapply(
    seq_along(hier_sets),
    function(i) true_labels[i] %in% hier_sets[[i]],
    logical(1)
)

mean(cvg1)

## ----common-ancestor----------------------------------------------------------
# Summarise each hierarchical prediction set by a single ontology node
# computed by getCommonAncestor(); `opi2` is forwarded through vapply's `...`
# as the second argument
spe_query$pred_set_hier_simp <- vapply(
    spe_query$pred_set_hier, getCommonAncestor, character(1), opi2
)
head(spe_query$pred_set_hier_simp)

## ----plotres------------------------------------------------------------------
# Highlight on the ontology the standard conformal set of query cell 75
plotResult(spe_query$pred_set[[75]], opi2,
    title = "Conformal Prediction set",
    add_scores = FALSE, attrs = attrs, col_grad = "pink"
)

## ----plotres2-----------------------------------------------------------------
# Highlight on the ontology the hierarchical set of query cell 75
plotResult(spe_query$pred_set_hier[[75]], opi2,
    title = "Hierarchical prediction set",
    add_scores = FALSE, attrs = attrs, col_grad = "pink"
)

## ----plotres3-----------------------------------------------------------------
# Same standard conformal set, now annotated with the predicted class
# probabilities of cell 75 on a three-colour gradient
plotResult(spe_query$pred_set[[75]], opi2,
    title = "Conformal Prediction set",
    probs = p_test[75, ],
    add_scores = TRUE,
    attrs = attrs,
    col_grad = c("lemonchiffon", "orange", "darkred")
)

## ----plotres4-----------------------------------------------------------------
# Same hierarchical set, now annotated with the predicted class
# probabilities of cell 75 on a three-colour gradient
plotResult(spe_query$pred_set_hier[[75]], opi2,
    title = "Hierarchical Prediction set",
    probs = p_test[75, ],
    add_scores = TRUE,
    attrs = attrs,
    col_grad = c("lemonchiffon", "orange", "darkred")
)

## ----SessionInfo, echo=FALSE, message=FALSE, warning=FALSE, comment=NA--------
# Record the R version and package versions used to build the vignette
utils::sessionInfo()

