萤火跑模型 | 可变形的 Attention 助力 ViT 优化

Gao Huang    July 11, 2022

Vision Transformer (ViT) 模型在各类视觉任务上都展现出了强⼤的性能。因其具有较⼤甚⾄增⼤到全局的感受野,ViT 相⽐卷积神经⽹络(CNN)能更好地对⻓距离依赖关系建模,特别是在⼤量训练数据的情况下,ViT 可以轻易扩展参数以达得 SOTA 的实验结果。但是,ViT 里的 Attention 机制也是⼀把双刃剑,⼤量的 key/value 增加了不少计算量,使模型难于收敛,也增加了过拟合的⻛险。

最近来自清华黄高课题组的研究者们对 ViT 模型中的 Attention 机制进行改进,提出了可变形的 attention 机制。研究者们让所有 query 都跟同⼀组 key 和 value 交互,通过对每个输⼊图像学习⼀组偏移量,移动 key 和 value 到重要的位置。这种设计不仅增强了 sparse attention 的表征能⼒,同时具有线性空间复杂度。

该项目利用了幻方AI深度学习训练平台的算力及加速性能,在大量的场景下进行了实验对比,验证了所提方法的优异性能。在前不久闭幕的 CVPR 2022 视觉领域顶级学术会议上,该项工作进入了 Best Paper 奖项的候选角逐。

论文标题:Vision Transformer with Deformable Attention

论文地址https://arxiv.org/pdf/2201.00520.pdf

模型仓库https://github.com/LeapLabTHU/DAT

模型介绍

1. 概述

为了避免过量的 attention,现有工作采取了很多的稀疏化 Attention 的办法:

  • Swin Transformer 设计了滑动窗口机制,每次在窗口内部计算 Attention;
  • PVT 将 key/value 进行降采样来节约计算的开销

虽然这些方法十分有效,但是它们手工设计的 Attention 模式容易将与任务相关的 key 和 value 信息丢弃。例如 PVT 降采样会损失特征的细节,而 Swin Transformer 则丢弃了单个窗口之外的内容,不同窗口的交互仅依靠滑动窗口间接实现。

1

理想情况下,Attention 关注的位置应该根据输入变化,key 和 value 应该关注到输入图像中重要的部分,比如目标检测中的目标物体上。在 CNN 的工作中,可变形卷积(DCN)学习偏移的感受野已经表明非常有效。受此启发研究者们尝试探索 Attention 机制中的可变形结构的设计。

然而,直接将 DCN 的机制实现到 ViT 的主干网络中并不是一个简单的问题,因为 Attention 需要计算每个 query 和其对应感受野内所有 key 的相似度,所以这种方案需要缓存每个 query 的 key 和 value。这样一来 Attention 的空间复杂度变为了平方量级,大部分的计算设备都无法接受这样的复杂度。研究者们巧妙地提出让所有 query 都跟同一组 key 和 value 交互,而通过对每个输入图像学习一组偏移量,移动 key 和 value 到重要的位置,最终实现可变形的 attention 机制。如上图 (d) 所示,研究者们提出的可变形 Attention 增强了 Sparse Attention 的表征能力,同时具有线性空间复杂度。

2. 方法

2

如上图所示,可变形注意力(Deformable Attention)首先将输入的特征图 xx 线性映射为 query:q=xWqq=xW_q。接下来,初始化一组均匀网格作为采样点的参考点 pp,偏移量由 query 通过子网络 θoffset\theta_\text{offset} 产生:

Δp=stanh(θoffset(q))\Delta{}p=s\tanh(\theta_\text{offset}(q))

其中 ss 为控制偏移量幅度的超参数。接下来,将偏移量加到参考点上得到采样点坐标,使用双线性插值对原特征图进行采样:

x~=ϕbilinear(x;p+Δp)\tilde{x}=\phi_\text{bilinear}(x;p+\Delta{}p)

其中 ϕbilinear(;)\phi_\text{bilinear}(\cdot;\cdot) 表示双线性采样操作,第一个操作数为特征图,第二个操作数为采样点坐标。具体来说,双线性插值的实现如下图所示。

3

在此基础上,变形后的 key 和 value 分别通过 k~=x~Wk\tilde{k}=\tilde{x}W_kv~=x~Wv\tilde{v}=\tilde{x}W_v 得到。常见于 ViT 中的位置编码则通过采样一张相对位置编码表 B^rpb\hat{B}_\text{rpb} 得到,

B~rpb=ϕbilinear(B^rpb;R(pq,pk~))\tilde{B}_\text{rpb}=\phi_\text{bilinear}(\hat{B}_\text{rpb};R(p_q,p_{\tilde{k}}))

其中 R(pq,pk~)R(p_q,p_{\tilde{k}}) 表示计算 query 和变形后的 key 的相对位置。最后多头注意力中的每个头的输出 z(m)z^{(m)} 如下式计算:

z(m)=softmax(q(m)k~(m)d+B~rpb)v~(m)z^{(m)}=\text{softmax}\left(\frac{q^{(m)}\tilde{k}^{(m)^\top}}{\sqrt{d}}+\tilde{B}_\text{rpb}\right)\tilde{v}^{(m)}

其中 dd 为每个头的特征维度,与常规 Transformer 中相同。

与 DCN 类似,将设计好的可变形注意力模块加入到 ViT 模型的最后两个阶段,得到 Deformable Attention Transformer,如下图所示。

4

实验

研究者们在 ImageNet-1K 图像分类、MS-COCO 目标检测和实例分割、ADE20K 语义分割三个重要的视觉任务上进行实验,数据集皆使用幻方 AI 整理提供的数据集仓库 hfai.datasets,详情可以参考:https://doc.hfai.high-flyer.cn/api/datasets.html

运行以上实验需要大量计算资源。研究者们与幻方 AI 合作,采用一系列幻方 AI 自研的深度学习套件,优化模型提升训练速度,在幻方萤火集群上获得了快速的训练迭代。接下来的几个表格展示了最终的实验结果:

ImageNet-1K 图像分类

5

COCO 目标检测:

6

7

ADE20K 语义分割:

8

可以看到研究者们提出的 DAT 在不同大小的模型上均有很好的表现。

可视化

为了验证 DAT 的有效性,研究者们展示了一些来自 MS-COCO 验证集的样例。

9

如上图所示,图中橙色圆形表示移动后的部分具有较高 Attention Score 的 key,圆形大小表示累积的 Attention 分数大小,该分数越大,圆形越大。该结果表明 DAT 学习关注到了输入图像中的重要部分。

总结

DAT 方法的提出,进一步推动了 ViT 模型的落地实践。研究者们创造性提出的可变形 Attention 方案,极大降低了 ViT 的训练成本,提升模型的效果。同时,研究者采用幻方 AI 的一系列深度学习自研套件,极大加速了模型的训练过程,获得了大量的实验结果。我们欢迎更多优秀的课题与幻方合作,一道推动 AI 技术的发展与落地。

综合体验打分如下:

  1. 研究指数:★★★★

    该模型是 ViT 模型的改进研究,推动了视觉 AI 技术的落地实践。

  2. 开源指数:★★★★★

    数据和代码都进行了完整且规整的开源。

  3. 门槛指数:★★★

    数据规模大,模型适中,适合多级多卡数据并行训练。一般单卡训练难度比较大。

  4. 通用指数:★★★★

    该方法能适用于很多视觉场景。

  5. 适配指数:★★★★★

    依赖简单,很容易与幻方 AI 的训练优化工具结合,提效明显。

幻方 AI 紧跟 AI 研究的前沿浪潮,致力于用领先算力助力 AI 落地与价值创造,欢迎各方数据研究者与开发者们一同共建。


本文作者: Gao Huang


您可以转载、不违背作品原意地摘录及引用本技术博客的内容,但必须遵守以下条款: 署名 — 您应当署名原作者,但不得以任何方式暗示幻方为您背书,亦不会对幻方的权利造成任何负面影响。 非商业性使用 — 您不得将本技术博客内容用于商业目的。 禁止演绎 — 如果基于该内容改编、转换、或者再创作,您不得公开或分发被修改内容,该内容仅可供个人使用。