---
title: "batchmix workflow"
author: "Stephen Coleman"
output: rmarkdown::html_vignette
description: |
  Tutorial for clustering/classifying with batchmix.
vignette: >
  %\VignetteIndexEntry{Introduction to batchmix}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = TRUE)
set.seed(1)
```

## Introduction

This document shows the basics of applying our Bayesian model-based clustering/classification with joint batch correction in ``R``. It shows how to generate some toy data, apply the model, assess convergence and process outputs.

## Data generation

We simulate some data using the ``generateBatchData`` function.

```{r dataGen}
library(ggplot2)
library(batchmix)

# Data dimensions
N <- 600
P <- 4
K <- 5
B <- 7

# Generating model parameters
mean_dist <- 2.25
batch_dist <- 0.3
group_means <- seq(1, K) * mean_dist
batch_shift <- rnorm(B, mean = batch_dist, sd = batch_dist)
std_dev <- rep(2, K)
batch_var <- rep(1.2, B)
group_weights <- rep(1 / K, K)
batch_weights <- rep(1 / B, B)
dfs <- c(4, 7, 15, 60, 120)

my_data <- generateBatchData(
  N,
  P,
  group_means,
  std_dev,
  batch_shift,
  batch_var,
  group_weights,
  batch_weights,
  type = "MVT",
  group_dfs = dfs
)
```

This gives us a named list with two related datasets: ``observed_data``, which includes batch effects, and ``corrected_data``, which is batch-free. The list also includes ``group_IDs``, a vector indicating class membership for each item; ``batch_IDs``, which indicates the batch of origin for each item; and ``fixed``, which indicates which labels are observed and fixed in the model. We pull these out of the named list in the format that the modelling functions expect.

```{r dataClean}
X <- my_data$observed_data

true_labels <- my_data$group_IDs
fixed <- my_data$fixed
batch_vec <- my_data$batch_IDs

alpha <- 1
initial_labels <- generateInitialLabels(alpha, K, fixed, true_labels)
```

## Modelling

Given some data, we are interested in modelling it.
We assume here that the set of observed labels includes at least one example of each class in the data.

```{r runMCMCChains}
# Sampling parameters
R <- 1000
thin <- 50
n_chains <- 4

# Density choice
type <- "MVT"

# MCMC samples and BIC vector
mcmc_output <- runMCMCChains(
  X,
  n_chains,
  R,
  thin,
  batch_vec,
  type,
  initial_labels = initial_labels,
  fixed = fixed
)
```

We want to assess two things. First, how frequently the proposed parameters in the Metropolis-Hastings step are accepted:

```{r plotAcceptanceRatesEarly}
plotAcceptanceRates(mcmc_output)
```

Second, we want to assess how well our chains have converged. To do this we plot the ``complete_likelihood`` of each chain. This is the quantity most relevant to a clustering/classification, being dependent on the labels. The ``observed_likelihood`` is independent of labels and more relevant for density estimation.

```{r likelihood}
plotLikelihoods(mcmc_output)
```

The chains disagree, so we need to run them for more iterations. We use the ``continueChains`` function for this.

```{r continueChains}
R_new <- 9000

# Given an initial value for the parameters
new_output <- continueChains(
  mcmc_output,
  X,
  fixed,
  batch_vec,
  R_new,
  keep_old_samples = TRUE
)
```

To see if the chains better agree we re-plot the likelihood.

```{r continuedLikelihood}
plotLikelihoods(new_output)
```

We also re-check the acceptance rates.

```{r plotAcceptanceRates}
plotAcceptanceRates(new_output)
```

In the likelihood plot, several of the chains appear to agree by the 5,000th iteration, so we use this as the burn-in.

## Process chains

We process the chains, acquiring point estimates of different quantities.

```{r processChains}
# Burn in
burn <- 5000

# Process the MCMC samples
processed_samples <- processMCMCChains(new_output, burn)
```

## Visualisation

For multidimensional data we use a PCA plot.

```{r pca}
# Use the point estimates from the first chain
chain_used <- processed_samples[[1]]

# PCA of the observed data and of the inferred dataset
pc <- prcomp(X, scale = TRUE)
pc_batch_corrected <- prcomp(chain_used$inferred_dataset)

plot_df <- data.frame(
  PC1 = pc$x[, 1],
  PC2 = pc$x[, 2],
  PC1_bf = pc_batch_corrected$x[, 1],
  PC2_bf = pc_batch_corrected$x[, 2],
  pred_labels = factor(chain_used$pred),
  true_labels = factor(true_labels),
  prob = chain_used$prob,
  batch = factor(batch_vec)
)

# Observed data, coloured by true label
plot_df |>
  ggplot(aes(
    x = PC1,
    y = PC2,
    colour = true_labels,
    alpha = prob
  )) +
  geom_point()

# Inferred dataset, coloured by predicted label
plot_df |>
  ggplot(aes(
    x = PC1_bf,
    y = PC2_bf,
    colour = pred_labels,
    alpha = prob
  )) +
  geom_point()

# Proportion of non-fixed (test) labels predicted correctly
test_inds <- which(fixed == 0)
sum(true_labels[test_inds] == chain_used$pred[test_inds]) / length(test_inds)
```
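
The last line of the chunk above reports the proportion of non-fixed items whose predicted label matches the truth. As a minimal sketch of the same calculation, the chunk below uses small hypothetical vectors (``toy_truth``, ``toy_pred`` and ``toy_fixed`` are illustrative stand-ins, not batchmix outputs) and base ``R`` only. It also cross-tabulates true against predicted labels, which may help show where any misclassifications fall.

```{r accuracySketch}
# Hypothetical toy vectors standing in for true_labels, chain_used$pred and fixed
toy_truth <- c(1, 1, 2, 2, 3, 3)
toy_pred <- c(1, 2, 2, 2, 3, 1)
toy_fixed <- c(1, 0, 0, 1, 0, 0)

# Items whose labels were not fixed form the test set
toy_test_inds <- which(toy_fixed == 0)

# Proportion of test items predicted correctly; the mean of a logical
# vector equals the number of matches divided by the number of items
toy_accuracy <- mean(toy_truth[toy_test_inds] == toy_pred[toy_test_inds])
toy_accuracy

# Cross-tabulate true against predicted labels on the test set
toy_confusion <- table(
  truth = toy_truth[toy_test_inds],
  predicted = toy_pred[toy_test_inds]
)
toy_confusion
```

Swapping the toy vectors for ``true_labels``, ``chain_used$pred`` and ``fixed`` should give the same accuracy as the final line of the PCA chunk, assuming those objects are plain vectors of equal length.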