## ----echo=FALSE, message=FALSE, warning = FALSE-------------------------------
# Default knitr chunk options applied to every later chunk.
knitr::opts_chunk$set(
  error = FALSE,
  message = FALSE,
  warning = FALSE,
  collapse = TRUE
)
library(BiocStyle)

## ----message=FALSE, warning = FALSE-------------------------------------------
library(poem)
library(ggplot2)
library(cowplot)
library(SpatialExperiment)
library(STexampleData)
library(dplyr)
library(tidyr)

## -----------------------------------------------------------------------------
# Eight hex colours, named by their position ("1", "2", ..., "8").
my_cols <- c(
  "#D55E00", "#CC79A7", "#E69F00", "#0072B2",
  "#009E73", "#F0E442", "#56B4E9", "#000000"
)
my_cols <- setNames(my_cols, as.character(seq_along(my_cols)))

## -----------------------------------------------------------------------------
# Load the example dataset and keep only the columns (spots) whose
# `reference` entry in colData is not missing, then print the object.
spe <- Visium_humanDLPFC()
spe <- spe[, !is.na(colData(spe)[["reference"]])]
spe

## -----------------------------------------------------------------------------
# Plain data frame with the spatial coordinates and the reference label of
# every spot; rows containing any NA are dropped.
data <- data.frame(spatialCoords(spe))
data[["reference"]] <- colData(spe)$reference
data <- na.omit(data)

# Fix the level order of the reference: WM first, then Layer6 down to Layer1.
data[["reference"]] <- factor(
  data[["reference"]],
  levels = c("WM", paste0("Layer", 6:1))
)

## ----fig.height = 4, fig.width = 4, fig.small = TRUE--------------------------
# Spatial scatter plot of the manual annotation: one point per row of `data`,
# coloured by `reference`. The row coordinate is negated, which flips the
# vertical axis of the plot.
p1 <- ggplot(data) +
  geom_point(aes(x = pxl_col_in_fullres, y = -pxl_row_in_fullres,
                 color = reference), size = 0.3) +
  labs(x = "", y = "", color = "", title = "Manual annotation") +
  theme_minimal() +
  # The palette is passed without names, so colours follow the level order.
  scale_color_manual(values = unname(my_cols)) +
  theme(
    # `linewidth` replaces the `size` argument of element_rect(), which has
    # been deprecated since ggplot2 3.4.0.
    legend.box.background = element_rect(fill = "grey90", color = "black",
                                         linewidth = 0.1),
    legend.box.margin = margin(-1, -1, -1, -1),
    axis.title.x = element_blank(),
    legend.position = "bottom",
    legend.box.spacing = margin(0),
    # Pixel coordinates carry no meaning for the reader: hide text and ticks.
    axis.text.x = element_blank(),
    axis.ticks.x = element_blank(),
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank(),
    panel.spacing.x = unit(-0.5, "cm"),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    plot.title = element_text(hjust = 0.5, size = 12,
                              margin = margin(b = 5, t = 15))
  ) +
  # Enlarge the legend keys; the plotted points themselves are very small.
  guides(color = guide_legend(keywidth = 1, keyheight = 0.8,
                              override.aes = list(size = 3)))

p1

## -----------------------------------------------------------------------------
set.seed(123) # For reproducibility

# Given a factor vector representing clustering results, simulate clustering
# variations by optionally merging clusters and by adding random noise.
#
# Arguments:
#   clusters       Factor of cluster labels, one entry per element.
#   split_cluster  Currently unused; kept so existing calls keep working.
#   merge_clusters Optional vector of level names of `clusters`. All matching
#                  clusters are collapsed into the first matching level (in
#                  level order). NULL or no match means no merging.
#   noise_level    Fraction of elements (rounded to a whole count) whose label
#                  is replaced by a label drawn at random, with replacement,
#                  from the labels present after merging.
#
# Returns a factor of the same length whose levels are the integer codes of
# the original levels (e.g. "1", "2", ...), not the original level names.
simulate_clustering_variation <- function(clusters, split_cluster = NULL,
                                          merge_clusters = NULL,
                                          noise_level = 0.1) {
  # Resolve the requested level names to integer codes while the factor
  # levels are still available, then work on the integer codes.
  merge_codes <- which(levels(clusters) %in% merge_clusters)
  codes <- as.numeric(clusters)

  # 1. Merging clusters: relabel every selected cluster with the first code.
  # `which()` never returns NULL, so test the length instead of is.null().
  if (length(merge_codes) > 0) {
    codes[codes %in% merge_codes] <- merge_codes[1]
  }

  # 2. Adding random noise
  n <- length(codes)
  n_noise <- round(n * noise_level) # Number of elements to replace
  if (n_noise > 0) {
    noise_indices <- sample(seq_len(n), n_noise) # Random indices to replace
    existing_levels <- unique(codes)
    # Draw positions rather than values: sample(x, ...) treats a single
    # number x as 1:x, which would invent labels when only one cluster with
    # code > 1 is present. For two or more labels this consumes the random
    # stream exactly like sample(existing_levels, n_noise, replace = TRUE).
    picks <- sample.int(length(existing_levels), n_noise, replace = TRUE)
    codes[noise_indices] <- existing_levels[picks]
  }

  # Convert back to factor and return
  factor(codes)
}

## -----------------------------------------------------------------------------
# P1: reference labels with random noise on 10% of the spots
data$P1 <- simulate_clustering_variation(data$reference, noise_level = 0.1)

# P2: split Layer 3 into 2 domains, add random noise.
# Layer 3 spots with a column coordinate below 8000 get the extra code 8;
# the codes are then renumbered consecutively before the noise step.
data$P2 <- local({
  codes <- as.numeric(data$reference)
  codes[data$reference == "Layer3" & data$pxl_col_in_fullres < 8000] <- 8
  factor(as.numeric(factor(codes)))
})
data$P2 <- simulate_clustering_variation(data$P2, noise_level = 0.1)

# P3: merge Layer 4 and Layer 5, add random noise
data$P3 <- simulate_clustering_variation(
  data$reference,
  merge_clusters = c("Layer4", "Layer5"),
  noise_level = 0.1
)


## ----fig.height = 7, fig.width = 7--------------------------------------------
# One spatial panel per simulated prediction. Every column except the two
# coordinates is stacked into (prediction, domain) pairs, and the rows coming
# from the `reference` column are removed again before plotting.
p2 <- data %>%
  pivot_longer(cols = -c("pxl_col_in_fullres", "pxl_row_in_fullres"),
               names_to = "prediction", values_to = "domain") %>%
  dplyr::filter(prediction != "reference") %>%
  ggplot() +
  geom_point(aes(x = pxl_col_in_fullres, y = -pxl_row_in_fullres,
                 color = domain), size = 0.4) +
  facet_wrap(~prediction, nrow = 2) +
  labs(x = "", y = "", color = "", title = "") +
  theme_minimal() +
  scale_color_manual(values = unname(my_cols)) +
  theme(
    # `linewidth` replaces the `size` argument of element_rect(), which has
    # been deprecated since ggplot2 3.4.0.
    legend.box.background = element_rect(fill = "grey90",
                                         color = "black", linewidth = 0.1),
    legend.box.margin = margin(-1, -1, -1, -1),
    axis.title.x = element_blank(),
    legend.position = "bottom",
    legend.justification = c(0, 0),
    legend.box.spacing = margin(0),
    axis.text.x = element_blank(),
    axis.ticks.x = element_blank(),
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank(),
    panel.spacing.x = unit(-0.5, "cm"),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    plot.title = element_text(hjust = 0.5, size = 10)
  ) +
  guides(color = guide_legend(keywidth = 1, keyheight = 0.8,
                              override.aes = list(size = 3)))

# Compose the figure: the faceted predictions fill the canvas and the manual
# annotation (p1) is drawn over the lower-right part of it.
ggdraw() +
  draw_plot(p2 + theme(plot.margin = margin(0, 2, 2, 2))) +  # Main plot
  draw_plot(p1, x = 0.5, y = -0.01, width = 0.5, height = 0.56)  # Inset plot

## -----------------------------------------------------------------------------
# Attach the simulated predictions to the colData of `spe`, then compute the
# external metrics of P3 against the reference by naming colData columns.
colData(spe) <- cbind(colData(spe), data[, c("P1", "P2", "P3")])
getSpatialExternalMetrics(
  object = spe,
  true = "reference",
  pred = "P3",
  k = 6
)

## -----------------------------------------------------------------------------
# Same comparison, passing the label vectors and the coordinates directly.
getSpatialExternalMetrics(
  true = colData(spe)$reference,
  pred = colData(spe)$P3,
  location = spatialCoords(spe),
  k = 6
)

## ----fig.height = 2.5, fig.width = 5------------------------------------------
# Dataset-level nsARI and SpatialARI of each prediction against the reference.
res3 <- getSpatialExternalMetrics(
  object = spe, true = "reference", pred = "P3",
  level = "dataset", k = 6,
  metrics = c("nsARI", "SpatialARI")
)

res2 <- getSpatialExternalMetrics(
  object = spe, true = "reference", pred = "P2",
  level = "dataset", k = 6,
  metrics = c("nsARI", "SpatialARI")
)

res1 <- getSpatialExternalMetrics(
  object = spe, true = "reference", pred = "P1",
  level = "dataset", k = 6,
  metrics = c("nsARI", "SpatialARI")
)

# Stack the three results (column `P` identifies the list position), reshape
# to one row per prediction and metric, and plot one panel per metric.
cbind(bind_rows(list(res1, res2, res3), .id = "P")) %>%
  pivot_longer(
    cols = c("nsARI", "SpatialARI"),
    names_to = "metric",
    values_to = "value"
  ) %>%
  ggplot(aes(x = P, y = value, group = metric)) +
  geom_point(size = 3, aes(color = P)) +
  facet_wrap(~metric, scales = "free") +
  theme_bw() +
  labs(x = "Prediction")

## -----------------------------------------------------------------------------
# Class-level nsAWH and nsAWC of each prediction against the reference.
res3 <- getSpatialExternalMetrics(
  object = spe, true = "reference", pred = "P3",
  level = "class", k = 6,
  metrics = c("nsAWH", "nsAWC")
)

res2 <- getSpatialExternalMetrics(
  object = spe, true = "reference", pred = "P2",
  level = "class", k = 6,
  metrics = c("nsAWH", "nsAWC")
)

res1 <- getSpatialExternalMetrics(
  object = spe, true = "reference", pred = "P1",
  level = "class", k = 6,
  metrics = c("nsAWH", "nsAWC")
)

# Print the result for P1 as an example of the output.
res1

## -----------------------------------------------------------------------------
# Build a prediction-by-predicted-cluster table of nsAWH values and draw it
# as a heatmap (p4).

# For each prediction keep the rows with a non-missing nsAWH/cluster pair and
# replace `cluster` by the matching level of that prediction.
# NOTE(review): assumes `cluster` holds integer level indices - confirm.
awh1 <- na.omit(res1[,c("nsAWH", "cluster")]) %>% 
  mutate(cluster = levels(data$P1)[cluster])
awh2 <- na.omit(res2[,c("nsAWH", "cluster")]) %>% 
  mutate(cluster = levels(data$P2)[cluster])
awh3 <- na.omit(res3[,c("nsAWH", "cluster")]) %>% 
  mutate(cluster = levels(data$P3)[cluster])

# Stack the three tables, spread clusters into columns (one row per
# prediction) and drop the id column `P`.
awh <- cbind(bind_rows(list(awh1, awh2, awh3), .id="P")) %>% 
  pivot_wider(names_from = cluster, values_from = nsAWH) %>%
  subset(select = -c(P))
awh <- as.matrix(awh)
# Row order follows the list order used in bind_rows() above.
rownames(awh) <- c("P1", "P2", "P3")
# Put the cluster columns in the order "1".."8"; all eight must be present.
awh <- awh[,c("1", "2", "3", "4", "5", "6", "7", "8")]

# Back to a data frame; the column names are reset to 1..8 after the
# conversion and the prediction name is stored as a regular column.
awh <- data.frame(awh)
colnames(awh) <- seq_len(8)
awh$prediction <- rownames(awh)


# Heatmap: predicted domain on x, prediction on y, tile fill = AWH value.
p4 <- awh %>% pivot_longer(cols=-c("prediction"), names_to="cluster", 
                           values_to = "AWH") %>%
  mutate(prediction = factor(prediction), cluster=factor(cluster)) %>%
  ggplot(aes(cluster, prediction, fill=AWH)) +
  geom_tile() +
  scale_fill_distiller(palette = "RdPu") +
  labs(x="Predicted domain", y="")


## -----------------------------------------------------------------------------
# Build a prediction-by-annotated-class table of nsAWC values and draw it as
# a heatmap (p5).

# For each prediction keep the rows with a non-missing nsAWC/class pair and
# replace `class` by the matching level of the reference annotation.
# NOTE(review): assumes `class` holds integer level indices - confirm.
awc1 <- na.omit(res1[,c("nsAWC", "class")]) %>% 
  mutate(class = levels(data$reference)[class])
awc2 <- na.omit(res2[,c("nsAWC", "class")]) %>% 
  mutate(class = levels(data$reference)[class])
awc3 <- na.omit(res3[,c("nsAWC", "class")]) %>% 
  mutate(class = levels(data$reference)[class])

# Stack the three tables, spread classes into columns (one row per
# prediction) and drop the id column `P`.
awc <- cbind(bind_rows(list(awc1, awc2, awc3), .id="P")) %>% 
  pivot_wider(names_from = class, values_from = nsAWC) %>%
  subset(select = -c(P))
awc <- as.matrix(awc)
# Row order follows the list order used in bind_rows() above.
rownames(awc) <- c("P1", "P2", "P3")


# Back to a data frame with the prediction name as a regular column.
awc <- data.frame(awc)
awc$prediction <- rownames(awc)


# Heatmap: annotated domain on x, prediction on y, tile fill = AWC value.
p5 <- awc %>% pivot_longer(cols=-c("prediction"), names_to="class", 
                           values_to = "AWC") %>%
  mutate(prediction = factor(prediction), class=factor(class)) %>%
  ggplot(aes(class, prediction, fill=AWC)) +
  geom_tile() +
  scale_fill_distiller(palette = "RdPu") +
  labs(x="Annotated domain", y="")


## ----fig.height = 2, fig.width = 9.5------------------------------------------
# Show both heatmaps side by side with equal relative width and scale.
plot_grid(
  p4, p5,
  rel_widths = c(1, 1),
  scale = c(1, 1)
)

## ----fig.height = 2.5, fig.width = 8------------------------------------------
# Element-level nsSPC of each prediction, with the spot coordinates appended
# column-wise so the values can be drawn in space.
res1 <- cbind(
  getSpatialExternalMetrics(
    object = spe, true = "reference", pred = "P1",
    level = "element", metrics = "nsSPC", k = 6,
    useNegatives = FALSE
  ),
  data[, c("pxl_col_in_fullres", "pxl_row_in_fullres")]
)

res2 <- cbind(
  getSpatialExternalMetrics(
    object = spe, true = "reference", pred = "P2",
    level = "element", metrics = "nsSPC", k = 6,
    useNegatives = FALSE
  ),
  data[, c("pxl_col_in_fullres", "pxl_row_in_fullres")]
)

res3 <- cbind(
  getSpatialExternalMetrics(
    object = spe, true = "reference", pred = "P3",
    level = "element", metrics = "nsSPC", k = 6,
    useNegatives = FALSE
  ),
  data[, c("pxl_col_in_fullres", "pxl_row_in_fullres")]
)


# One spatial panel per prediction; low values are drawn dark, high values
# white.
cbind(bind_rows(list(res1, res2, res3), .id = "P")) %>%
  pivot_longer(
    cols = c("nsSPC"),
    names_to = "metric",
    values_to = "value"
  ) %>%
  ggplot(aes(x = pxl_col_in_fullres, y = -pxl_row_in_fullres,
             color = value)) +
  geom_point(size = 0.3) +
  scale_colour_gradient(low = "deeppink4", high = "white") +
  facet_wrap(~P, scales = "free") +
  theme_bw() +
  labs(x = "Prediction", y = "", color = "nsSPC")

## -----------------------------------------------------------------------------
# P4: add 20% random noise 
# P4: add 20% random noise
data$P4 <- simulate_clustering_variation(data$reference, noise_level = 0.2)

# P5: add 30% random noise
data$P5 <- simulate_clustering_variation(data$reference, noise_level = 0.3)

## -----------------------------------------------------------------------------
# Internal metrics of the reference labels, naming a colData column of `spe`.
getSpatialInternalMetrics(
  object = spe,
  labels = "reference",
  k = 6
)

## -----------------------------------------------------------------------------
# Same computation, passing the label vector and the coordinates directly.
getSpatialInternalMetrics(
  labels = colData(spe)$reference,
  location = spatialCoords(spe),
  k = 6
)

## -----------------------------------------------------------------------------
# Attach P4 and P5 to the colData of `spe`, then compute the dataset-level
# PAS, ELSA and CHAOS for the reference and for every prediction.
colData(spe) <- cbind(colData(spe), data[, c("P4", "P5")])
internal <- lapply(
  # Name each element after itself so bind_rows() can use the names as ids.
  setNames(nm = c("reference", "P1", "P2", "P3", "P4", "P5")),
  function(label_col) {
    getSpatialInternalMetrics(
      object = spe, labels = label_col,
      k = 6, level = "dataset",
      metrics = c("PAS", "ELSA", "CHAOS")
    )
  }
)
internal <- bind_rows(internal, .id = "prediction")

## ----fig.height = 2.5, fig.width = 5------------------------------------------
# Plot the dataset-level internal metrics: one panel per metric, one point
# per labelling (reference and predictions).
internal %>%
  pivot_longer(cols = -c("prediction"),
               names_to = "metric", values_to = "value") %>%
  # Qualified like the other filter call in this file, so the dplyr verb is
  # used regardless of which attached package masks `filter`.
  dplyr::filter(metric %in% c("ELSA", "PAS", "CHAOS")) %>%
  ggplot(aes(x = prediction, y = value, group = metric)) +
  geom_point(size = 3, aes(color = prediction)) +
  facet_wrap(~metric, scales = "free") +
  theme_bw() + labs(x = "", color = "") +
  # The documented value for hiding the legend is lowercase "none".
  theme(legend.position = "none",
        axis.text.x = element_text(angle = 45, vjust = 0.5, hjust = 1))

## ----fig.height = 4, fig.width = 8--------------------------------------------
# Element-level ELSA for the reference and every prediction, with the spot
# coordinates appended column-wise to each result.
internal <- lapply(
  setNames(nm = c("reference", "P1", "P2", "P3", "P4", "P5")),
  function(label_col) {
    cbind(
      getSpatialInternalMetrics(
        object = spe, labels = label_col,
        k = 6, level = "element", metrics = c("ELSA")
      ),
      data[, c("pxl_col_in_fullres", "pxl_row_in_fullres")]
    )
  }
)
internal <- bind_rows(internal, .id = "prediction")

# One spatial panel per labelling; low ELSA is white, high ELSA is dark pink.
internal %>%
  ggplot(aes(x = pxl_col_in_fullres, y = -pxl_row_in_fullres,
             color = ELSA)) +
  geom_point(size = 0.4) +
  scale_colour_gradient(low = "white", high = "deeppink4") +
  facet_wrap(~prediction, scales = "free") +
  theme_bw() +
  labs(x = "", y = "", color = "ELSA")

## -----------------------------------------------------------------------------
# Print the R session details so the results can be traced to an environment.
sessionInfo()

