## ----setup, include=FALSE-----------------------------------------------------
knitr::opts_chunk$set(
  # Show code and collapse it with its output, prefixing output lines by "#>"
  echo = TRUE,
  collapse = TRUE,
  comment = "#>",
  # Keep the rendered vignette free of package messages and warnings
  message = FALSE,
  warning = FALSE
)

## ----installation, eval=FALSE-------------------------------------------------
# if (!require("BiocManager", quietly = TRUE)) {
#   install.packages("BiocManager")
# }
# 
# BiocManager::install("CellMentor")

## ----load_packages------------------------------------------------------------
# Fix the RNG seed so the random cell subsetting further down is reproducible
set.seed(100)

# Attach order matters for masking; keep it as is
library(CellMentor)
library(Matrix)
library(ggplot2)
library(SingleCellExperiment)
library(scater) 

## ----load_data----------------------------------------------------------------
# Loading reference dataset (Baron): expression matrix plus the cell-type
# labels that are later passed to CreateCSFNMFobject() as `ref_celltype`
baron <- hBaronDataset()
reference_matrix <- baron$data
reference_celltypes <- baron$celltypes

# Loading query dataset (Muraro)
muraro <- muraro_dataset()
query_matrix <- muraro$data
# The query labels would be unknown in a real application; they are kept
# here only for evaluation
query_celltypes <- muraro$celltypes

## -----------------------------------------------------------------------------
# Draw a class-balanced random subset of cells.
#
# For every distinct, non-NA cell type, up to `cells_per_type` cells are
# sampled without replacement. Types with fewer cells contribute all of
# theirs. Types are visited in order of first appearance, so with a fixed seed
# the selection is reproducible.
#
# Args:
#   matrix: matrix (dense or sparse) whose column names are cell identifiers.
#   celltypes: vector of cell-type labels named by cell identifier; the names
#     must match colnames(matrix). Cells labelled NA are never selected.
#   cells_per_type: maximum number of cells kept per cell type.
#
# Returns:
#   A list with `matrix` (the selected columns, always kept as a matrix even
#   when a single cell is selected) and `celltypes` (the labels of the selected
#   cells, in the same order as the columns).
create_subset <- function(matrix, celltypes, cells_per_type = 30) {
  # Cells are addressed by name, so unnamed labels cannot be matched to columns
  if (is.null(names(celltypes))) {
    stop(
      "`celltypes` must be named by cell (names matching colnames(matrix)).",
      call. = FALSE
    )
  }

  # NA labels cannot be balanced; comparing against NA would also select
  # NA names and break the column subsetting below
  unique_types <- unique(celltypes[!is.na(celltypes)])

  selected_per_type <- lapply(unique_types, function(cell_type) {
    # which() drops the NA comparisons produced by unlabelled cells
    type_cells <- names(celltypes)[which(celltypes == cell_type)]
    # If fewer cells than requested, take all of them
    n_to_select <- min(cells_per_type, length(type_cells))
    # Same RNG draw that sample() makes for a character vector, without
    # sample()'s special case for length-one numeric input
    type_cells[sample.int(length(type_cells), n_to_select)]
  })
  selected_cells <- unlist(selected_per_type, use.names = FALSE)

  list(
    # drop = FALSE keeps a one-cell selection as a matrix, not a vector
    matrix = matrix[, selected_cells, drop = FALSE],
    celltypes = celltypes[selected_cells]
  )
}

# Draw balanced subsets of both datasets (at most 30 cells per cell type).
# The reference is drawn first, then the query, so the seeded RNG stream
# yields the same cells on every run.
baron_subset <- create_subset(reference_matrix, reference_celltypes,
                              cells_per_type = 30)
muraro_subset <- create_subset(query_matrix, query_celltypes,
                               cells_per_type = 30)

# Continue with the subsets under the original variable names
reference_matrix <- baron_subset$matrix
reference_celltypes <- baron_subset$celltypes

query_matrix <- muraro_subset$matrix
# Query labels would be unknown in a real application; they are kept here
# only for evaluation
query_celltypes <- muraro_subset$celltypes

## ----create_object------------------------------------------------------------
# Build the CSFNMF object from the labelled reference and the query data
csfnmf_obj <- CreateCSFNMFobject(
  # Inputs
  ref_matrix = reference_matrix,
  ref_celltype = reference_celltypes,
  data_matrix = query_matrix,
  # Preprocessing options (see ?CreateCSFNMFobject)
  norm = TRUE,
  most.variable = TRUE,
  scale = TRUE,
  scale_by = "cells",
  # Execution: single core, with progress output
  num_cores = 1,
  verbose = TRUE
)

## ----run_cellmentor-----------------------------------------------------------
# Hyperparameter search over a deliberately small grid: two alpha values,
# while beta, gamma and delta are each pinned to a single value for speed
optimal_params <- CellMentor(
  csfnmf_obj,
  alpha_range = c(1, 5),
  beta_range = 5,
  gamma_range = 0.1,
  delta_range = 1,
  num_cores = 1,
  verbose = TRUE
)

# Keep the model CellMentor reports as best, and its rank (via cm_rank())
best_model <- optimal_params$best_model
K_VALUE <- cm_rank(best_model)

## ----project_data-------------------------------------------------------------
# Project query data onto the learned space.
# num_cores = 1 matches every other call in this vignette. Bioconductor asks
# vignettes to use at most 2 cores, and a hard-coded 5 can exceed what a build
# machine provides.
h_project <- project_data(
  W = W(best_model), # Learned gene weights
  X = data_matrix(matrices(best_model)), # Query data matrix stored in the model
  num_cores = 1,
  verbose = TRUE
)

## -----------------------------------------------------------------------------
# Ensure unique rownames for genes; the loadings are matched to rows by name
rownames(query_matrix) <- make.unique(rownames(query_matrix))

sce <- SingleCellExperiment(
  assays = list(counts = query_matrix)
)

# Store the ground-truth labels of the query cells. They are only available
# because this is a benchmark. query_celltypes is ordered like the columns of
# query_matrix, since both come from the same create_subset() call.
colData(sce)$celltype <- query_celltypes

# Cell embeddings (cells x K). The transpose assumes project_data() returns a
# K x cells matrix -- NOTE(review): confirm against ?project_data.
H_cell <- t(as.matrix(h_project))
# When the embedding carries cell names covering every cell, reorder it by
# name so each row lands on the right cell. Otherwise the rows are attached by
# position, and a name mismatch only surfaces as a reducedDim<- warning or
# error, depending on the SingleCellExperiment version. Warnings are silenced
# in this vignette.
if (!is.null(colnames(sce)) && !is.null(rownames(H_cell)) &&
    all(colnames(sce) %in% rownames(H_cell))) {
  H_cell <- H_cell[colnames(sce), , drop = FALSE]
}
if (nrow(H_cell) != ncol(sce)) {
  stop("The CellMentor embedding has ", nrow(H_cell), " cells but the ",
       "SingleCellExperiment has ", ncol(sce), ".", call. = FALSE)
}
colnames(H_cell) <- paste0("CM", seq_len(ncol(H_cell)))
reducedDim(sce, "CellMentor") <- H_cell

# Gene loadings (genes x K). They are attached only when they carry gene names
# to match on. Rows are re-indexed to follow the row order of `sce`, and genes
# of `sce` without a loading receive NA.
W_mat <- as.matrix(W(best_model))
if (!is.null(rownames(W_mat))) {
  gene_idx <- match(rownames(sce), rownames(W_mat))
  W_mat <- W_mat[gene_idx, , drop = FALSE]
  # One rowData column per factor: CM_loading_1, CM_loading_2, ...
  for (factor_k in seq_len(ncol(W_mat))) {
    loading_name <- paste0("CM_loading_", factor_k)
    rowData(sce)[[loading_name]] <- W_mat[, factor_k]
  }
}

# 2-D UMAP computed from the precomputed CellMentor embedding (dimred) and
# stored in `sce` as the reduced dimension "UMAP_CellMentor"
sce <- runUMAP(
  sce,
  ncomponents = 2,
  dimred = "CellMentor",
  name = "UMAP_CellMentor"
)

# Quick plots: the CellMentor UMAP, coloured by the stored query labels
plotReducedDim(sce, dimred = "UMAP_CellMentor", colour_by = "celltype")

## ----custom_data_workflow, eval=FALSE-----------------------------------------
# library(Matrix)
# library(CellMentor)
# 
# # 1) Build CSFNMF object
# csfnmf_obj <- CreateCSFNMFobject(
#   ref_matrix = ref_counts,
#   ref_celltype = ref_celltypes, # names(ref_celltypes) == colnames(ref_counts)
#   data_matrix = qry_counts,
#   norm = TRUE,
#   most.variable = TRUE,
#   scale = TRUE,
#   scale_by = "cells",
#   num_cores = 1,
#   verbose = TRUE
# )
# 
# # 2) Hyperparameter search & training
# optimal <- CellMentor(csfnmf_obj)
# 
# # 3) Get best model
# best_model <- optimal$best_model
# 
# # 4) Project data
# h_project <- project_data(
#   W = W(best_model),
#   X = data_matrix(matrices(best_model))
# )
# 
# # 5) Optional: SingleCellExperiment integration & UMAP
# # (Follow the same steps as in the demo section above)

## ----session_info-------------------------------------------------------------
# Record the R and package versions used to build this document
sessionInfo()

