Title: PointNeXt Source Code Reading (I): Registry Mechanism and Argument Parsing
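Preface
I have studied part of the PointNeXt source code and am writing it down here for future reference.
This post has two parts, the registry mechanism and argument parsing; the focus is on understanding the registry mechanism.
All annotations and debugging output below come from the following test session: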
```bash
CUDA_VISIBLE_DEVICES=0,1 python examples/segmentation/main.py \
    --cfg cfgs/s3dis/pointnext-s.yaml mode=train
```
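I. Registry Mechanism
The registry mechanism in PointNeXt registers modules/classes so that strings can be mapped to modules/classes. In other words, it lets the program read a name string from a configuration file and directly obtain an instance of the corresponding module/class. The original PointNeXt authors based this implementation on the registry mechanism in mmcv.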
1. The Registry Class
The registry mechanism itself is implemented by the class Registry. Its key methods are:
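Method | Description |
---|---|
__init__() | Initializes the class, including the dictionary of registered modules, self._module_dict = dict() |
get(self, key) | Maps a string to a class: the string key is looked up in self._module_dict and the registered class self._module_dict[real_key] is returned |
register_module(self, name=None, force=False, module=None) | Registers a module/class; it also serves as the decorator used to mark modules/classes for registration |
_register(cls) | Inner wrapper function returned by the decorator register_module, used when register_module is applied as a decorator; the argument cls is the class being decorated. It first calls _register_module(self, module_class, module_name=None, force=False) to register the class (self._module_dict[name] = module_class) and then returns the class definition cls unchanged, with no further processing |
build_from_cfg(cfg, registry, default_args=None) | Builds a module/class instance from a config dict, i.e. turns a name string into an instance. Strictly speaking this is a module-level function in registry.py rather than a method of Registry; it is the default build_func that Registry.build() calls |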
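Before the full listing, the core of the table above can be seen in a minimal sketch. It assumes a flat registry with no scopes, parents or children; MiniRegistry, Encoder and Decoder are hypothetical names used only for illustration, not openpoints code. register_module works both as a decorator and as a plain call, get maps a name to a class (or None), and force=False rejects duplicate names.

```python
import inspect


class MiniRegistry:
    """Hypothetical flat name -> class registry (no scope/parent/children)."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError(f'module must be a class, but got {type(module_class)}')
        name = module_name or module_class.__name__
        if not force and name in self._module_dict:
            raise KeyError(f'{name} is already registered in {self.name}')
        self._module_dict[name] = module_class

    def register_module(self, name=None, force=False, module=None):
        # Plain call: registry.register_module(module=SomeClass)
        if module is not None:
            self._register_module(module, name, force)
            return module

        # Decorator: @registry.register_module()
        def _register(cls):
            self._register_module(cls, name, force)
            return cls  # the class itself is returned unchanged

        return _register

    def get(self, key):
        return self._module_dict.get(key)  # None if the name is unknown


MODELS = MiniRegistry('models')


@MODELS.register_module()
class Encoder:
    pass


class Decoder:
    pass


MODELS.register_module(module=Decoder)               # plain-call form
MODELS.register_module(name='Dec2', module=Decoder)  # same class, another name
```

Registering Encoder a second time without force=True raises KeyError, mirroring _register_module in the real class.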
The Registry class is defined in openpoints/utils/registry.py; annotated below.
```python
# Acknowledgement: built upon mmcv
import inspect
import warnings
from functools import partial
import copy


class Registry:
    """A registry to map strings to classes.

    Registered object could be built from registry.
    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = MODELS.build(dict(NAME='ResNet'))

    Please refer to https://mmcv.readthedocs.io/en/latest/registry.html for
    advanced usage.

    Args:
        name (str): Registry name.
        build_func(func, optional): Build function to construct instance from
            Registry, func:`build_from_cfg` is used if neither ``parent`` or
            ``build_func`` is specified. If ``parent`` is specified and
            ``build_func`` is not given, ``build_func`` will be inherited
            from ``parent``. Default: None.
        parent (Registry, optional): Parent registry. The class registered in
            children registry could be built from parent. Default: None.
        scope (str, optional): The scope of registry. It is the key to search
            for children registry. If not specified, scope will be the name of
            the package where class is defined, e.g. mmdet, mmcls, mmseg.
            Default: None.
    """

    def __init__(self, name, build_func=None, parent=None, scope=None):
        self._name = name
        self._module_dict = dict()  # registered classes: name -> class
        self._children = dict()
        self._scope = self.infer_scope() if scope is None else scope
        # in this session: self._scope = 'openpoints'

        # self.build_func will be set with the following priority:
        # 1. build_func
        # 2. parent.build_func
        # 3. build_from_cfg
        if build_func is None:
            if parent is not None:
                self.build_func = parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_children(self)
            self.parent = parent
        else:
            self.parent = None

    def __len__(self):
        return len(self._module_dict)

    def __contains__(self, key):
        return self.get(key) is not None

    def __repr__(self):
        format_str = self.__class__.__name__ + \
                     f'(name={self._name}, ' \
                     f'items={self._module_dict})'
        return format_str

    @staticmethod
    def infer_scope():
        """Infer the scope of registry.

        The name of the package where registry is defined will be returned.

        Example:
            # in mmdet/models/backbone/resnet.py
            >>> MODELS = Registry('models')
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     pass
            The scope of ``ResNet`` will be ``mmdet``.

        Returns:
            scope (str): The inferred scope name.
        """
        # inspect.stack() trace where this function is called, the index-2
        # indicates the frame where `infer_scope()` is called
        filename = inspect.getmodule(inspect.stack()[2][0]).__name__
        # filename = 'openpoints.models.build'
        split_filename = filename.split('.')
        # split_filename = ['openpoints', 'models', 'build']
        return split_filename[0]  # 'openpoints'

    @staticmethod  # static method: receives neither self nor cls
    def split_scope_key(key):
        """Split scope and key.

        The first scope will be split from key.

        Examples:
            >>> Registry.split_scope_key('mmdet.ResNet')
            'mmdet', 'ResNet'
            >>> Registry.split_scope_key('ResNet')
            None, 'ResNet'

        Return:
            scope (str, None): The first scope.
            key (str): The remaining key.
        """
        split_index = key.find('.')
        # str.find returns -1 if key contains no '.',
        # otherwise the index of the first '.'
        if split_index != -1:
            return key[:split_index], key[split_index + 1:]
        else:
            return None, key

    @property
    def name(self):
        return self._name

    @property
    def scope(self):
        return self._scope

    @property
    def module_dict(self):
        return self._module_dict

    @property
    def children(self):
        return self._children

    def get(self, key):
        """Get the registry record.

        Args:
            key (str): The class name in string format.

        Returns:
            class: The corresponding class.
        """
        # Maps a string to a class: the string key is looked up in
        # self._module_dict and the registered class
        # self._module_dict[real_key] is returned (None if not found)
        scope, real_key = self.split_scope_key(key)
        # key = 'BaseSeg' -> scope = None, real_key = 'BaseSeg'
        if scope is None or scope == self._scope:
            # get from self
            if real_key in self._module_dict:
                return self._module_dict[real_key]
        else:
            # get from self._children
            if scope in self._children:
                return self._children[scope].get(real_key)
            else:
                # goto root
                parent = self.parent
                while parent.parent is not None:
                    parent = parent.parent
                return parent.get(key)

    def build(self, *args, **kwargs):
        return self.build_func(*args, **kwargs, registry=self)

    def _add_children(self, registry):
        """Add children for a registry.

        The ``registry`` will be added as children based on its scope.
        The parent registry could build objects from children registry.

        Example:
            >>> models = Registry('models')
            >>> mmdet_models = Registry('models', parent=models)
            >>> @mmdet_models.register_module()
            >>> class ResNet:
            >>>     pass
            >>> resnet = models.build(dict(NAME='mmdet.ResNet'))
        """
        assert isinstance(registry, Registry)
        assert registry.scope is not None
        assert registry.scope not in self.children, \
            f'scope {registry.scope} exists in {self.name} registry'
        self.children[registry.scope] = registry

    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')

        if module_name is None:
            module_name = module_class.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in self._module_dict:
                raise KeyError(f'{name} is already registered '
                               f'in {self.name}')
            self._module_dict[name] = module_class

    def deprecated_register_module(self, cls=None, force=False):
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.')
        if cls is None:
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls

    def register_module(self, name=None, force=False, module=None):  # decorator
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(ResNet)

        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')
        # NOTE: This is a walkaround to be compatible with the old api,
        # while it may introduce unexpected bugs.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        # raise the error ahead of time
        # NOTE: `misc` is not imported in this excerpt; misc.is_seq_of is only
        # evaluated when name is neither None nor a str.
        if not (name is None or isinstance(name, str) or misc.is_seq_of(name, str)):
            raise TypeError(
                'name must be either of None, an instance of str or a sequence'
                f' of str, but got {type(name)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        # (an ordinary call of register_module, not the decorator case)
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        # _register is the decorator's wrapper function;
        # cls is the class being decorated
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register  # the decorator returns this wrapper


def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (edict): Config dict. It should at least contain the key "NAME".
        registry (:obj:`Registry`): The registry to search the type from.

    Returns:
        object: The constructed object.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'NAME' not in cfg:
        if default_args is None or 'NAME' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "NAME", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')

    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    # if default_args is not None:
    #     cfg = config.merge_new_config(cfg, default_args)

    obj_type = cfg.get('NAME')  # 'BaseSeg'
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        # obj_cls = <class 'openpoints.models.segmentation.base_seg.BaseSeg'>
        # Look up the class/module registered in self._module_dict under the
        # name string, i.e. map the string to a class
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    try:
        obj_cfg = copy.deepcopy(cfg)
        if default_args is not None:
            obj_cfg.update(default_args)
        # Drop the "NAME" entry: it has already been mapped to obj_cls,
        # so obj_cfg keeps only the remaining entries
        obj_cfg.pop('NAME')
        # Unpack the remaining entries as keyword arguments, i.e.
        # openpoints.models.segmentation.base_seg.BaseSeg(**obj_cfg).
        # This does NOT register BaseSeg again: @MODELS.register_module()
        # ran once when base_seg.py was imported and returned the class
        # unchanged, so this is an ordinary constructor call.
        return obj_cls(**obj_cfg)
    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{obj_cls.__name__}: {e}')
```
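2. Registering Classes
First, openpoints/models/build.py declares and defines the global Registry object MODELS, referred to here as the registry.
Thanks to Python's import mechanism, MODELS is created as soon as that module is imported when the program starts, i.e. before __main__ / main() runs.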
```python
from openpoints.utils import registry

# Create a registry.Registry object (MODELS, "the registry") as a global.
# This runs when the module is first imported, i.e. before __main__/main()
# executes, so the registry MODELS exists from the very start of the program.
MODELS = registry.Registry('models')


def build_model_from_cfg(cfg, **kwargs):
    """
    Build a model, defined by `NAME`.
    Args:
        cfg (eDICT):
    Returns:
        Model: a constructed model specified by NAME.
    """
    return MODELS.build(cfg, **kwargs)
```
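Also because of the import mechanism, once MODELS exists, every class under openpoints/models whose definition is decorated with @MODELS.register_module() is registered into MODELS as soon as its module is imported and executed.
For example, the BaseSeg class below is registered into MODELS._module_dict at this point.

```python
"""
Author: PointNeXt
"""
import copy
from typing import List
import torch
import torch.nn as nn
import logging
from ...utils import get_missing_parameters_message, get_unexpected_parameters_message
from ..build import MODELS, build_model_from_cfg
from ..layers import create_linearblock, create_convblock1d


# BaseSeg is decorated with MODELS.register_module().
# The decorator runs once, when this module is imported: after the registry
# MODELS has been created in ../build.py, but before __main__/main() runs.
# It stores BaseSeg in MODELS._module_dict under the key 'BaseSeg' and
# returns the class unchanged, so a later call such as BaseSeg(...) is an
# ordinary constructor call that does not touch the registry. By the time
# main() runs, the registry can already map the string 'BaseSeg' to the class.
@MODELS.register_module()
class BaseSeg(nn.Module):
    def __init__(self,
                 encoder_args=None,
                 decoder_args=None,
                 cls_args=None,
                 **kwargs):
        super().__init__()
        # ... (rest of __init__ omitted)
```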
While debugging, inspecting MODELS._module_dict shows that many classes are already registered:
```text
MODELS._module_dict = {
    'PointNetEncoder': <class 'openpoints.models.backbone.pointnet.PointNetEncoder'>,
    'PointPatchEmbed': <class 'openpoints.models.layers.group_embed.PointPatchEmbed'>,
    'P3Embed': <class 'openpoints.models.layers.group_embed.P3Embed'>,
    'PointNet2Encoder': <class 'openpoints.models.backbone.pointnetv2.PointNet2Encoder'>,
    'PointNet2Decoder': <class 'openpoints.models.backbone.pointnetv2.PointNet2Decoder'>,
    'PointNet2PartDecoder': <class 'openpoints.models.backbone.pointnetv2.PointNet2PartDecoder'>,
    'PointNextEncoder': <class 'openpoints.models.backbone.pointnext.PointNextEncoder'>,
    'PointNextDecoder': <class 'openpoints.models.backbone.pointnext.PointNextDecoder'>,
    'PointNextPartDecoder': <class 'openpoints.models.backbone.pointnext.PointNextPartDecoder'>,
    'DGCNN': <class 'openpoints.models.backbone.dgcnn.DGCNN'>,
    'DeepGCN': <class 'openpoints.models.backbone.deepgcn.DeepGCN'>,
    'PointMLPEncoder': <class 'openpoints.models.backbone.pointmlp.PointMLPEncoder'>,
    'PointMLP': <class 'openpoints.models.backbone.pointmlp.PointMLP'>,
    'PointViT': <class 'openpoints.models.backbone.pointvit.PointViT'>,
    'PointViTDecoder': <class 'openpoints.models.backbone.pointvit.PointViTDecoder'>,
    'PointViTPartDecoder': <class 'openpoints.models.backbone.pointvit.PointViTPartDecoder'>,
    'InvPointViT': <class 'openpoints.models.backbone.pointvit_inv.InvPointViT'>,
    'InvPointViTDecoder': <class 'openpoints.models.backbone.pointvit_inv.InvPointViTDecoder'>,
    'InvPointViTPartDecoder': <class 'openpoints.models.backbone.pointvit_inv.InvPointViTPartDecoder'>,
    'CurveNet': <class 'openpoints.models.backbone.curvenet.CurveNet'>,
    'MVFC': <class 'openpoints.models.backbone.simpleview.MVFC'>,
    'MVModel': <class 'openpoints.models.backbone.simpleview.MVModel'>,
    'BaseSeg': <class 'openpoints.models.segmentation.base_seg.BaseSeg'>,
    'BasePartSeg': <class 'openpoints.models.segmentation.base_seg.BasePartSeg'>,
    'VariableSeg': <class 'openpoints.models.segmentation.base_seg.VariableSeg'>,
    'SegHead': <class 'openpoints.models.segmentation.base_seg.SegHead'>,
    'VariableSegHead': <class 'openpoints.models.segmentation.base_seg.VariableSegHead'>,
    'MultiSegHead': <class 'openpoints.models.segmentation.base_seg.MultiSegHead'>,
    'BaseCls': <class 'openpoints.models.classification.cls_base.BaseCls'>,
    'DistillCls': <class 'openpoints.models.classification.cls_base.DistillCls'>,
    'ClsHead': <class 'openpoints.models.classification.cls_base.ClsHead'>,
    'MaskedTransformerDecoder': <class 'openpoints.models.reconstruction.base_recontruct.MaskedTransformerDecoder'>,
    'FoldingNet': <class 'openpoints.models.reconstruction.base_recontruct.FoldingNet'>,
    'NodeShuffle': <class 'openpoints.models.reconstruction.base_recontruct.NodeShuffle'>,
    'MaskedPointViT': <class 'openpoints.models.reconstruction.maskedpointvit.MaskedPointViT'>,
    'MaskedPoint': <class 'openpoints.models.reconstruction.maskedpoint.MaskedPoint'>,
    'MaskedPointGroup': <class 'openpoints.models.reconstruction.maskedpointgroup.MaskedPointGroup'>
}
```
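Why registration has already happened before main() runs, and why instantiating the class later does not register it again, can be checked with a toy experiment. CountingRegistry and the stand-in BaseSeg below are hypothetical, simplified substitutes for MODELS and the real openpoints class; in the real program the class statement executes when its module is imported.

```python
class CountingRegistry:
    """Hypothetical registry that counts how often registration happens."""

    def __init__(self):
        self._module_dict = {}
        self.register_calls = 0

    def register_module(self):
        def _register(cls):
            self.register_calls += 1            # runs at class-definition time
            self._module_dict[cls.__name__] = cls
            return cls                          # class returned unchanged
        return _register


MODELS = CountingRegistry()
calls_before_class = MODELS.register_calls     # nothing registered yet


@MODELS.register_module()  # executed right here, when the class statement runs
class BaseSeg:
    def __init__(self, **kwargs):
        self.kwargs = kwargs


calls_after_class = MODELS.register_calls      # registered exactly once

seg_a = BaseSeg(width=32)                      # ordinary constructor call
seg_b = MODELS._module_dict['BaseSeg'](width=64)
calls_after_instances = MODELS.register_calls  # unchanged by instantiation
```

The counter goes from 0 to 1 when the class is defined and stays at 1 after both instantiations, which is the behavior the comments in build_from_cfg rely on.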
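3. Using the Registry
With the registry MODELS in place and the classes registered in it, its string-to-class mapping can be used to create class instances conveniently from a .yaml configuration file.
The rough call sequence is shown in the figure below:
The config dict cfg read from the .yaml file contains a NAME entry, and registry.get(...) returns the registered class corresponding to that NAME string. Once the class has been obtained, the NAME entry has done its job; the remaining entries are used to automatically configure and construct the concrete deep neural network modules/classes in PointNeXt (not covered in this post).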
For the detailed annotations, see build_from_cfg in the registry.py listing in Section I.1. Note that build_from_cfg is a module-level function in registry.py, not a method of the Registry class.
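A stripped-down sketch of the same flow, assuming a plain dict as the registry (REGISTRY, register and the SegHead stand-in are hypothetical simplifications, not openpoints code): the class is picked by NAME, NAME is popped from a deep copy, default_args are merged in, and everything else becomes constructor keyword arguments.

```python
import copy

REGISTRY = {}  # hypothetical flat registry: name -> class


def register(cls):
    REGISTRY[cls.__name__] = cls
    return cls


@register
class SegHead:  # stand-in with a few of the real cls_args keys
    def __init__(self, num_classes, in_channels=None, norm_args=None):
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.norm_args = norm_args


def build_from_cfg(cfg, registry, default_args=None):
    """Simplified: look up cfg['NAME'], drop it, pass the rest as kwargs."""
    if 'NAME' not in cfg:
        raise KeyError('cfg must contain the key "NAME"')
    obj_cls = registry.get(cfg['NAME'])
    if obj_cls is None:
        raise KeyError(f"{cfg['NAME']} is not registered")
    obj_cfg = copy.deepcopy(cfg)      # never mutate the caller's cfg
    if default_args is not None:
        obj_cfg.update(default_args)  # default_args override cfg entries
    obj_cfg.pop('NAME')               # NAME has done its job: picking the class
    return obj_cls(**obj_cfg)


cls_args = {'NAME': 'SegHead', 'num_classes': 13, 'norm_args': {'norm': 'bn'}}
head = build_from_cfg(cls_args, REGISTRY, default_args={'in_channels': 32})
```

Because of the deep copy, cls_args still contains NAME afterwards, and head.norm_args is a separate dict from cls_args['norm_args'].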
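II. Argument Parsing
The registry needs string arguments to build class instances, and those arguments are obtained by parsing the settings in the .yaml files into the program.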
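1. Command-Line Parsing
The main program first reads the relevant .yaml files and merges them into the config variable cfg (a dict subclass), annotated as follows.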
```python
import argparse
import os

import numpy as np
import yaml
from torch import multiprocessing as mp

# EasyConfig, dist_utils, find_free_port, generate_exp_directory and
# resume_exp_directory come from openpoints.utils; main() is defined
# earlier in this file.

if __name__ == "__main__":
    parser = argparse.ArgumentParser('Scene segmentation training/testing')  # create the parser
    parser.add_argument('--cfg', type=str, required=True, help='config file')
    parser.add_argument('--profile', action='store_true', default=False,
                        help='set to True to profile speed')  # add arguments
    args, opts = parser.parse_known_args()
    # CUDA_VISIBLE_DEVICES=0,1 python examples/segmentation/main.py --cfg cfgs/s3dis/pointnext-s.yaml mode=train
    # CUDA_VISIBLE_DEVICES=0,1 is an environment variable and is not parsed by the parser
    # args = Namespace(cfg='cfgs/s3dis/pointnext-s.yaml', profile=False)
    # opts = ['mode=train']  (arguments the parser does not recognize)
    cfg = EasyConfig()
    cfg.load(args.cfg, recursive=True)  # args.cfg = 'cfgs/s3dis/pointnext-s.yaml'
    cfg.update(opts)  # overwrite the default arguments in yml
    # mode=train is written into cfg
    if cfg.seed is None:
        cfg.seed = np.random.randint(1, 10000)

    # init distributed env first, since logger depends on the dist info.
    cfg.rank, cfg.world_size, cfg.distributed, cfg.mp = dist_utils.get_dist_info(cfg)
    cfg.sync_bn = cfg.world_size > 1
    # when debugging only a single GPU can be used; normal runs can use several GPUs in parallel

    # init log dir
    cfg.task_name = args.cfg.split('.')[-2].split('/')[-2]  # task/dataset name, e.g. s3dis, modelnet40_cls
    # args.cfg = 'cfgs/s3dis/pointnext-s.yaml'
    # args.cfg.split('.')[-2] = 'cfgs/s3dis/pointnext-s'
    # args.cfg.split('.')[-2].split('/')[-2] = 's3dis'
    cfg.cfg_basename = args.cfg.split('.')[-2].split('/')[-1]  # cfg_basename, e.g. pointnext-xl
    # args.cfg.split('.')[-2].split('/')[-1] = 'pointnext-s'
    tags = [
        cfg.task_name,  # task name (the folder name under ./cfgs)
        cfg.mode,
        cfg.cfg_basename,  # cfg file name
        f'ngpus{cfg.world_size}',
    ]
    # tags = ['s3dis', 'train', 'pointnext-s', 'ngpus1']
    opt_list = []  # for checking experiment configs from logging file
    for i, opt in enumerate(opts):
        if 'rank' not in opt and 'dir' not in opt and 'root' not in opt and 'pretrain' not in opt \
                and 'path' not in opt and 'wandb' not in opt and '/' not in opt:
            opt_list.append(opt)
    cfg.root_dir = os.path.join(cfg.root_dir, cfg.task_name)
    cfg.opts = '-'.join(opt_list)  # join with '-' as the separator

    cfg.is_training = cfg.mode not in ['test', 'testing', 'val', 'eval', 'evaluation']
    if cfg.mode in ['resume', 'val', 'test']:
        resume_exp_directory(cfg, pretrained_path=cfg.pretrained_path)
        # requires pretrained_path=XXX on the command line
        cfg.wandb.tags = [cfg.mode]
    else:
        generate_exp_directory(cfg, tags, additional_id=os.environ.get('MASTER_PORT', None))
        cfg.wandb.tags = tags
    os.environ["JOB_LOG_DIR"] = cfg.log_dir
    cfg_path = os.path.join(cfg.run_dir, "cfg.yaml")
    # cfg_path = 'log/s3dis/s3dis-train-pointnext-s-ngpus1-20240730-092203-hQtDgCBNbQaYpAMwLVn9TC/cfg.yaml'
    with open(cfg_path, 'w') as f:
        yaml.dump(cfg, f, indent=2)  # write cfg into the file f
        os.system('cp %s %s' % (args.cfg, cfg.run_dir))
        # args.cfg = 'cfgs/s3dis/pointnext-s.yaml'
        # cfg.run_dir = 'log/s3dis/s3dis-train-pointnext-s-ngpus1-20240730-092203-hQtDgCBNbQaYpAMwLVn9TC'
    cfg.cfg_path = cfg_path

    # wandb config
    cfg.wandb.name = cfg.run_name
    # cfg.run_name = 's3dis-train-pointnext-s-ngpus1-20240730-092203-hQtDgCBNbQaYpAMwLVn9TC'

    # multi processing.
    if cfg.mp:
        port = find_free_port()
        cfg.dist_url = f"tcp://localhost:{port}"
        print('using mp spawn for distributed training')
        mp.spawn(main, nprocs=cfg.world_size, args=(cfg,))
    else:
        main(0, cfg)
```
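The split between args and opts, and the task-name derivation from the config path, can be reproduced without the rest of main.py. This sketch feeds the parser an explicit argv list equal to the test session (instead of reading sys.argv) and repeats the same string slicing on args.cfg.

```python
import argparse

parser = argparse.ArgumentParser('Scene segmentation training/testing')
parser.add_argument('--cfg', type=str, required=True, help='config file')
parser.add_argument('--profile', action='store_true', default=False,
                    help='set to True to profile speed')

# Same arguments as the test session; CUDA_VISIBLE_DEVICES never reaches argv.
argv = ['--cfg', 'cfgs/s3dis/pointnext-s.yaml', 'mode=train']
args, opts = parser.parse_known_args(argv)  # unknown 'mode=train' lands in opts

# Same slicing as main.py
stem = args.cfg.split('.')[-2]   # 'cfgs/s3dis/pointnext-s'
task_name = stem.split('/')[-2]  # 's3dis'
cfg_basename = stem.split('/')[-1]  # 'pointnext-s'
```

parse_known_args is what allows free-form key=value overrides to pass through argparse untouched so that EasyConfig.update can apply them later. Note that the split('.') slicing assumes the config file name contains no other dots.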
2. Loading and Updating the Configuration
Reading and updating config entries is implemented in the class EasyConfig; partially annotated below.
```python
import hashlib
import json
import os
from ast import literal_eval
from typing import Any, Dict, List, Tuple, Union

import yaml
from multimethod import multimethod


class EasyConfig(dict):

    def __getattr__(self, key: str) -> Any:
        if key not in self:
            raise AttributeError(key)
        return self[key]

    def __setattr__(self, key: str, value: Any) -> None:
        self[key] = value

    def __delattr__(self, key: str) -> None:
        del self[key]

    def load(self, fpath: str, *, recursive: bool = False) -> None:
        """load cfg from yaml

        Args:
            fpath (str): path to the yaml file
            recursive (bool, optional): recursively load its parent default yaml files. Defaults to False.
        """
        if not os.path.exists(fpath):
            raise FileNotFoundError(fpath)
        fpaths = [fpath]  # ['cfgs/s3dis/pointnext-s.yaml']
        if recursive:  # True in this session
            extension = os.path.splitext(fpath)[1]  # '.yaml'
            while os.path.dirname(fpath) != fpath:
                # strip one path component per iteration; the loop stops once
                # os.path.dirname('') == '' (or at '/' for absolute paths)
                fpath = os.path.dirname(fpath)
                fpaths.append(os.path.join(fpath, 'default' + extension))
            # fpaths = ['cfgs/s3dis/pointnext-s.yaml',
            #           'cfgs/s3dis/default.yaml',
            #           'cfgs/default.yaml',
            #           'default.yaml']
        for fpath in reversed(fpaths):
            # iterate from the most generic file to the most specific one,
            # so later (more specific) files override earlier ones
            if os.path.exists(fpath):
                with open(fpath) as f:
                    self.update(yaml.safe_load(f))
        # all entries of every existing .yaml file in fpaths end up in one dict

    def reload(self, fpath: str, *, recursive: bool = False) -> None:
        self.clear()
        self.load(fpath, recursive=recursive)

    # multimethod makes python support function overloading
    @multimethod
    def update(self, other: Dict) -> None:
        # turn the .yaml items into key: value pairs of this dict
        for key, value in other.items():
            if isinstance(value, dict):
                if key not in self or not isinstance(self[key], EasyConfig):
                    # create a nested EasyConfig for the sub-section
                    self[key] = EasyConfig()
                # recursively update
                self[key].update(value)
            else:
                self[key] = value

    @multimethod
    def update(self, opts: Union[List, Tuple]) -> None:
        index = 0
        while index < len(opts):
            opt = opts[index]
            if opt.startswith('--'):
                opt = opt[2:]
            if '=' in opt:
                key, value = opt.split('=', 1)
                index += 1
            else:
                key, value = opt, opts[index + 1]
                index += 2
            current = self
            subkeys = key.split('.')
            try:
                value = literal_eval(value)
            except Exception:
                pass  # not a Python literal: keep the raw string
            for subkey in subkeys[:-1]:
                current = current.setdefault(subkey, EasyConfig())
            current[subkeys[-1]] = value

    def dict(self) -> Dict[str, Any]:
        configs = dict()
        for key, value in self.items():
            if isinstance(value, EasyConfig):
                value = value.dict()
            configs[key] = value
        return configs

    def hash(self) -> str:
        buffer = json.dumps(self.dict(), sort_keys=True)
        return hashlib.sha256(buffer.encode()).hexdigest()

    def __str__(self) -> str:
        texts = []
        for key, value in self.items():
            if isinstance(value, EasyConfig):
                separator = '\n'
            else:
                separator = ' '
            text = key + ':' + separator + str(value)
            lines = text.split('\n')
            for k, line in enumerate(lines[1:]):
                lines[k + 1] = (' ' * 2) + line
            texts.extend(lines)
        return '\n'.join(texts)
```
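Without the multimethod dispatch and the attribute access, the two update overloads boil down to a recursive dict merge plus dotted-key overrides parsed with ast.literal_eval. The sketch below uses plain dicts; merge_dict and apply_opts are hypothetical helper names, not part of openpoints, and the two small input dicts only imitate a default.yaml / dataset-level yaml pair.

```python
from ast import literal_eval


def merge_dict(dst, src):
    """Recursively merge src into dst, like EasyConfig.update(dict)."""
    for key, value in src.items():
        if isinstance(value, dict):
            if not isinstance(dst.get(key), dict):
                dst[key] = {}
            merge_dict(dst[key], value)  # nested sections are merged, not replaced
        else:
            dst[key] = value             # leaf values: the later source wins
    return dst


def apply_opts(cfg, opts):
    """Apply 'a.b.c=value' or '--a.b value' overrides, like EasyConfig.update(list)."""
    index = 0
    while index < len(opts):
        opt = opts[index]
        if opt.startswith('--'):
            opt = opt[2:]
        if '=' in opt:
            key, value = opt.split('=', 1)
            index += 1
        else:
            key, value = opt, opts[index + 1]
            index += 2
        try:
            value = literal_eval(value)  # '0.001' -> 0.001, '[1, 2]' -> [1, 2]
        except (ValueError, SyntaxError):
            pass                         # e.g. 'train': keep the string
        current = cfg
        subkeys = key.split('.')
        for subkey in subkeys[:-1]:
            current = current.setdefault(subkey, {})
        current[subkeys[-1]] = value
    return cfg


default_cfg = {'criterion_args': {'NAME': 'CrossEntropy'}, 'epochs': 100}
s3dis_cfg = {'criterion_args': {'label_smoothing': 0.2}, 'batch_size': 32}
cfg = merge_dict(merge_dict({}, default_cfg), s3dis_cfg)
apply_opts(cfg, ['mode=train', 'model.encoder_args.width=64', '--lr', '0.001'])
```

The recursive merge is why criterion_args ends up with both NAME and label_smoothing in the final cfg shown further down, instead of the dataset file replacing the whole section.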
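3. The Resulting Configuration
The entries of every existing .yaml file in fpaths = ['cfgs/s3dis/pointnext-s.yaml', 'cfgs/s3dis/default.yaml', 'cfgs/default.yaml', 'default.yaml'] (here the repository-root default.yaml does not exist) are parsed and written into cfg. Because the files are loaded in reverse order, entries in more specific files override those in more generic ones.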
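The fpaths list comes from the directory walk in EasyConfig.load. The sketch below isolates that loop in a hypothetical helper, default_cfg_paths, so the search order can be checked without any .yaml files on disk; the expected paths assume a POSIX path separator.

```python
import os


def default_cfg_paths(fpath):
    """fpath plus default.<ext> in every ancestor directory, most specific first."""
    fpaths = [fpath]
    extension = os.path.splitext(fpath)[1]       # '.yaml'
    while os.path.dirname(fpath) != fpath:       # stops once dirname('') == ''
        fpath = os.path.dirname(fpath)           # strip one path component
        fpaths.append(os.path.join(fpath, 'default' + extension))
    return fpaths


paths = default_cfg_paths('cfgs/s3dis/pointnext-s.yaml')
load_order = list(reversed(paths))  # generic defaults first, specific file last
```

Loading in load_order means the root-level default.yaml (if it existed) would be applied first and pointnext-s.yaml last, so the model-specific file has the final say.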
After parsing, cfg looks as follows; its cfg.model section is used to automatically configure and build the network model (class implementation).
```text
dist_url: tcp://localhost:8888
dist_backend: nccl
multiprocessing_distributed: False
ngpus_per_node: 1
world_size: 1
launcher: mp
local_rank: 0
use_gpu: True
seed: 3392
epoch: 0
epochs: 100
ignore_index: None
val_fn: validate
deterministic: False
sync_bn: False
criterion_args:
  NAME: CrossEntropy
  label_smoothing: 0.2
use_mask: False
grad_norm_clip: 10
layer_decay: 0
step_per_update: 1
start_epoch: 1
sched_on_epoch: True
wandb:
  use_wandb: False
  project: PointNeXt-S3DIS
  tags: ['s3dis', 'train', 'pointnext-s', 'ngpus1']
  name: s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
use_amp: False
use_voting: False
val_freq: 1
resume: False
test: False
finetune: False
mode: train
logname: None
load_path: None
print_freq: 50
save_freq: -1
root_dir: log/s3dis
pretrained_path: None
datatransforms:
  train: ['ChromaticAutoContrast', 'PointsToTensor', 'PointCloudScaling', 'PointCloudXYZAlign', 'PointCloudJitter', 'ChromaticDropGPU', 'ChromaticNormalize']
  val: ['PointsToTensor', 'PointCloudXYZAlign', 'ChromaticNormalize']
  vote: ['ChromaticDropGPU']
  kwargs:
    color_drop: 0.2
    gravity_dim: 2
    scale: [0.9, 1.1]
    angle: [0, 0, 1]
    jitter_sigma: 0.005
    jitter_clip: 0.02
feature_keys: x,heights
dataset:
  common:
    NAME: S3DIS
    data_root: data/S3DIS/s3disfull
    test_area: 5
    voxel_size: 0.04
  train:
    split: train
    voxel_max: 24000
    loop: 30
    presample: False
  val:
    split: val
    voxel_max: None
    presample: True
  test:
    split: test
    voxel_max: None
    presample: False
num_classes: 13
batch_size: 32
val_batch_size: 1
dataloader:
  num_workers: 6
cls_weighed_loss: False
optimizer:
  NAME: adamw
  weight_decay: 0.0001
sched: cosine
warmup_epochs: 0
min_lr: 1e-05
lr: 0.01
log_dir: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
model:
  NAME: BaseSeg
  encoder_args:
    NAME: PointNextEncoder
    blocks: [1, 1, 1, 1, 1]
    strides: [1, 4, 4, 4, 4]
    sa_layers: 2
    sa_use_res: True
    width: 32
    in_channels: 4
    expansion: 4
    radius: 0.1
    nsample: 32
    aggr_args:
      feature_type: dp_fj
      reduction: max
    group_args:
      NAME: ballquery
      normalize_dp: True
    conv_args:
      order: conv-norm-act
    act_args:
      act: relu
    norm_args:
      norm: bn
  decoder_args:
    NAME: PointNextDecoder
  cls_args:
    NAME: SegHead
    num_classes: 13
    in_channels: None
    norm_args:
      norm: bn
  in_channels: 4
rank: 0
distributed: False
mp: False
task_name: s3dis
cfg_basename: pointnext-s
opts: mode=train
is_training: True
run_name: s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
run_dir: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
exp_dir: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
ckpt_dir: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu/checkpoint
log_path: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu.log
cfg_path: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu/cfg.yaml
```
III. Summary
1. Result
The part of main() in examples/segmentation/main.py that builds the deep network model (class implementation):
```python
if cfg.model.get('in_channels', None) is None:
    cfg.model.in_channels = cfg.model.encoder_args.in_channels  # 4
model = build_model_from_cfg(cfg.model).to(cfg.rank)
model_size = cal_model_parm_nums(model)
logging.info(model)
logging.info('Number of params: %.4f M' % (model_size / 1e6))
```
The call build_model_from_cfg(cfg.model) goes through MODELS.build and build_from_cfg and finally executes openpoints.models.segmentation.base_seg.BaseSeg(**obj_cfg), which yields the following network structure:
```text
BaseSeg(
  (encoder): PointNextEncoder(
    (encoder): Sequential(
      (0): Sequential(
        (0): SetAbstraction(
          (convs): Sequential(
            (0): Sequential(
              (0): Conv1d(4, 32, kernel_size=(1,), stride=(1,))
            )
          )
        )
      )
      (1): Sequential(
        (0): SetAbstraction(
          (skipconv): Sequential(
            (0): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
          )
          (act): ReLU(inplace=True)
          (convs): Sequential(
            (0): Sequential(
              (0): Conv2d(35, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (grouper): QueryAndGroup()
        )
      )
      (2): Sequential(
        (0): SetAbstraction(
          (skipconv): Sequential(
            (0): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
          )
          (act): ReLU(inplace=True)
          (convs): Sequential(
            (0): Sequential(
              (0): Conv2d(67, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (grouper): QueryAndGroup()
        )
      )
      (3): Sequential(
        (0): SetAbstraction(
          (skipconv): Sequential(
            (0): Conv1d(128, 256, kernel_size=(1,), stride=(1,))
          )
          (act): ReLU(inplace=True)
          (convs): Sequential(
            (0): Sequential(
              (0): Conv2d(131, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (grouper): QueryAndGroup()
        )
      )
      (4): Sequential(
        (0): SetAbstraction(
          (skipconv): Sequential(
            (0): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
          )
          (act): ReLU(inplace=True)
          (convs): Sequential(
            (0): Sequential(
              (0): Conv2d(259, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (grouper): QueryAndGroup()
        )
      )
    )
  )
  (decoder): PointNextDecoder(
    (decoder): Sequential(
      (0): Sequential(
        (0): FeaturePropogation(
          (convs): Sequential(
            (0): Sequential(
              (0): Conv1d(96, 32, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv1d(32, 32, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
          )
        )
      )
      (1): Sequential(
        (0): FeaturePropogation(
          (convs): Sequential(
            (0): Sequential(
              (0): Conv1d(192, 64, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
          )
        )
      )
      (2): Sequential(
        (0): FeaturePropogation(
          (convs): Sequential(
            (0): Sequential(
              (0): Conv1d(384, 128, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv1d(128, 128, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
          )
        )
      )
      (3): Sequential(
        (0): FeaturePropogation(
          (convs): Sequential(
            (0): Sequential(
              (0): Conv1d(768, 256, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv1d(256, 256, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
          )
        )
      )
    )
  )
  (head): SegHead(
    (head): Sequential(
      (0): Sequential(
        (0): Conv1d(32, 32, kernel_size=(1,), stride=(1,), bias=False)
        (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (1): Dropout(p=0.5, inplace=False)
      (2): Sequential(
        (0): Conv1d(32, 13, kernel_size=(1,), stride=(1,))
      )
    )
  )
)
```
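As a small first step toward the Todo below, the channel numbers in this printout can be reproduced from cfg.model.encoder_args with simple arithmetic. This is a hedged reading of the printed shapes, not the openpoints code, and it rests on three assumptions: the width doubles at every stage whose stride is not 1; feature_type dp_fj concatenates the 3-D relative coordinates dp with the neighbor features fj (hence the +3); and each FeaturePropogation block concatenates the skip features with the upsampled coarser features.

```python
width = 32                 # model.encoder_args.width
strides = [1, 4, 4, 4, 4]  # model.encoder_args.strides
in_channels = 4            # model.encoder_args.in_channels

# Assumption: the width doubles at every stage with stride != 1.
stage_channels = []
c = width
for stride in strides:
    if stride != 1:
        c *= 2
    stage_channels.append(c)

# Stage 0 is a plain Conv1d from the input features to the base width.
first_conv = (in_channels, stage_channels[0])

# Assumption: dp_fj = 3-D offsets (dp) concatenated with previous-stage features (fj).
grouped_in = [stage_channels[i - 1] + 3 for i in range(1, len(strides))]

# Assumption: decoder input = skip features + upsampled coarser features.
fp_in = [stage_channels[i] + stage_channels[i + 1] for i in range(len(strides) - 1)]
```

Under these assumptions the numbers match the printout: Conv1d(4, 32) in stage 0, Conv2d inputs 35/67/131/259 in stages 1 to 4, and Conv1d inputs 96/192/384/768 in the four FeaturePropogation blocks.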
2. Todo
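How is the network structure above configured and constructed automatically? This is left for further reading of the source code.
Thanks to the authors of the paper and the code for open-sourcing their work!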