SimMIM:一个类BERT的计算机视觉的预训练框架

avatar
作者
猴君
阅读量:1

1、前言

呃…好久没有写博客了,主要是最近时间比较少。今天来做一期视频+博客的内容。本文主要讲SimMIM,它是一个将计算机视觉(图像)进行自监督训练的框架。

原论文:SimMIM:用于掩码图像建模的简单框架 (arxiv.org)

代码实现:https://github.com/microsoft/SimMIM

视频:【SimMIM:计算机视觉的随机掩码预训练-哔哩哔哩】

Demo:

在这里插入图片描述

2、引入

2.1、NLP自然语言处理

我们之前讲过了各种各样的NLP语言处理的模型框架,比如BERT,GPT(没有做相关的博文,只有视频)。它们都是基于在一大堆没有标签的文本上进行训练。然后才进行下游微调(GPT2及以后没有微调)。

为什么需要做预训练?在互联网上,大多数的文本数据都是没有标签的。而有标签的文本则相对较少。为了能够充分利用这些无标签文本,我们可以在这些文本上进行预训练,然后学习得到文字的表征。进而才执行各种下游任务

2.2、图像视觉

那人们就想,是否在图像视觉领域,也可以在无标签的图像上进行预训练呢?毕竟互联网上,也存在大量的无标签图像。不能把这些图像充分利用起来,实在可惜。而SimMIM,就是一个在无标签图像上预训练的框架

3、方法

同BERT一样,最直观的做法,当燃是使用随机掩码,我们只需要掩码掉其中一些像素值。然后,把它送进网络。最后预测出那些被掩码掉的部分,再把他们做损失即可。然后进行优化更新。

然而,这种做法,却存在几个很重要的问题:

3.1、问题

① 在图像领域,图像往往具有很强的局部性。也就是说,彼此接近的像素点之间,往往是高度相关的。这样的话,模型及其未必需要学习到像素的语义信息。它只需要复制旁边的没有被掩码的像素点,填充到被掩码的部分。就可以实现预测(彼此接近的像素点之间具有高度的相关性)

② 文字是人类生成的高级概念,而视觉信号,是相对原始、低级的。那么就有这么一个问题,低级信号的预测是否对高级视觉识别任务有用?

③ 文本是离散的,而视觉信号是连续的。目前还不清楚基于分类的掩码语言建模方法如何能够很好地处理连续的视觉信号。

3.2、问题解决

对于问题①:

我们可以使用较高的掩码比例以及较大的patch(此处指掩码块的大小,不是图像分块),让模型能够找到附近可见像素点的可能性减少。从而让模型无法直接从附近像素复制来填充被掩码的值。

作者经过实验,发现当掩码块大小为32时,掩码比例为10%~70%,就可以取得相对较好的性能;对于一个大小为8的掩码块,就需要80%的掩码比例才能达到良好的效果。

②: 使用原始像素回归任务。回归任务很好地符合视觉信号的连续特性,具有可排序性。

③: 采用了一个轻量级的预测头(如线性层)。使用轻量级的预测头带来了预训练的显著加速。虽然较重的头或较高的分辨率通常会导致更大的生成能力,但这种更大的能力不一定有利于下游的微调任务。

4、SimMIM

总的来说,模型图可以表达成这样

在这里插入图片描述

首先,将一张图像按照一定掩码块大小,按比例随机掩码掉一些块。然后送给一个编码器(比如ViT,Swin Transformer)。然后再加上一个预测头(one-layer prediction Head)。得到输出结果后。把预测的结果与真实的结果作损失(比如 L 1 L_1 L1损失)。

4.1、掩码

在论文里面提到,不论是将ViT还是Swin作为编码器结构,他们都采用32x32的掩码块。

论文采用了各种掩码规则进行消融对比。最后发现,采用随机掩码,块大小为32,掩码比例为0.5。取得的效果最好

在这里插入图片描述

实验对比

在这里插入图片描述

4.2、预测头

预测头的作用,就是把维度信息,重新映射成图像大小。

以Swin Transformer为例,我们最后的输出维度是 H 32 × W 32 × 8 C \frac{H}{32}\times\frac{W}{32}\times8C 32H×32W×8C

那么最后的预测头,就可以使用一个1x1的卷积层。把通道数从8C转化成32x32x3=3072维度。

这样一来,我们就得到维度为 H 32 × W 32 × 32 × 32 × 3 \frac{H}{32}\times\frac{W}{32}\times 32\times 32\times 3 32H×32W×32×32×3。就可以把它reshape成 H × W × 3 H\times W\times 3 H×W×3。也就是图像的大小。然后,把它与真实的图像作损失
L = 1 Ω ( x M ) ∥ y M − x M ∥ 1 L=\frac{1}{\Omega(x_M)}\Vert y_M-x_M \Vert_1 L=Ω(xM)1yMxM1
其 中 x , y ∈ R 3 H W × 1 x,y ∈ R^{3HW}×1 x,yR3HW×1 分 别 为 输 入 的RGB值和预测值;M表示掩码像素的集合;Ω(·)是元素的数量。

在实验中,还考虑了 L 2 L_2 L2 s m o o t h − L 1 smooth-L_1 smoothL1损失函数,它们的表现相似,默认采用 L 1 L_1 L1损失函数。

5、其他细节

在代码中,其实我还发现了有一个分类头。与BERT一样,在下游任务(比如图像分类)的时候,我们可以那他来作分类微调等等(非必须)。

另外还有一些,就是训练的细节了。那一部分很简单,就是一些超参数的配置等等。我也不可能一一全部列举出来,大家可以看论文,或者看代码里面就知道了。

6、结束

好了,本篇文章到此为止,如有问题,还望指出。阿里嘎多!

在这里插入图片描述

广告一刻

为您即时展示最新活动产品广告消息,让您随时掌握产品活动新动态!