# Vision Transformer
# 关于ViT
Transformer自2017年06月由谷歌团队在论文Attention Is All You Need中提出后,给自然语言处理领域带去了深远的影响,其并行化处理不定长序列的能力及自注意力机制表现亮眼。根据以往的惯例,一个新的机器学习方法往往先在NLP领域带来突破,然后逐渐被应用到计算机视觉领域。时间来到2020年10月,同样是谷歌团队提出了将Transformer应用到视觉任务的方法,Vision Transformer(ViT)。
论文AN IMAGE IS WORTH 16X16 WORDS:
TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE (opens new window)
关于对Transformer的介绍可以参考Transformer 介绍
(opens new window)。
将Transformer应用于视觉任务的一种想法是将图像每个像素都flatten,得到一个表示图像的序列,作为模型的输入。但对使用自注意力模块的transformer来说,这种方法随着图像分辨率的变大,计算复杂度也变得很高,因为scaled dot self attention计算时640*640的图像,序列长度transformer所能处理的序列长度。
在ViT中,作者是将输入图像等分成大小为16X16的patch,然后通过image embedding将输入从NCHW转换成(N, hidden_dim, (n_h * n_w)), n_h和n_w是H//patch_size和W//patch_size的大小,flatten后得到长度为
Image Embedding后得到的结果shape为[N, n_h*n_w, hidden_dim],作者将ViT用于分类任务,同BERT的思路,会在输入序列前插入一个cls_token用来输出每个图像所属的类别。处理后,输入给encoder的张量shape为[N, n_h*n_w+1, hidden_dim]。

ViT编码器的输出是的shape为:(N, L, hidden_dim), L是序列的长度,在这里为197=14*14+1,得到编码器的输出后,只取序列的首元素,shape为(N, hidden_dim)作为分类器的输入。从这里会发现,这种方式舍弃了编码器处理得到的大部分信息,只使用了cls_token部分。
这里理解,cls_token有点像完形填空中的单词补全,这里做图像图类,待补充的词元组是类别,而这也正是我们关心的部分。至于编码器提取的序列其他信息,因为没有使用就直接舍弃了。假如说,拿编码器输出序列除cls_token外的部分,再接一个分类器,整个分割效果会不会更好呢?
# 代码分析
输入数据x的shape, NCHW以(1,3,224,224)为例,
- Image Embedding
处理输入数据,NCHW变成(N, (n_h * n_w), hidden_dim)的张量,
n_h/n_w是除以patch_size后得到的图像的大小。
def _process_input(self, x: torch.Tensor) -> torch.Tensor:
n, c, h, w = x.shape
p = self.patch_size
n_h = h // p
n_w = w // p
x = self.conv_proj(x) # 卷积层,将`NCHW`变成`N,hidden_dim,n_h*n_w`,conv的stride=patch_size
x = x.reshape(n, self.hidden_dim, n_h * n_w)
x = x.permute(0, 2, 1)
return x
- cls_token
再将_process_input处理后的数据与self.cls_token合并,得到shape为N, n_h*n_w+1, hidden_dim的序列作为编码器的输入。
batch_class_token = self.class_token.expand(n, -1, -1) # (1,1,hidden_dim) ->(N,1,hidden_dim)
x = torch.cat([batch_class_token, x], dim=1) # `N, n_h*n_w+1, hidden_dim`
- position_embedding
编码器是标准的多头注意力Transoformer,在torchvision提供的模型中位置嵌入使用的可学习的参数,位置参数直接和输入数据相加
pos_embedding = torch.nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT
input = input + pos_embedding
- MSA
多头注意力的实现使用了torch.nn.MultiheadAttention模块,
CLASStorch.nn.MultiheadAttention(embed_dim,
num_heads,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
kdim=None, vdim=None,
batch_first=False,
device=None, dtype=None)
embed_dim是模型的维度,num_heads是头数kdim/vdim是Query权重和Value权重的维度,按transformer论文的介绍Query.weight.shape=(embed_dim, kdim),但是在pytorch目前实现的MultiheadAttention中必须kdim==vdim==embed_dim,否则计算将报错batch_first控制支持的输入数据shape,为True是支持N,L,hidden_dim维度的输入。
一个使用的例子:
embed_dim = 512
num_heads = 16
multihead_attn = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, kdim=512, vdim=512)
x = torch.randn(2,12, 512)
query, key, value = x, x, x
output = multihead_attn(query,key,value)
print(output[0].shape) # (1, 12, 512)
回到ViT,经过编码器后得到的输出shape为(N,L+1,hidden_dim),然后取输出序列cls_token对应位置的数据,作为特征送入线性分类器即可得到分类结果。
linear_classifier = nn.Linear(hidden_dim, num_classes)
x = self.encoder(x)
# Classifier "token" as used by standard language architectures
x = x[:, 0] # shape(N, hidden_dim)
linear_classifier(x) # shape(N, num_class)
torchvision 在0.13版本实现了ViT模型。