Adversarial counterfactual augmentation: application in Alzheimer’s disease classification

Due to the limited availability of medical data, deep learning approaches for medical image analysis tend to generalise poorly to unseen data. Augmenting data during training with random transformations has been shown to help and has become a ubiquitous technique for training neural networks. Here, we propose a novel adversarial counterfactual augmentation scheme that aims to find the most effective synthesised images for improving downstream tasks, given a pre-trained generative model. Specifically, we construct an adversarial game in which we update the input conditional factor of the generator and the downstream classifier alternately and iteratively with gradient backpropagation. This can be viewed as finding the ‘weakness’ of the classifier and purposely forcing it to overcome this weakness via the generative model. To demonstrate the effectiveness of the proposed approach, we validate the method with the classification of Alzheimer’s Disease (AD) as a downstream task. The pre-trained generative model synthesises brain images using age as the conditional factor. Extensive experiments and ablation studies show that the proposed approach improves classification performance and has the potential to alleviate spurious correlations and catastrophic forgetting. Code: https://github.com/xiat0616/adversarial_counterfactual_augmentation


Introduction
Deep learning has played an increasingly important role in medical image analysis over the past decade, with great success in segmentation, diagnosis, detection, etc. (1). Although deep-learning-based models can significantly outperform traditional machine learning methods, they rely heavily on the size and quality of the training data (2). In medical image analysis, the availability of large datasets is a persistent issue because acquiring and labelling medical imaging data is expensive (3). When only limited training data are available, deep neural networks tend to memorise the data and fail to generalise to unseen data (4,5). This is known as over-fitting (4). Data augmentation has become a popular approach to mitigating this issue: it generates additional data that increase the variation of the training data.
Conventional data augmentation approaches mainly apply random image transformations, such as cropping, flipping and rotation, to the data. Even though such techniques are general, they may not transfer well from one task to another (6). For instance, color augmentation can be useful for natural images but is not suitable for MRI, which is greyscale (3). Furthermore, traditional data augmentation methods may introduce distribution shift, i.e., a change in the joint distribution of inputs and outputs, and consequently degrade performance on non-augmented data during inference (i.e., during the application phase of the learned model) (7).
Some recently developed approaches learn data augmentation parameters that better improve the performance of downstream tasks such as segmentation, detection and diagnosis (6,8,9), or select, for each training sample, the hardest augmentation for the target model from a small batch of random augmentations (10). However, these approaches still use conventional image transformations and do not consider semantic augmentation (11), i.e., creating unseen samples by changing the semantic information of images, such as the background of an object or the age of a brain image. Semantic augmentation can complement traditional techniques and improve the diversity of augmented samples (11).
One way to achieve semantic augmentation is to train a deep generative model to create counterfactuals, i.e., synthetic modifications of a sample in which some aspects of the original data remain unchanged (12–16). However, these approaches mostly focus on the training stage of generative models and generate samples for data augmentation at random, without considering which counterfactuals are more effective for downstream tasks, i.e., the data efficiency of the generated samples. Xue et al. (18) propose a cGAN-based model to augment the classification of histopathology images with a selective strategy based on assigned label confidence and feature similarity to real data. By contrast, our approach focuses on finding the weakness (i.e., the hard counterfactuals) of a downstream task model (e.g., a classifier) and forcing it to overcome this weakness. Ye et al. (17) also select synthetic data for augmentation, using a policy-based reinforcement learning (RL) strategy with validation accuracy as the reward, but the instability of RL training may hinder the utility of their approach. Wang et al. (11), Li et al. (19) and Chen and Su (20) proposed to augment data in the latent space of the target deep neural network by estimating, for each class (e.g., car, horse, tree), the covariance matrix of latent features obtained from the network's latent layers and sampling directions from the resulting feature distributions. These directions should be semantically meaningful, such that moving along one direction manipulates one property of the image, e.g., the color of a car. However, there is no guarantee that the found directions will be semantically meaningful, and it is hard to know which direction controls a particular property of interest.
In this work, we consider the scenario in which we have a classifier that we want to improve (e.g., an image-based classifier of Alzheimer's Disease (AD) given brain images). We are also given some data and a pre-trained generative model that can create new data given an input image and conditioning factors that alter the corresponding attributes of the input. For example, the generative model can alter the brain age of the input. We propose an approach that guides a pre-trained generative model to generate the most effective counterfactuals via an adversarial game between the input conditioning factor of the generator and the downstream classifier, in which we use gradient backpropagation to update the conditioning factor and the classifier alternately and iteratively. A schematic of the proposed approach is shown in Figure 1.
Specifically, we choose the classification of AD as the downstream task and use a pre-trained brain ageing synthesis model to improve the AD classifier. The brain ageing generative model used in this paper is adopted from a recent work (21); it takes a brain image and a target age as inputs and outputs an aged brain image. We show that the proposed approach can improve the test accuracy of the AD classifier. We also demonstrate that it can be used in a continual learning context to alleviate catastrophic forgetting, i.e., the tendency of deep models to forget what they have learnt from previous data when trained on new data, and to alleviate spurious correlations, i.e., cases where two variables appear to be causally related but in fact are not. Our contributions can be summarised as follows:
1. We propose to improve a downstream classifier with a pre-trained generator via an adversarial game between the generator's conditional input and the classifier. To the best of our knowledge, this is the first approach that formulates such an adversarial scheme to utilise pre-trained generators in medical imaging.
2. We improve a recent brain ageing synthesis model by using Fourier encoding, which enables gradient backpropagation to update the conditional factor, and demonstrate the effectiveness of our approach on the task of AD classification.
3. We consider the scenario of using generative models in a continual learning context and show that our approach can help alleviate catastrophic forgetting.
4. We apply the brain ageing synthesis model to brain rejuvenation synthesis and demonstrate that the proposed approach has the potential to alleviate spurious correlations.

Notations and problem overview
We denote an image as $x \in X$, and a conditional generative model G that takes an image x and a conditional vector v as input and generates a counterfactual $\tilde{x}$ that corresponds to v: $\tilde{x} = G(x, v)$. For each x, there is a label $y \in Y$. We define a classifier C that predicts the label $\hat{y}$ for a given x: $\hat{y} = C(x)$. In this paper, x is a brain image, y is the AD diagnosis of x, and v represents the target age a and the AD diagnosis on which the generator G is conditioned. We select age and AD status as conditioning factors because they are major contributors to brain ageing. We use a 2D slice-based brain ageing generative model as G, and a VGG-based (22) AD classification model as C; VGG is a popular deep neural network architecture that has been widely used for classification. In Xia et al. (21), the brain ageing generative model is evaluated in multiple ways, including several quantitative metrics: Structural Similarity (SSIM), Peak Signal-to-Noise Ratio (PSNR) and Mean Squared Error (MSE) between the synthetically aged brain images and the ground-truth follow-up images, and Predicted Age Difference (PAD), i.e., the difference between the age predicted by a pre-trained age predictor and the desired target age. For more details of the evaluation metrics, please refer to Xia et al. (21), Section 4. Note that we only change the target age a in this paper, so we write the generative process as $\tilde{x} = G(x, a)$ for simplicity.
Suppose a pre-trained G and C are given. The question we want to answer is: how can we use G to improve C in a (data-)efficient manner? To this end, we propose an approach that uses G to improve C via an adversarial game, with gradient backpropagation used to update a and C alternately and iteratively.

Fourier encoding for conditional factors
The proposed approach requires backpropagating gradients to the conditional factor to find the hard counterfactuals. However, the original brain ageing synthesis model (21) used ordinal encoding for the conditional age and AD diagnosis. The encoded vectors are discrete in nature and must maintain a certain shape, which hinders updating them by gradient backpropagation: doing so would require that we first quantize each element to 0/1 and then check that the ordinal order of the 1 digits is preserved, and neither step is easily differentiable.

Figure 1: A schematic of the adversarial classification training. The pre-trained generator G takes a brain image x and a target age a as input and outputs a synthetically aged image $\tilde{x}$ that corresponds to the target age a. The classifier C aims to predict the AD label for a given brain image. To utilise G to improve C, we formulate an adversarial game between a (in red box) and C (in cyan box), where a and C are updated alternately and iteratively using $\mathcal{L}_1$ and $\mathcal{L}_2$, respectively (see Section 2.3). Note that G is frozen.
To enable gradient backpropagation to update the conditional vectors, we propose to use Fourier encoding (23,24) to encode the conditional attributes, i.e., age and health state (AD diagnosis). The effectiveness of Fourier encoding has been shown experimentally in Tancik et al. (23) and Mildenhall et al. (24). We also compared generative models using Fourier vs. ordinal encoding with the quantitative metrics briefly introduced in Section 2.1; the results are presented in Table 1. The generator using Fourier encoding achieves very similar quantitative results to the generator using ordinal encoding, indicating that Fourier encoding is an effective way to encode age and health status.
The key idea of Fourier encoding is to map low-dimensional vectors to a higher-dimensional domain using a set of sinusoids. For an input $v \in \mathbb{R}^d$, the encoding is

$$g(v) = \left[p_1 \cos(2\pi b_1^\top v),\ p_1 \sin(2\pi b_1^\top v),\ \ldots,\ p_m \cos(2\pi b_m^\top v),\ p_m \sin(2\pi b_m^\top v)\right]^\top,$$

where $b_j$ can be viewed as the Fourier basis frequencies and $p_j^2$ as the Fourier series coefficients. In this work, the vector v represents the target age a and the health state (AD diagnosis), so $d = 2$. In our experiments, we set $p_j^2 = 1$ for $j = 1, \ldots, m$, and the $b_j$ are sampled independently at random from a zero-mean Gaussian distribution, $b_j \sim \mathcal{N}(\mathbf{0}, m_{\text{scale}}\,\mathbf{I})$, where $m_{\text{scale}}$ is set to 10. We set $m = 100$, so the resulting g(v) is 200-dimensional. After encoding, the generator G takes the encoded vector g(v) as input.
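The encoding can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions, not the authors' implementation: the function names are hypothetical, `scale` is treated as the standard deviation of each frequency component (the text does not say whether $m_{\text{scale}} = 10$ acts as a standard deviation or a variance), the $2\pi$ factor follows Tancik et al. (23), and the input vector (a normalised age and a binary AD status) is made up for illustration.

```python
import math
import random


def sample_frequencies(m=100, d=2, scale=10.0, seed=0):
    """Draw m frequency vectors b_j in R^d from a zero-mean isotropic Gaussian.

    Assumption: `scale` is used as the standard deviation of each component.
    """
    rng = random.Random(seed)
    return [[rng.gauss(0.0, scale) for _ in range(d)] for _ in range(m)]


def fourier_encode(v, freqs, coeffs=None):
    """Map a d-dimensional vector v to a 2m-dimensional vector of sinusoids."""
    if coeffs is None:
        coeffs = [1.0] * len(freqs)  # p_j = 1, i.e. p_j^2 = 1 as in the text
    out = []
    for p, b in zip(coeffs, freqs):
        phase = 2.0 * math.pi * sum(bi * vi for bi, vi in zip(b, v))
        out.extend([p * math.cos(phase), p * math.sin(phase)])
    return out


freqs = sample_frequencies()
code = fourier_encode([0.7, 1.0], freqs)  # hypothetical (normalised age, AD status)
print(len(code))  # 200
```

Every entry is a smooth function of v, so writing the same expression in an autodiff framework (such as TensorFlow, which the experiments use) gives gradients with respect to v directly.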
The use of Fourier encoding offers two advantages. First, Xia et al. (21) encoded age and health state into two separate vectors and had to use two MLPs to embed them into the model. This is not a big issue when the number of factors is small, but extending the generative model to tens or hundreds of conditioning factors would increase memory and computation costs significantly. With Fourier encoding, we can encode all factors into a single vector, which makes it easier to scale the model to multiple conditional factors. Second, because the encoding is differentiable, Fourier encoding allows us to compute gradients with respect to the input vector v or individual elements of v. We therefore replace ordinal encoding with Fourier encoding in all experiments. The generative model G takes v as input, $\tilde{x} = G(x, v)$, where v represents the target age and health state; since we only change the target age a in this paper, we write the generative process as $\tilde{x} = G(x, a)$ for simplicity.
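To make the second advantage concrete, assume the standard random Fourier feature form of Tancik et al. (23), in which each frequency $b_j \in \mathbb{R}^d$ contributes the pair $p_j\cos(2\pi b_j^\top v)$ and $p_j\sin(2\pi b_j^\top v)$ to g(v). Each pair has a closed-form derivative with respect to v. The $2\pi$ factor and the convention that the age a is the first entry of v (possibly after a normalisation that the text does not specify) are assumptions of this sketch.

$$
\frac{\partial}{\partial v}\, p_j \cos(2\pi b_j^\top v) = -2\pi p_j \sin(2\pi b_j^\top v)\, b_j,
\qquad
\frac{\partial}{\partial v}\, p_j \sin(2\pi b_j^\top v) = 2\pi p_j \cos(2\pi b_j^\top v)\, b_j .
$$

Restricting to the age entry $a = v_1$ replaces $b_j$ by its first component $b_{j,1}$:

$$
\frac{\partial g(v)}{\partial a} = 2\pi\left[-p_1 b_{1,1}\sin(2\pi b_1^\top v),\ p_1 b_{1,1}\cos(2\pi b_1^\top v),\ \ldots,\ -p_m b_{m,1}\sin(2\pi b_m^\top v),\ p_m b_{m,1}\cos(2\pi b_m^\top v)\right]^\top .
$$

By contrast, an ordinal code is piecewise constant in the age, so its derivative is zero almost everywhere and provides no useful update direction.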

Adversarial counterfactual augmentation
Suppose we have a conditional generative model G and a classification model C. The goal is to utilise G to improve the performance of C. To this end, we propose an approach consisting of three steps: pre-training, hard sample selection and adversarial classification training. A schematic of the adversarial classification training is presented in Figure 1. Algorithm 1 summarises the steps of the method. Below we describe each step in detail.

Pre-training
The generative model is pre-trained using the same losses as in Xia et al. (21), except that we use Fourier encoding to encode age and AD diagnosis. Consequently, we obtain a pre-trained G that can generate counterfactuals conditioned on given target ages a: $\tilde{x} = G(x, a)$. The classification model C is a VGG-based network (22) trained to predict the AD diagnosis from brain images, optimised by minimising:

$$\min_{C} \sum_{x \in X_{\text{train}},\, y \in Y_{\text{train}}} \mathcal{L}_s(C(x), y), \tag{2}$$

where $\mathcal{L}_s(\cdot)$ is a supervised loss (binary cross-entropy in this paper), x is a brain image, and y is its ground-truth AD label. Note that if a pre-trained G and C are already available in practice, the pre-training step can be skipped.
Algorithm 1: Adversarial counterfactual augmentation.
Input: training set $D_{\text{train}}$; hyperparameters k, N; a pre-trained G; C.
Pre-training:
1. Train the classifier C on $D_{\text{train}}$ (Eq. 2).
Hard sample selection:
2. Select the N samples from $D_{\text{train}}$ that result in the highest classification errors for C, denoted as $D_{\text{hard}}$.
Adversarial classification training:
3. Randomly initialise target ages a, and obtain initial synthetic data.
For k iterations, repeat steps 4 to 6:
4. Update a in the direction that maximises the classification error (Eq. 4).
5. Obtain synthetic images from $D_{\text{hard}}$ with the updated a, denoted as $D_{\text{syn}}$.
6. Update C to optimise Eq. 5 on $D_{\text{train}} \cup D_{\text{syn}}$ for one epoch.
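To show how the three stages of Algorithm 1 fit together, here is a self-contained toy sketch in plain Python. Everything in it is a hypothetical stand-in chosen so that the gradients can be written by hand: each "image" is a single scalar feature, the frozen generator `toy_generator` shifts that feature linearly with age, C is a one-feature logistic regression, and the hyperparameter values are arbitrary rather than those used in the paper. Only the control flow (pre-train, select the hardest N samples, then alternate an ascent step on the target ages with one epoch on the union of the real and synthetic data) mirrors the algorithm.

```python
import math
import random

SLOPE = 0.05  # hypothetical: toy feature increase per year of ageing


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def bce(p, y, eps=1e-7):
    # Binary cross-entropy, playing the role of the supervised loss L_s.
    p = min(max(p, eps), 1.0 - eps)
    return -(y * math.log(p) + (1 - y) * math.log(1.0 - p))


def toy_generator(x, real_age, target_age):
    # Hypothetical frozen stand-in for G(x, a): shifts the feature with age.
    return x + SLOPE * (target_age - real_age)


class ToyClassifier:
    # Hypothetical stand-in for C: logistic regression on one feature.
    def __init__(self):
        self.w, self.b = 1.0, 0.0

    def prob(self, x):
        return sigmoid(self.w * x + self.b)

    def sgd_epoch(self, data, lr):
        for x, y in data:
            g = self.prob(x) - y  # dL_s/dz for BCE on a sigmoid output
            self.w -= lr * g * x
            self.b -= lr * g


def adversarial_augmentation(train, clf, n_hard=10, k=5, step_a=10.0,
                             lr_c=0.01, pre_epochs=20, seed=0):
    """train: list of (feature, real_age, label). Returns the final target ages."""
    rng = random.Random(seed)
    xy = [(x, y) for x, _, y in train]
    # Step 1: pre-train C on D_train.
    for _ in range(pre_epochs):
        clf.sgd_epoch(xy, lr_c)
    # Step 2: D_hard = the n_hard samples with the largest loss under C.
    losses = [bce(clf.prob(x), y) for x, y in xy]
    order = sorted(range(len(train)), key=lambda i: losses[i], reverse=True)
    hard = [train[i] for i in order[:n_hard]]
    # Step 3: initialise each target age between its real age and 90.
    ages = [rng.uniform(r, 90.0) for _, r, _ in hard]
    for _ in range(k):
        # Step 4: one gradient-ascent step per age, then clip to [60, 90].
        for j, (x, r, y) in enumerate(hard):
            p = clf.prob(toy_generator(x, r, ages[j]))
            grad_a = (p - y) * clf.w * SLOPE  # chain rule: dL/dz * dz/dG * dG/da
            ages[j] = min(max(ages[j] + step_a * grad_a, 60.0), 90.0)
        # Step 5: synthesise D_syn from D_hard with the updated ages.
        syn = [(toy_generator(x, r, a), y) for (x, r, y), a in zip(hard, ages)]
        # Step 6: one epoch on the union of D_train and D_syn (replay).
        combined = xy + syn
        rng.shuffle(combined)
        clf.sgd_epoch(combined, lr_c)
    return ages


# Hypothetical toy data: AD samples have a larger feature, which also grows with age.
data_rng = random.Random(1)
train = []
for _ in range(200):
    y = data_rng.randint(0, 1)
    age = data_rng.uniform(60.0, 90.0)
    x = data_rng.gauss(1.0 if y else 0.0, 0.5) + SLOPE * (age - 60.0)
    train.append((x, age, y))

clf = ToyClassifier()
final_ages = adversarial_augmentation(train, clf)
```

In the actual method, the gradient in step 4 comes from backpropagating through C, the frozen G and the Fourier encoding rather than from a closed-form expression.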

Hard sample selection
Liu et al. (25) and Feldman and Zhang (26) suggested that different training samples have different influence on the training of a supervised model, i.e., some training samples are harder for the task and more effective for training the model than others. Liu et al. (25) propose to up-sample, i.e., duplicate, the hard samples as a way to improve model performance. Based on these observations, we adopt a strategy similar to that of Liu et al. (25) to select hard samples: we record the classification errors of all training samples for the pre-trained C and then select the N = 100 samples with the highest errors. The selected hard samples are denoted as $D_{\text{hard}} = \{X_{\text{hard}}, Y_{\text{hard}}\}$.

Previous works (14,27,28) augmented datasets by randomly generating a number of synthetic samples with pre-trained generators. As with training samples, some synthetic samples could be more effective for downstream tasks than others. Here we assume that a hard synthetic sample is more effective for training. We propose an adversarial game to find hard synthetic data to boost C.
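Selecting $D_{\text{hard}}$ amounts to ranking training samples by their loss under the pre-trained C and keeping the top N. The minimal sketch below assumes that "classification error" is measured by the per-sample binary cross-entropy (the loss C is trained with); the function names and predicted probabilities are hypothetical, and the paper uses N = 100 on the full training set.

```python
import heapq
import math


def bce(p, y, eps=1e-7):
    # Per-sample binary cross-entropy; probabilities are clamped away from 0 and 1.
    p = min(max(p, eps), 1.0 - eps)
    return -(y * math.log(p) + (1 - y) * math.log(1.0 - p))


def select_hard_samples(probs, labels, n):
    """Return indices of the n samples with the largest BCE loss, hardest first."""
    losses = [bce(p, y) for p, y in zip(probs, labels)]
    return heapq.nlargest(n, range(len(losses)), key=losses.__getitem__)


probs = [0.9, 0.2, 0.6, 0.05, 0.55]  # hypothetical outputs of C
labels = [1, 1, 0, 0, 1]
print(select_hard_samples(probs, labels, 2))  # [1, 2]
```

Sample 1 (an AD case predicted at 0.2) and sample 2 (a CN case predicted at 0.6) are the most confidently wrong, so they are selected first.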

Adversarial classification training
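Specifically, let us first define the classification loss for synthetic data as

$$\mathcal{L}_C = \sum_{x \in X_{\text{hard}},\, y \in Y_{\text{hard}}} \mathcal{L}_s(C(\tilde{x}), y),$$

where $\tilde{x}$ is a generated sample conditioned on the target age a, $\tilde{x} = G(x, a)$, and y is the ground-truth AD label for x. Here we assume that changing the target age does not change the AD status, so x and $\tilde{x}$ have the same AD label. Since the encoding of age a is differentiable (see Section 2.2), we can obtain the gradients of $\mathcal{L}_C$ with respect to a as $\nabla_a \mathcal{L}_C = \nabla_a [\mathcal{L}_s(C(G(x, a)), y)]$, and update a in the direction that maximises $\mathcal{L}_C$: $\tilde{a} = a + g_a \nabla_a \mathcal{L}_C$, where $g_a$ is the step size (learning rate) for updating a. Formally, the optimisation function for a can be written as

$$\max_{a} \sum_{x \in X_{\text{hard}},\, y \in Y_{\text{hard}}} \mathcal{L}_s(C(G(x, a)), y). \tag{4}$$

We then obtain a set of synthetic data using the updated a: $D_{\text{syn}} = \{\tilde{X}, Y_{\text{hard}}\}$, with $\tilde{X} = \{G(x, \tilde{a}) : x \in X_{\text{hard}}\}$. The classifier C is updated by optimising

$$\min_{C} \sum_{x \in X_{\text{combined}},\, y \in Y_{\text{combined}}} \mathcal{L}_s(C(x), y), \tag{5}$$

where $D_{\text{combined}} = \{X_{\text{combined}}, Y_{\text{combined}}\} = D_{\text{train}} \cup D_{\text{syn}}$ is a combined dataset consisting of the training dataset and the synthetic dataset. Following Liu et al. (25), we update C on $D_{\text{combined}}$ instead of $D_{\text{syn}}$ alone, as we found that updating C only on $D_{\text{syn}}$ can cause catastrophic forgetting (29). The adversarial game is formulated by alternately and iteratively updating a and the classifier C via Eqs. 4 and 5, respectively. In practice, to prevent a from reaching unsuitable ages, we clip it to [60, 90] after every update.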
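The ascent step on the target age can be unpacked with the chain rule. The sketch below assumes, as the plural "target ages" in Algorithm 1 suggests, that each hard sample $x_i$ has its own target age $a_i$ with encoded conditioning vector $g(v_i)$; only $a_i$ changes in this step, while G and C are held fixed.

$$
\nabla_{a_i}\mathcal{L}_C
= \frac{\partial \mathcal{L}_s}{\partial \hat{y}_i}\,
  \frac{\partial C}{\partial \tilde{x}_i}\,
  \frac{\partial G}{\partial g(v_i)}\,
  \frac{\partial g(v_i)}{\partial a_i},
\qquad \tilde{x}_i = G(x_i, a_i),\quad \hat{y}_i = C(\tilde{x}_i),
$$

$$
\tilde{a}_i = \min\!\big(\max\big(a_i + g_a \nabla_{a_i}\mathcal{L}_C,\ 60\big),\ 90\big).
$$

Because each $a_i$ appears only in its own term of $\mathcal{L}_C$, the other samples contribute nothing to $\nabla_{a_i}\mathcal{L}_C$. Every factor is the Jacobian of a differentiable map, so the gradient is available by ordinary backpropagation, and the second line combines the ascent step with the clipping to [60, 90] years.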

Updating a vs. updating G
Note here the adversarial game is formulated between a and C, instead of G and C. This is because training G against C allows G to change its latent space without considering image quality, and the output of G could be unrealistic. Please refer to Section 4.1.2 for more details and results.

Counterfactual augmentation vs. conventional augmentation
Here we choose to augment data counterfactually instead of applying conventional augmentation techniques. This is because the training and testing data are already pre-processed and registered to MNI152 space, and in this case conventional augmentations do not introduce helpful variations. Please refer to Section 4.1.3 for more details and results.

Adversarial classification training in a continual learning context
Most previous works (14,27,28,30–32) that used pre-trained deep generative models for augmentation focused on generating a large number of synthetic samples, merging them with the original dataset and training the downstream task model (e.g., a classifier) on this augmented dataset. However, this requires training the task model from scratch, which can be inconvenient and expensive. Imagine that we are given a pre-trained classifier, and we have a generator at hand which may or may not be pre-trained on the same dataset. We would like to use the generator to improve the classifier, i.e., to transfer the knowledge learnt by the generator to the classifier. The strategy of previous works would be to produce a large amount of synthetic data covering the knowledge learnt by the generator and then train the classifier on both real and synthetic data from scratch. In this work, we instead consider transferring knowledge from the generator to the classifier in a continual learning context, by treating synthetic data as new samples. We want the classifier to learn new knowledge from these synthetic data without forgetting what it has learnt from the original classification training set, and we will show how our approach helps in this context.
In Section 2.3, after we obtain the synthetic set $D_{\text{syn}}$, we choose to update the classifier C on the augmented dataset $D_{\text{syn}} \cup D_{\text{train}}$ instead of $D_{\text{syn}}$ alone (step 6 in Algorithm 1). This is because re-training the classifier only on $D_{\text{syn}}$ would result in catastrophic forgetting (29), a phenomenon in which deep neural networks tend to forget what they have learnt from previous data when trained on new data samples. To alleviate catastrophic forgetting, efforts have been devoted to developing approaches that allow artificial neural networks to learn in a sequential manner (33,34). These approaches are known as continual learning (33,35,36), lifelong learning (37,38), sequential learning (39,40) or incremental learning (41,42). Despite their different names and focuses, the main purpose of these approaches is to overcome catastrophic forgetting and to learn sequentially.
If we consider the generated data as new samples, then updating the pre-trained classifier C can be viewed as a continual learning problem, i.e., how to learn new knowledge from the synthetic set $D_{\text{syn}}$ without forgetting old knowledge learnt from the original training data $D_{\text{train}}$. To alleviate catastrophic forgetting, we re-train the classifier on both $D_{\text{syn}}$ and $D_{\text{train}}$. This strategy is known as memory replay in continual learning (43,44) and has also been used in other augmentation works (25). The key idea is to store previous data in a memory buffer and replay the saved data to the model when training on new data. However, storing and revisiting all the training data can be expensive, especially when the dataset is large (44). In Section 4.2, we perform experiments in which we provide only a portion (M%) of the training data to the classifier when re-training with synthetic data, to simulate the memory buffer. In this case, we create synthetic data only from the memory buffer. We want to see whether catastrophic forgetting happens when only M% of the training data is provided and, if so, how much it affects test accuracy. Algorithm 2 summarises the steps of the method in the continual learning context.
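The first two steps of Algorithm 2, building the memory buffer $D_{\text{store}}$ and selecting hard samples only from it, can be sketched as below. The function names and toy losses are hypothetical, and keeping at least one sample when M% of the training set rounds to zero (relevant for M = 1 on small datasets) is our own choice rather than something the text specifies.

```python
import random


def build_memory_buffer(n_train, m_percent, seed=0):
    """Return sorted indices of a random M% subset of the training set (D_store)."""
    k = max(1, round(n_train * m_percent / 100))
    return sorted(random.Random(seed).sample(range(n_train), k))


def hard_from_buffer(store, losses, n):
    """Return the n stored indices with the largest loss (D_hard within D_store)."""
    return sorted(store, key=lambda i: losses[i], reverse=True)[:n]


losses = [0.1 * (i % 7) for i in range(50)]  # hypothetical per-sample losses
store = build_memory_buffer(len(losses), m_percent=20)
hard = hard_from_buffer(store, losses, n=3)
print(len(store), len(hard))  # 10 3
```

Only indices in `store` are ever candidates for synthesis or replay, which is what distinguishes this setting from the full-replay version in Algorithm 1.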

Data
We use the ADNI dataset (45) for experiments. We select 380 AD and 380 CN (control normal) T1 volumes from subjects between 60 and 90 years old. We split the AD and CN data into training/validation/testing sets with 260/40/80 volumes, respectively. All volumetric data are skull-stripped using DeepBrain and linearly registered to MNI152 space using FSL-FLIRT (46). We normalise brain

Implementation
The generator is trained in the same way as in Xia et al. (21), except that we replace ordinal encoding with Fourier encoding. We pre-train the classifier C for 100 epochs using Adam with a learning rate of 0.00001 and a decay of 0.0001. During adversarial learning, the step size for a is tuned to 0.01, and the learning rate for C is 0.00001. The experiments are implemented in Keras and TensorFlow and run on an NVIDIA Titan X GPU.

Comparison methods
We compare with the following baselines:
1. Naïve: we directly use the pre-trained C as the lower bound.
2. RSRS (Random Selection + Random Synthesis): we randomly select N = 100 samples from the training set $D_{\text{train}}$, denoted as $D_{\text{rand}}$, and then use the generator G to randomly generate $N_{\text{synthesis}} = 5$ synthetic samples for each sample in $D_{\text{rand}}$, denoted as $D_{\text{syn}}$. We then train the classifier on the combined dataset $D_{\text{train}} \cup D_{\text{syn}}$ for k = 5 steps. This is the typical strategy used by most previous works (14,27,28).
3. HSRS (Hard Selection + Random Synthesis): we select N = 100 hard samples from $D_{\text{train}}$ based on their classification errors for C, denoted as $D_{\text{hard}}$, and then use the generator G to randomly generate $N_{\text{synthesis}} = 5$ synthetic samples for each sample in $D_{\text{hard}}$, denoted as $D_{\text{syn}}$. We then train the classifier on the combined dataset $D_{\text{train}} \cup D_{\text{syn}}$ for k = 5 steps.
4. RSAT (Random Selection + Adversarial Training): we randomly select N = 100 samples from the training set $D_{\text{train}}$, denoted as $D_{\text{rand}}$, and then use the adversarial training strategy described in Section 2.3 to update the classifier C. The difference between RSAT and our approach is that we select hard samples for generating counterfactuals, while RSAT uses random samples.

Algorithm 2: the method in the continual learning context.
Input: training dataset $D_{\text{train}}$; hyperparameters M, N, k; a pre-trained generator G; a pre-trained classifier C.
Construct $D_{\text{store}}$:
1. Randomly select M% of the data from $D_{\text{train}}$, denoted as $D_{\text{store}}$.
Hard sample selection:
2. Select the N samples from $D_{\text{store}}$ that result in the highest classification errors for C, denoted as $D_{\text{hard}}$.
Adversarial training:
3. Randomly initialise target ages a, and obtain initial synthetic data.

Comparison with baselines
We first compare our method with baseline approaches by evaluating the test accuracy of the classifiers. We set N = 100 and k = 5 in these experiments. We pre-train C for 100 epochs and G as described in Section 3. The weights of the pre-trained C and the pre-trained G are the same for all methods. For a fair comparison, the total number of synthetically generated samples is fixed to 500 for RSRS, HSRS, RSAT and our approach. For JTT, 2,184 samples are misclassified by C and oversampled. We initialise a randomly between the real age of each original brain image x and the maximal age (90 years).
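The 500-sample budget can be made concrete with a short sketch of how RSRS and HSRS choose what to synthesise: 100 seed samples (random for RSRS, highest-loss for HSRS), each paired with 5 target ages. The text does not state how these baselines draw their random target ages, so drawing them uniformly between the real age and 90, as for the initialisation of a, is an assumption; the helper name and data are hypothetical, and the resulting (index, target age) pairs would then be passed to G.

```python
import random


def random_synthesis_plan(seed_indices, real_ages, n_per_sample=5, max_age=90.0, seed=0):
    """Pair each seed sample with n_per_sample random target ages in [real age, max_age]."""
    rng = random.Random(seed)
    return [(i, rng.uniform(real_ages[i], max_age))
            for i in seed_indices for _ in range(n_per_sample)]


rng = random.Random(1)
n_train = 1000
real_ages = [rng.uniform(60.0, 90.0) for _ in range(n_train)]  # hypothetical ages
losses = [rng.random() for _ in range(n_train)]                 # hypothetical losses

rsrs_seeds = rng.sample(range(n_train), 100)  # RSRS: random selection
hsrs_seeds = sorted(range(n_train), key=lambda i: losses[i], reverse=True)[:100]  # HSRS

rsrs_plan = random_synthesis_plan(rsrs_seeds, real_ages)
hsrs_plan = random_synthesis_plan(hsrs_seeds, real_ages)
print(len(rsrs_plan), len(hsrs_plan))  # 500 500
```

RSAT and the proposed approach use the same 100 × 5 budget, but move the target ages by gradient ascent instead of leaving them at their random draws.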
From Table 2, we observe that our proposed procedure achieves the best overall test accuracy, followed by baseline RSAT. This demonstrates the advantage of adversarial training between the conditional factor (target age) a and the classifier. Because our approach also outperforms RSAT, selecting hard examples for creating synthetic samples helps, which is further supported by the improvement of HSRS over Naïve. We also observe that JTT (25) improves classifier performance over Naïve, showing the benefit of up-sampling hard samples. In contrast, baseline RSRS achieves the lowest overall test accuracy, even lower than Naïve. This shows that randomly synthesising counterfactuals from randomly selected samples can produce synthetic images that are harmful to the classifier.
Furthermore, we observe that for all methods, the worst-group performance is obtained on the 80–90 CN group. A possible reason is that as age increases, the brain shrinks, and it becomes harder to tell whether an ageing pattern is due to AD or to normal ageing. Nevertheless, for this worst group, our proposed method still achieves the best performance, followed by RSAT. This shows that adversarial training can help improve the performance of the classifier, especially for hard groups. The next best results are achieved by HSRS and JTT, which shows that finding hard samples and up-sampling or augmenting them helps to improve worst-group performance. RSRS also improves worst-group performance over Naïve, but the improvement is small compared to the other baselines. Figure 2 presents histograms of the original ages of training subjects and of the target ages after adversarial training, showing that adversarial training tends to balance the data.
We also report the precision and recall for all methods, as presented in Table 3. We can observe that our approach achieves the highest overall precision and recall results.
In summary, the quantitative results show that it is helpful to find and utilise hard counterfactuals for improving the classifier.

Train G against C
We choose to formulate an adversarial game between the conditional generative factor a (the target age) and the classifier C, instead of between the generator G and the classifier C. This is because we are concerned that an adversarial game between G and C could result in unrealistic outputs of G. In this section, we perform an experiment to investigate this.
Specifically, we define the optimisation function

$$\max_{G} \sum_{x \in X_{\text{train}},\, y \in Y_{\text{train}}} \mathcal{L}_s(C(G(x, a)), y), \tag{6}$$

where we aim to train G in the direction that maximises the loss of the classifier C on the synthetic data G(x, a).
After every update of G, we construct a synthetic set $D_{\text{syn}}$ by generating 100 synthetic images from $D_{\text{train}}$, and update C on $D_{\text{train}} \cup D_{\text{syn}}$ via Equation 5. The adversarial game G vs. C is formulated by alternately optimising Equations 6 and 5 for 10 epochs.
In Figure 3, we present the synthetic brain ageing progression of a CN subject before and after the adversarial training of G vs. C. After the adversarial training, the generator G produces unrealistic results. This could be because no loss or constraint prevents G from producing low-quality results: the adversarial game only requires G to produce images that are hard for C, and low-quality images are naturally hard for C. A potential solution could be to add a GAN loss with a discriminator to improve output quality, but this would make training much more complex and require more memory and computation. We also measure the test accuracy of the classifier C after training G against C to be 81.6%, which is much lower than that of the Naïve method (88.4%) and our approach (91.1%) in Table 2. A potential reason is that C is misled by the unrealistic samples generated by G.

Effect of conventional augmentations for registered brain MRI data
In this section, we test the effect of applying several commonly used conventional augmentations, e.g., rotation, shift, scaling and flipping, when training the AD classifier. These are typical conventional augmentation techniques for computer vision classification tasks. Specifically, we train the classifier in the same way as Naïve, except that we augment the training data with conventional augmentations.
Interestingly, we find that after applying rotation (range 10 degrees), shift (range 0.2), scaling (range 0.2) and flipping to augment the training data, the accuracy of the trained classifier drops from 88.4% to 71.6%. We then measure the accuracies when training with each augmentation separately: 74.1% (rotation), 87.1% (shift), 82.9% (scaling) and 87.8% (flipping). We also trained the classifier with random gamma correction (gamma ranging from 0.2 to 1.8), and the resulting test accuracy is 84.4%. This could be because both training and testing data are already pre-processed, including registration to MNI152 space and contrast normalisation, so these conventional augmentations do not introduce helpful variations to the training data but instead distract the classifier from the subtle differences between AD and CN brains. We also trained the classifier with MaxUp (10) using conventional augmentations. The idea of MaxUp is to generate a small batch of augmented samples for each training sample and train the classifier on the worst-performing augmented sample. The overall test accuracy is 57.7%. This could be because MaxUp tends to select the augmentations that most distract the classifier from subtle AD features.

Figure 2: Histograms of subject ages before and after adversarial learning. Adversarial training tends to balance the data.

Table 3: Precision for different age groups (columns 2–4) and for all testing data (column 5), followed by recall for different age groups (columns 6–8) and for all testing data (column 9). For each group, the best results are shown in bold.

Xia et al. 10.3389/fradi.2022.1039160 Frontiers in Radiology
The results with conventional augmentations (and MaxUp) suggest that for the task of AD classification, when training and testing data are well pre-processed, conventional data augmentation techniques do not seem to improve classification performance. Instead, these augmentations distract the classifier from identifying subtle changes between CN and AD brains. By contrast, the proposed procedure augments data in terms of semantic information, which can alleviate data imbalance and improve classification performance.
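The MaxUp selection step discussed in this section can be sketched as follows: for each training sample, draw a small batch of m augmentations and keep the one with the highest loss, which is then used for the gradient step. The augmentation, loss and value of m below are hypothetical toy choices; in the experiments the augmentations were the conventional transformations listed above and the loss was that of the AD classifier.

```python
import random


def maxup_pick(sample, augment, loss_fn, m=4, rng=None):
    """Draw m augmentations of `sample` and return (worst, candidates).

    MaxUp trains on the candidate with the largest loss; only that
    selection step is shown here.
    """
    rng = rng if rng is not None else random.Random(0)
    candidates = [augment(sample, rng) for _ in range(m)]
    return max(candidates, key=loss_fn), candidates


def jitter(x, rng):
    # Hypothetical augmentation: add small Gaussian noise to each feature.
    return [xi + rng.gauss(0.0, 0.1) for xi in x]


def toy_loss(x):
    # Hypothetical per-sample loss: squared distance from the origin.
    return sum(xi * xi for xi in x)


worst, candidates = maxup_pick([1.0, -0.5, 0.2], jitter, toy_loss, m=4)
```

Training then minimises the loss on `worst`, so the model is optimised against the hardest of the sampled augmentations rather than a random one.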

Xia et al. 10.3389/fradi.2022.1039160 Frontiers in Radiology

Figure caption: The synthetic results for a healthy (CN) subject x at age 70: (A) the results of the pre-trained G, i.e. before we train G against C; (B) the results of G after we train G against C. We synthesise aged images x̂ at different target ages a. We also visualise the difference between x and x̂, |x̂ − x|.

4.2. Adversarial counterfactual augmentation in a continual learning context
4.2.1. Results when re-training with a portion (M%) of training data
Suppose we have a pre-trained classifier C and a pre-trained generator G, and we want to improve C by using G for data augmentation. However, after pre-training, we only store M% (M ∈ (0, 100]) of the training dataset, denoted as D_store. During the adversarial training, we synthesise N samples using the generator G, denoted as D_syn. Then we update the classifier C on D_store ∪ D_syn using Equation 5, where D_combined = D_store ∪ D_syn. The target ages are initialised and updated in the same way as in Section 4.1. Algorithm 2 illustrates the procedure in this section. Table 4 presents the test accuracies of our approach and the baselines as M changes. For Naïve-100, the results are the same as in Table 2. For JTT, the original paper Liu et al. (25) retrained the classifier using the whole training set. Here we first randomly select M% of the training samples as D_store and find the misclassified data D_mis within D_store to up-sample; then we retrain the classifier on the augmented set. We can observe that when M decreases, catastrophic forgetting happens for all approaches. However, our method suffers the least from catastrophic forgetting, especially when M is small. With M = 20% of the training data for retraining, our approach achieves better results than Naïve. This might be because the adversarial training between a and C tries to detect what is missing in D_store and to recover the missing data by updating a towards those directions. We observe that RSAT achieves the second-best results, only slightly worse than the proposed approach. Moreover, HSRS and JTT are more affected by catastrophic forgetting and achieve worse results. This might be because the importance of selecting hard samples declines as M decreases, since D_store becomes smaller. These results demonstrate that our approach could alleviate catastrophic forgetting. This could be helpful in cases where we want to utilise generative models to improve pre-trained classifiers (or other task models) without revisiting all the training data (a continual learning context).
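Algorithm 2 is not reproduced in this section, so the following is only a minimal NumPy sketch of the loop described above, under loudly stated assumptions: `toy_generator` is a hypothetical linear stand-in for the pre-trained G (it shifts features along a fixed "ageing" direction in proportion to target age minus real age), C is a logistic-regression classifier, the N seeds for D_syn are drawn from D_store, and plain binary cross-entropy is used in place of Equation 5 (not shown here). Each iteration takes a gradient-ascent step on the target ages a (searching for the classifier's weaknesses) and then a gradient-descent step on C over D_store ∪ D_syn, clipping a so it never falls below the real age, since this G only simulates ageing.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def toy_generator(x, age, target_age, direction):
    # HYPOTHETICAL stand-in for the pre-trained G: shift features along a fixed
    # "ageing" direction in proportion to (target age - real age).
    return x + (target_age - age)[:, None] * direction


def retrain_with_counterfactuals(x, y, age, direction, m_percent=20, n_syn=50,
                                 steps=200, lr_c=0.1, lr_a=1.0, max_age=90.0,
                                 theta=None, b=0.0, seed=0):
    rng = np.random.default_rng(seed)
    n, d = x.shape
    # C's weights; pass pre-trained values via `theta`/`b` (zeros by default).
    theta = np.zeros(d) if theta is None else theta.astype(float).copy()
    # D_store: only M% of the original training set is kept after pre-training.
    store = rng.choice(n, size=max(1, round(n * m_percent / 100)), replace=False)
    # Seeds for D_syn (assumption: drawn from D_store).
    syn = rng.choice(store, size=min(n_syn, len(store)), replace=False)
    a = age[syn].astype(float)  # target ages initialised as real ages
    for _ in range(steps):
        # (1) Update a by gradient ASCENT on C's loss: look for C's weaknesses.
        p = sigmoid(toy_generator(x[syn], age[syn], a, direction) @ theta + b)
        grad_a = (p - y[syn]) * (direction @ theta)  # dBCE/da for this toy G
        a = np.clip(a + lr_a * grad_a, age[syn], max_age)  # ageing only
        # (2) Update C by gradient DESCENT on D_combined = D_store ∪ D_syn.
        x_syn = toy_generator(x[syn], age[syn], a, direction)
        xc = np.vstack([x[store], x_syn])
        yc = np.concatenate([y[store], y[syn]])
        err = sigmoid(xc @ theta + b) - yc  # dBCE/dlogit
        theta -= lr_c * xc.T @ err / len(yc)
        b -= lr_c * err.mean()
    return theta, b, a, syn
```

In the actual method G is a deep conditional network, so the gradient with respect to a would come from backpropagation through G rather than from the closed form used in this toy.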

Results when number of samples used for synthesis (N) changes
We also performed experiments in which we varied N, i.e. the number of samples used for generating counterfactuals. Specifically, we set M = 1, i.e. only 1% of the original training data are used for re-training C, to see how many synthetic samples are needed to maintain good accuracy when only a few training samples are stored in D_store. This shows how efficient the synthetic samples are in terms of training C and alleviating catastrophic forgetting. The results are presented in Table 5.
From Table 5, we can observe that the best results are achieved by our method, followed by RSAT. Even with only one sample for synthesis, our method could still achieve a test accuracy of 80%. This is probably because the adversarial training of a vs. C guides G to generate hard counterfactuals, which are efficient for training the classifier. The results demonstrate that our approach could help alleviate catastrophic forgetting even with a small number of synthetic samples used for augmentation. This experiment could also be viewed as a measure of sample efficiency, i.e. how useful a single synthetic sample is for re-training a classifier.

Can the proposed procedure alleviate spurious correlations?
Spurious correlation occurs when two factors appear to be correlated with each other but in fact are not (47). Spurious correlations can affect the performance of deep neural networks and have been actively studied in the computer vision field (25,(48)(49)(50)(51) and in the medical image analysis field (52,53). For instance, suppose we have a dataset of bird and bat photos, in which most bird photos have sky backgrounds and most bat photos have cave backgrounds. If a classifier learns this spurious correlation, e.g. it classifies a photo as a bird whenever the background is sky, then it will perform poorly on images of bats flying in the sky. In this section, we investigate whether our approach can correct such spurious correlations by changing a to generate hard counterfactuals.
Here we create a dataset, denoted as D_spurious, in which 7,860 images of subjects between 60 and 75 yrs old are AD and 7,680 images of subjects between 75 and 90 yrs old are healthy. This constructs a spurious correlation: young → AD and old → CN (in reality, older people have a higher chance of developing AD (54)). Then we pre-train C on D_spurious. The brain ageing model proposed in Xia et al. (21) only considered simulating the ageing process and did not consider brain rejuvenation, i.e. the reverse of ageing. To utilise old CN data, we pre-train another generator in the rejuvenation direction, i.e. generating younger brain images from older ones. As a result, we obtain two generators pre-trained on D_train, denoted as G_ageing and G_rejuve, where G_rejuve is trained to simulate the rejuvenation process. Figure 4 shows visual results of G_rejuve. After that, we select 50 CN and 50 AD hard images from D_spurious, denoted as D_hard, and perform the adversarial classification training using G_rejuve for old CN samples and G_ageing for young AD samples. The target ages a are initialised as the real ages of x.
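To make the construction of D_spurious concrete, here is a small sketch that filters (age, label) records so that AD appears only in the young range and CN only in the old range, and then shows that a shortcut rule based purely on age is perfect on that subset but not on a set that contains young CN and old AD subjects. The label encoding (1 = AD, 0 = CN), the boundary handling at 75 yrs and the function names are illustrative assumptions, not the authors' code.

```python
import numpy as np


def make_spurious_subset(ages, labels, young=(60, 75), old=(75, 90)):
    # Labels: 1 = AD, 0 = CN (assumed encoding). Young range is [60, 75),
    # old range is [75, 90] (assumed boundary handling).
    ages, labels = np.asarray(ages), np.asarray(labels)
    young_ad = (labels == 1) & (ages >= young[0]) & (ages < young[1])
    old_cn = (labels == 0) & (ages >= old[0]) & (ages <= old[1])
    return np.flatnonzero(young_ad | old_cn)


def age_shortcut_accuracy(ages, labels, cut=75):
    # Accuracy of the spurious rule "young -> AD, old -> CN".
    ages, labels = np.asarray(ages), np.asarray(labels)
    return float(np.mean((ages < cut) == (labels == 1)))


ages = np.array([62, 70, 80, 88, 65, 85])
labels = np.array([1, 0, 1, 0, 1, 0])
idx = make_spurious_subset(ages, labels)
print(idx.tolist())                                   # [0, 3, 4, 5]
print(age_shortcut_accuracy(ages[idx], labels[idx]))  # 1.0
print(age_shortcut_accuracy(ages, labels))            # 0.6666666666666666
```

A classifier trained only on the filtered subset has no data that contradicts the age shortcut, which is why it fails on young CN and old AD test subjects.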
Table caption (fragment): We also show the percentage of N vs. the total number of D_store.

After we obtain G_ageing and G_rejuvenation, we select the 50 CN and 50 AD images from D_spurious that result in the highest training errors, denoted as D_hard. Note that the selected CN images are between 75 and 90 yrs old, and the AD images are between 60 and 75 yrs old. Then we generate synthetic images from D_hard using G_rejuvenation for old CN samples and G_ageing for young AD samples. The target ages a are initialised as their ground-truth ages. Finally, we perform the adversarial training between a and the classifier C. Here we want to see whether the adversarial training can detect the spurious correlations we purposely created and, more importantly, whether the adversarial training between a and C can break them. Table 6 presents the test accuracies of our approach and the baselines. For Naïve, we directly use the classifier C pre-trained on D_spurious. For HSRS, we randomly generate synthetic samples from D_hard for augmentation. For JTT, we simply select misclassified samples from D_spurious and upsample them.
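The two selection rules above differ in an important way: D_hard keeps the highest-loss samples of each class, whereas JTT repeats the misclassified training samples. The sketch below illustrates both, assuming per-sample losses and predictions are already available. The function names and the upsampling factor `lam` are hypothetical, and `jtt_upsample` follows the general error-set upsampling idea of Liu et al. (25) rather than its exact configuration.

```python
import numpy as np


def select_hard(losses, labels, k_per_class=50):
    # D_hard-style selection: the k highest-loss samples within each class.
    losses, labels = np.asarray(losses), np.asarray(labels)
    picked = []
    for c in np.unique(labels):
        idx = np.flatnonzero(labels == c)
        order = idx[np.argsort(-losses[idx], kind="stable")]  # descending loss
        picked.extend(order[:k_per_class].tolist())
    return np.array(picked)


def jtt_upsample(preds, labels, lam=3):
    # JTT-style upsampling: every sample once, plus each misclassified sample
    # repeated so that it appears `lam` times in total.
    preds, labels = np.asarray(preds), np.asarray(labels)
    errors = np.flatnonzero(preds != labels)
    return np.concatenate([np.arange(len(labels))] + [errors] * (lam - 1))


losses = np.array([0.1, 0.9, 0.5, 0.7, 0.2, 0.8])
labels = np.array([0, 0, 0, 1, 1, 1])
print(select_hard(losses, labels, k_per_class=2).tolist())  # [1, 2, 5, 3]
preds = np.array([0, 1, 0, 1, 0, 1])
print(jtt_upsample(preds, labels).tolist())  # [0, 1, 2, 3, 4, 5, 1, 4, 1, 4]
```

Both rules can only reuse samples that already exist in D_spurious, whereas the adversarial scheme can synthesise the missing young CN and old AD cases.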
We can observe from Table 6 that the classifier C pre-trained on D_spurious (Naïve) achieves much worse performance (67.0% accuracy) than in Table 2 (88.4% accuracy). Specifically, it tends to misclassify young CN images as AD and old AD images as CN. This is likely due to the spurious correlations that we purposely created in D_spurious: young → AD and old → CN. We notice that for Naïve, the test accuracies of the AD groups are higher than those of the CN groups. This is likely because we have more AD training data, so the classifier is biased towards classifying a subject as AD. This can be viewed as another spurious correlation. Overall, we observe that our method achieves the best results, followed by HSRS. This shows that the synthetic images produced by the generators help alleviate the effect of spurious correlations and improve downstream tasks. The improvement of our approach over HSRS is due to the adversarial training between a and C, which guides the generator to produce hard counterfactuals. We observe that JTT does not improve the test accuracies significantly. A potential reason is that JTT tries to find "hard" samples in the training dataset. However, in this experiment, the "hard" samples should be young CN and old AD samples, which do not exist in the training dataset D_spurious. By contrast, our procedure could guide G to generate these samples, and HSRS could create them by random chance. Figure 5 plots the histograms of the target ages a before and after the adversarial training. From Figure 5 we can observe that the adversarial training pushes a towards the hard direction, which could alleviate the spurious correlations. For instance, in D_spurious and D_hard the AD subjects are all in the young group, i.e. 60-75 yrs old, and the classifier learns the spurious correlation young → AD; but in Figure 5A we can observe that the adversarial training learns to generate AD synthetic images in the range of 75-90 yrs old. These old AD synthetic images can help alleviate the spurious correlation and improve the performance of C. Similarly, we can observe in Figure 5B that a is pushed towards younger ages for CN subjects.

Figure caption: Example results of brain rejuvenation for an image x of an 85-year-old CN subject. We synthesise rejuvenated images x̂ at different target ages a. We also show the differences between x̂ and x, x̂ − x.

Table caption: We first present the average test accuracies for different age groups with CN diagnosis (columns 2-3) or AD diagnosis (columns 4-5), and then the average test accuracy for the whole testing set (column 6). For each method, the worst-group performance is shown in italic. For each age group, i.e. each column, the best performance is shown in bold.
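The worst-group performance reported for each method can be computed by splitting the test set by diagnosis and age band and taking the lowest per-group accuracy. A minimal sketch, assuming labels 0 = CN and 1 = AD and a single cut at 75 yrs (both assumptions based on the 60-75 and 75-90 ranges used in the text); the function names are hypothetical.

```python
import numpy as np


def group_accuracies(preds, labels, ages, age_cut=75):
    # Assumed encoding: label 0 = CN, 1 = AD; "young" means age < age_cut.
    preds, labels, ages = (np.asarray(v) for v in (preds, labels, ages))
    accs = {}
    for diag, diag_name in ((0, "CN"), (1, "AD")):
        for band, in_band in (("young", ages < age_cut), ("old", ages >= age_cut)):
            mask = (labels == diag) & in_band
            if mask.any():  # skip empty groups instead of dividing by zero
                accs[(diag_name, band)] = float(np.mean(preds[mask] == labels[mask]))
    return accs


def worst_group_accuracy(accs):
    # Accuracy of the weakest (diagnosis, age band) group.
    return min(accs.values())


labels = [0, 0, 1, 1, 0, 0, 1, 1]
ages = [65, 70, 65, 70, 80, 85, 80, 85]
preds = [1, 0, 1, 1, 0, 0, 0, 1]  # young CN and old AD partly misclassified
accs = group_accuracies(preds, labels, ages)
print(accs[("CN", "young")], accs[("AD", "old")])  # 0.5 0.5
print(worst_group_accuracy(accs))                  # 0.5
```

Under the spurious correlation, overall accuracy can look acceptable while the young CN and old AD groups lag behind, which is exactly what the worst-group number exposes.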

Conclusion
We presented a novel adversarial counterfactual scheme to utilise conditional generative models for downstream tasks, e.g. classification. The proposed procedure formulates an adversarial game between the conditional factor of a pre-trained generative model and the downstream classifier. The synthesis model used in this work uses two generators, one for ageing and one for rejuvenation. Others have shown that a single model can handle both tasks, albeit on another dataset and with fewer conditioning factors (55). We highlight, though, that our approach is agnostic to the generator used and hence could benefit from advances in (conditional) generative modelling. In this paper, we demonstrate that several conventional augmentation techniques are not helpful for registered MRI. However, there might be other heuristic-based augmentation techniques that improve performance, and it is worth trying to combine our semantic augmentation strategy with such techniques to further boost performance. The proposed adversarial counterfactual scheme could be applied to generative models that produce other types of counterfactuals than the ageing brain, e.g. the ageing heart (55, 56), future disease outcomes (57), the existence of pathology (58,59), etc. The way we updated the conditional factor (target age) could also be improved: instead of a continuous scalar (target age), the proposed adversarial counterfactual augmentation could be extended to update other types of conditional factors, e.g. discrete factors or images. Finally, the strategy that we used to select hard samples may not be the most effective and could be improved.

Data availability statement
Publicly available datasets were analyzed in this study. These data can be found at: https://adni.loni.usc.edu.

Ethics statement
Ethical review and approval was not required for this study in accordance with the local legislation and institutional requirements.

Author contributions
TX, PS, CQ and SAT contributed to the conceptualization of this work. TX, PS, CQ and SAT designed the methodology. TX developed the software tools necessary for preprocessing and analysing image files and for training the model. TX drafted this manuscript. All authors contributed to the article and approved the submitted version.

Figure caption: Histograms of target ages a before and after adversarial training: (A) the histogram of a for the 50 AD subjects in D_hard; (B) the histogram of a for the 50 CN subjects in D_hard. Histograms are shown before (orange) and after (blue) the adversarial training.