date
Mar 2, 2026
slug
2026-03-02-the-step-by-step-implementation-of-a-simple-ViT-modal
status
Published
tags
ViT
Transformer
summary
本文针对简单的MNIST手写数字数据集识别的需求,完成了一个最简单的Vision Transformer模型的实现、训练和验证测试,建立对Vision Transformer模型实现流程的完整理解。
MNIST手写数字数据集是最简单的机器视觉数据集,基于MNIST实现一个Vision Transformer模型来实现手写数字字符的识别,难度不会太大,对于模型训练所需要的数据以及算力资源要求也不高,因此通过训练一个MNIST数据集的ViT识别模型,是一个绝佳的入门Vision Transformer模型的实验。
type
Post
category
AI
AI summary
从头实现一个Vision Transformer(ViT)模型
本文针对简单的MNIST手写数字数据集识别的需求,完成了一个最简单的Vision Transformer模型的实现、训练和验证测试,建立对Vision Transformer模型实现流程的完整理解。
MNIST手写数字数据集是最简单的机器视觉数据集,基于MNIST实现一个Vision Transformer模型来实现手写数字字符的识别,难度不会太大,对于模型训练所需要的数据以及算力资源要求也不高,因此通过训练一个MNIST数据集的ViT识别模型,是一个绝佳的入门Vision Transformer模型的实验。

关于Vision Transformer模型的整体架构,可以参加另外一篇笔记[[一文入门Vision in Transformer(ViT)模型的架构]]。

注:参考资料1中提供了一个完整的、对MNIST手写字符集进行识别的Vision Transformer模型,本文实际上是对以上资料的学习记录,并增加自己对于实现过程的理解。
如[[一文入门Vision in Transformer(ViT)模型的架构]]这篇笔记所总结的,ViT模型的架构和工作流程大致可以分为以下几个阶段:图像切片、添加Cls Token、位置编码、基于多头自注意力机制的Transformer Encoder、最终的训练头。因此,本文仍然按照以上结构来实现一个完整的Vision Transformer结构,并使用MNIST数据集对其进行训练以及最后的测试验证。
1. 导入必要的包
- 导入torchvision.transforms用于实现对MNIST字符集图像分辨率的resize,确保模型输入分辨率是patch分辨率的整数倍;以及把MNIST字符集的图像数据转换为Tensor格式。
- 训练过程所使用的优化器使用torch.optim中所包含的Adam。
- 利用torchvision.datasets.mnist从Pytorch官网自动下载MNIST数据集。
2. Patch Embeddings
这一步实现把输入图像切片为固定大小的patch,并且序列化。如[[一文入门Vision in Transformer(ViT)模型的架构]]所提到的,对于这部分功能,一般是通过一个步长(Stride)等于卷积核大小(kernel_size)的 2D 卷积(Conv2d)来实现的。
以上的 Conv2d 操作实际上是一个卷积投影操作:
- kernel_size=patch_size :是用 patch_size×patch_size 的卷积核提取每个 patch
- stride=patch_size :步长等于 patch 尺寸,从而实现非重叠的切分动作
- 输出通道数=d_model :每个 patch 被映射为 d_model 维的向量,实现每个patch的向量化
patch embedding的完整代码如下:
在以上的forward操作中,self.linear_project(x)这一步把一张输入图像的维度从(C,H,W)转换为(d_model,P_col,P_row);

x.flatten(2)这一步把所有的patch展平,得到P个序列化Patch,每个patch的维度长度都是d_model,这实际上就是patch embedding操作的输出结果;

最后的x = x.transpose操作就只是交换输出数据的维度,以适配 Transformer 架构的标准输入格式。

3. Cls Token与位置编码
为Batch中的每张图片的Patch列表的开头,增加一个Cls Token的逻辑相对比较简单,就是新建一个长度与Patch向量相同的cls token,然后把这个token放到patch向量的开头作为其0号token即可。
除了Cls Token以外,还需要为Patch向量中增加其位置编码。每个位置编码对于它所代表的位置来说都是唯一的,这使得模型可以识别每个Patch的位置。为了将位置编码添加到Patch向量中,它们必须具有相同的维度 d_model。位置向量的构建公式如下:

因此总的来说,这一步包含两个步骤:为Batch中每张图片的Patch向量列表中的初始位置增加一个Cls token;为所有的向量增加其位置编码。完整代码如下:
4. 单头及多头注意力模块的实现
有关多头注意力模块的理论以及计算流程等方面的内容,在[[如何理解Transformer架构中的多头注意力机制?]]一文中已有非常详细的总结和解释,在此不再赘述。
单头注意力机制的计算公式如下:
可结合上面的公式以及代码中的注释来理解单头注意力的计算逻辑和流程,基本上就是以上公式的一步步计算而已:
基于以上的单头注意力模块,实现多头注意力模块就比较简单了,就是对多个注意力模块独立计算,把计算结果合并起来,最后再通过一个线性层把多个单头输出的特征进一步融合并输出:
以上单头和多头注意力模块的流程和结构如下图所示:

5. Transformer Encoder
接下来是完整的Transformer Encoder的实现。Transformer Encoder的结构如下所示:

其对应的代码如下:
6.整个Vision Transformer模型
最终实现的整个Vision Transformer的模型架构如下:

- 输入图像首先在PatchEmbedding模块中进行切分和向量化、序列化的操作。
- 序列化的Patch向量再在PositionalEncoding模块中增加额外的Cls Token,以及位置向量。
- 以上数据再经过连续N个TransformerEncoder模块的处理,反复通过多头注意力模块提取和融合图像中的特征。
- 最后在一个分类头处理后通过Softmax输出分类识别的概率。
完整Vision Transformer模型的实现如下:
7. 训练和测试
最后对以上构建出来的模型进行训练,以及对训练出来的模型使用MNIST的测试数据集进行验证。
超参数设置
对模型进行训练的超参数设置如下:
训练以及测试数据集的准备
训练和测试使用从Pytorch官网下载的MNIST数据集,使用以下代码下载数据集并准备dataloader:
模型训练
接下来就是对已经构建出来的模型参数进行训练,整个训练的执行流程实际上与[[基于Pytorch实现手写数字识别的卷积神经网络]]一文所描述的基于卷积神经网络的手写数字识别的训练流程大同小异:
训练了15个epoch,从打印的损失值可以看到其损失在稳定向下收敛过程中:
模型测试
最后就是使用以上训练好的模型,对MNIST的测试数据集进行测试:
经过上面的15轮测试,模型对测试集数据的测试准确度就达到了93%。
参考资料
- Building a Vision Transformer Model From Scratch | by Matt Nguyen | Toward Humanoids | Medium
尝试从底层原理的角度去理解和解释技术问题:音视频/摄像头/智能家居/蓝牙/WiFi/无线通信/AI。
敬请关注微信公众号:Pavel Han。