在深度学习中,注意力机制已成为提升模型性能的重要手段。其中,多尺度注意力模块(Multi-Scale Attention Module) 通过整合不同尺度的特征表示,显著增强了模型对复杂场景的感知能力。然而,现有方法如 EMA(Efficient Multi-Scale Attention) 虽然在图像分类和目标检测任务中表现优异,但其固定分组策略、单一空间注意力机制以及缺乏全局建模能力,限制了其进一步发展。
本文提出了一种新的注意力模块——EMAX(Enhanced Multi-scale Attention with eXpressive learning),在保持 EMA 高效性的同时,引入了动态分组、通道注意力、门控残差连接和全局注意力增强等创新设计,使其具备更强的表达能力和泛化能力。


一、背景与动机

1.1 多尺度注意力机制概述

多尺度注意力的核心思想是通过构建不同尺度的特征表示,使模型能够同时关注局部细节与全局结构。例如:

  • Swin Transformer 使用滑动窗口机制进行局部到全局的信息融合;
  • EMA 模块 通过将通道分组并结合空间注意力,实现了高效的多尺度建模;
  • CBAM 引入通道与空间注意力的串行组合,提升特征选择能力。
    尽管这些方法各具优势,但在以下方面仍存在不足:
  • 分组策略固定,无法适应输入特征分布的变化;
  • 注意力机制局限于单一维度(如仅空间或仅通道);
  • 缺乏对长距离依赖的有效建模;
  • 可解释性较弱,难以可视化分析注意力权重。
    因此,我们提出了 EMAX,以尽力解决上述问题。

二、EMAX 模块的设计与实现

2.1 核心组件介绍

EMAX 包含以下关键组成部分:

组件功能
动态分组(Dynamic Grouping)根据输入特征统计信息自适应调整分组数量
空间注意力分支(Spatial Attention Branch)建模高度与宽度方向的注意力权重
通道注意力分支(Channel Attention Branch)提取通道维度的重要性权重
门控残差连接(Gated Residual Connection)控制注意力输出与原始输入的融合比例
全局注意力增强(Global Attention Enhancement)引入长距离依赖建模

2.2 PyTorch 实现代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import torch.nn as nn
import torch.nn.functional as F

class EMAX(nn.Module):
def __init__(self, channels, c2=None, factor=32, reduction=16):
super(EMAX, self).__init__()
self.groups = factor
self.channels = channels
assert channels // self.groups > 0

# Spatial attention branch
self.softmax = nn.Softmax(-1)
self.agp = nn.AdaptiveAvgPool2d((1, 1))
self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
self.pool_w = nn.AdaptiveAvgPool2d((1, None))
self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups)
self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0)
self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1)

# Channel attention branch
self.channel_gate = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0),
nn.ReLU(inplace=True),
nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0),
nn.Sigmoid()
)

# Input Feature Distribution Adaptive Dynamic Grouping Controller
self.group_controller = nn.Linear(channels, self.groups)

# Gating residual connection
self.gamma = nn.Parameter(torch.zeros(1))

# Global attention branch
self.global_attn = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(channels, channels // reduction, kernel_size=1),
nn.ReLU(True),
nn.Conv2d(channels // reduction, channels, kernel_size=1),
nn.Sigmoid()
)

def forward(self, x):
b, c, h, w = x.size()

# Dynamic grouping
group_logits = self.agp(x).view(b, c) # [b, c]
group_weights = self.group_controller(group_logits).softmax(dim=-1) # [b, groups]
group_indices = torch.multinomial(group_weights, 1).squeeze() # [b]
groups = torch.clamp(group_indices, min=1, max=self.groups).mode().values.item()
group_x = x.reshape(b * groups, -1, h, w)

# Spatial attention
x_h = self.pool_h(group_x)
x_w = self.pool_w(group_x).permute(0, 1, 3, 2)
hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))
x_h, x_w = torch.split(hw, [h, w], dim=2)
x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())
x2 = self.conv3x3(group_x)

x11 = self.softmax(self.agp(x1).reshape(b * groups, -1, 1).permute(0, 2, 1))
x12 = x2.reshape(b * groups, c // groups, -1)
x21 = self.softmax(self.agp(x2).reshape(b * groups, -1, 1).permute(0, 2, 1))
x22 = x1.reshape(b * groups, c // groups, -1)
weights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * groups, 1, h, w)
spatial_out = (group_x * weights.sigmoid()).reshape(b, c, h, w)

# Channel attention
channel_weights = self.channel_gate(x)
channel_out = x * channel_weights

# Fusing spatial information with channel attention
fused = spatial_out + self.gamma * channel_out

# Global attention enhancement
global_weights = self.global_attn(fused)
out = fused * global_weights

return out

三、数学推导与理论分析

3.1 输入张量定义

设输入特征图 XRb×c×h×wX \in \mathbb{R}^{b \times c \times h \times w},其中:

  • bb: batch size
  • cc: channel 数量
  • h,wh, w: 特征图高宽

3.2 动态分组机制

引入一个轻量级控制器 fθ:RcRGmaxf_{\theta}:\mathbb{R}^c \to \mathbb{R}^{G_{max}} (其中 GmaxG_{max} 是组的最大数量,例如因子数)处理全局平均池化输入动态选择分组数 GG'

p(G)=Softmax(fθ(AGP(X)))p(G) = \text{Softmax}(f_\theta(\text{AGP}(X)))

群组 GG' 的数量是基于 P(G) 选定的(例如,通过采样或 argmax,然后可能是在批次中取众数):

G=argmax(p(G))G' = \arg\max(p(G))

输入重塑:

Xgroup=Reshape(X,(bG,c/G,h,w))X_{\text{group}} = \text{Reshape}(X, (b \cdot G', c / G', h, w))

3.3 空间注意力机制

高度池化与宽度池化分别提取行/列方向的统计信息:

Xh=AdaptiveAvgPool2D(Xgroup,(None,1))Xw=AdaptiveAvgPool2D(Xgroup,(1,None)).permute(0,1,3,2)X_h = \text{AdaptiveAvgPool2D}(X_{\text{group}}, (None, 1)) \\ X_w = \text{AdaptiveAvgPool2D}(X_{\text{group}}, (1, None)).\text{permute}(0, 1, 3, 2)

合并后使用卷积融合,并拆分为原始尺寸:

HW=Conv1×1([Concat(Xh,Xw)])Xh,Xw=Split(HW,[h,w])HW = \text{Conv1×1}([\text{Concat}(X_h, X_w)]) \\ X_h', X_w' = \text{Split}(HW, [h, w])

注意力加权输出为:

As=Xgroupσ(Xh)σ(Xw.permute(0,1,3,2))A_s = X_{\text{group}} \cdot \sigma(X_h') \cdot \sigma(X_w'.\text{permute}(0, 1, 3, 2))

3.4 通道注意力机制

使用 SE-like 结构产生通道注意力权重:

wc=Sigmoid(Conv1×1(ReLU(Conv1×1(AGP(X)))))w_c = \text{Sigmoid}\left( \text{Conv}_{1×1}\left( \text{ReLU}\left( \text{Conv}_{1×1}(\text{AGP}(X)) \right) \right) \right)

输出为:

Yc=XwcY_c = X \odot w_c

3.5 门控残差连接

引入可学习参数 γR\gamma \in \mathbb{R} 控制融合比例:

Yfusion=YEMA+γYcY_{\text{fusion}} = Y_{\text{EMA}} + \gamma \cdot Y_c

3.6 全局注意力增强

提取全局上下文:

wg=Sigmoid(Conv1×1(ReLU(Conv1×1(AGP(Yfusion)))))w_g = \text{Sigmoid}\left( \text{Conv}_{1×1}\left( \text{ReLU}\left( \text{Conv}_{1×1}(\text{AGP}(Y_{\text{fusion}})) \right) \right) \right)

最终输出为:

Yfinal=YfusionwgY_{\text{final}} = Y_{\text{fusion}} \odot w_g

四、实验与可视化支持

4.1 注意力权重提取

forward 函数中添加钩子函数,提取中间变量:

1
2
3
4
5
6
7
def forward(self, x):
# ...前向传播...
# 记录注意力权重
self.spatial_weights = weights.sigmoid().detach().cpu()
self.channel_weights = channel_weights.detach().cpu()
self.global_weights = global_weights.detach().cpu()
return out

4.2 可视化注意力图

使用 matplotlibseaborn 对注意力权重进行热力图绘制:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import matplotlib.pyplot as plt
import seaborn as sns
def visualize_attention(model, input_tensor):
model.eval()
_ = model(input_tensor)
spatial_weights = model.emax.spatial_weights[0, 0].numpy()
channel_weights = model.emax.channel_weights[0].squeeze().numpy()
global_weights = model.emax.global_weights[0].squeeze().numpy()
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
sns.heatmap(spatial_weights, cmap='viridis', ax=axes[0], cbar=True)
axes[0].set_title('Spatial Attention Map')
sns.barplot(y=np.arange(len(channel_weights)), x=channel_weights, ax=axes[1], palette="Blues_d")
axes[1].set_title('Channel Attention Weights')
axes[1].set_xlabel('Attention Weight')
axes[1].set_ylabel('Channel Index')
sns.barplot(y=np.arange(len(global_weights)), x=global_weights, ax=axes[2], palette="Greens_d")
axes[2].set_title('Global Attention Enhancement')
axes[2].set_xlabel('Weight Value')
axes[2].set_ylabel('Channel Index')
plt.tight_layout()
plt.show()

五、对比分析与未来展望

方法多尺度建模动态分组注意力维度全局建模可解释性计算效率
EMA空间
CBAM空间/通道
SE Block通道
MSA多头
EMAX空间/通道/全局

5.1 优势总结

  • 动态分组机制 提升了模型对输入特征分布变化的适应能力;
  • 三重注意力机制 增强了模型对局部与全局信息的建模能力;
  • 轻量化设计 保证了计算效率;
  • 模块化结构 支持即插即用,便于部署。

5.2 局限与改进方向

  • 当前版本适用于图像任务,未来可扩展至视频、NLP 等领域;
  • 动态分组可能带来训练初期不稳定,建议采用 warm-up 策略;
  • 参数调优较为复杂,需结合自动化工具优化配置。