快捷键

使用自定义模型和标准数据集进行训练

在本指南中,您将了解如何在标准数据集上训练、测试和推断您自己的自定义模型。我们以使用 cityscapes 数据集训练自定义 Cascade Mask R-CNN R50 模型为例,演示了整个过程,该过程使用 AugFPN 替换默认的 FPN 作为颈部,并添加 RotateTranslateX 作为训练时自动增强。

基本步骤如下:

  1. 准备标准数据集

  2. 准备您自己的自定义模型

  3. 准备配置文件

  4. 在标准数据集上训练、测试和推断模型。

准备标准数据集

在本指南中,我们以标准的 cityscapes 数据集为例。

建议将数据集根目录符号链接到 $MMDETECTION/data。如果您使用的文件夹结构不同,则可能需要更改配置文件中的相应路径。

mmdetection
├── mmdet
├── tools
├── configs
├── data
│   ├── coco
│   │   ├── annotations
│   │   ├── train2017
│   │   ├── val2017
│   │   ├── test2017
│   ├── cityscapes
│   │   ├── annotations
│   │   ├── leftImg8bit
│   │   │   ├── train
│   │   │   ├── val
│   │   ├── gtFine
│   │   │   ├── train
│   │   │   ├── val
│   ├── VOCdevkit
│   │   ├── VOC2007
│   │   ├── VOC2012

或者,您可以通过以下方式设置数据集根目录:

export MMDET_DATASETS=$data_root

我们将用 $MMDET_DATASETS 替换数据集根目录,因此您不必修改配置文件中的相应路径。

cityscapes 注释必须使用 tools/dataset_converters/cityscapes.py 转换为 coco 格式。

pip install cityscapesscripts
python tools/dataset_converters/cityscapes.py ./data/cityscapes --nproc 8 --out-dir ./data/cityscapes/annotations

目前,cityscapes 中的配置文件使用 COCO 预训练权重进行初始化。如果网络不可用或速度慢,您可以提前下载预训练模型,否则会在训练开始时出现错误。

准备您自己的自定义模型

第二步是使用您自己的模块或训练设置。假设我们想要实现一个名为 AugFPN 的新颈部,以替换现有检测器 Cascade Mask R-CNN R50 中的默认 FPN。以下是在 MMDetection 中实现 AugFPN 的步骤。

1. 定义一个新的颈部(例如 AugFPN)

首先创建一个新文件 mmdet/models/necks/augfpn.py

import torch.nn as nn
from mmdet.registry import MODELS


@MODELS.register_module()
class AugFPN(nn.Module):

    def __init__(self,
                in_channels,
                out_channels,
                num_outs,
                start_level=0,
                end_level=-1,
                add_extra_convs=False):
        pass

    def forward(self, inputs):
        # implementation is ignored
        pass

2. 导入模块

您可以将以下行添加到 mmdet/models/necks/__init__.py 中,

from .augfpn import AugFPN

或者,您也可以将以下内容添加到配置文件中,避免修改原始代码。

custom_imports = dict(
    imports=['mmdet.models.necks.augfpn'],
    allow_failed_imports=False)

to the config file and avoid modifying the original code.

3. 修改配置文件

neck=dict(
    type='AugFPN',
    in_channels=[256, 512, 1024, 2048],
    out_channels=256,
    num_outs=5)

有关自定义您自己的模型(例如,实现新的主干、头部、损失等)和运行时训练设置(例如,定义新的优化器,使用梯度裁剪,自定义训练计划和钩子等)的更详细用法,请分别参考指南 自定义模型自定义运行时设置

准备配置文件

第三步是为您的训练设置准备配置文件。假设我们想要将 AugFPNRotateTranslate 增强添加到现有的 Cascade Mask R-CNN R50 中,以训练 cityscapes 数据集,并假设配置文件位于 configs/cityscapes/ 目录下,并命名为 cascade-mask-rcnn_r50_augfpn_autoaug-10e_cityscapes.py,配置文件如下所示。

# The new config inherits the base configs to highlight the necessary modification
_base_ = [
    '../_base_/models/cascade-mask-rcnn_r50_fpn.py',
    '../_base_/datasets/cityscapes_instance.py', '../_base_/default_runtime.py'
]

model = dict(
    # set None to avoid loading ImageNet pre-trained backbone,
    # instead here we set `load_from` to load from COCO pre-trained detectors.
    backbone=dict(init_cfg=None),
    # replace neck from defaultly `FPN` to our new implemented module `AugFPN`
    neck=dict(
        type='AugFPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    # We also need to change the num_classes in head from 80 to 8, to match the
    # cityscapes dataset's annotation. This modification involves `bbox_head` and `mask_head`.
    roi_head=dict(
        bbox_head=[
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                # change the number of classes from defaultly COCO to cityscapes
                num_classes=8,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.1, 0.1, 0.2, 0.2]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
                               loss_weight=1.0)),
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                # change the number of classes from defaultly COCO to cityscapes
                num_classes=8,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.05, 0.05, 0.1, 0.1]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
                               loss_weight=1.0)),
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                # change the number of classes from defaultly COCO to cityscapes
                num_classes=8,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.033, 0.033, 0.067, 0.067]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
        ],
        mask_head=dict(
            type='FCNMaskHead',
            num_convs=4,
            in_channels=256,
            conv_out_channels=256,
            # change the number of classes from default COCO to cityscapes
            num_classes=8,
            loss_mask=dict(
                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))))

# over-write `train_pipeline` for new added `AutoAugment` training setting
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(
        type='AutoAugment',
        policies=[
            [dict(
                 type='Rotate',
                 level=5,
                 img_border_value=(124, 116, 104),
                 prob=0.5)
            ],
            [dict(type='Rotate', level=7, img_border_value=(124, 116, 104)),
             dict(
                 type='TranslateX',
                 level=5,
                 prob=0.5,
                 img_border_value=(124, 116, 104))
            ],
        ]),
    dict(
        type='RandomResize',
        scale=[(2048, 800), (2048, 1024)],
        keep_ratio=True),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PackDetInputs'),
]

# set batch_size per gpu, and set new training pipeline
train_dataloader = dict(
    batch_size=1,
    num_workers=3,
    # over-write `pipeline` with new training pipeline setting
    dataset=dict(pipeline=train_pipeline))

# Set optimizer
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001))

# Set customized learning policy
param_scheduler = [
    dict(
        type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500),
    dict(
        type='MultiStepLR',
        begin=0,
        end=10,
        by_epoch=True,
        milestones=[8],
        gamma=0.1)
]

# train, val, test loop config
train_cfg = dict(max_epochs=10, val_interval=1)

# We can use the COCO pre-trained Cascade Mask R-CNN R50 model for a more stable performance initialization
load_from = 'https://download.openmmlab.com/mmdetection/v2.0/cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco/cascade_mask_rcnn_r50_fpn_1x_coco_20200203-9d4dcb24.pth'

训练一个新的模型

要使用新的配置文件训练模型,您可以简单地运行以下命令:

python tools/train.py configs/cityscapes/cascade-mask-rcnn_r50_augfpn_autoaug-10e_cityscapes.py

有关更详细的用法,请参考 训练指南

测试和推断

要测试训练好的模型,您可以简单地运行以下命令:

python tools/test.py configs/cityscapes/cascade-mask-rcnn_r50_augfpn_autoaug-10e_cityscapes.py work_dirs/cascade-mask-rcnn_r50_augfpn_autoaug-10e_cityscapes/epoch_10.pth

有关更详细的用法,请参考 测试指南