Improving Across-Dataset Brain Tissue Segmentation Using Transformer

Brain tissue segmentation has demonstrated great utility in quantifying MRI data through Voxel-Based Morphometry and highlighting subtle structural changes associated with various conditions within the brain. However, manual segmentation is highly labor-intensive, and automated approaches have struggled due to properties inherent to MRI acquisition, leaving a great need for an effective segmentation tool. Despite the recent success of deep convolutional neural networks (CNNs) for brain tissue segmentation, many such solutions do not generalize well to new datasets, which is critical for a reliable solution. Transformers have demonstrated success in natural image segmentation and have recently been applied to 3D medical image segmentation tasks due to their ability to capture long-distance relationships in the input where the local receptive fields of CNNs struggle. This study introduces a novel CNN-Transformer hybrid architecture designed for brain tissue segmentation. We validate our model's performance across four multi-site T1w MRI datasets, covering different vendors, field strengths, scan parameters, time points, and neuropsychiatric conditions. In all situations, our model achieved the greatest generality and reliability. Our method is inherently robust and can serve as a valuable tool for brain-related T1w MRI studies. The code for the TABS network is available at: https://github.com/raovish6/TABS.

Despite the demonstrated utility of brain tissue segmentation, there is no universally accepted method capable of segmenting accurately and efficiently across a wide variety of datasets. Manual segmentation of brain tissue is extremely labor intensive, often impractical for larger datasets, and difficult even for experts. Alternatively, automated segmentation has proven challenging due to properties inherent to the MRI scans themselves. Changes in vendors or field strength have both been linked with increased variance in repeated scan measures (Han et al., 2006), and scans acquired through different imaging protocols tend to fluctuate more in terms of volumetric brain measures (Kruggel et al., 2010). Time of day, as well as time between scans, has been associated with variable tissue volume estimation (Karch et al., 2019), while neuropsychiatric conditions such as schizophrenia have been linked with subtle anatomical changes in brain tissue (Koutsouleris et al., 2014). Together, these inconsistencies make it difficult for brain tissue segmentation solutions to be applicable across datasets of differing vendors, collection parameters, time points, and neuropsychiatric conditions. Many of the earlier proposed automated solutions have depended on intensity thresholding (Dora et al., 2017), population-based atlases (Cabezas et al., 2011), clustering (Dora et al., 2017; Mahmood et al., 2015), statistical methods (Angelini et al., 2007; Greenspan et al., 2006; Marroquín et al., 2002; Zhang et al., 2000), and standard machine learning algorithms. Thresholding-based approaches often struggle to segment low contrast input images with overlapping brain tissue intensity histograms. Alternatively, atlas-based algorithm performance heavily depends on the quality of the population-derived brain atlas.
While machine learning algorithms such as support vector machine (SVM) (Bauer et al., 2011), random forest (Dadar and Collins, 2021), and neural networks (Amiri et al., 2013) have demonstrated reasonable segmentation performance, their accuracy largely relies on the quality of manually extracted features. In general, many of these algorithms require a priori information to properly segment brain tissue, which is often infeasible to obtain for every new scan to be segmented. FSL FAST is a popular statistical brain tissue segmentation toolkit that combines Gaussian mixture models with hidden Markov random fields to achieve reliable segmentation performance across a variety of datasets (Zhang et al., 2000). However, segmentation via FAST is time consuming and therefore not ideal for many real-time segmentation applications.
Convolutional neural networks (CNNs) have recently emerged as a superior alternative to standard machine learning algorithms for classification-based brain segmentation given their feature-encoding capabilities (Akkus et al., 2017). CNNs have been found to outperform machine learning algorithms such as random forest and SVM specifically for brain tissue segmentation (Zhang et al., 2015). Following their introduction, many other CNN-based networks have been proposed for brain tissue segmentation (Khagi and Kwon, 2018; Moeskops et al., 2016) as well as brain tumor segmentation (Beers et al., 2017; Feng et al., 2020a; Mlynarski et al., 2019), including both 2D and 3D approaches. Unet represents one popular segmentation algorithm (Çiçek et al., 2016; Ronneberger et al., 2015), which consists of symmetric encoding and decoding convolutional operations that allow for the preservation of the initial image resolution following segmentation. Variants of Unet have been successfully applied to brain tissue segmentation, achieving state-of-the-art performance. For example, one study achieved a DICE score of 0.988 using 3D Unet, which even outperformed human experts (Kolařík et al., 2018). More recently, 2D patch-based Unet and Unet-inspired implementations have gained traction (Yamanakkanavar and Lee, 2020) to better preserve and account for local details; such models have outperformed their non-patch-based variants.
Despite the impressive performance CNNs have demonstrated for brain tissue segmentation, they often struggle to generalize well when presented with new datasets. Many prior brain tissue segmentation approaches only report test performance on the same dataset upon which the model was trained. While such metrics validate the generality of the proposed model on MRI scans from the same dataset, they fail to quantify model performance across different datasets where changes in acquisition parameters can impact MRI image features and thus decrease the model's generality. Given the importance of brain tissue segmentation in VBM and pre-processing, it is not practical to retrain a CNN model every time a scan is obtained differently. As such, model generality is especially imperative to developing a widely applicable automated brain tissue segmentation solution.
Transformers are an alternative to CNNs that have recently demonstrated state-of-the-art results in natural image segmentation. Emerging evidence suggests that Transformers coupled with CNNs may improve performance and generalization for medical image segmentation tasks including brain tissue segmentation (Hatamizadeh et al., 2021; Sun et al., 2021; Wang et al., 2021). In this study, we sought to improve the traditional Unet architecture using Transformers to not only achieve higher brain tissue segmentation performance, but also generalize better across different datasets while remaining reliable. Here, we propose Transformer-based Automated Brain Tissue Segmentation (TABS), a new 3D CNN-Transformer hybrid deep learning architecture for brain tissue segmentation.
Our main contributions include: 1. A novel CNN-Transformer hybrid architecture designed for brain tissue segmentation. 2. We elucidate the benefits of embedding a Transformer module within a CNN encoder-decoder architecture for brain tissue segmentation. 3. After achieving improved within dataset performance, we are the first to rigorously demonstrate model generality and reliability across multiple vendors, field strengths, scan parameters, time points, and neuropsychiatric conditions.

Study Design
We conducted three experiments to evaluate model performance, generality, and reliability for brain tissue segmentation. The experimental pipeline for these experiments is visualized in Figure 1. First, we trained and tested all of the models on three separate datasets (DLBS, SALD, and IXI) of differing acquisition parameters along with an aggregate total dataset containing all of the scans combined. We then evaluated model generality across field strength and scanner parameters; models trained on the 3T datasets (DLBS and SALD) were tested on the 1.5T IXI dataset and on one another. Finally, we extended our generalization testing to an alternate dataset (COBRE) containing test-retest repeated scans of both schizophrenia patients and healthy controls. We applied models pre-trained on the 3T SALD dataset to COBRE to give them the best chance of generalizing well, as SALD and COBRE were collected using similar acquisition parameters. We compared the reliability of TABS, the best generalizing model, to that of the ground truth by evaluating the similarity of outputs on the test-retest repeated scans. Given that each pair of scans was acquired from the same subject within a small time frame, we expected a more reliable tool to output very similar segmentation predictions across both scans.
We compared TABS to three other benchmark CNN models in our experiments: vanilla Unet, Unet-SE, and ResUnet. We chose Unet given its prior state of the art performance in 3D brain tissue segmentation (Kolařík et al., 2018), and we also compared to prior attempts at improving Unet including squeeze-excitation (SE) blocks (Hu et al., 2018) before each downsampling operation (Unet-SE) and residual connections (ResUnet) (Zhang et al., 2018). Moreover, given that the model architecture for TABS is identical to that of ResUnet except for the Vision Transformer, comparing to ResUnet allowed us to highlight the specific benefits conferred by the Transformer. All of the tested models were the same depth and encoded the same number of features. Finally, we also compared to FSL FAST, the tool used to generate the ground truths, in our reliability evaluation.

Data Selection and Pre-Processing
We collected MRI scans of healthy participants over a broad age range from three datasets for our first two experiments: DLBS (Rodrigue et al., 2012), SALD (Wei et al., 2018), and IXI (Biomedical Image Analysis Group, 2018). While they all use an MPRAGE sequence, the datasets vary in terms of their other acquisition parameters. Firstly, they differ by field strength, where DLBS and SALD contain 3T scans and IXI contains 1.5T scans. Moreover, all three datasets were acquired using different scanners, with the SALD dataset acquired using a Siemens manufactured scanner as opposed to Philips. Lastly, the datasets differ in terms of scan parameters such as repetition/echo time and flip angle. We split each dataset into 3:1:1 train/validation/test groups while maintaining a broad age distribution across each subsection. The age distributions across these splits for each of these datasets are shown in Figure 2a-c. We also collected paired test-retest scans taken at different time points of healthy participants and schizophrenia patients from the COBRE dataset (Bustillo et al., 2017) for our third experiment. The demographic information and acquisition parameters for all four datasets are outlined in Table 1. We followed the initial pre-processing protocol outlined by Feng et al. (2020b) for all of the datasets, which includes bias field correction (Sled et al., 1998), brain extraction using FreeSurfer (Ségonne et al., 2004), and affine registration to the 1 mm³ isotropic MNI152 brain template with trilinear interpolation using FSL FLIRT (Jenkinson et al., 2002). After these steps, the DLBS/SALD/IXI MRI images were 182x218x182, and the COBRE images were 193x229x193. We padded and cropped the images to reach an input dimension of 192x192x192, using a maximum intensity projection across all scans for each dataset to ensure that we did not remove important anatomical components. Finally, we normalized the intensities for each scan to values between -1 and 1.
The pre-processing pipeline is shown in Figure 2d.
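The padding/cropping and normalization steps can be illustrated with a minimal NumPy sketch. This is our own simplified illustration, not code from the TABS repository: it uses a plain center pad/crop per axis rather than the maximum-intensity-projection check described above, and the function names are hypothetical.

```python
import numpy as np

TARGET = 192  # model input size per axis, as described in the text

def pad_or_crop_axis(size, target):
    """Return (pad_before, pad_after, crop_start, crop_end) for one axis."""
    if size < target:
        diff = target - size
        return diff // 2, diff - diff // 2, 0, size
    start = (size - target) // 2
    return 0, 0, start, start + target

def to_model_input(volume, target=TARGET):
    """Center pad/crop a 3D volume to target^3, then scale intensities to [-1, 1].

    Assumes the volume is not constant-valued (min != max)."""
    out = volume
    for axis in range(3):
        pb, pa, cs, ce = pad_or_crop_axis(out.shape[axis], target)
        slicer = [slice(None)] * 3
        slicer[axis] = slice(cs, ce)
        out = out[tuple(slicer)]          # crop (no-op if already small enough)
        pad = [(0, 0)] * 3
        pad[axis] = (pb, pa)
        out = np.pad(out, pad, mode="constant")  # zero-pad to target size
    lo, hi = out.min(), out.max()
    return 2.0 * (out - lo) / (hi - lo) - 1.0   # map [min, max] -> [-1, 1]
```

For a 182x218x182 registered scan, this pads the first and third axes and crops the second, yielding a 192x192x192 array with intensities in [-1, 1].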

Model Architecture and Implementation
The architecture of our proposed model is shown in Figure 3. TABS is a ResUnet (Zhang et al., 2018) inspired model that consists of a 5-layered 3D CNN encoder and decoder. TABS takes an input dimension of 192x192x192, and the five encoder layers downsample the original image to f x 12 x 12 x 12, where f represents the number of encoded features. For this specific implementation, we chose an f value of 128. We follow the same "linear projection and learned positional embedding" operations introduced in Wang et al. (2021) to convert the encoded feature tensor into 512 tokenized vectors that are sequentially fed into the Transformer module in the order determined by the learned positional embeddings. Our Transformer encoder consists of 4 layers and 8 heads following the implementation initially described by Vaswani et al. (2017). The output of the Transformer is 512x1728, which we then reshape to 512x12x12x12 and reduce the feature dimensionality to f via convolution. The decoder portion of the network reconstructs the image to the original input dimension, and a final convolution operation is applied to generate a 3-channel output with each channel corresponding to an individual tissue type. We used a Softmax activation function to ensure that the probabilities for each voxel across the three channels add up to 1.
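The Transformer bottleneck between encoder and decoder can be sketched in PyTorch as below. This is a sketch under our own assumptions, not the released TABS code: we realize the "linear projection" as a 1x1x1 convolution that expands the f feature channels into 512 token channels, add a learned positional embedding, and use the standard `nn.TransformerEncoder`; the class and parameter names are hypothetical.

```python
import torch
import torch.nn as nn

class TransformerBottleneck(nn.Module):
    """Transformer bridge between a 3D CNN encoder and decoder, following the
    description in the text: an f x g x g x g encoded feature map is projected
    into n_tokens tokenized vectors with learned positional embeddings, passed
    through a Transformer encoder, reshaped back to a volume, and reduced to
    f channels by convolution."""

    def __init__(self, f=128, grid=12, n_tokens=512, n_layers=4, n_heads=8):
        super().__init__()
        self.grid = grid
        d_model = grid ** 3  # each token spans the flattened spatial grid (12^3 = 1728)
        # "linear projection": expand f feature channels to n_tokens token channels
        self.project = nn.Conv3d(f, n_tokens, kernel_size=1)
        self.pos_emb = nn.Parameter(torch.zeros(1, n_tokens, d_model))
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads,
                                           batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        # reduce n_tokens channels back to f for the CNN decoder
        self.reduce = nn.Conv3d(n_tokens, f, kernel_size=1)

    def forward(self, x):                      # x: (B, f, g, g, g)
        b = x.size(0)
        t = self.project(x).flatten(2)         # (B, n_tokens, g^3)
        t = self.encoder(t + self.pos_emb)     # (B, n_tokens, g^3)
        v = t.view(b, -1, self.grid, self.grid, self.grid)
        return self.reduce(v)                  # (B, f, g, g, g)
```

With the defaults above, the module consumes a 128x12x12x12 encoded tensor and returns one of the same shape, matching the 512x1728 Transformer output and subsequent reshape described in the text.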

Training Protocol
All four models were trained using the same parameters described below. We trained for 350 epochs on the three individual datasets, while we trained over 200 epochs for the larger total dataset. We selected pre-trained models based on the best validation performance. We used FAST to generate ground truth probability maps for each brain tissue type and stacked and cropped them to generate a three-channel image matching the output shape of our models (3x192x192x192). The models were trained on three 24 GB NVIDIA Quadro 6000 graphical processing units using mean-squared-error (MSE) loss with a batch size of 3. We used group normalization as opposed to batch normalization due to group normalization's increased stability for smaller batch sizes (Wu and He, 2018). We trained using Adam (Kingma and Ba, 2014) as the optimization algorithm with a learning rate of 1E-5 and weight decay set to 1E-6.
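The training setup above can be condensed into a short PyTorch loop. This is a schematic of the reported hyper-parameters (MSE loss against the 3-channel FAST probability maps, Adam with learning rate 1e-5 and weight decay 1e-6), not the authors' training script; the function name and loader interface are our own assumptions.

```python
import torch
import torch.nn as nn

def train(model, loader, epochs, device="cpu"):
    """Train with the hyper-parameters reported in the text.

    `loader` is assumed to yield (scan, target) pairs shaped
    (B, 1, 192, 192, 192) and (B, 3, 192, 192, 192)."""
    opt = torch.optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-6)
    loss_fn = nn.MSELoss()  # regression against FAST probability maps
    model.to(device).train()
    for _ in range(epochs):
        for scan, target in loader:
            opt.zero_grad()
            pred = model(scan.to(device))
            loss = loss_fn(pred, target.to(device))
            loss.backward()
            opt.step()
    return model
```

Group normalization inside the model (rather than batch normalization in this loop) is what accommodates the small batch size of 3 noted above.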

Evaluation Metrics
All evaluation metrics were only taken for the portion of the outputs containing the brain, meaning that the background voxels outside of the segmentation field were not considered. Additionally, all metrics were calculated individually for each brain tissue type. Segmentation similarity using continuous probability estimates was quantified using Pearson correlation, Spearman correlation, and MSE. Segmentation maps for each tissue type were then generated from the probability estimations by taking the argmax along the channel axis. We generated binary maps for each tissue type based on the numerical value assigned to each voxel of the argmax output. Segmentation similarity between these binary maps was quantified using DICE score, Jaccard index, and Hausdorff distance (HD) (Beauchemin et al., 1998). Models were compared directly on these metrics, where higher correlation, DICE, and Jaccard values and lower MSE and HD indicate better performance.
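The per-tissue metric computation can be sketched in NumPy as follows. This is an illustrative reimplementation of the continuous (Pearson, MSE) and binary (DICE, Jaccard) metrics described above, restricted to a brain mask; the function name is ours, Spearman correlation and HD are omitted for brevity, and no guard is included for tissue classes absent from the mask.

```python
import numpy as np

def evaluate_tissue(pred, truth, mask):
    """Per-tissue-type metrics within a brain mask.

    pred, truth: (3, D, H, W) probability maps (one channel per tissue type)
    mask: (D, H, W) boolean brain mask
    Returns a list of metric dicts, one per tissue channel."""
    results = []
    pred_lab = pred.argmax(axis=0)   # hard label per voxel
    true_lab = truth.argmax(axis=0)
    for c in range(pred.shape[0]):
        p, t = pred[c][mask], truth[c][mask]
        pearson = np.corrcoef(p, t)[0, 1]        # continuous similarity
        mse = np.mean((p - t) ** 2)
        pb = (pred_lab == c) & mask              # binary maps from argmax
        tb = (true_lab == c) & mask
        inter = np.logical_and(pb, tb).sum()
        union = np.logical_or(pb, tb).sum()
        dice = 2 * inter / (pb.sum() + tb.sum())
        jaccard = inter / union
        results.append({"pearson": pearson, "mse": mse,
                        "dice": dice, "jaccard": jaccard})
    return results
```

Comparing a probability map against itself yields DICE and Pearson of 1 and MSE of 0 for every tissue type, which is the upper bound against which model outputs are scored.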

Model Performance
The performance results for each model trained and tested on DLBS, IXI, SALD, and Total datasets individually are reported in Table 2. TABS outperformed ResUnet, Unet-SE, and Unet on all the datasets for most metrics except for the 1.5T IXI dataset, where TABS outperformed Unet-SE and Unet while only performing slightly worse than ResUnet. TABS consistently achieves higher DICE/Jaccard metrics across all tissue types along with higher correlation and lower MSE on most tissue types. In general, all models performed better on WM and CSF as opposed to GM. Figure 4 shows representative segmentation outputs from performance testing on each of the datasets.

Model Generality -DLBS, IXI, and SALD
The generality results for all models trained on DLBS/SALD and applied to IXI as well as trained on DLBS/SALD and applied to SALD/DLBS are shown in Table 3. TABS generalized better across datasets on most metrics for the DLBS→IXI and SALD→DLBS tests, with higher DICE/Jaccard and correlation metrics for at least two tissue types. Additionally, for the SALD→IXI generalization test, TABS reached higher DICE/Jaccard metrics for both GM and WM. We observed that models trained on SALD performed better when applied to IXI than models trained on IXI itself. TABS also exhibited a similar increase in performance when pre-trained on DLBS and applied to IXI compared to TABS trained on IXI. Representative segmentation outputs for all models for each test scenario are shown in Figure 5.

COBRE Test-Retest
TABS showcased better reliability compared to FAST, the tool used to generate the ground truths. Similarity metrics between test-retest repeated images for both TABS and FAST are shown in Table 5 for the control, schizophrenia, and total aggregate datasets. TABS proved consistently more reliable across almost all metrics for GM and CSF. Moreover, TABS reached a higher Pearson correlation and lower MSE over all tissue types, and only performed slightly worse than FAST on WM DICE/Jaccard. Representative segmentation outputs for paired repeated scans from both control and schizophrenia datasets are visualized in Figure 6.

Discussion
In this study, we present TABS, a new Transformer-CNN hybrid deep learning architecture designed for brain tissue segmentation. TABS showcased superior performance compared to prior state-of-the-art CNN implementations while also generalizing exceptionally well across datasets and remaining reliable between paired test-retest scans. These traits are critical to developing a useful and more widely applicable brain tissue segmentation toolkit. Through TABS, we also demonstrate the methodological utility of using a Vision Transformer to improve the Unet architecture for brain tissue segmentation.
Our experimental protocol was designed to elucidate the real-world applicability of TABS compared to various benchmark models. The datasets included in this study were chosen with the goal of emulating the extreme differences in MRI input a brain tissue segmentation algorithm would receive in real-world applications; the DLBS, SALD, and IXI datasets varied in terms of manufacturer, field strengths, and scanner parameters. Moreover, our test-retest dataset consisted of repeated scans from schizophrenia and healthy patients taken at different time points, presenting an even more challenging segmentation task. Due to these factors, we believe our evaluation methodology accurately captures the versatility of TABS.

Superior Performance for TABS Compared to Benchmark Unet Models
We found that TABS was the best performing model when trained and tested on the same dataset. While TABS achieved significantly higher performance than both Unet and Unet-SE, we observed marginal performance benefits over ResUnet. We hypothesize that the residual connections are responsible for the bulk of the performance gain over the traditional Unet models, with the Transformer module providing a small but consistent performance increase within the datasets.

Superior Generality for TABS Compared to Benchmark Unet Models
TABS generalized the best on most datasets compared to the benchmark Unet models. The most significant generalization differences we observed were between TABS and ResUnet. Given that their model architectures are identical except for the Transformer, we believe that the addition of the Transformer significantly improves model generality. CNNs are not well suited to capture long-range dependencies in the input image due to the local receptive fields of convolutional kernels. We believe that this property could make Transformer-based networks agnostic to dataset-specific variations and thus more generalizable. The addition of the Transformer allows TABS to preserve and even improve the within dataset performance conferred by residual connections while also generalizing better than the vanilla Unet, where ResUnet struggled.
We also noticed that all of the models tested improved in performance when trained on SALD and applied to IXI as opposed to training on IXI itself. This disparity could be due to the difference in field strength: the higher quality 3T MRI images from SALD may provide more globally relevant features than the 1.5T MRI images from IXI. However, for TABS specifically, we observed this same effect when pre-trained on 3T DLBS scans. These results indicate that TABS can potentially take better advantage of higher quality training data compared to the benchmark models.
Finally, we found that TABS generalized the best on an alternate COBRE dataset consisting of both healthy and schizophrenia scans. Schizophrenia patients often exhibit subtle anatomical differences compared to healthy subjects, such as alterations in GM volume (Koutsouleris et al., 2014). These changes make generalizing to the schizophrenia dataset an especially difficult task. Additionally, the mean age of the COBRE dataset was slightly lower than the datasets TABS was originally trained on, making generalizing to COBRE potentially even more challenging. TABS generalized the best compared to the benchmark models on the overall COBRE dataset, with even more pronounced differences for the schizophrenia portion. Therefore, we believe that TABS may excel in more difficult segmentation cases where standard Unet models yield errors.

Superior Reliability for TABS Compared to FAST
Finally, our test-retest experiment highlights the reliability of TABS, the best generalizing model on the COBRE dataset, compared with the ground truth FAST. The test-retest repeated scans used in this study were taken from the same patient within a short time frame, meaning that we expected minimal differences in the segmentation output. One of the primary advantages of FAST has been its generality and reliability. Through this test, we find that TABS not only generalizes well on the COBRE dataset, but also maintains this performance more reliably than FAST.

Limitations and Future work
In general, 3D CNN models require a large amount of computational power to efficiently train. While we were able to use full resolution MRI inputs for our model, we were limited to a batch size of 3 due to memory constraints. Using a larger batch size may have resulted in better performance. Additionally, even though we trained TABS on three large datasets, our performance could be further improved by increasing our sample size. Recent findings suggest that patch-based 2D CNN approaches perform better than non-patch-based variants for brain tissue segmentation (Yamanakkanavar et al., 2020). As such, we believe that we could extend TABS to a patch-based 3D model in future studies to better capture local information that may be lost by processing the entire image at once.

Conclusion
In conclusion, we believe TABS represents a compelling brain tissue segmentation alternative. TABS performs and generalizes better than comparable state-of-the-art CNN models across vendor, field strength, scan parameters, and neuropsychiatric condition while remaining consistent across time points. Our results also demonstrate that the embedding of a Transformer module between the encoder and decoder portions of a CNN architecture represents an efficient method to improve brain tissue segmentation performance and generality.

Data and Code Availability Statement
The code used in this project is proprietary. The code for the TABS model is available at https://github.com/raovish6/TABS, and the entire TABS package is available upon request of the corresponding author.
The code for TABS is © 2021 The Trustees of Columbia University in the City of New York. This work may be reproduced and distributed for academic non-commercial purposes only.