1、从卷积到自注意力的范式革命
Vision Transformer(ViT)的诞生标志着计算机视觉领域的一次重大范式转变。在2020年之前,卷积神经网络(CNN)几乎统治了所有视觉任务,从图像分类到目标检测,卷积操作凭借其平移不变性和局部感受野的特性,天然适合处理图像数据。然而,Google研究团队在论文《An Image is Worth 16x16 Words》中提出了一个颠覆性的问题:如果我们将图像视为文本序列,用处理语言的方式处理图像,会发生什么?这一设想的理论基础来自Transformer架构在自然语言处理领域的巨大成功。Transformer的核心优势在于自注意力机制(Self-Attention),它能够直接建模序列中任意两个位置之间的依赖关系,而不像RNN那样受限于序列顺序,也不像CNN那样局限于局部邻域。这种全局建模能力理论上同样适用于图像数据,因为图像中的远距离依赖关系——比如一只猫的耳朵和尾巴之间的关联——往往对准确识别至关重要。ViT的解决方案简洁而优雅:将图像均匀切割成固定大小的小块(patch),每个patch被视为一个"视觉单词",然后通过线性投影将这些patch映射为向量序列,输入到标准的Transformer编码器中进行处理。这种设计彻底抛弃了卷积操作,完全依赖自注意力机制来学习图像特征。实验结果表明,当在足够大规模的数据集(如ImageNet-21k或JFT-300M)上预训练后,ViT能够超越同等规模的CNN模型,证明了纯Transformer架构在视觉领域的可行性。这一发现开启了视觉Transformer研究的浪潮,催生了后续众多改进工作。
2、核心架构设计与工作机制
ViT的核心架构遵循Transformer编码器的经典设计,但针对图像输入做了关键适配。整个流程可以分为四个关键步骤:图像分块与序列化、位置编码注入、多头自注意力特征提取、全局分类输出。首先,输入图像被切分成N个不重叠的patch,例如224×224的图像使用16×16的patch大小会得到196个patch。每个patch通过一个可训练的线性投影层(实际实现中常用卷积层代替)映射为固定维度的嵌入向量,这一过程类似于NLP中的词嵌入。为了保留patch之间的空间位置关系,ViT引入了可学习的位置编码,将其直接加到patch嵌入上。这里有一个重要的设计选择:ViT使用了一维位置编码而非二维,意味着模型需要自行学习patch之间的二维空间结构。此外,借鉴BERT的[CLS]标记设计,ViT在序列开头添加了一个额外的可学习分类标记(class token),这个标记不携带任何图像信息,但在经过所有Transformer层后,其对应的输出向量会被送入分类头进行最终预测。这种设计使得分类token能够聚合来自所有patch的全局信息。每个Transformer层内部包含两个核心子层:多头自注意力(MSA)和多层感知机(MLP),每个子层前都应用层归一化(LayerNorm),并使用残差连接来缓解深层网络的梯度消失问题。在MSA中,输入序列被投影为查询(Q)、键(K)、值(V)三组向量,通过计算Q与K的点积相似度来得到注意力权重,再用这些权重对V进行加权求和,从而实现"让模型自己决定关注哪些位置"的效果。多头机制允许模型从不同表示子空间同时学习多种注意力模式。MLP则通常采用两层的全连接网络,中间使用GELU激活函数,隐藏层维度通常是嵌入维度的4倍。通过堆叠12层或更多这样的编码器层,ViT能够逐步提取从低级纹理到高级语义的多层次视觉特征。
3、数据依赖性与后续发展演进
尽管ViT在架构设计上展现出了优雅的简洁性,但其最初版本存在一个显著的局限性:严重依赖大规模预训练数据。在没有大规模数据集的情况下,ViT的性能往往不及ResNet等成熟CNN架构。这是因为CNN内置的归纳偏置(inductive bias)——如平移不变性和局部性——使其天然具备一定的视觉先验知识,能够从较少样本中学习有效的特征表示。而ViT作为一个通用架构,缺乏这些针对图像设计的先验假设,需要从海量数据中自行"发现"这些规律。这一特性使得ViT在ImageNet-1K(约130万张图片)上训练时效果不如CNN,但当预训练数据扩展到JFT-300M(约3亿张图片)时,ViT展现出强大的可扩展性和性能上限。针对数据效率问题,后续研究从多个方向进行了改进。DeiT(Data-efficient Image Transformers)引入了知识蒸馏策略,用一个训练好的CNN作为教师模型指导ViT训练,显著降低了ViT对数据量的需求,使在ImageNet-1K上从头训练ViT成为可能。Swin Transformer则借鉴了CNN的层次化设计思想,提出了移位窗口注意力机制,在不同层之间逐步合并相邻patch来构建特征金字塔,同时通过窗口划分将自注意力的计算复杂度从二次降为线性,使得Transformer能够高效处理高分辨率图像。PVT(Pyramid Vision Transformer)和T2T-ViT等工作也在层次化结构和patch处理策略上做了创新。这些改进使得视觉Transformer家族在图像分类、目标检测、语义分割等几乎所有视觉任务上都取得了最优性能。如今,视觉Transformer已经从一个实验性的想法发展为计算机视觉领域的主流架构,其在多模态学习(如CLIP)、图像生成(如DIT)等前沿方向也展现出巨大潜力,深刻改变了我们对视觉模型设计的基本认知。
4、Vision Transformer最小体量的PyTorch代码复现
import torch
import torch.nn as nn
import torch.nn.functional as F
# ==================== 1. 图像分块与嵌入模块 ====================
class PatchEmbedding(nn.Module):
"""
将输入图像切割成小块(patch),并将每个patch映射为向量
例如:224x224的图像,使用16x16的patch,会得到196个patch
"""
def __init__(self, image_size=224, patch_size=16, in_channels=3, embed_dim=768):
"""
参数说明:
image_size: 输入图像的尺寸(正方形,所以只需要一个数)
patch_size: 每个patch的大小
in_channels: 输入图像的通道数(RGB图像为3)
embed_dim: 每个patch映射后的向量维度(论文中ViT-Base为768)
"""
super().__init__()
# 计算一共有多少个patch
self.num_patches = (image_size // patch_size) ** 2 # 224/16=14, 14*14=196
# 使用卷积层实现分块和嵌入:卷积核大小=patch_size,步长=patch_size
# 这样一次卷积操作就能同时完成图像分块和线性映射
self.proj = nn.Conv2d(
in_channels, # 输入通道数(RGB为3)
embed_dim, # 输出通道数(embed_dim)
kernel_size=patch_size, # 卷积核大小等于patch大小
stride=patch_size # 步长等于patch大小,这样patch之间不重叠
)
def forward(self, x):
"""
前向传播
输入x形状: (batch_size, 3, 224, 224)
输出形状: (batch_size, num_patches, embed_dim)
"""
batch_size = x.shape[0]
# 通过卷积层: (B, 3, 224, 224) -> (B, 768, 14, 14)
x = self.proj(x)
# 展平patch维度: (B, 768, 14, 14) -> (B, 768, 196)
x = x.flatten(2)
# 转置: (B, 768, 196) -> (B, 196, 768)
# 这样每个patch对应一个768维的向量
x = x.transpose(1, 2)
return x
# ==================== 2. 多头自注意力机制 ====================
class MultiHeadAttention(nn.Module):
"""
多头自注意力机制:让模型能够关注输入序列中不同位置的信息
这是Transformer的核心组件
"""
def __init__(self, embed_dim=768, num_heads=12, dropout=0.1):
"""
参数说明:
embed_dim: 嵌入维度
num_heads: 注意力头的数量(embed_dim必须能被num_heads整除)
dropout: dropout比率,防止过拟合
"""
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads # 每个头的维度: 768/12=64
# 将embed_dim拆分成num_heads个head_dim
# 一次性生成Q、K、V三个矩阵,提高效率
self.qkv = nn.Linear(embed_dim, embed_dim * 3)
# 注意力后的线性投影
self.proj = nn.Linear(embed_dim, embed_dim)
# Dropout层
self.attn_drop = nn.Dropout(dropout)
self.proj_drop = nn.Dropout(dropout)
def forward(self, x):
"""
前向传播
输入x形状: (batch_size, num_patches + 1, embed_dim) # +1是分类token
输出形状: 与输入相同
"""
batch_size, num_tokens, embed_dim = x.shape
# 1. 生成Q、K、V: (B, N, 768) -> (B, N, 2304)
qkv = self.qkv(x)
# 2. 重塑形状,分出多头: (B, N, 2304) -> (B, N, 3, 12, 64) -> (3, B, 12, N, 64)
qkv = qkv.reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4) # 将3移到最前面
# 3. 分离Q、K、V: 每个形状为 (B, 12, N, 64)
q, k, v = qkv[0], qkv[1], qkv[2]
# 4. 计算注意力分数: (B, 12, N, N)
# Q @ K^T / sqrt(d_k)
scale = self.head_dim ** 0.5 # 缩放因子,防止点积过大
attn = (q @ k.transpose(-2, -1)) / scale
# 5. Softmax归一化
attn = F.softmax(attn, dim=-1)
attn = self.attn_drop(attn)
# 6. 用注意力权重加权V: (B, 12, N, 64)
x = attn @ v
# 7. 合并多头: (B, 12, N, 64) -> (B, N, 768)
x = x.transpose(1, 2).reshape(batch_size, num_tokens, embed_dim)
# 8. 最终投影
x = self.proj(x)
x = self.proj_drop(x)
return x
# ==================== 3. Transformer编码器层 ====================
class TransformerEncoderLayer(nn.Module):
"""
Transformer编码器层:由多头自注意力和前馈网络组成
使用残差连接和层归一化
"""
def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4, dropout=0.1):
"""
参数说明:
embed_dim: 嵌入维度
num_heads: 注意力头数
mlp_ratio: MLP隐藏层扩展比率(通常是4)
dropout: dropout比率
"""
super().__init__()
# 层归一化1(用于注意力前)
self.norm1 = nn.LayerNorm(embed_dim)
# 多头自注意力
self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
# 层归一化2(用于MLP前)
self.norm2 = nn.LayerNorm(embed_dim)
# 前馈网络(MLP): 两层全连接,中间有GELU激活
hidden_dim = int(embed_dim * mlp_ratio) # 隐藏层维度: 768*4=3072
self.mlp = nn.Sequential(
nn.Linear(embed_dim, hidden_dim), # 扩展
nn.GELU(), # GELU激活函数
nn.Dropout(dropout),
nn.Linear(hidden_dim, embed_dim), # 还原
nn.Dropout(dropout)
)
def forward(self, x):
"""
前向传播,使用残差连接(Pre-Norm结构)
输入输出形状相同: (B, N, embed_dim)
"""
# 注意力子层:x + Attention(Norm(x))
x = x + self.attn(self.norm1(x))
# MLP子层:x + MLP(Norm(x))
x = x + self.mlp(self.norm2(x))
return x
# ==================== 4. Vision Transformer主模型 ====================
class VisionTransformer(nn.Module):
"""
Vision Transformer完整模型
结构:图像分块 -> 添加分类token和位置编码 -> 多个Transformer层 -> 分类头
"""
def __init__(self,
image_size=224, # 输入图像大小
patch_size=16, # patch大小
in_channels=3, # 输入通道数
num_classes=10, # 分类类别数(CIFAR-10为10类)
embed_dim=768, # 嵌入维度
depth=12, # Transformer层数
num_heads=12, # 注意力头数
mlp_ratio=4, # MLP扩展比率
dropout=0.1): # dropout比率
super().__init__()
# 图像分块嵌入
self.patch_embed = PatchEmbedding(image_size, patch_size, in_channels, embed_dim)
num_patches = self.patch_embed.num_patches
# 分类token:一个可学习的向量,用于最后的分类
# 形状: (1, 1, 768),会广播到batch中每个样本
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# 位置编码:让模型知道patch的位置信息
# 形状: (1, 197, 768) -> 1个cls_token + 196个patch
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
# 分类token的dropout
self.pos_drop = nn.Dropout(dropout)
# 堆叠多个Transformer编码器层
# 使用ModuleList管理多个层
self.blocks = nn.ModuleList([
TransformerEncoderLayer(embed_dim, num_heads, mlp_ratio, dropout)
for _ in range(depth)
])
# 最后的层归一化
self.norm = nn.LayerNorm(embed_dim)
# 分类头:只使用cls_token对应的输出进行分类
self.head = nn.Linear(embed_dim, num_classes)
# 初始化权重
self._init_weights()
def _init_weights(self):
"""初始化模型权重"""
# 使用截断正态分布初始化位置编码和分类token
nn.init.trunc_normal_(self.pos_embed, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)
# 对线性层和层归一化进行初始化
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.zeros_(m.bias)
nn.init.ones_(m.weight)
def forward(self, x):
"""
完整前向传播
输入x形状: (batch_size, 3, 224, 224)
输出形状: (batch_size, num_classes)
"""
batch_size = x.shape[0]
# 1. 图像分块嵌入: (B, 3, 224, 224) -> (B, 196, 768)
x = self.patch_embed(x)
# 2. 添加分类token: 将cls_token扩展到batch_size个样本
# (1, 1, 768) -> (B, 1, 768)
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
# 3. 拼接cls_token和patch嵌入: (B, 197, 768)
x = torch.cat([cls_tokens, x], dim=1)
# 4. 添加位置编码: 简单的加法
x = x + self.pos_embed
x = self.pos_drop(x)
# 5. 通过所有Transformer层
for block in self.blocks:
x = block(x)
# 6. 最后的层归一化
x = self.norm(x)
# 7. 取cls_token对应的输出进行分类
# x[:, 0]取所有batch的第一个token(cls_token): (B, 768)
x = self.head(x[:, 0])
return x
# ==================== 5. 使用示例 ====================
def create_vit_tiny():
"""
创建一个更小的ViT版本,适合在小数据集上训练
这个轻量版参数量更少,训练更快
"""
return VisionTransformer(
image_size=224, # 输入图像大小
patch_size=16, # 16x16的patch
in_channels=3, # RGB图像
num_classes=10, # 假设是10分类任务(如CIFAR-10)
embed_dim=192, # 降低嵌入维度(原本768)
depth=12, # 保持12层
num_heads=3, # 减少注意力头数
mlp_ratio=4, # MLP扩展比率
dropout=0.1 # Dropout比率
)
# ==================== 6. 测试代码 ====================
if __name__ == "__main__":
# 创建一个小型ViT模型
model = create_vit_tiny()
# 计算参数量
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"模型总参数量: {total_params:,}")
print(f"可训练参数量: {trainable_params:,}")
# 创建一个模拟输入:batch_size=2, 3通道, 224x224图像
dummy_input = torch.randn(2, 3, 224, 224)
# 前向传播测试
output = model(dummy_input)
print(f"输入形状: {dummy_input.shape}")
print(f"输出形状: {output.shape}")
print(f"\n模型结构:\n{model}")
# 验证各个组件的维度
print("\n=== 维度变化演示 ===")
print(f"输入图像: {dummy_input.shape}")
# 分块嵌入
patches = model.patch_embed(dummy_input)
print(f"分块后: {patches.shape} # (batch, 196个patch, 768维)")
# 添加分类token和位置编码
cls_tokens = model.cls_token.expand(2, -1, -1)
with_pos = torch.cat([cls_tokens, patches], dim=1)
print(f"添加token后: {with_pos.shape} # (batch, 197个token, 768维)")4.1、核心组件说明:
PatchEmbedding: 将图像切分成固定大小的块,每个块映射为向量
MultiHeadAttention: 多头自注意力机制,让模型关注不同位置的信息
TransformerEncoderLayer: 标准的Transformer层,包含注意力和前馈网络
VisionTransformer: 完整模型,组合所有组件
4.2、关键设计思想:
分块处理: 像处理文本一样处理图像,将图像块当作"单词"
位置编码: 由于没有RNN/CNN的固有位置信息,需要显式添加位置编码
分类token: 借鉴BERT的[CLS] token,用于聚合全局信息进行分类
残差连接: 每个子层都有残差连接,帮助训练深层网络