@ Ruijun Gao # Computer Vision# Transformer

解构 Visual Transformer

通过对 Attention 公式的变换,理解 Visual Transformer 的各个模块。

本文针对 arXiv:2006.03677 的早期版本,与最新版本会有所出入。

知识背景

Attention (Retrieval)

文中所有的公式都是从下式派生,

Attention(Q,K,V)=softmax(QKT)V,\mathrm{Attention}\left(Q,K,V\right)=\mathrm{softmax}\left(QK^\mathrm{T}\right)V,

其中 QRNQ×dKQ\in\mathbb{R}^{N_Q\times d_K}KRNV×dKK\in\mathbb{R}^{N_V\times d_K}VRNV×dVV\in\mathbb{R}^{N_V\times d_V}RRNQ×dVR\in\mathbb{R}^{N_Q\times d_V}

具体解释为:

有时 QQKKVV 并不直接存在,而是由已有的值进行变换得到,例如:

Attention(QWQ,KWK,VWV)=softmax((QWQ)(KWK)T)(VWV).\mathrm{Attention}\left(QW_Q,KW_K,VW_V\right)=\mathrm{softmax}\left(\left(QW_Q\right){\left(KW_K\right)}^\mathrm{T}\right)\left(VW_V\right).

此时,通常 K=VK=V 或者它们的空间维度一致。例如检索电影时,电影的内容和类型是混杂在一起的信息,并具有很多冗余,需要利用 WKW_KWVW_V 对其进行变换,分别得到类型矩阵和内容矩阵;同时检索的查询关键字也存在冗余,或者并非一个直接的查询,需要 WQW_Q 进行变换。

Multi-Head

本文中没有详细介绍(但是提到使用了该技术),使用该技术可以显著减少 Attention 的计算开销。

MultiHead(Q,K,V)=concat(headi)WO,\mathrm{MultiHead}\left(Q,K,V\right)=\mathrm{concat}\left(\mathrm{head}_i\right)W_O,

其中,headi=Attention(QWQi,KWKi,VWVi)\mathrm{head}_i=\mathrm{Attention}\left(Q{W_Q}_i,K{W_K}_i,V{W_V}_i\right)i=1,2,,hi=1,2,\dots,h。这将大型矩阵乘法划分在 hh 个子空间计算,最后使用 WOW_O 进行组合。

视觉模块

Filter-Based Tokenizer (Static Tokenizer)

「利用静态的 Queries 从像素特征图检索 Tokens 信息」

由下式开始变换,

R=softmax(QKT)V,R=\mathrm{softmax}\left(QK^\mathrm{T}\right)V,

K=XˉWKK=\bar{X}W_KV=XˉWVV=\bar{X}W_VXˉRHW×C\bar{X}\in\mathbb{R}^{HW\times C} 是由原特征图将空间维度 flatten 得到,XKX_KXVX_V 分别是从像素特征空间变换为关键字空间和值空间(也即 Token 的特征空间)的线性变换。意为从原特征图中利用静态的查询 QQ 检索信息,有

T=softmax(Q(XˉWK)T)(XˉWV)=softmax((QWKT)XˉT)(XˉWV).\begin{aligned} T &=\mathrm{softmax}\left(Q{\left(\bar{X}W_K\right)}^\mathrm{T}\right)\left(\bar{X}W_V\right)\\ &=\mathrm{softmax}\left(\left(QW_K^\mathrm{T}\right)\bar{X}^\mathrm{T}\right)\left(\bar{X}W_V\right). \end{aligned}

若取 dK=Cd_K=CdV=CTd_V=C_T,并记 QWKT=WAQW_K^\mathrm{T}=W_A 得原文公式,

T=softmax(WAXˉT)(XˉWV)=ATV,T=\mathrm{softmax}\left(W_A\bar{X}^\mathrm{T}\right)\left(\bar{X}W_V\right)=A^\mathrm{T}V,

可学习的参数为 WAW_{A}WVW_V

Recurrent Tokenizer (Dynamic Tokenizer)

「利用已有的 Tokens 得到 Queries 从像素特征图检索新的 Tokens 信息」

是 Filter-based Tokenizer 的变种,此时使用的 QQ 并非静态的,而是已有 Tokens 进行变换得到,即 Q=TinWQQ=T_\mathrm{in}W_{Q},有

Tout=softmax((TinWQ)(XˉinWK)T)(XˉinWV)=softmax((TinWQWKT)XˉinT)(XˉinWV),\begin{aligned} T_\mathrm{out} &=\mathrm{softmax}\left(\left(T_\mathrm{in}W_Q\right){\left(\bar{X}_\mathrm{in}W_K\right)}^\mathrm{T}\right)\left(\bar{X}_\mathrm{in}W_V\right)\\ &=\mathrm{softmax}\left(\left(T_\mathrm{in}W_QW_K^\mathrm{T}\right)\bar{X}_\mathrm{in}^\mathrm{T}\right)\left(\bar{X}_\mathrm{in}W_V\right), \end{aligned}

WQWKT=WTWAW_QW_K^\mathrm{T}=W_{T\to W_A}

Tout=softmax((TinWTWA)XˉinT)(XˉinWV)=StaticTokenizer(TinWTWA,Xˉin),\begin{aligned} T_\mathrm{out} &=\mathrm{softmax}\left(\left(T_\mathrm{in}W_{T\to W_A}\right)\bar{X}_\mathrm{in}^\mathrm{T}\right)\left(\bar{X}_\mathrm{in}W_V\right)\\ &=\mathrm{StaticTokenizer}\left(T_\mathrm{in}W_{T\to W_A},\bar{X}_\mathrm{in}\right), \end{aligned}

可学习的参数为 WTWAW_{T\to W_A}WVW_V,原文形式是只更新一半的 Tokens。

Position Encoder

​「利用 Tokenizer 的匹配结果,从静态的 Position Encodings 中检索位置信息」

保持 Tokenizer 的 Match 不变,令 V=WAPRHˉWˉ×CPV=W_{A\to P}\in\mathbb{R}^{\bar{H}\bar{W}\times C_P},是从位置编码 WAPW_{A\to P} 中按权重取回值(我理解是为了降维或者固定位置编码的维数的考虑,进行了下采样),得

P=downsample(A)TWAP,P=\mathrm{downsample}{\left(A\right)}^\mathrm{T}W_{A\to P},

可学习的参数为 WAPW_{A\to P}

Transformer (Self-Attention)

「Tokens 进行自源的信息检索」

对于

R=softmax((QWQ)(KWK)T)(VWV),R=\mathrm{softmax}\left(\left(QW_Q\right){\left(KW_K\right)}^\mathrm{T}\right)\left(VW_V\right),

Q=K=V=TinQ=K=V=T_\mathrm{in},改记 WQ=QW_Q=QWK=KW_K=KWV=VW_V=V,并一般取 dK=CT/2d_K=C_T/2

Tout=softmax((TinQ)(TinK)T)(TinV).T_\mathrm{out}=\mathrm{softmax}\left(\left(T_\mathrm{in}Q\right){\left(T_\mathrm{in}K\right)}^\mathrm{T}\right)\left(T_\mathrm{in}V\right).

再加上残差结构,即

Tout=Tin+softmax((TinQ)(TinK)T)(TinV),T_\mathrm{out}=T_\mathrm{in}+\mathrm{softmax}\left(\left(T_\mathrm{in}Q\right){\left(T_\mathrm{in}K\right)}^\mathrm{T}\right)\left(T_\mathrm{in}V\right),

可学习的参数为 QQKKVV

Projector

「利用像素特征图从 Tokens 中检索信息,从而对特种图进行了语义的增强」

对于

R=softmax((QWQ)(KWK)T)(VWV),R=\mathrm{softmax}\left(\left(QW_Q\right){\left(KW_K\right)}^\mathrm{T}\right)\left(VW_V\right),

Q=XˉQ=\bar{X}K=V=ToutK=V=T_\mathrm{out}WQ=QTXW_Q=Q_{T\to X}WK=KTXW_K=K_{T\to X}WV=VTVW_V=V_{T\to V},并一般取 dK=CT/2d_K=C_T/2(原文这三个变换矩阵的维度有误),有

Xˉout=softmax((XˉinQTX)(ToutKTX)T)(ToutVTX).\bar{X}_\mathrm{out}=\mathrm{softmax}\left(\left(\bar{X}_\mathrm{in}Q_{T\to X}\right) {\left(T_\mathrm{out}K_{T\to X}\right)}^\mathrm{T}\right)\left(T_\mathrm{out}V_{T\to X}\right).

同样地再加上残差结构,即

Xout=Xin+softmax((XˉinQTX)(ToutKTX)T)(ToutVTX),X_\mathrm{out}=X_\mathrm{in}+\mathrm{softmax}\left(\left(\bar{X}_\mathrm{in}Q_{T\to X}\right) {\left(T_\mathrm{out}K_{T\to X}\right)}^\mathrm{T}\right)\left(T_\mathrm{out}V_{T\to X}\right),

可学习的参数为 QTXQ_{T\to X}KTXK_{T\to X}VTXV_{T\to X}

模块总结

NameQueryKeyValueResult
Static Tokenizer (ST)StaticFeature MapFeature MapTokens
Dynamic Tokenizer (DT)TokensFeature MapFeature MapRefined Tokens
Position Encoder (PE)Static / TokensTokensStaticPosition Encodings
Transformer (Tr)TokensTokensTokensTransformed Tokens
Projector (Pr)Feature MapTokensTokensRefined Feature Map

模型构建

  1. Classification: 使用 ST、Tr、DT、Tr、……、DT、Tr 结构代替了 ResNet 中的 Stage 5。
  2. Semantic Segmentation: 使用 ST、Tr、Pr 结构代替了 FPN 中的 Lateral Conv 和 Downsample、Down Conv。(依据原文描述,原文的示意图可能有误,标出了 downsample)