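Paper: [2403.19967] Rewrite the Stars
GitHub repository: GitHub - ma-xu/Rewrite-the-Stars: [CVPR 2024] Rewrite the Stars
The CVPR 2024 paper Rewrite the Stars shows that the star operation (element-wise multiplication) can map inputs into a high-dimensional, non-linear feature space, much like a kernel trick.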
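Advantages
High-Dimensional and Non-Linear Feature Transformation
- StarNet uses the star operation to map features into a high-dimensional, non-linear space without adding computational cost. As with traditional kernel tricks, the star operation implicitly obtains high-dimensional features from a low-dimensional input (ar5iv).
- For YOLO-series networks, this means richer and more expressive feature representations at the same computational budget, which matters for capturing fine-grained features in object detection.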
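To make the kernel-trick analogy concrete, here is a minimal sketch in plain Python (not taken from the paper's code). It expands one star operation on a single output channel, (w1 · x) * (w2 · x), into its implicit second-order terms x_i * x_j. The function names and the example weights are illustrative assumptions, and biases are left out for brevity.

```python
from itertools import combinations_with_replacement


def star(w1, w2, x):
    """One star operation on a single output channel: (w1 . x) * (w2 . x)."""
    dot1 = sum(a * b for a, b in zip(w1, x))
    dot2 = sum(a * b for a, b in zip(w2, x))
    return dot1 * dot2


def implicit_terms(w1, w2, x):
    """Expand the same product into distinct monomials x_i * x_j (i <= j)."""
    terms = {}
    for i, j in combinations_with_replacement(range(len(x)), 2):
        # The coefficient merges the symmetric pair (i, j) and (j, i).
        if i == j:
            coeff = w1[i] * w2[j]
        else:
            coeff = w1[i] * w2[j] + w1[j] * w2[i]
        terms[(i, j)] = coeff * x[i] * x[j]
    return terms


# Hypothetical example values.
x = [1.0, 2.0, 3.0]
w1 = [0.5, -1.0, 2.0]
w2 = [1.5, 0.25, -0.5]

terms = implicit_terms(w1, w2, x)
# d = 3 inputs give d * (d + 1) / 2 = 6 distinct second-order features.
print(len(terms))
# The cheap product and the explicit expansion agree.
print(abs(star(w1, w2, x) - sum(terms.values())) < 1e-9)
```

The product costs two dot products, yet it equals a weighted sum over d * (d + 1) / 2 pairwise features that are never materialized, which is the sense in which the feature space is "implicit".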
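Efficient Network Design
- StarNet obtains efficient feature representations through the star operation, without elaborate network design or extra computational overhead. Its distinctive property is that it computes in a low-dimensional space while implicitly working with very high-dimensional features (ar5iv).
- This makes StarNet a candidate backbone for YOLO-series networks: it is cheap to compute and offers good feature representations, which helps detection performance in resource-constrained settings.
Multi-Layer Implicit Feature Expansion
- By stacking star operations, StarNet recursively increases the implicit feature dimension, which approaches infinity as layers are added. For networks with sufficient width and depth, this can considerably strengthen feature expressiveness (ar5iv).
- For YOLO-series networks, this means that a suitable choice of depth and width can improve the quality of feature extraction and, in turn, detection accuracy.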
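The growth across layers can be estimated with a small calculation. The sketch below is a simplification and not the paper's exact formula: it assumes every star layer doubles the polynomial degree of the output in the original d inputs (activations ignored), and it counts the monomials of degree at most 2^l as a combinatorial upper bound on the number of implicit features. The function name is hypothetical.

```python
from math import comb


def implicit_dim_upper_bound(d: int, num_layers: int) -> int:
    """Count monomials of degree <= 2**num_layers in d variables.

    Each star layer multiplies two linear maps of its input, so the
    polynomial degree doubles per layer (assumption: activations ignored).
    """
    degree = 2 ** num_layers
    return comb(d + degree, degree)


for layers in (1, 2, 3):
    print(layers, implicit_dim_upper_bound(128, layers))
```

For a single layer the bound is (d + 2)(d + 1) / 2, and it grows very quickly with depth, which is why a few stacked star blocks already reach an enormous implicit dimension.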
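Problems Addressed
Balance Between Computational Complexity and Performance
- StarNet maps features into a high-dimensional space through the star operation while keeping computational complexity low. This eases the trade-off between complexity and performance found in conventional efficient network designs (ar5iv).
- YOLO-series networks have to balance real-time speed against detection accuracy, and StarNet's efficiency fits that requirement.
Richness of Feature Representation
- Conventional convolutional networks are limited in how well they express high-dimensional, non-linear feature transformations, whereas StarNet obtains richer feature representations through the star operation (ar5iv).
- In object detection, particularly for small objects and complex scenes, richer feature representations can improve detection results, which helps YOLO-series networks in these scenarios.
Simplified Network Design
- The star operation offers a way to simplify network design: efficient feature representations are obtained without elaborate feature fusion or multi-branch structures (ar5iv).
- For YOLO-series networks, this makes an efficient backbone easier to design and implement, and reduces the effort spent on design and debugging.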
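1. Download imagenet/starnet.py from the repository mentioned above.
2. Modify the forward function in starnet.py and add an out_indices parameter so that the network returns the feature maps of the selected stages.
3. Register class StarNet with @MODELS.register_module() and add it to the backbones __init__.py.
4. Modify the config file, mainly the input and output channel counts of the YOLOv5 neck and head.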
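Before editing the config in step 4, it helps to know what each stage returns. The sketch below is plain Python and is derived from the StarNet constructor in this post: the stem downsamples by 2, every stage starts with a stride-2 ConvBN, and stage i has base_dim * 2**i channels. The helper name stage_specs is a hypothetical name used only for illustration.

```python
def stage_specs(base_dim, out_indices, num_stages=4, stem_stride=2):
    """Return (stage index, channels, total stride) for the selected stages."""
    specs = []
    for i in range(num_stages):
        channels = base_dim * 2 ** i          # embed_dim = base_dim * 2 ** i_layer
        stride = stem_stride * 2 ** (i + 1)   # stem, then one stride-2 conv per stage
        if i in out_indices:
            specs.append((i, channels, stride))
    return specs


for stage, channels, stride in stage_specs(24, (0, 1, 2)):
    print(f"stage {stage}: {channels} channels, stride {stride}")

for stage, channels, stride in stage_specs(24, (1, 2, 3)):
    print(f"stage {stage}: {channels} channels, stride {stride}")
```

With base_dim=24, out_indices=(0, 1, 2) selects feature maps at strides 4, 8 and 16, while (1, 2, 3) selects strides 8, 16 and 32. Check the result against the strides and channel lists in the config before training.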
"""
Implementation of Proof-of-Concept Network: StarNet.

We make StarNet as simple as possible [to show the key contribution of element-wise multiplication]:
    - like NO layer-scale in network design,
    - and NO EMA during training,
    - which would improve the performance further.

Created by: Xu Ma (Email: ma.xu1@northeastern.edu)
Modified Date: Mar/29/2024
"""
import torch
import torch.nn as nn
from timm.models.layers import DropPath, trunc_normal_
from typing import List, Sequence, Union
# from timm.models.registry import register_model
from mmyolo.registry import MODELS

model_urls = {
    "starnet_s1": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s1.pth.tar",
    "starnet_s2": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s2.pth.tar",
    "starnet_s3": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s3.pth.tar",
    "starnet_s4": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s4.pth.tar",
}


class ConvBN(torch.nn.Sequential):
    """Conv2d optionally followed by BatchNorm2d."""

    def __init__(self, in_planes, out_planes, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, with_bn=True):
        super().__init__()
        self.add_module('conv', torch.nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, dilation, groups))
        if with_bn:
            self.add_module('bn', torch.nn.BatchNorm2d(out_planes))
            torch.nn.init.constant_(self.bn.weight, 1)
            torch.nn.init.constant_(self.bn.bias, 0)


class Block(nn.Module):
    def __init__(self, dim, mlp_ratio=3, drop_path=0.):
        super().__init__()
        self.dwconv = ConvBN(dim, dim, 7, 1, (7 - 1) // 2, groups=dim, with_bn=True)
        self.f1 = ConvBN(dim, mlp_ratio * dim, 1, with_bn=False)
        self.f2 = ConvBN(dim, mlp_ratio * dim, 1, with_bn=False)
        self.g = ConvBN(mlp_ratio * dim, dim, 1, with_bn=True)
        self.dwconv2 = ConvBN(dim, dim, 7, 1, (7 - 1) // 2, groups=dim, with_bn=False)
        self.act = nn.ReLU6()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x1, x2 = self.f1(x), self.f2(x)
        # star operation: element-wise multiplication of the two branches
        x = self.act(x1) * x2
        x = self.dwconv2(self.g(x))
        x = input + self.drop_path(x)
        return x


@MODELS.register_module()
class StarNet(nn.Module):
    def __init__(self, base_dim=32, out_indices: Sequence[int] = (0, 1, 2), depths=[3, 3, 12, 5], mlp_ratio=4,
                 drop_path_rate=0.0, num_classes=1000, **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.in_channel = 32
        self.out_indices = out_indices
        self.depths = depths
        # stem layer
        self.stem = nn.Sequential(ConvBN(3, self.in_channel, kernel_size=3, stride=2, padding=1), nn.ReLU6())
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth
        # build stages
        self.stages = nn.ModuleList()
        cur = 0
        for i_layer in range(len(depths)):
            embed_dim = base_dim * 2 ** i_layer
            down_sampler = ConvBN(self.in_channel, embed_dim, 3, 2, 1)
            self.in_channel = embed_dim
            blocks = [Block(self.in_channel, mlp_ratio, dpr[cur + i]) for i in range(depths[i_layer])]
            cur += depths[i_layer]
            self.stages.append(nn.Sequential(down_sampler, *blocks))
        # head (the classification head is not needed for a detection backbone)
        # self.norm = nn.BatchNorm2d(self.in_channel)
        # self.avgpool = nn.AdaptiveAvgPool2d(1)
        # self.head = nn.Linear(self.in_channel, num_classes)
        # self.apply(self._init_weights)

    def _init_weights(self, m):
        # isinstance needs a tuple of types; `nn.Linear or nn.Conv2d` would only test nn.Linear
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        x = self.stem(x)
        # collect the outputs of the selected stages
        outs = []
        for i in range(len(self.depths)):
            x = self.stages[i](x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)


@MODELS.register_module()
def starnet_s1(pretrained=False, **kwargs):
    model = StarNet(24, (0, 1, 2), [2, 2, 8, 3], **kwargs)
    if pretrained:
        url = model_urls['starnet_s1']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])
    return model


@MODELS.register_module()
def starnet_s2(pretrained=False, **kwargs):
    model = StarNet(32, (0, 1, 2), [1, 2, 6, 2], **kwargs)
    if pretrained:
        url = model_urls['starnet_s2']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])
    return model


@MODELS.register_module()
def starnet_s3(pretrained=False, **kwargs):
    model = StarNet(32, (0, 1, 2), [2, 2, 8, 4], **kwargs)
    if pretrained:
        url = model_urls['starnet_s3']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])
    return model


@MODELS.register_module()
def starnet_s4(pretrained=False, **kwargs):
    model = StarNet(32, (0, 1, 2), [3, 3, 12, 5], **kwargs)
    if pretrained:
        url = model_urls['starnet_s4']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])
    return model


# very small networks #
@MODELS.register_module()
def starnet_s050(pretrained=False, **kwargs):
    return StarNet(16, (0, 1, 2), [1, 1, 3, 1], 3, **kwargs)


@MODELS.register_module()
def starnet_s100(pretrained=False, **kwargs):
    return StarNet(20, (0, 1, 2), [1, 2, 4, 1], 4, **kwargs)


@MODELS.register_module()
def starnet_s150(pretrained=False, **kwargs):
    return StarNet(24, (0, 1, 2), [1, 2, 4, 2], 3, **kwargs)


if __name__ == '__main__':
    model = StarNet()
    input_tensor = torch.randn(1, 3, 224, 224)
    outputs = model(input_tensor)
# Copyright (c) OpenMMLab. All rights reserved.
from .base_backbone import BaseBackbone
from .csp_darknet import YOLOv5CSPDarknet, YOLOv8CSPDarknet, YOLOXCSPDarknet
from .csp_resnet import PPYOLOECSPResNet
from .cspnext import CSPNeXt
from .efficient_rep import YOLOv6CSPBep, YOLOv6EfficientRep
from .yolov7_backbone import YOLOv7Backbone
from .starnet import StarNet

__all__ = [
    'YOLOv5CSPDarknet', 'BaseBackbone', 'YOLOv6EfficientRep', 'YOLOv6CSPBep',
    'YOLOXCSPDarknet', 'CSPNeXt', 'YOLOv7Backbone', 'PPYOLOECSPResNet',
    'YOLOv8CSPDarknet', 'StarNet'
]
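Registration is what lets the config refer to the backbone by the string type='StarNet'. The toy registry below sketches that idea in plain Python. It is a simplified stand-in written for this post, not mmengine's actual Registry implementation, and ToyRegistry, TOY_MODELS and ToyBackbone are hypothetical names.

```python
class ToyRegistry:
    """Maps a string name to a class so that config dicts can build objects."""

    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register_module(self, name=None):
        def decorator(obj):
            key = name or obj.__name__
            if key in self._modules:
                raise KeyError(f"{key} is already registered in {self.name}")
            self._modules[key] = obj
            return obj  # the decorated class itself is left unchanged
        return decorator

    def build(self, cfg):
        cfg = dict(cfg)              # copy so the caller's dict is not mutated
        obj_type = cfg.pop("type")   # 'type' selects the registered class
        return self._modules[obj_type](**cfg)


TOY_MODELS = ToyRegistry("toy_models")


@TOY_MODELS.register_module()
class ToyBackbone:
    def __init__(self, base_dim=32, out_indices=(0, 1, 2)):
        self.base_dim = base_dim
        self.out_indices = out_indices


backbone = TOY_MODELS.build(dict(type="ToyBackbone", base_dim=24))
print(type(backbone).__name__, backbone.base_dim)
```

The decorator only runs when the module is imported, which is why the class also has to be imported in the package's __init__.py: without that import, the lookup by name fails even though the decorator is present in the source file.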
_base_ = ['../_base_/default_runtime.py', '../_base_/det_p5_tta.py']

# ========================Frequently modified parameters======================
# -----data related-----
data_root = 'data/coco/'  # Root path of data
# Path of train annotation file
train_ann_file = 'annotations/instances_train2017.json'
train_data_prefix = 'train2017/'  # Prefix of train image path
# Path of val annotation file
val_ann_file = 'annotations/instances_val2017.json'
val_data_prefix = 'val2017/'  # Prefix of val image path

num_classes = 80  # Number of classes for classification
# Batch size of a single GPU during training
train_batch_size_per_gpu = 16
# Worker to pre-fetch data for each single GPU during training
train_num_workers = 8
# persistent_workers must be False if num_workers is 0
persistent_workers = True

# -----model related-----
# Basic size of multi-scale prior box
anchors = [
    [(10, 13), (16, 30), (33, 23)],  # P3/8
    [(30, 61), (62, 45), (59, 119)],  # P4/16
    [(116, 90), (156, 198), (373, 326)]  # P5/32
]

# -----train val related-----
# Base learning rate for optim_wrapper. Corresponding to 8xb16=128 bs
base_lr = 0.01
max_epochs = 300  # Maximum training epochs

model_test_cfg = dict(
    # The config of multi-label for multi-class prediction.
    multi_label=True,
    # The number of boxes before NMS
    nms_pre=30000,
    score_thr=0.001,  # Threshold to filter out boxes.
    nms=dict(type='nms', iou_threshold=0.65),  # NMS type and threshold
    max_per_img=300)  # Max number of detections of each image

# ========================Possible modified parameters========================
# -----data related-----
img_scale = (640, 640)  # width, height
# Dataset type, this will be used to define the dataset
dataset_type = 'YOLOv5CocoDataset'
# Batch size of a single GPU during validation
val_batch_size_per_gpu = 1
# Worker to pre-fetch data for each single GPU during validation
val_num_workers = 2

# Config of batch shapes. Only on val.
# It means not used if batch_shapes_cfg is None.
batch_shapes_cfg = dict(
    type='BatchShapePolicy',
    batch_size=val_batch_size_per_gpu,
    img_size=img_scale[0],
    # The image scale of padding should be divided by pad_size_divisor
    size_divisor=32,
    # Additional paddings for pixel scale
    extra_pad_ratio=0.5)

# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 0.33
# The scaling factor that controls the width of the network structure
widen_factor = 0.5
# Strides of multi-scale prior box
strides = [8, 16, 32]
num_det_layers = 3  # The number of model output scales
norm_cfg = dict(type='BN', momentum=0.03, eps=0.001)  # Normalization config

# -----train val related-----
affine_scale = 0.5  # YOLOv5RandomAffine scaling ratio
loss_cls_weight = 0.5
loss_bbox_weight = 0.05
loss_obj_weight = 1.0
prior_match_thr = 4.  # Priori box matching threshold
# The obj loss weights of the three output layers
obj_level_weights = [4., 1., 0.4]
lr_factor = 0.01  # Learning rate scaling factor
weight_decay = 0.0005
# Save model checkpoint and validation intervals
save_checkpoint_intervals = 10
# The maximum checkpoints to keep.
max_keep_ckpts = 3
# Single-scale training is recommended to
# be turned on, which can speed up training.
env_cfg = dict(cudnn_benchmark=True)

'''
variant: base_dim, starnet_channel, depths, mlp_ratio
s1: 24, [48, 96, 192], [2, 2, 8, 3], 4
s2: 32, [64, 128, 256], [1, 2, 6, 2], 4
s3: 32, [64, 128, 256], [2, 2, 8, 4], 4
s4: 32, [64, 128, 256], [3, 3, 12, 5], 4
starnet_s050: 16, [32, 64, 128], [1, 1, 3, 1], 3
starnet_s100: 20, [40, 80, 160], [1, 2, 4, 1], 4
starnet_s150: 24, [48, 96, 192], [1, 2, 4, 2], 3
'''
# NOTE: StarNet stage i outputs base_dim * 2**i channels at stride 4 * 2**i,
# so base_dim=24 with out_indices=(0, 1, 2) yields 24/48/96 channels at
# strides 4/8/16. The neck and head scale starnet_channel by widen_factor=0.5,
# which gives the same 24/48/96. To match strides = [8, 16, 32], use
# out_indices=(1, 2, 3) and adjust the channels accordingly.
starnet_channel = [48, 96, 192]
depths = [1, 2, 6, 2]

# ===============================Unmodified in most cases====================
model = dict(
    type='YOLODetector',
    data_preprocessor=dict(
        type='mmdet.DetDataPreprocessor',
        mean=[0., 0., 0.],
        std=[255., 255., 255.],
        bgr_to_rgb=True),
    backbone=dict(
        # base_dim=24 as in s1; the depths set above are the s2 values
        type='StarNet',
        base_dim=24,
        out_indices=(0, 1, 2),
        depths=depths,
        mlp_ratio=4,
        num_classes=num_classes,
        # deepen_factor=deepen_factor,
        # widen_factor=widen_factor,
        # norm_cfg=norm_cfg,
        # act_cfg=dict(type='SiLU', inplace=True)
    ),
    neck=dict(
        type='YOLOv5PAFPN',
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        in_channels=starnet_channel,
        out_channels=starnet_channel,
        num_csp_blocks=3,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='SiLU', inplace=True)),
    bbox_head=dict(
        type='YOLOv5Head',
        head_module=dict(
            type='YOLOv5HeadModule',
            num_classes=num_classes,
            in_channels=starnet_channel,
            widen_factor=widen_factor,
            featmap_strides=strides,
            num_base_priors=3),
        prior_generator=dict(
            type='mmdet.YOLOAnchorGenerator',
            base_sizes=anchors,
            strides=strides),
        # scaled based on number of detection layers
        loss_cls=dict(
            type='mmdet.CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=loss_cls_weight *
            (num_classes / 80 * 3 / num_det_layers)),
        # Edit this entry to swap in a different IoU loss
        loss_bbox=dict(
            type='IoULoss',
            focal=True,
            iou_mode='ciou',
            bbox_format='xywh',
            eps=1e-7,
            reduction='mean',
            loss_weight=loss_bbox_weight * (3 / num_det_layers),
            return_iou=True),
        loss_obj=dict(
            type='mmdet.CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=loss_obj_weight *
            ((img_scale[0] / 640) ** 2 * 3 / num_det_layers)),
        prior_match_thr=prior_match_thr,
        obj_level_weights=obj_level_weights),
    test_cfg=model_test_cfg)

albu_train_transforms = [
    dict(type='Blur', p=0.01),
    dict(type='MedianBlur', p=0.01),
    dict(type='ToGray', p=0.01),
    dict(type='CLAHE', p=0.01)
]

pre_transform = [
    dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
    dict(type='LoadAnnotations', with_bbox=True)
]

train_pipeline = [
    *pre_transform,
    dict(
        type='Mosaic',
        img_scale=img_scale,
        pad_val=114.0,
        pre_transform=pre_transform),
    dict(
        type='YOLOv5RandomAffine',
        max_rotate_degree=0.0,
        max_shear_degree=0.0,
        scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
        # img_scale is (width, height)
        border=(-img_scale[0] // 2, -img_scale[1] // 2),
        border_val=(114, 114, 114)),
    dict(
        type='mmdet.Albu',
        transforms=albu_train_transforms,
        bbox_params=dict(
            type='BboxParams',
            format='pascal_voc',
            label_fields=['gt_bboxes_labels', 'gt_ignore_flags']),
        keymap={
            'img': 'image',
            'gt_bboxes': 'bboxes'
        }),
    dict(type='YOLOv5HSVRandomAug'),
    dict(type='mmdet.RandomFlip', prob=0.5),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
                   'flip_direction'))
]

train_dataloader = dict(
    batch_size=train_batch_size_per_gpu,
    num_workers=train_num_workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file=train_ann_file,
        data_prefix=dict(img=train_data_prefix),
        filter_cfg=dict(filter_empty_gt=False, min_size=32),
        pipeline=train_pipeline))

test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
    dict(type='YOLOv5KeepRatioResize', scale=img_scale),
    dict(
        type='LetterResize',
        scale=img_scale,
        allow_scale_up=False,
        pad_val=dict(img=114)),
    dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor', 'pad_param'))
]

val_dataloader = dict(
    batch_size=val_batch_size_per_gpu,
    num_workers=val_num_workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        test_mode=True,
        data_prefix=dict(img=val_data_prefix),
        ann_file=val_ann_file,
        pipeline=test_pipeline,
        batch_shapes_cfg=batch_shapes_cfg))

test_dataloader = val_dataloader

param_scheduler = None
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='SGD',
        lr=base_lr,
        momentum=0.937,
        weight_decay=weight_decay,
        nesterov=True,
        batch_size_per_gpu=train_batch_size_per_gpu),
    constructor='YOLOv5OptimizerConstructor')

default_hooks = dict(
    param_scheduler=dict(
        type='YOLOv5ParamSchedulerHook',
        scheduler_type='linear',
        lr_factor=lr_factor,
        max_epochs=max_epochs),
    checkpoint=dict(
        type='CheckpointHook',
        interval=save_checkpoint_intervals,
        save_best='auto',
        max_keep_ckpts=max_keep_ckpts))

custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0001,
        update_buffers=True,
        strict_load=False,
        priority=49)
]

val_evaluator = dict(
    type='mmdet.CocoMetric',
    proposal_nums=(100, 1, 10),
    ann_file=data_root + val_ann_file,
    metric='bbox')
test_evaluator = val_evaluator

train_cfg = dict(
    type='EpochBasedTrainLoop',
    max_epochs=max_epochs,
    val_interval=save_checkpoint_intervals)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
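A channel mismatch between backbone and neck only shows up when the model is built, so a quick arithmetic check can save a failed run. The sketch below is an assumption-laden illustration: it assumes the neck multiplies each entry of in_channels by widen_factor and rounds up to a multiple of 8, which is how mmyolo's make_divisible helper is understood to behave here. Verify this against the installed version. The values are the ones used in the config above.

```python
import math


def make_divisible(x, widen_factor=1.0, divisor=8):
    """Scale x by widen_factor and round up to a multiple of divisor.

    Assumption: mirrors the rounding rule applied to channel counts
    inside the neck and head.
    """
    return math.ceil(x * widen_factor / divisor) * divisor


# Values taken from the config in this post.
starnet_channel = [48, 96, 192]
widen_factor = 0.5
base_dim = 24
out_indices = (0, 1, 2)

# Channels the neck expects after applying widen_factor.
neck_inputs = [make_divisible(c, widen_factor) for c in starnet_channel]
# Channels StarNet returns: stage i has base_dim * 2**i channels.
backbone_outputs = [base_dim * 2 ** i for i in out_indices]

print("neck expects:    ", neck_inputs)
print("backbone returns:", backbone_outputs)
assert neck_inputs == backbone_outputs, "adjust starnet_channel or widen_factor"
```

Under these assumptions the two lists agree, because halving [48, 96, 192] gives the channel counts of stages 0 to 2. If out_indices or widen_factor is changed, rerun the check and update starnet_channel so that the lists match again.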