C-Mamba: Channel Correlation Enhanced State Space Models for Multivariate Time Series Forecasting
Contents
- Abstract
- Introduction
- Preliminary
  - MTS forecasting
  - Mamba
- Methodology
  - Channel Mixup
  - C-Mamba Block
- Experiments
0. Abstract
Limitations of previous models
- Linear: limited capacity
- Attention: quadratic complexity
- CNN: restricted receptive field

CI (channel-independent) strategy: ignores cross-channel correlations!
CD (channel-dependent) strategy: models inter-channel relationships
- via self-attention, linear combination, or convolution
- \(\rightarrow\) high computational costs

C-Mamba
- (1) an SSM that captures (2) cross-channel dependencies, while maintaining (3) linear complexity (4) without losing the global receptive field
- Components
  - (1) Channel mixup
    - two channels are mixed to enhance the training set
  - (2) Channel attention enhanced patch-wise Mamba encoder
    - leverages the ability of the SSM to capture cross-time dependencies
    - models correlations between channels by mining their weight relationships
 
1. Introduction
Cross-channel dependencies are also vital for MTSF!
Ex) two variables over time in the ETT dataset
Observations:
- (1) The two variables exhibit strongly similar temporal characteristics
- (2) They show a strong proportional relationship
  - MULL (Middle UseLess Load) = 1/2 × HULL (High UseLess Load)
  - \(\rightarrow\) necessity of modeling cross-channel dependencies such as proportional relationships

C-Mamba
- Goal: better capture both cross-time and cross-channel dependencies
- A channel-correlation enhanced SSM
- Problem & Solution [1]
  - Problem) Oversmoothing caused by the CD strategy
  - Solution) Channel mixup strategy
    - inspired by mixup data augmentation
    - fuses two channels via a linear combination during training ( = generates a virtual channel )
    - the generated virtual channels integrate characteristics from different channels while retaining their shared cross-time dependencies, which is expected to improve the generalizability of the model
- Problem & Solution [2]
  - Problem) Capturing both cross-time and cross-channel dependencies
  - Solution) Channel attention enhanced patch-wise Mamba encoder
    - a) Cross-time dependencies: via Mamba ( + patching ) → patch-wise Mamba module
    - b) Cross-channel dependencies: via channel attention
      - a lightweight mechanism that considers various relationships between channels: (1) weighted summation relationships & (2) proportional relationships
2. Preliminary
(1) MTS forecasting
Notation
- Input: \(\mathbf{X}=\left\{\mathbf{x}_1, \ldots, \mathbf{x}_L\right\} \in \mathbb{R}^{L \times V}\)
  - \(\mathbf{X}_{t,:}\): values of all channels at time step \(t\)
  - \(\mathbf{X}_{:, v}\): the entire sequence of the channel indexed by \(v\)
- Target: \(\mathbf{Y}=\left\{\mathbf{x}_{L+1}, \ldots, \mathbf{x}_{L+T}\right\} \in \mathbb{R}^{T \times V}\)
(2) Mamba
Notation
- Input \(\mathbf{x}(t) \in \mathbb{R}\)
- Output \(\mathbf{y}(t) \in \mathbb{R}\)
- Hidden state \(\mathbf{h}(t) \in \mathbb{R}^N\)

\(\begin{aligned} \mathbf{h}^{\prime}(t) & =\mathbf{A h}(t)+\mathbf{B x}(t), \\ \mathbf{y}(t) & =\mathbf{C h}(t), \end{aligned}\)
- \(\mathbf{A} \in \mathbb{R}^{N \times N}\): state transition matrix
- \(\mathbf{B} \in \mathbb{R}^{N \times 1}\) and \(\mathbf{C} \in \mathbb{R}^{1 \times N}\): projection matrices

Multivariate TS
- \(\mathbf{x}(t) \in \mathbb{R}^V\) and \(\mathbf{y}(t) \in \mathbb{R}^V\)
- \(\mathbf{A} \in \mathbb{R}^{V \times N \times N}\), \(\mathbf{B} \in \mathbb{R}^{V \times N}\), \(\mathbf{C} \in \mathbb{R}^{V \times N}\)
  - since \(\mathbf{A}\) is diagonal in practice, it can be compressed to \(V \times N\)
 
Discretization (zero-order hold)
\(\begin{aligned} \overline{\mathbf{A}} & =\exp (\Delta \mathbf{A}) \\ \overline{\mathbf{B}} & =(\Delta \mathbf{A})^{-1}(\exp (\Delta \mathbf{A})-\mathbf{I}) \Delta \mathbf{B} \\ \mathbf{h}_t & =\overline{\mathbf{A}} \mathbf{h}_{t-1}+\overline{\mathbf{B}} \mathbf{x}_t \\ \mathbf{y}_t & =\mathbf{C h}_t \end{aligned}\)
- where \(\Delta \in \mathbb{R}^V\) is the sampling time interval
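A minimal NumPy sketch of this discretization and recurrence for a single channel; the diagonal \(\mathbf{A}\), \(\mathbf{B}\), \(\mathbf{C}\), and \(\Delta\) values below are arbitrary placeholders, not the paper's parameters.

```python
import numpy as np

N, L = 16, 32
rng = np.random.default_rng(0)
A = -np.diag(rng.uniform(0.1, 1.0, N))        # stable, diagonal state-transition matrix
B = rng.standard_normal((N, 1))
C = rng.standard_normal((1, N))
delta = 0.1                                   # sampling interval Δ

# Zero-order hold: A_bar = exp(ΔA), B_bar = (ΔA)^{-1}(exp(ΔA) - I) ΔB
A_bar = np.diag(np.exp(delta * np.diag(A)))   # A is diagonal, so expm is elementwise
B_bar = np.linalg.inv(delta * A) @ (A_bar - np.eye(N)) @ (delta * B)

# Recurrence: h_t = A_bar h_{t-1} + B_bar x_t,  y_t = C h_t
x = rng.standard_normal(L)
h, y = np.zeros((N, 1)), np.zeros(L)
for t in range(L):
    h = A_bar @ h + B_bar * x[t]
    y[t] = (C @ h).item()
```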
 
Global convolution
\(\begin{aligned} & \overline{\mathbf{K}}=\left(\mathbf{C} \overline{\mathbf{B}}, \mathbf{C} \overline{\mathbf{A}} \overline{\mathbf{B}}, \ldots, \mathbf{C} \overline{\mathbf{A}}^{L-1} \overline{\mathbf{B}}\right) \\ & \mathbf{Y}=\mathbf{X} * \overline{\mathbf{K}} \end{aligned}\)
- where \(L\) is the length of the sequence
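A small sketch checking that the recurrence and the global convolution with kernel \(\overline{\mathbf{K}}\) produce the same output; the discretized matrices here are random stable placeholders.

```python
import numpy as np

N, L = 8, 20
rng = np.random.default_rng(1)
A_bar = np.diag(rng.uniform(0.1, 0.9, N))     # already-discretized, stable
B_bar = rng.standard_normal((N, 1))
C = rng.standard_normal((1, N))
x = rng.standard_normal(L)

# Recurrent scan
h, y_scan = np.zeros((N, 1)), np.zeros(L)
for t in range(L):
    h = A_bar @ h + B_bar * x[t]
    y_scan[t] = (C @ h).item()

# Global convolution with kernel K_bar[k] = C A_bar^k B_bar
K_bar = np.array([(C @ np.linalg.matrix_power(A_bar, k) @ B_bar).item() for k in range(L)])
y_conv = np.array([np.dot(K_bar[: t + 1], x[t::-1]) for t in range(L)])  # causal convolution

assert np.allclose(y_scan, y_conv)            # scan and convolution agree
```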
 
Selective scan mechanism ( data-dependent parameterization )
- \(\mathbf{B} \in \mathbb{R}^{L \times V \times N}\), \(\mathbf{C} \in \mathbb{R}^{L \times V \times N}\)
- \(\Delta \in \mathbb{R}^{L \times V}\)
- all derived from the input \(\mathbf{X} \in \mathbb{R}^{L \times V}\)
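A rough PyTorch sketch of what "derived from the input" could look like, with \(\mathbf{B}\), \(\mathbf{C}\), \(\Delta\) produced by linear projections of \(\mathbf{X}\); the layer shapes are assumptions, not Mamba's exact implementation.

```python
import torch
import torch.nn as nn

L, V, N = 96, 7, 16
X = torch.randn(L, V)

proj_B = nn.Linear(V, V * N)
proj_C = nn.Linear(V, V * N)
proj_delta = nn.Linear(V, V)

B = proj_B(X).reshape(L, V, N)                         # B ∈ R^{L×V×N}
C = proj_C(X).reshape(L, V, N)                         # C ∈ R^{L×V×N}
delta = torch.nn.functional.softplus(proj_delta(X))    # Δ ∈ R^{L×V}, positive step sizes
print(B.shape, C.shape, delta.shape)
```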
 
3. Methodology

Before training
- Channel mixup module: mixes the MTS along the channel dimension

Model
- C-Mamba block
  - Vanilla Mamba module
  - Channel attention module
- Exploits both cross-time and cross-channel dependencies
- Operates on patch-wise sequences
(1) Channel Mixup
Mixup (original data augmentation)
\(\begin{aligned} & \tilde{x}=\lambda x_i+(1-\lambda) x_j \\ & \tilde{y}=\lambda y_i+(1-\lambda) y_j \end{aligned}\)
- where \((\tilde{x}, \tilde{y})\) is the synthesized virtual sample and \(\lambda \in[0,1]\)

Channel Mixup
\(\begin{aligned} & \mathbf{X}^{\prime}=\mathbf{X}_{:, i}+\lambda \mathbf{X}_{:, j}, \quad i, j=0, \ldots, V-1 \\ & \mathbf{Y}^{\prime}=\mathbf{Y}_{:, i}+\lambda \mathbf{Y}_{:, j}, \quad i, j=0, \ldots, V-1 \end{aligned}\)
- where \(\mathbf{X}^{\prime} \in \mathbb{R}^{L \times 1}\) and \(\mathbf{Y}^{\prime} \in \mathbb{R}^{T \times 1}\) are the hybrid channels
- where \(j\) is taken from \(\operatorname{randperm}(V)\), a randomly permuted array of \(0 \sim V-1\)
- \(\lambda \sim N\left(0, \sigma^2\right)\) is the linear combination coefficient

Normal distribution with a mean of 0
\(\rightarrow\) ensures that the overall characteristics of each channel remain unchanged in expectation. A minimal sketch follows below.
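A minimal PyTorch sketch of the channel mixup above; the pairing via randperm and \(\lambda \sim N(0, \sigma^2)\) follow the description, while \(\sigma\) and the batch-first shapes are assumptions.

```python
import torch

def channel_mixup(X, Y, sigma=0.5):
    """X: (B, L, V) lookback, Y: (B, T, V) target.
    Mixes channel i with channel perm[i] using a shared lambda ~ N(0, sigma^2)."""
    V = X.shape[-1]
    perm = torch.randperm(V)              # random channel pairing, j = perm[i]
    lam = torch.randn(()) * sigma         # linear combination coefficient λ
    X_mix = X + lam * X[..., perm]        # X'_{:,i} = X_{:,i} + λ X_{:,perm(i)}
    Y_mix = Y + lam * Y[..., perm]        # the same mixing is applied to the targets
    return X_mix, Y_mix

X, Y = torch.randn(32, 96, 7), torch.randn(32, 24, 7)
X_mix, Y_mix = channel_mixup(X, Y)        # same shapes as X, Y
```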

(2) C-Mamba Block
Consists of two key components:
- (1) Patch-wise Mamba module
- (2) Channel attention module

\(\rightarrow\) capture cross-time and cross-channel dependencies, respectively!
a) PatchMamba
Patching
- Each univariate TS \(\mathbf{X}_{:, v} \in \mathbb{R}^L\) is split into patches
- \(\hat{\mathbf{X}}_{:, v}=\operatorname{Patching}\left(\mathbf{X}_{:, v}\right) \in \mathbb{R}^{N \times P}\), where \(N\) is the number of patches and \(P\) is the patch length
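A one-line sketch of patching with tensor.unfold; the patch length \(P\) and stride \(S\) are placeholder values, not the paper's settings.

```python
import torch

L, V, P, S = 96, 7, 16, 8                          # sequence length, channels, patch length, stride
X = torch.randn(V, L)                              # one row per univariate series X_{:,v}
X_patch = X.unfold(dimension=-1, size=P, step=S)   # -> (V, N, P), with N = (L - P) // S + 1 patches
print(X_patch.shape)                               # torch.Size([7, 11, 16])
```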
 
b) Channel Attention
Structure of the channel attention module
Notation:
- \(\mathbf{H}_l \in \mathbb{R}^{V \times N \times D}\): embedding after the \(l^{th}\) PatchMamba module
- \(\operatorname{Att}_l=\operatorname{sigmoid}\left(\operatorname{MLP}\left(\operatorname{MaxPool}\left(\mathbf{H}_l\right)\right)+\operatorname{MLP}\left(\operatorname{AvgPool}\left(\mathbf{H}_l\right)\right)\right)\)
  - AvgPool and MaxPool are applied to the last two dimensions
- Expanded form: \(\operatorname{Att}_l=\operatorname{sigmoid}\left(\mathbf{W}_1\left(\operatorname{Gelu}\left(\mathbf{W}_0 \mathbf{F}_{\text {max }}^l\right)\right)+\mathbf{W}_1\left(\operatorname{Gelu}\left(\mathbf{W}_0 \mathbf{F}_{\text {avg }}^l\right)\right)\right)\)
  - \(\mathbf{F}_{\text {max }}^l \in \mathbb{R}^{V \times 1 \times 1}\) and \(\mathbf{F}_{\text {avg }}^l \in \mathbb{R}^{V \times 1 \times 1}\): pooled features
  - \(\mathbf{W}_0 \in \mathbb{R}^{V / r \times V}\) and \(\mathbf{W}_1 \in \mathbb{R}^{V \times V / r}\): shared MLP weights
    - \(r\): reduction ratio controlling the parameter complexity
      ( essential for time series with hundreds of channels; tuned in \(\{2,4,8\}\) )
- \(\operatorname{Att}_l \in \mathbb{R}^{V \times 1 \times 1}\): per-channel attention weights

\(\mathbf{C A}_l=\operatorname{Att}_l \odot \mathbf{H}_l\): output of the channel attention module
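A minimal PyTorch sketch of this channel attention module (SE/CBAM-style pooling over the last two dimensions plus a shared MLP); the module layout is an assumption, not the official code.

```python
import torch
import torch.nn as nn

class ChannelAttention(nn.Module):
    def __init__(self, num_channels: int, reduction: int = 4):
        super().__init__()
        hidden = max(num_channels // reduction, 1)
        self.mlp = nn.Sequential(                     # shared MLP: W_1 · GELU(W_0 · F)
            nn.Linear(num_channels, hidden),
            nn.GELU(),
            nn.Linear(hidden, num_channels),
        )

    def forward(self, H):                             # H_l: (..., V, N, D)
        F_avg = H.mean(dim=(-2, -1))                  # AvgPool over the last two dims -> (..., V)
        F_max = H.amax(dim=(-2, -1))                  # MaxPool over the last two dims -> (..., V)
        att = torch.sigmoid(self.mlp(F_max) + self.mlp(F_avg))   # Att_l in (0, 1)^V
        return att[..., None, None] * H               # CA_l = Att_l ⊙ H_l

H = torch.randn(7, 11, 64)                            # V = 7 channels, N = 11 patches, D = 64
out = ChannelAttention(num_channels=7, reduction=2)(H)   # same shape as H
```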
c) Overall Pipeline
( Instance normalization is applied to mitigate distribution shifts )
- \(\mathbf{X}^{\prime}, \mathbf{Y}^{\prime}=\operatorname{Mixup}(\mathbf{X}, \mathbf{Y})\)
- \(\mathbf{X}_{\text {norm }}^{\prime}=\operatorname{InstanceNorm}\left(\mathbf{X}^{\prime}\right)\)
- \(\hat{\mathbf{X}}=\operatorname{Patching}\left(\mathbf{X}_{\text {norm }}^{\prime}\right)\)
- \(\mathbf{Z}_0=\hat{\mathbf{X}} \mathbf{W}_p+\mathbf{W}_{\text {pos }}\)
  - \(\mathbf{W}_{\text {pos }}\): learnable position encoding
  - where \(\hat{\mathbf{X}} \in \mathbb{R}^{V \times N \times P}\), \(\mathbf{W}_p \in \mathbb{R}^{P \times D}\), \(\mathbf{W}_{\text {pos }} \in \mathbb{R}^{N \times D}\), and \(\mathbf{Z}_0 \in \mathbb{R}^{V \times N \times D}\)

[C-Mamba encoder: \(k\) stacked blocks]
- \(\begin{aligned} \mathbf{H}_l & =\operatorname{PatchMamba}\left(\mathbf{Z}_{l-1}\right), \\ \mathbf{Z}_l & =\operatorname{Att}_l\left(\mathbf{H}_l\right) \odot \mathbf{H}_l+\mathbf{Z}_{l-1} \end{aligned}\)
- \(\hat{\mathbf{Y}}_p=\operatorname{Flatten}\left(\operatorname{Silu}\left(\operatorname{RMS}\left(\mathbf{Z}_k\right)\right)\right) \mathbf{W}_{\text {proj }}\)
  - linear projection layer: \(\mathbf{W}_{\text {proj }} \in \mathbb{R}^{(N \cdot D) \times T}\)
  - \(\hat{\mathbf{Y}}_p \in \mathbb{R}^{V \times T}\)
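A rough end-to-end sketch of the pipeline above, where a per-channel GRU stands in for the patch-wise Mamba module (the real block needs an SSM layer) and LayerNorm stands in for RMSNorm; channel mixup is omitted since it is a training-time augmentation. All hyperparameters are placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

B, L, T, V, P, S, D, k = 8, 96, 24, 7, 16, 8, 64, 2
N = (L - P) // S + 1                                   # number of patches

class ChannelAttention(nn.Module):                     # compact version of the earlier sketch
    def __init__(self, v, r=2):
        super().__init__()
        h = max(v // r, 1)
        self.mlp = nn.Sequential(nn.Linear(v, h), nn.GELU(), nn.Linear(h, v))
    def forward(self, H):                              # H: (B, V, N, D)
        att = torch.sigmoid(self.mlp(H.amax((-2, -1))) + self.mlp(H.mean((-2, -1))))
        return att[..., None, None] * H                # Att_l ⊙ H_l

class TemporalMixer(nn.Module):                        # stand-in for the patch-wise Mamba module
    def __init__(self, d):
        super().__init__()
        self.rnn = nn.GRU(d, d, batch_first=True)
    def forward(self, z):                              # z: (B·V, N, D)
        return self.rnn(z)[0]

mixers = nn.ModuleList([TemporalMixer(D) for _ in range(k)])
attns = nn.ModuleList([ChannelAttention(V) for _ in range(k)])
embed = nn.Linear(P, D)                                # W_p
pos = nn.Parameter(torch.zeros(N, D))                  # W_pos, learnable position encoding
norm = nn.LayerNorm(D)                                 # stand-in for RMSNorm
head = nn.Linear(N * D, T)                             # W_proj

X = torch.randn(B, L, V)                               # (channel mixup would be applied here during training)
mean, std = X.mean(1, keepdim=True), X.std(1, keepdim=True) + 1e-5
Xn = (X - mean) / std                                  # instance normalization
Z = embed(Xn.permute(0, 2, 1).unfold(-1, P, S)) + pos  # Z_0: (B, V, N, D)
for mixer, attn in zip(mixers, attns):
    H = mixer(Z.reshape(B * V, N, D)).reshape(B, V, N, D)   # H_l = PatchMamba(Z_{l-1})  [stand-in]
    Z = attn(H) + Z                                         # Z_l = Att_l(H_l) ⊙ H_l + Z_{l-1}
Yp = head(F.silu(norm(Z)).flatten(-2))                 # (B, V, T)
Y_hat = Yp.permute(0, 2, 1) * std + mean               # de-normalized forecast, (B, T, V)
```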
 
4. Experiments



