STCA-SNN: self-attention-based temporal-channel joint attention for spiking neural networks

Spiking Neural Networks (SNNs) have shown great promise in processing spatio-temporal information compared to Artificial Neural Networks (ANNs). However, there remains a performance gap between SNNs and ANNs, which impedes the practical application of SNNs. With intrinsic event-triggered property and temporal dynamics, SNNs have the potential to effectively extract spatio-temporal features from event streams. To leverage the temporal potential of SNNs, we propose a self-attention-based temporal-channel joint attention SNN (STCA-SNN) with end-to-end training, which infers attention weights along both temporal and channel dimensions concurrently. It models global temporal and channel information correlations with self-attention, enabling the network to learn ‘what’ and ‘when’ to attend simultaneously. Our experimental results show that STCA-SNNs achieve better performance on N-MNIST (99.67%), CIFAR10-DVS (81.6%), and N-Caltech 101 (80.88%) compared with the state-of-the-art SNNs. Meanwhile, our ablation study demonstrates that STCA-SNNs improve the accuracy of event stream classification tasks.

Frontiers in Neuroscience, frontiersin.org
[...] ANNs. Recently, ANN modules (Hu et al., 2021; Yang et al., 2021; Yao et al., 2021, 2023c) have been integrated into SNNs to improve their performance. CSNN (Xu et al., 2018) first validated the application of convolution structure on SNNs, promoting the development of SNNs. Convolution-based SNNs share weights across both temporal and spatial dimensions, following the assumption of spatio-temporal invariance (Huang et al., 2022). This approach can be regarded as a local way of information extraction since convolutional operations can only process a local neighborhood at a time, either in space or time. However, when dealing with sequential data like event streams, capturing long-distance dependencies is of central importance to modeling complex temporal dynamics. Non-local operations (Wang et al., 2018) provided a solution as a building block by computing the response at a position as a weighted sum of the features at all positions. The range of positions can span across space, time, or spacetime, allowing non-local operators to achieve remarkable success in vision attention.
The attention mechanism is inspired by the human ability to selectively find prominent areas in complex scenes (Itti et al., 1998). A popular research direction is to present attention as a lightweight auxiliary unit to improve the representation power of the basic model. In the ANNs domain, Ba et al. (2014) first introduced the term "visual attention" for image classification tasks, utilizing attention to identify relevant regions and locations within the input image. This approach also reduces the computational complexity of the proposed model regarding the size of the input image. SENet (Hu et al., 2018) was introduced to reweight the channel-wise responses of the convolutional features, determining "what" to pay attention to. CBAM (Woo et al., 2018) inferred attention maps sequentially along channel-wise and spatial dimensions for refining the input feature, determining "what" and "where" to pay attention to concurrently. In the SNNs domain, TA-SNN (Yao et al., 2021) first extended the channel-wise attention concept to temporal-wise attention and integrated it into SNNs to determine 'when' to pay attention. MA-SNN (Yao et al., 2023c) extended CBAM to SNNs and proposed a multi-dimensional attention module along temporal-wise, channel-wise, and spatial-wise dimensions, applied separately or simultaneously. Recently, TCJA-SNN (Zhu et al., 2022) combined temporal-wise and channel-wise attention using 1-D convolution operations to represent the correlation between time-steps and channels. However, the receptive field of TCJA-SNN is a local cross shape that is restricted by its convolution kernels, as shown in Figure 1A. Thus long-range dependencies can only be captured when the 1-D convolution operation is repeated, which makes multi-hop dependency modeling difficult. On the other hand, self-attention, another vital feature of the human biological system, possesses the ability to capture feature dependencies effectively as an additional non-local operator alongside SE and CBAM. It has sparked a significant wave of interest and achieved remarkable success in various tasks (Vaswani et al., 2017; Dosovitskiy et al., 2020; Liu et al., 2021). Intuitively, there is a compelling interest in investigating the application of self-attention in SNNs to advance deep learning, when considering the biological characteristics of both mechanisms (Yao et al., 2023a,b; Zhou C. et al., 2023; Zhou Z. et al., 2023).
To address the local spatio-temporal receptive field limitation of TCJA, we first adopt self-attention, a non-local operation, to model global temporal and channel information correlations. The self-attention module we employ captures a global spatio-temporal receptive field, as shown in Figure 1B, allowing direct modeling of long-range dependencies, which is the highlight of our work. We propose a plug-and-play Self-attention-based Temporal-Channel joint Attention (STCA) module for SNNs with end-to-end training. The STCA-SNNs can learn to focus on different features of the input at each time-step. In other words, the STCA-SNNs can learn 'when' and 'what' to attend concurrently, enhancing the ability of the SNNs to process temporal information. We evaluated the effectiveness of STCA-SNNs across different architectures on three benchmark event stream classification datasets: N-MNIST, CIFAR10-DVS, and N-Caltech 101. Our detailed experiments show that STCA-SNNs achieve competitive accuracy with existing state-of-the-art SNNs.
The main contributions of our work are summarized as follows:
1. We propose STCA-SNNs for event streams that can undertake end-to-end training and inference tasks.
2. The plug-and-play STCA module models global temporal and channel correlations with self-attention, allowing the network to learn 'when' and 'what' to attend simultaneously. This enhances the ability of SNNs to process temporal information.
3. We evaluate the performance of STCA-SNNs on three benchmark event stream classification datasets, N-MNIST, CIFAR10-DVS, and N-Caltech 101. Our experimental results demonstrate that STCA-SNNs achieve competitive accuracy compared to existing state-of-the-art SNNs.

Related work

Attention in SNNs
Spiking neural networks benefit from biological plausibility and continuously pursue the combination with brain mechanisms. The attention mechanism draws inspiration from the human ability to selectively identify salient regions within complex scenes and has gained remarkable success in deep learning by allocating attention weights preferentially to the most informative input components. A popular research direction is to present attention as an auxiliary module that can be easily integrated with existing architectures to boost the representation power of the basic model (Hu et al., 2018; Woo et al., 2018; Guo et al., 2022; Li et al., 2022). Yao et al. (2021) first suggested using an extra plug-and-play temporal-wise attention module for SNNs to bypass a few unnecessary input time-steps. Then they proposed a multi-dimensional attention module along temporal-wise, channel-wise, and spatial-wise dimensions, applied separately or simultaneously, to optimize membrane potentials, which in turn regulate the spiking response (Yao et al., 2023c). STSC-SNN (Yu et al., 2022) employed temporal convolution and attention mechanisms to improve spatio-temporal receptive fields of synaptic connections. SCTFA-SNN (Cai et al., 2023) computed channel-wise and spatial-wise attention separately to optimize membrane potentials along the temporal dimension. Yao et al. (2023a,b) recently proposed an advanced spatial attention module to harness SNNs' redundancy, which can adaptively optimize their membrane potential distribution by a pair of individual spatial attention sub-modules. TCJA-SNN (Zhu et al., 2022) jointly modeled temporal-wise and channel-wise attention correlations using 1-D convolution operations. However, the temporal-channel receptive field of TCJA is a local cross shape that is restricted by its convolution kernels, requiring multiple repeated computations to establish long-range dependencies of features. Therefore, it is computationally inefficient and makes multi-hop dependency modeling difficult.
Among the attention mechanisms, self-attention, as another important feature of the human biological system, possesses the ability to capture feature dependencies. Originally developed for natural language processing (Vaswani et al., 2017), self-attention has been extended to computer vision, where it has achieved significant success in various applications. The self-attention module can also be considered a building block of CNN architectures, which are known for their limited scalability when it comes to large receptive fields (Han et al., 2022). In contrast to the progressive behavior of the convolution operation, self-attention can capture long-range dependencies directly by computing interactions between any two positions, regardless of their positional distance. Moreover, it is commonly integrated into the top of the networks to enhance high-level semantic features for vision tasks. Recently, an emerging research direction is to explore the biological characteristics associated with the fusion of self-attention and SNNs (Yao et al., 2023a,b; Zhou C. et al., 2023; Zhou Z. et al., 2023). These efforts primarily revolve around optimizing the computation of self-attention within SNNs by circumventing multiplicative operations, leading to performance degradation. Diverging from these studies, our primary goal is to explore how self-attention can enhance the spatio-temporal information processing capabilities of SNNs.

Learning algorithms for SNNs
Existing SNN training methods can be roughly divided into three categories: 1) the biologically plausible method, 2) the conversion method, and 3) the gradient-based direct training method. The first one is based on biologically plausible local learning rules, like spike timing dependent plasticity (STDP) (Diehl and Cook, 2015; Kheradpisheh et al., 2018) and ReSuMe (Ponulak and Kasinski, 2010), but achieving high performance for deep networks is challenging. The conversion method offers an alternative way to obtain high-performance SNNs by converting a well-trained ANN and mapping its parameters to an SNN with an equivalent architecture, where the firing rate of the SNN acts as the ReLU activation (Cao et al., 2015; Rueckauer et al., 2017; Sengupta et al., 2019; Ding et al., 2021; Bu et al., 2022; Wu et al., 2023). Moreover, some works explored post-conversion fine-tuning of converted SNNs to reduce latency and increase accuracy (Rathi et al., 2020; Rathi and Roy, 2021; Wu et al., 2021). However, this method is not suitable for neuromorphic datasets. The gradient-based direct training methods primarily include voltage gradient-based (Zhang et al., 2020), timing gradient-based (Zhang et al., 2021), and activation gradient-based approaches. Among them, the activation gradient-based method demonstrates notable effectiveness when performing challenging tasks. This approach uses surrogate gradients to address the non-differentiable spike activity issue, allowing for error back-propagation through time (BPTT) to interface with gradient descent directly on SNNs for end-to-end training (Neftci et al., 2019; Wu et al., 2019; Yang et al., 2021; Zenke and Vogels, 2021). These efforts have shown strong potential in achieving high performance by exploiting spatio-temporal information. However, further research is required to determine how to make better use of spatio-temporal data and how to efficiently extract spatio-temporal features. This is the question our work aims to address.

Materials and methods
In this section, we first present the representation of event streams and the adopted spiking neuron model and later propose our STCA module based on this neuron model. Finally, we introduce the training method adopted in this paper.

Representation of event streams
An event, e, encodes three pieces of information: the pixel location (x, y) of the event, the timestamp t′ recording the time when the event is triggered, and the polarity of each single event p ∈ {−1, +1} reflecting an increase or decrease of brightness via +1/−1. Formally, a set of events at the timestamp t′ can be defined as:
$$E_{t'} = \{\, e_i \mid e_i = [x_i, y_i, t', p_i] \,\}$$
Assuming the spatial resolution is h × w, the event set is equivalent to the spike pattern tensor $S_{t'} \in \mathbb{R}^{2 \times h \times w}$ at the timestamp t′. However, processing these events one by one can be inefficient due to the limited amount of information contained in a single event. We follow the frame-based representation in SpikingJelly (Fang et al., 2020) that transforms event streams into high-rate frame sequences during preprocessing. Each frame includes many blank (zero) areas, and SNNs can skip the computation of the zero areas in each input frame (Roy et al., 2019), improving overall efficiency.
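The frame-based representation can be illustrated with a small plain-Python sketch. It is not the SpikingJelly routine: the function name `events_to_frames` is hypothetical, events are assumed to be `(x, y, t, p)` tuples, slices are assumed to hold roughly equal numbers of events, and polarity −1 is assumed to map to channel 0 and +1 to channel 1.

```python
def events_to_frames(events, num_frames, height, width):
    """Accumulate (x, y, t, p) events into num_frames frames of shape 2 x H x W.

    Assumption: events are split into slices holding (almost) equal numbers
    of events; polarity -1 goes to channel 0 and +1 to channel 1.
    """
    events = sorted(events, key=lambda e: e[2])  # order by timestamp
    frames = [[[[0] * width for _ in range(height)] for _ in range(2)]
              for _ in range(num_frames)]
    n = len(events)
    for idx, (x, y, _t, p) in enumerate(events):
        # Map the event's rank in time to a frame index.
        frame_idx = min(idx * num_frames // n, num_frames - 1)
        channel = 1 if p > 0 else 0
        frames[frame_idx][channel][y][x] += 1  # count events per pixel
    return frames


# Four toy events on a 2 x 2 sensor, integrated into two frames.
events = [(0, 0, 10, +1), (1, 0, 20, -1), (1, 1, 30, +1), (0, 1, 40, +1)]
frames = events_to_frames(events, num_frames=2, height=2, width=2)
```

Most entries of each frame stay at zero, which is the sparsity that the text says SNNs can exploit by skipping computation.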

Spiking neural models
Spiking neurons in SNNs integrate synaptic inputs from the previous layer and the residual membrane potential into the latest membrane potential. We adopt the Parametric Leaky Integrate-and-Fire (PLIF) neuron (Fang et al., 2021). The subthreshold dynamics of the PLIF neuron are defined as:
$$\tau \frac{dV(t)}{dt} = -\left(V(t) - V_{rest}\right) + X(t)$$
where $V(t)$ indicates the membrane potential of the neuron at time $t$, $\tau$ is the membrane time constant that controls the decay of $V(t)$, $X(t)$ is the input collected from the presynaptic neurons, and $V_{rest}$ is the resting potential. When the membrane potential $V(t)$ exceeds the neuron threshold at time $t$, the neuron emits a spike, and the membrane potential then returns to a reset value $V_{reset}$. We set $V_{rest} = V_{reset} = 0$. The iterative representation of the PLIF model can be described as follows:
$$H^{t,l} = V^{t-1,l} + \frac{1}{\tau}\left(X^{t,l} - \left(V^{t-1,l} - V_{reset}\right)\right)$$
$$S^{t,l} = \Theta\left(H^{t,l} - V_{th}\right)$$
$$V^{t,l} = H^{t,l}\left(1 - S^{t,l}\right) + V_{reset}\, S^{t,l}$$
where superscripts $t$ and $l$ indicate the time step and layer index.
To avoid confusion, we use $H^{t,l}$ and $V^{t,l}$ to represent the membrane potential after neuronal dynamics and after the trigger of a spike in layer $l$ at time-step $t$, respectively. $V_{th}$ is the firing threshold. $S^{t,l}$ is determined by $\Theta(x)$, the Heaviside step function that outputs 1 if $x \geq 0$ or 0 otherwise. The time constant is $\tau = 1/k(a)$, where $k(a)$ is a sigmoid function $1/(1 + \exp(-a))$ with a trainable parameter $a$.
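The charge, fire, and reset steps of the PLIF neuron can be traced for a single neuron with the sketch below. It assumes the discrete update of Fang et al. (2021), in which the membrane potential moves toward the input by a factor 1/τ = k(a) at each step. The function name `plif_step`, the threshold of 1.0, the input of 1.5, and the value a = 0 are illustrative choices, not values taken from the paper.

```python
import math


def plif_step(v_prev, x, a, v_th=1.0, v_reset=0.0):
    """One discrete PLIF update for a single neuron.

    Returns (spike, v), where v is the membrane potential after the reset.
    v_th = 1.0 is an assumed threshold, not a value from the paper.
    """
    k = 1.0 / (1.0 + math.exp(-a))               # k(a) = 1 / tau (sigmoid of a)
    h = v_prev + k * (x - (v_prev - v_reset))    # charge: leaky integration
    spike = 1 if h - v_th >= 0 else 0            # fire: Heaviside step
    v = h * (1 - spike) + v_reset * spike        # hard reset after a spike
    return spike, v


# Drive one neuron with a constant input for four time-steps.
a = 0.0            # hypothetical value of the trainable parameter; gives tau = 2
v, spikes = 0.0, []
for _ in range(4):
    s, v = plif_step(v, x=1.5, a=a)
    spikes.append(s)
```

With these values the potential reaches 0.75 after the first step and 1.125 after the second, so the neuron fires on every second step. In training, `a` would be learned per layer, which changes how quickly the potential tracks the input.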

Self-attention-based temporal-channel joint attention module
The processing of temporal information in SNNs is generally attributed to spiking neurons because their dynamics naturally depend on the temporal dimension. However, the LIF neuron and its variants, including the PLIF neuron, only sustain very weak temporal linkages. Additionally, event streams are inherently time-dependent; therefore, it is necessary to establish spatio-temporal correlations to improve data utilization. The focus of this work is to model temporal-wise and channel-wise attention correlations globally by adopting a self-attention mechanism. We present our idea of attention with a pluggable module termed the Self-attention-based Temporal-Channel joint Attention (STCA), which is depicted in Figure 2.
Formally, we collect the intermediate spatial features of the $l$-th layer at all time-steps, $X^{l} = [X^{1,l}, \ldots, X^{T,l}] \in \mathbb{R}^{T \times C \times H \times W}$, as the input of the STCA module, where $T$ is the number of time-steps, $C$ denotes the number of channels, and $H$ and $W$ are the height and width of the feature, respectively. The spatial feature $X^{t,l}$ can be extracted from the input spikes $S^{t,l-1}$:
$$X^{t,l} = \mathrm{BN}\left(\mathrm{Conv}\left(W^{l}, S^{t,l-1}\right)\right)$$
where BN(·) and Conv(·) denote batch normalization and the convolutional operation, $W^{l}$ is the weight matrix, $S^{t,l-1}$ ($l \neq 1$) is a spike tensor that only contains 0 and 1, and $X^{t,l} \in \mathbb{R}^{C \times H \times W}$. To simplify the notation, bias terms are omitted. Since BN is a default operation following Conv, we also omit it in the rest of this paper. Since each spatial feature $X^{t,l}$ in $X^{l}$ is time-dependent, our idea of attention is to utilize the temporal correlation of these features. It is well known that each channel of feature maps corresponds to a specific visual pattern.
Our STCA module aims to determine 'when' to attend to 'what', where 'what' refers to the semantic attributes of the given input. For efficiency, STCA only focuses on temporal and channel modeling; the spatial information of the feature is aggregated by using both avg-pooling and max-pooling operations, AvgPool(·) and MaxPool(·). The two resulting temporal-channel context descriptors, the avg-pooled features and the max-pooled features, are merged into $R^{l} \in \mathbb{R}^{T \times C}$ and then fed into a self-attention (SA) block. We follow the convention of Wang et al. (2018) to formulate the SA block, where the input feature in layer $l$ is $R^{l} \in \mathbb{R}^{T \times C}$, and the output feature is generated as:
$$a_i = \frac{1}{C(r_i)} \sum_{\forall j} f(r_i, r_j)\, g(r_j)$$
where $r_i \in \mathbb{R}^{1 \times C}$ and $a_i \in \mathbb{R}^{1 \times C}$ indicate the $i$-th position of the input feature $R^{l}$ and output feature $A^{l}$, respectively. Subscript $j$ is the index that enumerates all positions along the temporal domain, i.e., $i, j \in [1, 2, \ldots, T]$, and the pairwise function $f(\cdot)$ computes a scalar representing the relationship between $i$ and each $j$. The function $g(\cdot)$ computes a representation of the input signal at time-step $j$, and the response is normalized by a factor $C(r_i)$. We use a simple extension of the Gaussian function to compute the similarity in an embedding space, and the function $f(\cdot)$ can be formulated as:
$$f(r_i, r_j) = e^{\theta(r_i)\, \phi(r_j)^{\top}}$$
where $\theta(\cdot)$ and $\phi(\cdot)$ can be any embedding layers. We consider $\theta(\cdot)$, $\phi(\cdot)$, and $g(\cdot)$ in the form of linear embeddings:
$$\theta(r_i) = r_i W_{\theta}, \quad \phi(r_j) = r_j W_{\phi}, \quad g(r_j) = r_j W_{g}$$
where $W_{\theta}$, $W_{\phi}$, and $W_{g}$ are weight matrices, and we set the normalization factor as $C(r_i) = \sum_{\forall j} f(r_i, r_j)$. For a given index $i$, $\frac{1}{C(r_i)} f(r_i, r_j)$ becomes the softmax output along the dimension $j$. The formulation can be further rewritten as:
$$A^{l} = \mathrm{softmax}\left(R^{l} W_{\theta} W_{\phi}^{\top} (R^{l})^{\top}\right) R^{l} W_{g} \tag{9}$$
where $A^{l} \in \mathbb{R}^{T \times C}$ is the output feature of the same size as $R^{l}$. Given the query, key, and value representations:
$$Q = R^{l} W_{Q}, \quad K = R^{l} W_{K}, \quad V = R^{l} W_{V}$$
Once $W_{Q} = W_{\theta}$, $W_{K} = W_{\phi}$, and $W_{V} = W_{g}$, with $W_{Q}, W_{K}, W_{V} \in \mathbb{R}^{C \times C}$, Eq. 9 can be formulated as:
$$A^{l} = \mathrm{softmax}\left(Q K^{\top}\right) V$$
In this way, the SA block is constructed. Then we employ a residual connection around the SA block. Finally, the attention process of STCA can be formulated as:
$$X^{l}_{STCA} = f \odot X^{l}$$
where $f = \sigma(R^{l} + A^{l}) \in \mathbb{R}^{T \times C}$ is the weight vector of STCA, $\odot$ is element-wise multiplication, $\sigma$ is the sigmoid function, and $X^{l}_{STCA} \in \mathbb{R}^{T \times C \times H \times W}$ denotes the feature extracted by the STCA module along temporal and channel dimensions.
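The attention process described above can be sketched in plain Python on nested lists. This is a minimal illustration rather than the authors' implementation. The function name `stca`, the helper functions, and the identity weight matrices in the usage are hypothetical. The avg-pooled and max-pooled descriptors are assumed to be merged by element-wise addition, which the text does not specify. No scaling is applied inside the softmax, following the formulation in the text.

```python
import math


def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax_rows(m):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    out = []
    for row in m:
        top = max(row)
        exps = [math.exp(v - top) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def stca(x, w_q, w_k, w_v):
    """Apply temporal-channel attention to x of shape T x C x H x W."""
    # Spatial aggregation: one avg-pooled and one max-pooled value per (t, c).
    # Assumption: the two descriptors are merged by element-wise addition.
    r = []
    for frame in x:
        row = []
        for fmap in frame:
            vals = [v for line in fmap for v in line]
            row.append(sum(vals) / len(vals) + max(vals))
        r.append(row)                                    # r has shape T x C

    # Self-attention over the T time-steps: A = softmax(Q K^T) V.
    q, k, v = matmul(r, w_q), matmul(r, w_k), matmul(r, w_v)
    k_t = [list(col) for col in zip(*k)]                 # transpose of K
    attn = softmax_rows(matmul(q, k_t))                  # T x T weights
    a = matmul(attn, v)                                  # T x C

    # Residual connection, sigmoid gate, then broadcast over H and W.
    gate = [[1.0 / (1.0 + math.exp(-(ri + ai))) for ri, ai in zip(rr, ar)]
            for rr, ar in zip(r, a)]
    out = [[[[gate[t][c] * val for val in line] for line in fmap]
            for c, fmap in enumerate(frame)]
           for t, frame in enumerate(x)]
    return out, attn


# Toy input with T = 3 time-steps, C = 2 channels, and 1 x 2 feature maps.
identity = [[1.0, 0.0], [0.0, 1.0]]
x = [[[[0.0, 1.0]], [[1.0, 1.0]]],
     [[[1.0, 0.0]], [[0.0, 0.0]]],
     [[[0.0, 0.0]], [[1.0, 0.0]]]]
out, attn = stca(x, identity, identity, identity)
```

Each row of `attn` weights all T time-steps at once, which is the global temporal receptive field the text contrasts with the local cross-shaped field of 1-D convolutions. The output keeps the input shape, so the sketch behaves like a plug-in gate on the feature tensor.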

Training
We integrate the STCA module into networks and utilize the BPTT method to train SNNs. Since the process of neuron firing is non-differentiable, we use the derivative of the ATan surrogate function, $\alpha / \left(2\left(1 + \left(\frac{\pi}{2}\alpha x\right)^{2}\right)\right)$, in the backward pass, where $\alpha$ controls the slope of the surrogate. For a given input with label $n$, the neuron that represents class $n$ has the highest excitatory level while other neurons remain silent. So the target output is defined by $Y = [y_{t,i}]$ with $y_{t,i} = 1$ for $i = n$, and $y_{t,i} = 0$ for $i \neq n$. The loss function is the spike mean squared error between $Y$ and $O$, where $O = [o_{t,i}]$ is the average spiking events of neurons under the voting strategy.
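The two training ingredients named above can be sketched as follows. The surrogate is assumed to be the ATan form g(x) = arctan(πα x / 2)/π + 1/2, whose derivative replaces the gradient of the Heaviside step. The loss is assumed to average the squared error over all time-steps and output neurons, and the exact normalization used by the authors may differ. The function names and α = 2.0 are illustrative choices.

```python
import math


def atan_surrogate_grad(x, alpha=2.0):
    """Derivative of g(x) = arctan(pi / 2 * alpha * x) / pi + 1 / 2."""
    return alpha / (2.0 * (1.0 + (math.pi / 2.0 * alpha * x) ** 2))


def spike_mse_loss(outputs, label):
    """Mean squared error between outputs o[t][i] and a one-hot target.

    Assumption: the error is averaged over all time-steps and neurons.
    """
    total, count = 0.0, 0
    for step in outputs:
        for i, o in enumerate(step):
            y = 1.0 if i == label else 0.0   # target: only class `label` fires
            total += (o - y) ** 2
            count += 1
    return total / count


# Two time-steps, three output neurons, true class 1.
outputs = [[0.0, 1.0, 0.0],
           [0.0, 0.5, 0.5]]
loss = spike_mse_loss(outputs, label=1)
peak = atan_surrogate_grad(0.0)   # largest gradient, at the firing threshold
```

The surrogate gradient peaks where the membrane potential equals the threshold and decays on both sides, so neurons close to firing receive the largest updates during BPTT.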

Implementation details
We implement our experiments with the PyTorch package and the SpikingJelly framework. All experiments were conducted using the BPTT learning algorithm on 4 NVIDIA RTX 2080 Ti GPUs. We utilized the Adam optimizer (Kingma and Ba, 2015) to accelerate the training process and implemented some standard training techniques of deep learning such as batch normalization and dropout. The corresponding hyper-parameters and SNN hyper-parameters are shown in Table 1. We verify our method on the following DVS benchmarks.
CIFAR10-DVS contains 10K DVS images of 10 classes recorded with the dynamic vision sensor from the original static CIFAR10 dataset. We apply a 9:1 train-valid split (i.e., 9K training images and 1K validation images). The resolution is 128 × 128; we resize all samples to 48 × 48 for training and integrate the event data into 10 frames per sample (Li et al., 2017).
The N-Caltech 101 dataset contains 8,831 DVS images converted from the original version of Caltech 101 with a slight change in object classes to avoid confusion. N-Caltech 101 consists of 100 object classes plus one background class. Similarly, we apply the same 9:1 train-test split as for CIFAR10-DVS. We use the SpikingJelly (Fang et al., 2020) package to process the data and integrate them into 14 frames per sample (Orchard et al., 2015).
The neuromorphic MNIST (N-MNIST) dataset is converted from the original static MNIST dataset (Orchard et al., 2015). It contains 50K training images and 10K validation images. We integrate the event data into 10 frames per sample using the SpikingJelly (Fang et al., 2020) package.

Networks
The network structures with STCA for different datasets are provided in Table 2; the architectures we use have been shown to perform well on each dataset. Specifically, for the CIFAR10-DVS dataset, we adopt a VGG11-like architecture. To mitigate the apparent overfitting on the CIFAR10-DVS dataset, we adopt neuromorphic data augmentation, including horizontal flipping and Mixup in each frame, which is also used in Zhu et al. (2022) for training on the same dataset. For the N-Caltech 101 dataset, we adopt the same architecture as Zhu et al. (2022), and for N-MNIST we follow PLIF (Fang et al., 2021). The voting layers are implemented using average pooling for classification robustness.
Table 3 displays the accuracy of the proposed STCA-SNNs compared to other competing methods on three neuromorphic datasets, N-MNIST, CIFAR10-DVS, and N-Caltech 101. We mainly include direct training results of SNNs with signal transmission via binary spikes. Among them, some works (Wu et al., 2019; Yao et al., 2021) replace binary spikes with floating-point spikes and maintain the same forward pipeline as SNNs to obtain enhanced classification accuracy. STCA-SNNs achieve better performance than existing state-of-the-art SNNs on all datasets. We first compare our method on the CIFAR10-DVS dataset. We continue to use MSE as the loss function and the same network architecture as TCJA-SNN (Zhu et al., 2022) and STSC-SNN (Yu et al., 2022) to preserve consistency, and our method reaches 81.6% top-1 accuracy, improving the accuracy by 0.9% over TCJA-SNN (Zhu et al., 2022). We also compare our method on the N-Caltech 101 dataset. Under the same conditions as TCJA-SNN (Zhu et al., 2022), with MSE as the loss function, we obtain a 2.38% increase over it and outperform the comparable result. Finally, we test our algorithm on the N-MNIST dataset. As shown in Table 3, most compared works achieve over 99% accuracy. We use the same architecture as PLIF. Our STCA-SNN reaches the best accuracy of 99.67%.

Ablation study
We performed ablation experiments based on the PLIF neuron model to evaluate the effectiveness of the STCA module. For each dataset, we trained three types of SNNs: STCA-SNNs, TA-SNNs with a temporal-wise attention module (Yao et al., 2023c), and vanilla SNNs (PLIF-SNN) without any attention module. The SE attention employed by TA-SNNs in the temporal dimension and the self-attention employed in this work are both non-local operators; thus, we compared the performance of these two classic non-local operators under the same experimental conditions. We followed the learning process described in the implementation details for all ablation experiments, and the attention locations were identical for both TA-SNNs and STCA-SNNs. Table 4 shows that all STCA-SNNs outperformed vanilla SNNs on three event stream classification datasets, suggesting that the benefits of the STCA module are not limited to a specific dataset or architecture. Furthermore, Figure 3 illustrates the accuracy trend of vanilla SNN, TA-SNN, and our proposed STCA-SNN over 1,000 epochs on the N-Caltech 101 dataset. As the training epoch increased, our proposed STCA-SNN demonstrated comparable performance with TA-SNN. This indicates that our STCA module can enhance the representation ability of SNNs.

Discussion of pooling operations
To investigate the influence of the avg-pooling and max-pooling operations, we conducted several ablation studies. As is well known, avg-pooling can capture the degree information of target objects,

Conclusion
In this work, we propose the STCA-SNNs to enhance the temporal information processing capabilities of SNNs. The STCA module captures temporal dependencies across channels globally using self-attention, enabling the network to learn 'when' to attend to 'what'. We verified the performance of STCA-SNNs on various neuromorphic datasets across different architectures. The experimental results show that STCA-SNNs achieve competitive accuracy on N-MNIST, CIFAR10-DVS, and N-Caltech 101 datasets.

FIGURE 1 Illustration of receptive fields on channel and temporal domains. T means the temporal domain, C means the channel domain, and H, W represent the spatial domain. (A) TCJA-SNN utilizes two local attention mechanisms with 1-D convolution along the temporal and channel dimensions, respectively, then fuses them, forming a cross-shaped receptive field. (B) STCA-SNN uses the self-attention operation to establish temporal-wise and channel-wise correlations, forming a global receptive field.

FIGURE 2 Diagram of the STCA module. The STCA module first aggregates spatial information by average-pooling and max-pooling, then merges the results and feeds them into a self-attention block to establish the correlations in both temporal and channel dimensions.

FIGURE 3 Convergence of compared SNN methods on the N-Caltech 101 dataset.

TABLE 3
Accuracy performance comparison between the proposed method and the SOTA methods on different datasets.

TABLE 4
Accuracy of vanilla SNN, TA-SNN, and STCA-SNN models on different datasets.