解构 Visual Transformer
通过对 Attention 公式的变换,理解 Visual Transformer 的各个模块。
本文针对 arXiv:2006.03677 的早期版本,与最新版本会有所出入。
知识背景
Attention (Retrieval)
文中所有的公式都是从下式派生,
Attention(Q,K,V)=softmax(QKT)V,
其中 Q∈RNQ×dK、K∈RNV×dK、V∈RNV×dV、R∈RNQ×dV。
具体解释为:
- Q 即 Query,代表 NQ 个想要检索的目标,每行是一个 dK 维的查询关键字。例如,可以按类型将电影 embed 到一个关键字空间中,每个分量可以代表其具有的不同属性;
- K 即 Key,代表 NV 个被检索的候选者,每行是一个 dK 维的查询关键字。是所有被检索的候选者在关键字空间中的表示;
- V 即 Value,代表 NV 个被检索的候选者,每行是一个 dV 维的值。例如,将电影的内容记录为一个向量,是真正想要取回的内容;
- M=softmax(QKT)∈RNQ×NV 即 Match,根据内积计算相似度并归一化,对于每个检索目标,相似度高的候选者将取得较高的值。
有时 Q、K、V 并不直接存在,而是由已有的值进行变换得到,例如:
Attention(QWQ,KWK,VWV)=softmax((QWQ)(KWK)T)(VWV).
此时,通常 K=V 或者它们的空间维度一致。例如检索电影时,电影的内容和类型是混杂在一起的信息,并具有很多冗余,需要利用 WK、WV 对其进行变换,分别得到类型矩阵和内容矩阵;同时检索的查询关键字也存在冗余,或者并非一个直接的查询,需要 WQ 进行变换。
Multi-Head
本文中没有详细介绍(但是提到使用了该技术),使用该技术可以显著减少 Attention 的计算开销。
MultiHead(Q,K,V)=concat(headi)WO,
其中,headi=Attention(QWQi,KWKi,VWVi),i=1,2,…,h。这将大型矩阵乘法划分在 h 个子空间计算,最后使用 WO 进行组合。
视觉模块
Filter-Based Tokenizer (Static Tokenizer)
「利用静态的 Queries 从像素特征图检索 Tokens 信息」
由下式开始变换,
R=softmax(QKT)V,
令 K=XˉWK、V=XˉWV 且 Xˉ∈RHW×C 是由原特征图将空间维度 flatten 得到,XK、XV 分别是从像素特征空间变换为关键字空间和值空间(也即 Token 的特征空间)的线性变换。意为从原特征图中利用静态的查询 Q 检索信息,有
T=softmax(Q(XˉWK)T)(XˉWV)=softmax((QWKT)XˉT)(XˉWV).
若取 dK=C、dV=CT,并记 QWKT=WA 得原文公式,
T=softmax(WAXˉT)(XˉWV)=ATV,
可学习的参数为 WA、WV。
Recurrent Tokenizer (Dynamic Tokenizer)
「利用已有的 Tokens 得到 Queries 从像素特征图检索新的 Tokens 信息」
是 Filter-based Tokenizer 的变种,此时使用的 Q 并非静态的,而是已有 Tokens 进行变换得到,即 Q=TinWQ,有
Tout=softmax((TinWQ)(XˉinWK)T)(XˉinWV)=softmax((TinWQWKT)XˉinT)(XˉinWV),
记 WQWKT=WT→WA 得
Tout=softmax((TinWT→WA)XˉinT)(XˉinWV)=StaticTokenizer(TinWT→WA,Xˉin),
可学习的参数为 WT→WA、WV,原文形式是只更新一半的 Tokens。
Position Encoder
「利用 Tokenizer 的匹配结果,从静态的 Position Encodings 中检索位置信息」
保持 Tokenizer 的 Match 不变,令 V=WA→P∈RHˉWˉ×CP,是从位置编码 WA→P 中按权重取回值(我理解是为了降维或者固定位置编码的维数的考虑,进行了下采样),得
P=downsample(A)TWA→P,
可学习的参数为 WA→P。
「Tokens 进行自源的信息检索」
对于
R=softmax((QWQ)(KWK)T)(VWV),
令 Q=K=V=Tin,改记 WQ=Q、WK=K、WV=V,并一般取 dK=CT/2 有
Tout=softmax((TinQ)(TinK)T)(TinV).
再加上残差结构,即
Tout=Tin+softmax((TinQ)(TinK)T)(TinV),
可学习的参数为 Q、K、V。
Projector
「利用像素特征图从 Tokens 中检索信息,从而对特种图进行了语义的增强」
对于
R=softmax((QWQ)(KWK)T)(VWV),
令 Q=Xˉ、K=V=Tout 且 WQ=QT→X、WK=KT→X、WV=VT→V,并一般取 dK=CT/2(原文这三个变换矩阵的维度有误),有
Xˉout=softmax((XˉinQT→X)(ToutKT→X)T)(ToutVT→X).
同样地再加上残差结构,即
Xout=Xin+softmax((XˉinQT→X)(ToutKT→X)T)(ToutVT→X),
可学习的参数为 QT→X、KT→X、VT→X。
模块总结
Name | Query | Key | Value | Result |
---|
Static Tokenizer (ST) | Static | Feature Map | Feature Map | Tokens |
Dynamic Tokenizer (DT) | Tokens | Feature Map | Feature Map | Refined Tokens |
Position Encoder (PE) | Static / Tokens | Tokens | Static | Position Encodings |
Transformer (Tr) | Tokens | Tokens | Tokens | Transformed Tokens |
Projector (Pr) | Feature Map | Tokens | Tokens | Refined Feature Map |
模型构建
- Classification: 使用 ST、Tr、DT、Tr、……、DT、Tr 结构代替了 ResNet 中的 Stage 5。
- Semantic Segmentation: 使用 ST、Tr、Pr 结构代替了 FPN 中的 Lateral Conv 和 Downsample、Down Conv。(依据原文描述,原文的示意图可能有误,标出了 downsample)