## ----eval=TRUE----------------------------------------------------------------
if (keras::is_keras_available() && reticulate::py_available()) {
    # `&&` is the scalar, short-circuiting operator intended for `if` conditions
    library(VAExprs)
    
    ### simulate differentially expressed genes
    set.seed(1)
    g <- 3          # number of groups
    n <- 100        # cells per group
    m <- 1000       # genes
    mu <- 5         # baseline mean
    sigma <- 5      # baseline standard deviation
    # genes in rows, cells in columns
    mat <- matrix(rnorm(n*m*g, mu, sigma), m, n*g)
    rownames(mat) <- paste0("gene", seq_len(m))
    colnames(mat) <- paste0("cell", seq_len(n*g))
    # one label per column: n consecutive cells for each group
    group <- factor(rep(paste0("group", seq_len(g)), each = n))
    names(group) <- colnames(mat)
    mu_upreg <- 6
    sigma_upreg <- 10
    deg <- 100      # differentially expressed genes per group
    for (i in seq_len(g)) {
        deg_rows <- (deg*(i-1) + 1):(deg*i)
        group_cols <- group == paste0("group", i)
        # Shift the i-th gene block in the i-th group's cells. Reading and
        # writing the SAME rows keeps each gene's own baseline; reading from
        # rows 1:deg would overwrite later blocks with copies of genes 1..deg.
        # The deg-length shift vector is recycled down the columns, so each
        # gene receives a single shift shared by all cells of the group.
        mat[deg_rows, group_cols] <-
            mat[deg_rows, group_cols] + rnorm(deg, mu_upreg, sigma_upreg)
    }
    # positive expression only
    mat[mat < 0] <- 0
    # cells in rows, genes in columns
    x_train <- t(mat)
    
    # heatmap of the simulated matrix without clustering or rescaling
    heatmap(mat, Rowv = NA, Colv = NA, 
            col = colorRampPalette(c("green", "red"))(100), 
            scale = "none")
}

## ----eval=TRUE----------------------------------------------------------------
if (keras::is_keras_available() && reticulate::py_available()) {
    # model parameters
    batch_size <- 32
    # input/output width follows the training matrix (one column per gene)
    # instead of a hard-coded constant
    original_dim <- ncol(x_train)
    intermediate_dim <- 512
    epochs <- 100
    
    # VAE: one hidden dense layer in the encoder, mirrored in the decoder.
    # Note that the validation data passed here is the training matrix itself.
    vae_result <- fit_vae(x_train = x_train, x_val = x_train,
                        encoder_layers = list(layer_input(shape = c(original_dim)),
                                            layer_dense(units = intermediate_dim,
                                                        activation = "relu")),
                        decoder_layers = list(layer_dense(units = intermediate_dim,
                                                        activation = "relu"),
                                            layer_dense(units = original_dim,
                                                        activation = "sigmoid")),
                        epochs = epochs, batch_size = batch_size,
                        use_generator = FALSE,
                        # stop once val_loss has not improved for 10 epochs
                        # and roll back to the best weights seen
                        callbacks = keras::callback_early_stopping(
                            monitor = "val_loss",
                            patience = 10,
                            restore_best_weights = TRUE))
}

## ----eval=TRUE----------------------------------------------------------------
if (keras::is_keras_available() && reticulate::py_available()) {
    # model architecture of the VAE fitted on the simulated data
    plot_vae(vae_result$model)
}

## ----eval=TRUE----------------------------------------------------------------
if (keras::is_keras_available() && reticulate::py_available()) {
    # sample generation (seeded so the generated samples are reproducible)
    set.seed(1)
    gen_sample_result <- gen_exprs(vae_result, num_samples = 100)
    
    # heatmap: both matrices are transposed and bound column-wise, so the
    # original cells and the generated samples appear side by side;
    # Rowv = NA keeps the row order as given
    heatmap(cbind(t(x_train), t(gen_sample_result$x_gen)),
            col = colorRampPalette(c("green", "red"))(100),
            Rowv = NA)
}

## ----eval=TRUE----------------------------------------------------------------
if (keras::is_keras_available() && reticulate::py_available()) {
    # plot for augmented data, using the "PCA" option of plot_aug
    plot_aug(gen_sample_result, "PCA")
}

## ----eval=TRUE----------------------------------------------------------------
if (keras::is_keras_available() && reticulate::py_available()) {
    library(VAExprs)
    library(SC3)
    library(SingleCellExperiment)
    
    # create a SingleCellExperiment object from the `yan` counts and the
    # `ann` annotation table
    sce <- SingleCellExperiment::SingleCellExperiment(
        assays = list(counts = as.matrix(yan)),
        colData = ann
    )
    
    # define feature names in feature_symbol column
    rowData(sce)$feature_symbol <- rownames(sce)
    # remove features with duplicated names
    sce <- sce[!duplicated(rowData(sce)$feature_symbol), ]
    # remove genes that are not expressed in any samples
    sce <- sce[which(rowMeans(assay(sce)) > 0),]
    # explicit print(): a bare expression in the middle of a braced block is
    # evaluated but its value is never displayed
    print(dim(assay(sce)))
    
    # model parameters
    batch_size <- 32
    # number of genes remaining after filtering, rather than a hard-coded
    # count (assumes fit_vae uses one input unit per gene -- TODO confirm)
    original_dim <- nrow(sce)
    intermediate_dim <- 256
    epochs <- 100
    
    # model
    cvae_result <- fit_vae(object = sce,
                        encoder_layers = list(layer_input(shape = c(original_dim)),
                                            layer_dense(units = intermediate_dim,
                                                        activation = "relu")),
                        decoder_layers = list(layer_dense(units = intermediate_dim,
                                                        activation = "relu"),
                                            layer_dense(units = original_dim,
                                                        activation = "sigmoid")),
                        epochs = epochs, batch_size = batch_size,
                        use_generator = TRUE,
                        # stop once the training loss has not improved for
                        # 20 epochs and roll back to the best weights seen
                        callbacks = keras::callback_early_stopping(
                            monitor = "loss",
                            patience = 20,
                            restore_best_weights = TRUE))
    
    # model architecture
    plot_vae(cvae_result$model)
}

## ----eval=TRUE----------------------------------------------------------------
if (keras::is_keras_available() && reticulate::py_available()) {
    # sample generation (seeded so the generated samples are reproducible);
    # batch_size stays positional and therefore binds exactly as before
    set.seed(1)
    gen_sample_result <- gen_exprs(cvae_result, num_samples = 100,
                                batch_size, use_generator = TRUE)
    
    # plot for augmented data, using the "PCA" option of plot_aug
    plot_aug(gen_sample_result, "PCA")
}

## ----eval=TRUE----------------------------------------------------------------
# report the R version, platform and attached/loaded package versions so the
# results above can be reproduced
sessionInfo()

