PyTorch基于注意力的目标检测模型DETR

avatar
作者
筋斗云
阅读量:0

【图书推荐】《PyTorch深度学习与计算机视觉实践》-CSDN博客

目标检测是计算机视觉领域的一个重要任务,它的目标是在图像或视频中识别并定位出特定的对象。在这个过程中,需要确定对象的位置和类别,以及可能存在的多个实例。

DETR模型通过端到端的方式进行目标检测,即从原始图像直接检测出目标的位置和类别,而不需要进行区域提议或特征金字塔等步骤。

DETR模型的核心思想是将目标检测任务转换为一个序列到序列的问题。它将输入图像视为一个序列,并使用Transformer编码器将其转换为一种可被解码器理解的形式。具体来说,DETR模型使用CNN来提取图像特征,然后将其输入Transformer编码器中进行处理。再使用一个Transformer解码器来逐步解码出目标的位置和类别。完整的DETR的架构如图13-11所示。

图13-11  完整的DETR模型架构

下面借用在13.2节中实现的DETR目标检测模型进行讲解。完整的DETR模型代码如下:

import torch from torch import nn from torchvision.models import resnet50  class DETR(nn.Module):     def __init__(self,num_classes = 92,hidden_dim=256,nheads=8,num_encoder_layers=6,num_decoder_layers=6):         super().__init__()         #创建ResNet-50的骨干(backbone)网         with torch.no_grad():             self.backbone = resnet50()             #清除ResNet-50骨干网最后的全连接层             del self.backbone.fc         #创建转换层,1×1的卷积,主要起到改变通道大小的作用         self.conv = nn.Conv2d(2048,hidden_dim,1)         #利用PyTorch内嵌的类创建Transformer实例         self.transformer = nn.Transformer(hidden_dim,nheads,num_encoder_layers,num_decoder_layers)         #预测头,多出的类别用于预测non-empty slots         self.linear_class = nn.Linear(hidden_dim,num_classes)         self.linear_bbox = nn.Linear(hidden_dim,4)         # 输出检测槽编码(object queries)         self.query_pos = nn.Parameter(torch.rand(100,hidden_dim))         #可学习的位置编码,用于指导输入图形的坐标         self.row_embed = nn.Parameter(torch.rand(50,hidden_dim//2))         self.col_embed = nn.Parameter(torch.rand(50,hidden_dim//2))         self._reset_parameters()      def forward(self,inputs):         #将ResNet-50网络作为backbone         x = self.backbone.conv1(inputs)                x = self.backbone.bn1(x)                         x = self.backbone.relu(x)         x = self.backbone.maxpool(x)               x = self.backbone.layer1(x)                      x = self.backbone.layer2(x)                      x = self.backbone.layer3(x)                      x = self.backbone.layer4(x)     	#将ResNet-50网络作为backbone          #从2048维度转换到可被Transformer接受的256维特征平面         h = self.conv(x)                                                 #(1,2048,25,34)->(1,hidden_dim,25,34)         # 构建位置编码         B,C,H,W = h.shape         #创建一个可训练的与输入向量同样维度的位置向量,与原始的DETR的不同之处在于这里的位置向量是可训练的         pos = torch.cat([self.col_embed[:W].unsqueeze(0).repeat(H,1,1),self. row_embed[:H].unsqueeze(1).repeat(1,W,1),],dim=-1).flatten(0,1).unsqueeze(1) 		 	   #将图像特征与位置信息进行合并         src = pos+0.1*h.flatten(2).permute(2,0,1)         #创建查询函数         tgt = self.query_pos.unsqueeze(1).repeat(1,B,1)         #通过Transformer继续前向传播         #参数1:(h*w,batch_size,256),参数2:(100,batch_size,hidden_dim)         #输出:(hidden_dim,100)-->(100,hidden_dim)         h = self.transformer(src,tgt).transpose(0,1)         #将Transformer的输出投影到分类标签及边界框         return {'pred_logits':self.linear_class(h),'pred_boxes': self.linear_bbox(h).sigmoid()}      def _reset_parameters(self):         for p in self.parameters():             if p.dim() > 1:                 torch.nn.init.xavier_uniform_(p) 

从上面模型架构的实现代码上来看,整体DETR设计较为简单,可以分为3个主要部分:backbone、Transfomer和FFN。

1. backbone组件

backbone是DETR模型的第一部分,主要用于在图像上提取特征,生成特征图。这些特征图将作为输入传递给Transformer Encoder。backbone通常使用类似于ResNet或CNN模型来提取特征。

DETR将Resnet50作为backbone进行特征抽取,这样做的目的是可以直接使用PyTorch 2.0中提供的预训练模型和权重,从而节省了训练时间。

2. Transformer构成

Transformer是DETR模型的第二部分,它是由编码器和解码器构成,如图13-12所示。

编码器用于对backbone输出的特征图进行编码。这个编码过程主要是通过多头自注意力机制实现的。在DETR模型中,每个多头自注意力之前都使用了位置编码,这种位置编码方式可以帮助模型更好地理解图像中的空间信息。

图13-12  DETR中的Transformer组件

3. 分类器FFN

FFN一般使用两个全连接层作为分类器,其作用是对基于Transformer编码和查询后的特征向量进行分类计算,代码如下:

{'pred_logits':self.linear_class(h),'pred_boxes':self.linear_bbox(h).sigmoid()}

这里的self.linear_class和linear_bbox分别是对查询结果类别和位置的计算,分别用于预测分类和边界框回归。

以上就是对DETR模型的讲解。可以看到,DETR模型在架构设计上并没有太过于难懂的部分,可以认为是前面所学知识的集成。DETR在目标检测上的成功除了模型的设计外,还有一个重大创新就是开创性地提出了新的损失函数,目标检测中的损失函数通常由两部分组成:类别损失和边界框损失。对于类别损失,一般采用交叉熵损失函数,而在边界框损失方面,一般采用L1或L2损失函数。然而,DETR算法采用了不同的方式来计算类别损失和边界框损失。

DETR算法中的损失函数采用了基于二部图匹配的方式进行计算。具体来说,该算法首先将ground truth和预测的bounding box进行匹配,然后通过对比匹配结果和真实标签之间的差异来计算损失值。

《PyTorch深度学习与计算机视觉实践(人工智能技术丛书)》(王晓华)【摘要 书评 试读】- 京东图书 (jd.com)

广告一刻

为您即时展示最新活动产品广告消息,让您随时掌握产品活动新动态!