PointNeXt Source Code Reading (I): Registry Mechanism and Argument Parsing

Author: 猴君

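Preface

I have studied part of the PointNeXt source code and am writing it down here for future reference.

This post has two parts, the registry mechanism and argument parsing. The registry mechanism is the main thing to understand.

All annotations and debug output are based on the following test session.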

CUDA_VISIBLE_DEVICES=0,1 python examples/segmentation/main.py \
    --cfg cfgs/s3dis/pointnext-s.yaml mode=train

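I. Registry Mechanism

The registry mechanism is how PointNeXt registers modules/classes so that a string can be mapped to a module/class. In other words, a parameter string read from a configuration file can be mapped directly to the corresponding class, which is then instantiated. For this part, the PointNeXt authors drew on the registry mechanism in mmcv.

1. The Registry class

The mechanism is implemented by class Registry, together with the module-level function build_from_cfg defined in the same file. The key pieces are: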

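| Method | Description |
| --- | --- |
| __init__() | Class initialization; also initializes the registration dict self._module_dict = dict() |
| get(self, key) | Maps a string to a class: the string key is looked up in self._module_dict and the registered class self._module_dict[real_key] is returned |
| register_module(self, name=None, force=False, module=None) | Registers a module/class; can also be used as a decorator on a module/class |
| _register(cls) | The wrapper function inside the decorator register_module, used in the decorator case; the argument cls is the class being decorated. It first calls _register_module(self, module_class, module_name=None, force=False) to register the class (self._module_dict[name] = module_class), then returns the class cls unchanged, with no other processing |
| build_from_cfg(cfg, registry, default_args=None) | Module-level function (not a method of Registry) that builds a module/class instance from a config dict, i.e. goes from a string to an instance |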
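To make the table concrete, here is a minimal sketch of the same idea. MiniRegistry, TOY_MODELS, ToySeg and TinyNet are hypothetical names made up for illustration; this is not the openpoints implementation, only the core pattern (a dict plus a decorator that returns the class unchanged).

```python
class MiniRegistry:
    """Toy string-to-class registry (illustrative only, not openpoints' Registry)."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self, name=None):
        # Decorator factory, used as: @REG.register_module()
        def _register(cls):
            key = name if name is not None else cls.__name__
            if key in self._module_dict:
                raise KeyError(f"{key} is already registered in {self.name}")
            self._module_dict[key] = cls
            return cls  # the class itself is returned unchanged
        return _register

    def get(self, key):
        # Returns None when the key was never registered
        return self._module_dict.get(key)


TOY_MODELS = MiniRegistry("toy_models")


@TOY_MODELS.register_module()
class ToySeg:
    def __init__(self, num_classes=13):
        self.num_classes = num_classes


@TOY_MODELS.register_module(name="tiny")
class TinyNet:
    pass


# String -> class -> instance
model_cls = TOY_MODELS.get("ToySeg")
model = model_cls(num_classes=13)
print(type(model).__name__, model.num_classes)
```

The decorator only stores a reference to the class; the class can still be imported and used directly as usual.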

The Registry class is defined in openpoints/utils/registry.py. The listing below adds annotations.

# Acknowledgement: built upon mmcv
import inspect
import warnings
from functools import partial
import copy


class Registry:
    """A registry to map strings to classes.
    Registered object could be built from registry.
    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = MODELS.build(dict(NAME='ResNet'))
    Please refer to https://mmcv.readthedocs.io/en/latest/registry.html for
    advanced usage.
    Args:
        name (str): Registry name.
        build_func(func, optional): Build function to construct instance from
            Registry, func:`build_from_cfg` is used if neither ``parent`` or
            ``build_func`` is specified. If ``parent`` is specified and
            ``build_func`` is not given,  ``build_func`` will be inherited
            from ``parent``. Default: None.
        parent (Registry, optional): Parent registry. The class registered in
            children registry could be built from parent. Default: None.
        scope (str, optional): The scope of registry. It is the key to search
            for children registry. If not specified, scope will be the name of
            the package where class is defined, e.g. mmdet, mmcls, mmseg.
            Default: None.
    """

    def __init__(self, name, build_func=None, parent=None, scope=None):
        self._name = name
        self._module_dict = dict()
        self._children = dict()
        self._scope = self.infer_scope() if scope is None else scope
        # self._scope = 'openpoints'

        # self.build_func will be set with the following priority:
        # 1. build_func
        # 2. parent.build_func
        # 3. build_from_cfg
        if build_func is None:
            if parent is not None:
                self.build_func = parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_children(self)
            self.parent = parent
        else:
            self.parent = None

    def __len__(self):
        return len(self._module_dict)

    def __contains__(self, key):
        return self.get(key) is not None

    def __repr__(self):
        format_str = self.__class__.__name__ + \
                     f'(name={self._name}, ' \
                     f'items={self._module_dict})'
        return format_str

    @staticmethod
    def infer_scope():
        """Infer the scope of registry.
        The name of the package where registry is defined will be returned.
        Example:
            # in mmdet/models/backbone/resnet.py
            >>> MODELS = Registry('models')
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     pass
            The scope of ``ResNet`` will be ``mmdet``.
        Returns:
            scope (str): The inferred scope name.
        """
        # inspect.stack() trace where this function is called, the index-2
        # indicates the frame where `infer_scope()` is called
        filename = inspect.getmodule(inspect.stack()[2][0]).__name__
        # filename = 'openpoints.models.build'
        split_filename = filename.split('.')  # ['openpoints', 'models', 'build']
        return split_filename[0]  # 'openpoints'

    @staticmethod  # declares a static method
    def split_scope_key(key):
        """Split scope and key.
        The first scope will be split from key.
        Examples:
            >>> Registry.split_scope_key('mmdet.ResNet')
            'mmdet', 'ResNet'
            >>> Registry.split_scope_key('ResNet')
            None, 'ResNet'
        Return:
            scope (str, None): The first scope.
            key (str): The remaining key.
        """
        split_index = key.find('.')
        # find() returns -1 if key contains no '.', otherwise the index of
        # the first '.'
        if split_index != -1:
            return key[:split_index], key[split_index + 1:]
        else:
            return None, key

    @property
    def name(self):
        return self._name

    @property
    def scope(self):
        return self._scope

    @property
    def module_dict(self):
        return self._module_dict

    @property
    def children(self):
        return self._children

    def get(self, key):
        # Maps a string to a class: the string key is looked up in
        # self._module_dict and the registered class
        # self._module_dict[real_key] is returned
        """Get the registry record.
        Args:
            key (str): The class name in string format.
        Returns:
            class: The corresponding class.
        """
        scope, real_key = self.split_scope_key(key)
        # key = BaseSeg; scope = None; real_key = BaseSeg
        if scope is None or scope == self._scope:
            # get from self
            if real_key in self._module_dict:
                return self._module_dict[real_key]
        else:
            # get from self._children
            if scope in self._children:
                return self._children[scope].get(real_key)
            else:
                # goto root
                parent = self.parent
                while parent.parent is not None:
                    parent = parent.parent
                return parent.get(key)

    def build(self, *args, **kwargs):
        return self.build_func(*args, **kwargs, registry=self)

    def _add_children(self, registry):
        """Add children for a registry.
        The ``registry`` will be added as children based on its scope.
        The parent registry could build objects from children registry.
        Example:
            >>> models = Registry('models')
            >>> mmdet_models = Registry('models', parent=models)
            >>> @mmdet_models.register_module()
            >>> class ResNet:
            >>>     pass
            >>> resnet = models.build(dict(NAME='mmdet.ResNet'))
        """

        assert isinstance(registry, Registry)
        assert registry.scope is not None
        assert registry.scope not in self.children, \
            f'scope {registry.scope} exists in {self.name} registry'
        self.children[registry.scope] = registry

    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')

        if module_name is None:
            module_name = module_class.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in self._module_dict:
                raise KeyError(f'{name} is already registered '
                               f'in {self.name}')
            self._module_dict[name] = module_class

    def deprecated_register_module(self, cls=None, force=False):
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.')
        if cls is None:
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls

    def register_module(self, name=None, force=False, module=None):
        # Decorator (factory)
        """Register a module.
        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.
        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass
            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(ResNet)
        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')
        # NOTE: This is a workaround to be compatible with the old api,
        # while it may introduce unexpected bugs.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        # raise the error ahead of time
        # NOTE: `misc` is not imported in this listing; the check
        # short-circuits before reaching it whenever name is None or a str.
        if not (name is None or isinstance(name, str) or misc.is_seq_of(name, str)):
            raise TypeError(
                'name must be either of None, an instance of str or a sequence'
                f'  of str, but got {type(name)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        # A plain call to register_module, not the decorator case
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        # This is the decorator's wrapper function.
        # In the decorator case, cls is the class being decorated.
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register  # the decorator returns this wrapper function


def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.
    Args:
        cfg (edict): Config dict. It should at least contain the key "NAME".
        registry (:obj:`Registry`): The registry to search the type from.
    Returns:
        object: The constructed object.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'NAME' not in cfg:
        if default_args is None or 'NAME' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "NAME", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')

    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    # if default_args is not None:
    #     cfg = config.merge_new_config(cfg, default_args)

    obj_type = cfg.get('NAME')   # 'BaseSeg'

    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        # <class 'openpoints.models.segmentation.base_seg.BaseSeg'>
        # Look up the class/module in self._module_dict by its name string,
        # i.e. map the string to the class
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    try:
        obj_cfg = copy.deepcopy(cfg)
        if default_args is not None:
            obj_cfg.update(default_args)

        obj_cfg.pop('NAME')
        # Remove the "NAME" entry; obj_cfg keeps every other entry.
        # 'NAME' has already done its job of selecting the class obj_cls.

        # With the variables expanded, the call below is
        # openpoints.models.segmentation.base_seg.BaseSeg(**obj_cfg).
        # The decorator @MODELS.register_module() ran exactly once, when the
        # class statement was executed at import time, and returned the class
        # unchanged. Instantiating the class here therefore does not go
        # through the registry again; it is an ordinary constructor call.
        return obj_cls(**obj_cfg)

    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{obj_cls.__name__}: {e}')

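2. Registering classes

First, openpoints/models/build.py declares and defines the global Registry object MODELS, referred to as the registry.

Because of Python's import mechanism, MODELS is created the first time openpoints/models/build.py is imported, that is, while the imports at the top of the main script are executed and before the body of __main__/main() runs.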

from openpoints.utils import registry
MODELS = registry.Registry('models')
# Create a registry.Registry object (MODELS, also called the registry) as a global variable.
# This runs at import time, before __main__/main() executes, so the registry MODELS exists from the very start of the program.

def build_model_from_cfg(cfg, **kwargs):
    """
    Build a model, defined by `NAME`.
    Args:
        cfg (eDICT):
    Returns:
        Model: a constructed model specified by NAME.
    """
    return MODELS.build(cfg, **kwargs)

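Also because of the import mechanism, once MODELS exists, every class under openpoints/models whose definition is decorated with @MODELS.register_module() is registered into MODELS as soon as the module defining it is imported and its class statement is executed.

For example, the BaseSeg class below is registered into MODELS._module_dict in this way.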

""" Author: PointNeXt """
import copy
from typing import List
import torch
import torch.nn as nn
import logging
from ...utils import get_missing_parameters_message, get_unexpected_parameters_message
from ..build import MODELS, build_model_from_cfg
from ..layers import create_linearblock, create_convblock1d

# The class BaseSeg is decorated with MODELS.register_module().
# This is equivalent to BaseSeg = MODELS.register_module()(BaseSeg), executed once when the class is defined;
# creating objects later with BaseSeg(...) does not call the decorator again.
# The decoration runs at import time: before __main__/main(), but after the registry MODELS has been created.
# So the class is already registered early on, and by the time main() is called the registry can map the string to the class.

@MODELS.register_module()
class BaseSeg(nn.Module):
    def __init__(self,
                 encoder_args=None,
                 decoder_args=None,
                 cls_args=None,
                 **kwargs):
        super().__init__()
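The timing is easy to check with a small experiment. The sketch below uses a hypothetical register function and Demo class (stand-ins for the function returned by MODELS.register_module() and for BaseSeg) to show that a class decorator runs once, when the class statement executes, and not again on instantiation.

```python
registered = []  # records every registration event


def register(cls):
    """Hypothetical stand-in for the function returned by MODELS.register_module()."""
    registered.append(cls.__name__)
    return cls  # class returned unchanged


@register  # same as: Demo = register(Demo), executed once at definition time
class Demo:
    def __init__(self, width=32):
        self.width = width


count_after_definition = len(registered)

# Creating instances is an ordinary constructor call; no new registration happens
a = Demo()
b = Demo(width=64)
count_after_instantiation = len(registered)

print(count_after_definition, count_after_instantiation)  # prints: 1 1
```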

While debugging, inspecting MODELS._module_dict shows that many classes have already been registered.

MODELS._module_dict = {
    'PointNetEncoder': <class 'openpoints.models.backbone.pointnet.PointNetEncoder'>,
    'PointPatchEmbed': <class 'openpoints.models.layers.group_embed.PointPatchEmbed'>,
    'P3Embed': <class 'openpoints.models.layers.group_embed.P3Embed'>,
    'PointNet2Encoder': <class 'openpoints.models.backbone.pointnetv2.PointNet2Encoder'>,
    'PointNet2Decoder': <class 'openpoints.models.backbone.pointnetv2.PointNet2Decoder'>,
    'PointNet2PartDecoder': <class 'openpoints.models.backbone.pointnetv2.PointNet2PartDecoder'>,
    'PointNextEncoder': <class 'openpoints.models.backbone.pointnext.PointNextEncoder'>,
    'PointNextDecoder': <class 'openpoints.models.backbone.pointnext.PointNextDecoder'>,
    'PointNextPartDecoder': <class 'openpoints.models.backbone.pointnext.PointNextPartDecoder'>,
    'DGCNN': <class 'openpoints.models.backbone.dgcnn.DGCNN'>,
    'DeepGCN': <class 'openpoints.models.backbone.deepgcn.DeepGCN'>,
    'PointMLPEncoder': <class 'openpoints.models.backbone.pointmlp.PointMLPEncoder'>,
    'PointMLP': <class 'openpoints.models.backbone.pointmlp.PointMLP'>,
    'PointViT': <class 'openpoints.models.backbone.pointvit.PointViT'>,
    'PointViTDecoder': <class 'openpoints.models.backbone.pointvit.PointViTDecoder'>,
    'PointViTPartDecoder': <class 'openpoints.models.backbone.pointvit.PointViTPartDecoder'>,
    'InvPointViT': <class 'openpoints.models.backbone.pointvit_inv.InvPointViT'>,
    'InvPointViTDecoder': <class 'openpoints.models.backbone.pointvit_inv.InvPointViTDecoder'>,
    'InvPointViTPartDecoder': <class 'openpoints.models.backbone.pointvit_inv.InvPointViTPartDecoder'>,
    'CurveNet': <class 'openpoints.models.backbone.curvenet.CurveNet'>,
    'MVFC': <class 'openpoints.models.backbone.simpleview.MVFC'>,
    'MVModel': <class 'openpoints.models.backbone.simpleview.MVModel'>,
    'BaseSeg': <class 'openpoints.models.segmentation.base_seg.BaseSeg'>,
    'BasePartSeg': <class 'openpoints.models.segmentation.base_seg.BasePartSeg'>,
    'VariableSeg': <class 'openpoints.models.segmentation.base_seg.VariableSeg'>,
    'SegHead': <class 'openpoints.models.segmentation.base_seg.SegHead'>,
    'VariableSegHead': <class 'openpoints.models.segmentation.base_seg.VariableSegHead'>,
    'MultiSegHead': <class 'openpoints.models.segmentation.base_seg.MultiSegHead'>,
    'BaseCls': <class 'openpoints.models.classification.cls_base.BaseCls'>,
    'DistillCls': <class 'openpoints.models.classification.cls_base.DistillCls'>,
    'ClsHead': <class 'openpoints.models.classification.cls_base.ClsHead'>,
    'MaskedTransformerDecoder': <class 'openpoints.models.reconstruction.base_recontruct.MaskedTransformerDecoder'>,
    'FoldingNet': <class 'openpoints.models.reconstruction.base_recontruct.FoldingNet'>,
    'NodeShuffle': <class 'openpoints.models.reconstruction.base_recontruct.NodeShuffle'>,
    'MaskedPointViT': <class 'openpoints.models.reconstruction.maskedpointvit.MaskedPointViT'>,
    'MaskedPoint': <class 'openpoints.models.reconstruction.maskedpoint.MaskedPoint'>,
    'MaskedPointGroup': <class 'openpoints.models.reconstruction.maskedpointgroup.MaskedPointGroup'>
}

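3. Using the registry

With the registry MODELS in place and the classes registered in it, its string-to-class mapping makes it easy to create class instances from a .yaml configuration.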

The rough sequence is shown in the figure below:

Participants: examples/segmentation/main(), openpoints/models/build.py, class Registry (openpoints/utils/registry.py)
1. build.py creates the instance registry.Registry('models'); Registry.__init__() sets self.build_func = build_from_cfg. The result is the global object MODELS (the registry).
2. main() calls build_model_from_cfg(cfg.model) in build.py.
3. build.py calls MODELS.build(cfg, **kwargs).
4. Registry.build(self, *args, **kwargs) calls build_from_cfg(cfg, registry, default_args=None).
5. build_from_cfg returns obj_cls(**obj_cfg) [equivalent to BaseSeg(**obj_cfg)], and model is returned to main().
Fig 1. Sequence for creating a class object (deep neural network model) using the registry

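The config dict cfg, read from the .yaml files, contains a NAME entry, and registry.get(*) returns the registered class that corresponds to the NAME string. Once the class has been found, NAME has done its job, and the remaining config entries are used to configure and construct the specific deep neural network modules/classes in PointNeXt automatically (not covered in this post).

For the details, see the annotated function build_from_cfg (a module-level function in registry.py, not a method of Registry), repeated here: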
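The "NAME selects the class, the rest become constructor arguments" step can be sketched in a few lines. toy_build_from_cfg, TOY_REGISTRY and ToyHead below are hypothetical names for illustration; a plain dict stands in for the Registry object, and the type checks of the real function are left out.

```python
import copy


class ToyHead:
    """Hypothetical class standing in for a registered module such as a segmentation head."""

    def __init__(self, num_classes, in_channels=None):
        self.num_classes = num_classes
        self.in_channels = in_channels


# A plain dict plays the role of the registry's _module_dict here
TOY_REGISTRY = {"ToyHead": ToyHead}


def toy_build_from_cfg(cfg, registry, default_args=None):
    obj_cfg = copy.deepcopy(cfg)       # never mutate the caller's config
    if default_args is not None:
        obj_cfg.update(default_args)   # same merge order as the listing: default_args win
    name = obj_cfg.pop("NAME")         # NAME only selects the class
    obj_cls = registry.get(name)
    if obj_cls is None:
        raise KeyError(f"{name} is not registered")
    return obj_cls(**obj_cfg)          # remaining entries become keyword arguments


cfg = {"NAME": "ToyHead", "num_classes": 13}
head = toy_build_from_cfg(cfg, TOY_REGISTRY, default_args={"in_channels": 32})
print(type(head).__name__, head.num_classes, head.in_channels)
```

Because the config is deep-copied first, cfg still contains its NAME entry afterwards and can be reused to build another instance.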

def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.
    Args:
        cfg (edict): Config dict. It should at least contain the key "NAME".
        registry (:obj:`Registry`): The registry to search the type from.
    Returns:
        object: The constructed object.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'NAME' not in cfg:
        if default_args is None or 'NAME' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "NAME", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')

    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    # if default_args is not None:
    #     cfg = config.merge_new_config(cfg, default_args)

    obj_type = cfg.get('NAME')   # 'BaseSeg'

    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        # <class 'openpoints.models.segmentation.base_seg.BaseSeg'>
        # Look up the class/module in self._module_dict by its name string,
        # i.e. map the string to the class
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    try:
        obj_cfg = copy.deepcopy(cfg)
        if default_args is not None:
            obj_cfg.update(default_args)

        obj_cfg.pop('NAME')
        # Remove the "NAME" entry; obj_cfg keeps every other entry.
        # 'NAME' has already done its job of selecting the class obj_cls.

        # With the variables expanded, the call below is
        # openpoints.models.segmentation.base_seg.BaseSeg(**obj_cfg).
        # The decorator @MODELS.register_module() ran exactly once, when the
        # class statement was executed at import time, and returned the class
        # unchanged. Instantiating the class here therefore does not go
        # through the registry again; it is an ordinary constructor call.
        return obj_cls(**obj_cfg)

    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{obj_cls.__name__}: {e}')

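II. Argument Parsing

The registry needs string parameters in order to build class instances, and those parameters are obtained by parsing the configuration in the .yaml files into the program.

1. Command-line parsing

The main program first reads the relevant .yaml files and merges them into the dict-like variable cfg, annotated as follows.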

if __name__ == "__main__":
    # Create the parser
    parser = argparse.ArgumentParser('Scene segmentation training/testing')
    # Add arguments
    parser.add_argument('--cfg', type=str, required=True, help='config file')
    parser.add_argument('--profile', action='store_true', default=False, help='set to True to profile speed')
    args, opts = parser.parse_known_args()
    # CUDA_VISIBLE_DEVICES=0,1 python examples/segmentation/main.py --cfg cfgs/s3dis/pointnext-s.yaml  mode=train
    # CUDA_VISIBLE_DEVICES=0,1 is an environment variable and is not parsed by parser
    # args = Namespace(cfg='cfgs/s3dis/pointnext-s.yaml', profile=False)
    # opts = ['mode=train']

    cfg = EasyConfig()
    cfg.load(args.cfg, recursive=True)  # args.cfg = 'cfgs/s3dis/pointnext-s.yaml'
    cfg.update(opts)  # overwrite the default arguments in yml
    # mode=train is merged into the cfg dict

    if cfg.seed is None:
        cfg.seed = np.random.randint(1, 10000)

    # init distributed env first, since logger depends on the dist info.
    cfg.rank, cfg.world_size, cfg.distributed, cfg.mp = dist_utils.get_dist_info(cfg)
    cfg.sync_bn = cfg.world_size > 1  # single GPU only when debugging; multiple GPUs in parallel in a normal run

    # init log dir
    cfg.task_name = args.cfg.split('.')[-2].split('/')[-2]
    # task/dataset name, e.g. s3dis, modelnet40_cls
    # args.cfg = 'cfgs/s3dis/pointnext-s.yaml'
    # args.cfg.split('.')[-2] = 'cfgs/s3dis/pointnext-s'
    # args.cfg.split('.')[-2].split('/')[-2] = 's3dis'
    cfg.cfg_basename = args.cfg.split('.')[-2].split('/')[-1]
    # cfg_basename, e.g. pointnext-xl
    # args.cfg.split('.')[-2].split('/')[-1] = 'pointnext-s'
    tags = [
        cfg.task_name,  # task name (the name of the folder under ./cfgs)
        cfg.mode,
        cfg.cfg_basename,  # cfg file name
        f'ngpus{cfg.world_size}',
    ]
    # tags = ['s3dis', 'train', 'pointnext-s', 'ngpus1']
    opt_list = []  # for checking experiment configs from logging file
    for i, opt in enumerate(opts):
        if 'rank' not in opt and 'dir' not in opt and 'root' not in opt and 'pretrain' not in opt and 'path' not in opt and 'wandb' not in opt and '/' not in opt:
            opt_list.append(opt)
    cfg.root_dir = os.path.join(cfg.root_dir, cfg.task_name)
    cfg.opts = '-'.join(opt_list)  # join with '-' as the separator

    cfg.is_training = cfg.mode not in ['test', 'testing', 'val', 'eval', 'evaluation']
    if cfg.mode in ['resume', 'val', 'test']:
        resume_exp_directory(cfg, pretrained_path=cfg.pretrained_path)
        # requires pretrained_path=XXX on the command line
        cfg.wandb.tags = [cfg.mode]
    else:
        generate_exp_directory(cfg, tags, additional_id=os.environ.get('MASTER_PORT', None))
        cfg.wandb.tags = tags
    os.environ["JOB_LOG_DIR"] = cfg.log_dir
    cfg_path = os.path.join(cfg.run_dir, "cfg.yaml")
    # cfg_path = 'log/s3dis/s3dis-train-pointnext-s-ngpus1-20240730-092203-hQtDgCBNbQaYpAMwLVn9TC/cfg.yaml'
    with open(cfg_path, 'w') as f:
        yaml.dump(cfg, f, indent=2)  # write cfg to the file f
        os.system('cp %s %s' % (args.cfg, cfg.run_dir))
        # args.cfg = 'cfgs/s3dis/pointnext-s.yaml'
        # cfg.run_dir = 'log/s3dis/s3dis-train-pointnext-s-ngpus1-20240730-092203-hQtDgCBNbQaYpAMwLVn9TC'
    cfg.cfg_path = cfg_path

    # wandb config
    cfg.wandb.name = cfg.run_name
    # cfg.run_name = 's3dis-train-pointnext-s-ngpus1-20240730-092203-hQtDgCBNbQaYpAMwLVn9TC'

    # multi processing.
    if cfg.mp:
        port = find_free_port()
        cfg.dist_url = f"tcp://localhost:{port}"
        print('using mp spawn for distributed training')
        mp.spawn(main, nprocs=cfg.world_size, args=(cfg,))
    else:
        main(0, cfg)
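The split between args and opts comes from parse_known_args, which returns the arguments it recognizes plus a list of everything else. The sketch below replays the test session with an explicit argv list so it can be run anywhere. The pathlib lines are an alternative shown only for comparison and are not what main.py does; the split('.') form in main.py assumes the path contains exactly one dot.

```python
import argparse
from pathlib import Path

parser = argparse.ArgumentParser("demo")
parser.add_argument("--cfg", type=str, required=True, help="config file")
parser.add_argument("--profile", action="store_true", default=False)

# Same arguments as the test session, passed explicitly instead of via sys.argv
argv = ["--cfg", "cfgs/s3dis/pointnext-s.yaml", "mode=train"]
args, opts = parser.parse_known_args(argv)
print(args)  # Namespace(cfg='cfgs/s3dis/pointnext-s.yaml', profile=False)
print(opts)  # ['mode=train']

# How main.py derives the task name and config basename
task_name = args.cfg.split(".")[-2].split("/")[-2]     # 's3dis'
cfg_basename = args.cfg.split(".")[-2].split("/")[-1]  # 'pointnext-s'

# Alternative (not used by main.py): pathlib tolerates extra dots in the path
alt_task_name = Path(args.cfg).parent.name
alt_basename = Path(args.cfg).stem

print(task_name, cfg_basename, alt_task_name, alt_basename)
```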

2. Loading and updating parameters

Reading and updating config entries is implemented in class EasyConfig, partly annotated below.

class EasyConfig(dict):
    def __getattr__(self, key: str) -> Any:
        if key not in self:
            raise AttributeError(key)
        return self[key]

    def __setattr__(self, key: str, value: Any) -> None:
        self[key] = value

    def __delattr__(self, key: str) -> None:
        del self[key]

    def load(self, fpath: str, *, recursive: bool = False) -> None:
        """load cfg from yaml

        Args:
            fpath (str): path to the yaml file
            recursive (bool, optional): recursively load its parent default yaml files. Defaults to False.
        """
        if not os.path.exists(fpath):
            raise FileNotFoundError(fpath)
        fpaths = [fpath]
        # 'cfgs/s3dis/pointnext-s.yaml'
        if recursive:  # True
            extension = os.path.splitext(fpath)[1]   # .yaml
            while os.path.dirname(fpath) != fpath:   # loop until the path cannot be shortened any further
                fpath = os.path.dirname(fpath)  # drop the last component; one level is removed per iteration
                fpaths.append(os.path.join(fpath, 'default' + extension))
                #  fpaths =['cfgs/s3dis/pointnext-s.yaml',
                #           'cfgs/s3dis/default.yaml',
                #           'cfgs/default.yaml',
                #           'default.yaml']
        for fpath in reversed(fpaths):   # reverse order: the most general defaults load first, more specific files override them
            if os.path.exists(fpath):
                with open(fpath) as f:
                    self.update(yaml.safe_load(f))
                    # merge the entries of every .yaml file in fpaths into one dict

    def reload(self, fpath: str, *, recursive: bool = False) -> None:
        self.clear()
        self.load(fpath, recursive=recursive)

    # multimethod makes python support function overloading
    @multimethod
    def update(self, other: Dict) -> None:
        # turn the .yaml items into key: value pairs of the dict
        for key, value in other.items():
            if isinstance(value, dict):
                if key not in self or not isinstance(self[key], EasyConfig):
                    # nested sub-entry
                    self[key] = EasyConfig()
                # recursively update
                self[key].update(value)
            else:
                self[key] = value

    @multimethod
    def update(self, opts: Union[List, Tuple]) -> None:
        index = 0
        while index < len(opts):
            opt = opts[index]
            if opt.startswith('--'):
                opt = opt[2:]
            if '=' in opt:
                key, value = opt.split('=', 1)
                index += 1
            else:
                key, value = opt, opts[index + 1]
                index += 2
            current = self
            subkeys = key.split('.')
            try:
                value = literal_eval(value)
            except:
                pass  # not a Python literal: keep the raw string
            for subkey in subkeys[:-1]:
                current = current.setdefault(subkey, EasyConfig())
            current[subkeys[-1]] = value

    def dict(self) -> Dict[str, Any]:
        configs = dict()
        for key, value in self.items():
            if isinstance(value, EasyConfig):
                value = value.dict()
            configs[key] = value
        return configs

    def hash(self) -> str:
        buffer = json.dumps(self.dict(), sort_keys=True)
        return hashlib.sha256(buffer.encode()).hexdigest()

    def __str__(self) -> str:
        texts = []
        for key, value in self.items():
            if isinstance(value, EasyConfig):
                seperator = '\n'
            else:
                seperator = ' '
            text = key + ':' + seperator + str(value)
            lines = text.split('\n')
            for k, line in enumerate(lines[1:]):
                lines[k + 1] = (' ' * 2) + line
            texts.extend(lines)
        return '\n'.join(texts)
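Two steps of EasyConfig are worth trying in isolation: how the list of default.yaml candidates is built, and how key=value options with dotted keys override nested entries. The sketch below re-implements just those two steps on plain dicts. collect_cfg_paths and apply_opts are hypothetical helper names, no YAML file is read, and posixpath is used instead of os.path only so that the result is the same on every platform.

```python
import posixpath
from ast import literal_eval


def collect_cfg_paths(fpath):
    """List the config file followed by the default.yaml candidates of each parent directory."""
    fpaths = [fpath]
    extension = posixpath.splitext(fpath)[1]
    while posixpath.dirname(fpath) != fpath:  # stops once dirname('') == ''
        fpath = posixpath.dirname(fpath)
        fpaths.append(posixpath.join(fpath, "default" + extension))
    return fpaths


def apply_opts(cfg, opts):
    """Apply 'a.b.c=value' overrides to a nested dict (only the key=value form is handled)."""
    for opt in opts:
        key, value = opt.split("=", 1)
        try:
            value = literal_eval(value)  # '64' -> 64, '0.01' -> 0.01
        except (ValueError, SyntaxError):
            pass                         # e.g. 'train' stays a string
        current = cfg
        subkeys = key.split(".")
        for subkey in subkeys[:-1]:
            current = current.setdefault(subkey, {})
        current[subkeys[-1]] = value
    return cfg


paths = collect_cfg_paths("cfgs/s3dis/pointnext-s.yaml")
print(paths)

cfg = {"mode": "test", "model": {"encoder_args": {"width": 32}}}
apply_opts(cfg, ["mode=train", "model.encoder_args.width=64", "lr=0.01"])
print(cfg)
```

Files are then loaded in reversed(paths) order, so the broadest default.yaml is applied first and the file named on the command line has the last word among the files; the command-line options are applied after all of them.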

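3. Resulting parameters

All entries in every .yaml file listed in fpaths = ['cfgs/s3dis/pointnext-s.yaml', 'cfgs/s3dis/default.yaml', 'cfgs/default.yaml', 'default.yaml'] are parsed and written into the cfg dict, provided the file exists (here the last one, default.yaml, does not).
The cfg dict obtained after parsing is shown below. Its cfg.model part is used to configure and build the network model (class implementation) automatically.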

dist_url: tcp://localhost:8888
dist_backend: nccl
multiprocessing_distributed: False
ngpus_per_node: 1
world_size: 1
launcher: mp
local_rank: 0
use_gpu: True
seed: 3392
epoch: 0
epochs: 100
ignore_index: None
val_fn: validate
deterministic: False
sync_bn: False
criterion_args:
  NAME: CrossEntropy
  label_smoothing: 0.2
use_mask: False
grad_norm_clip: 10
layer_decay: 0
step_per_update: 1
start_epoch: 1
sched_on_epoch: True
wandb:
  use_wandb: False
  project: PointNeXt-S3DIS
  tags: ['s3dis', 'train', 'pointnext-s', 'ngpus1']
  name: s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
use_amp: False
use_voting: False
val_freq: 1
resume: False
test: False
finetune: False
mode: train
logname: None
load_path: None
print_freq: 50
save_freq: -1
root_dir: log/s3dis
pretrained_path: None
datatransforms:
  train: ['ChromaticAutoContrast', 'PointsToTensor', 'PointCloudScaling', 'PointCloudXYZAlign', 'PointCloudJitter', 'ChromaticDropGPU', 'ChromaticNormalize']
  val: ['PointsToTensor', 'PointCloudXYZAlign', 'ChromaticNormalize']
  vote: ['ChromaticDropGPU']
  kwargs:
    color_drop: 0.2
    gravity_dim: 2
    scale: [0.9, 1.1]
    angle: [0, 0, 1]
    jitter_sigma: 0.005
    jitter_clip: 0.02
feature_keys: x,heights
dataset:
  common:
    NAME: S3DIS
    data_root: data/S3DIS/s3disfull
    test_area: 5
    voxel_size: 0.04
  train:
    split: train
    voxel_max: 24000
    loop: 30
    presample: False
  val:
    split: val
    voxel_max: None
    presample: True
  test:
    split: test
    voxel_max: None
    presample: False
num_classes: 13
batch_size: 32
val_batch_size: 1
dataloader:
  num_workers: 6
cls_weighed_loss: False
optimizer:
  NAME: adamw
  weight_decay: 0.0001
sched: cosine
warmup_epochs: 0
min_lr: 1e-05
lr: 0.01
log_dir: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
model:
  NAME: BaseSeg
  encoder_args:
    NAME: PointNextEncoder
    blocks: [1, 1, 1, 1, 1]
    strides: [1, 4, 4, 4, 4]
    sa_layers: 2
    sa_use_res: True
    width: 32
    in_channels: 4
    expansion: 4
    radius: 0.1
    nsample: 32
    aggr_args:
      feature_type: dp_fj
      reduction: max
    group_args:
      NAME: ballquery
      normalize_dp: True
    conv_args:
      order: conv-norm-act
    act_args:
      act: relu
    norm_args:
      norm: bn
  decoder_args:
    NAME: PointNextDecoder
  cls_args:
    NAME: SegHead
    num_classes: 13
    in_channels: None
    norm_args:
      norm: bn
  in_channels: 4
rank: 0
distributed: False
mp: False
task_name: s3dis
cfg_basename: pointnext-s
opts: mode=train
is_training: True
run_name: s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
run_dir: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
exp_dir: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
ckpt_dir: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu/checkpoint
log_path: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu.log
cfg_path: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu/cfg.yaml

III. Summary

1. Result

The part of the main() function in examples/segmentation/main.py that builds the deep network model (class implementation):

    if cfg.model.get('in_channels', None) is None:
        cfg.model.in_channels = cfg.model.encoder_args.in_channels # 4
    model = build_model_from_cfg(cfg.model).to(cfg.rank)
    model_size = cal_model_parm_nums(model)
    logging.info(model)
    logging.info('Number of params: %.4f M' % (model_size / 1e6))
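To make the name-based dispatch behind build_model_from_cfg concrete, here is a minimal sketch of the idea: a registry maps the NAME string in a config dict to a registered class and instantiates it with the remaining keys. This is a simplified stand-in written for this post. The `Registry` below is not the openpoints class, and `ToySeg` is a hypothetical model used only for illustration.

```python
class Registry:
    """Tiny name -> class registry (illustrative, not the openpoints class)."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self):
        """Return a decorator that records the class under its own name."""
        def _register(cls):
            self._module_dict[cls.__name__] = cls
            return cls  # the class definition is returned unchanged
        return _register

    def build(self, cfg):
        """Look up cfg['NAME'] and call the class with the remaining keys."""
        args = dict(cfg)  # shallow copy so the caller's config is untouched
        cls = self._module_dict[args.pop("NAME")]
        return cls(**args)


MODELS = Registry("models")


@MODELS.register_module()
class ToySeg:
    """Hypothetical stand-in for a segmentation model class."""

    def __init__(self, encoder_args=None, in_channels=None):
        self.encoder_args = encoder_args
        self.in_channels = in_channels


cfg_model = {
    "NAME": "ToySeg",
    "encoder_args": {"NAME": "PointNextEncoder", "width": 32},
    "in_channels": 4,
}
model = MODELS.build(cfg_model)
print(type(model).__name__, model.in_channels)
```

The config dict is copied before NAME is popped, so the same config can be reused to build the model again.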

The call build_model_from_cfg(cfg.model) in turn executes openpoints.models.segmentation.base_seg.BaseSeg(**obj_cfg), which yields the following network structure:

BaseSeg(
  (encoder): PointNextEncoder(
    (encoder): Sequential(
      (0): Sequential(
        (0): SetAbstraction(
          (convs): Sequential(
            (0): Sequential(
              (0): Conv1d(4, 32, kernel_size=(1,), stride=(1,))
            )
          )
        )
      )
      (1): Sequential(
        (0): SetAbstraction(
          (skipconv): Sequential(
            (0): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
          )
          (act): ReLU(inplace=True)
          (convs): Sequential(
            (0): Sequential(
              (0): Conv2d(35, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (grouper): QueryAndGroup()
        )
      )
      (2): Sequential(
        (0): SetAbstraction(
          (skipconv): Sequential(
            (0): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
          )
          (act): ReLU(inplace=True)
          (convs): Sequential(
            (0): Sequential(
              (0): Conv2d(67, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (grouper): QueryAndGroup()
        )
      )
      (3): Sequential(
        (0): SetAbstraction(
          (skipconv): Sequential(
            (0): Conv1d(128, 256, kernel_size=(1,), stride=(1,))
          )
          (act): ReLU(inplace=True)
          (convs): Sequential(
            (0): Sequential(
              (0): Conv2d(131, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (grouper): QueryAndGroup()
        )
      )
      (4): Sequential(
        (0): SetAbstraction(
          (skipconv): Sequential(
            (0): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
          )
          (act): ReLU(inplace=True)
          (convs): Sequential(
            (0): Sequential(
              (0): Conv2d(259, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (grouper): QueryAndGroup()
        )
      )
    )
  )
  (decoder): PointNextDecoder(
    (decoder): Sequential(
      (0): Sequential(
        (0): FeaturePropogation(
          (convs): Sequential(
            (0): Sequential(
              (0): Conv1d(96, 32, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv1d(32, 32, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
          )
        )
      )
      (1): Sequential(
        (0): FeaturePropogation(
          (convs): Sequential(
            (0): Sequential(
              (0): Conv1d(192, 64, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
          )
        )
      )
      (2): Sequential(
        (0): FeaturePropogation(
          (convs): Sequential(
            (0): Sequential(
              (0): Conv1d(384, 128, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv1d(128, 128, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
          )
        )
      )
      (3): Sequential(
        (0): FeaturePropogation(
          (convs): Sequential(
            (0): Sequential(
              (0): Conv1d(768, 256, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv1d(256, 256, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
          )
        )
      )
    )
  )
  (head): SegHead(
    (head): Sequential(
      (0): Sequential(
        (0): Conv1d(32, 32, kernel_size=(1,), stride=(1,), bias=False)
        (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (1): Dropout(p=0.5, inplace=False)
      (2): Sequential(
        (0): Conv1d(32, 13, kernel_size=(1,), stride=(1,))
      )
    )
  )
)
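The channel counts in this printout can be cross-checked against the encoder_args values width: 32, strides: [1, 4, 4, 4, 4] and feature_type: dp_fj. The sketch below rests on three assumptions inferred from the printed numbers rather than confirmed from the source: the width doubles at every stage whose stride is not 1, dp_fj concatenates a 3-dimensional relative position with the neighbor features (hence the +3 on each Conv2d input), and each decoder stage takes the concatenation of a skip feature and the next deeper feature.

```python
width = 32
strides = [1, 4, 4, 4, 4]

# Assumption: the channel width doubles whenever a stage downsamples.
channels = []
c = width
for s in strides:
    if s != 1:
        c *= 2
    channels.append(c)

# Assumption: dp_fj = relative position (3 dims) concatenated with features,
# so the first Conv2d of stages 1..4 sees (previous stage channels + 3).
grouped_in = [c_in + 3 for c_in in channels[:-1]]

# Assumption: each decoder stage concatenates a skip feature with the
# feature coming from the next deeper stage.
decoder_in = [skip + deep for skip, deep in zip(channels[:-1], channels[1:])]

print("encoder channels:", channels)
print("Conv2d inputs:   ", grouped_in)
print("decoder inputs:  ", decoder_in)
```

Under these assumptions the three lists match the Conv1d/Conv2d input sizes shown in the encoder and decoder above.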

2. Todo

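How is the network structure above configured and constructed automatically? That remains to be worked out by reading the source code.

Thanks to the authors of the paper and the code for open-sourcing their research!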

