快捷键

自定义损失函数

MMDetection 为用户提供了不同的损失函数。但是默认配置可能不适用于不同的数据集或模型,因此用户可能希望修改特定的损失函数以适应新情况。

本教程首先阐述损失函数的计算流程,然后给出一些关于如何修改每个步骤的说明。修改可以分为微调和加权两种类型。

损失函数的计算流程

给定输入预测和目标,以及权重,损失函数将输入张量映射到最终的损失标量。映射可以分为五个步骤

  1. 设置采样方法以采样正负样本。

  2. 通过损失核函数获取**逐元素**或**逐样本**的损失。

  3. **逐元素**地用权重张量加权损失。

  4. 将损失张量缩减为**标量**。

  5. 用**标量**加权损失。

设置采样方法(步骤 1)

对于某些损失函数,需要采样策略来避免正负样本之间的不平衡。

例如,当在 RPN 头部使用 CrossEntropyLoss 时,我们需要在 train_cfg 中设置 RandomSampler

train_cfg=dict(
    rpn=dict(
        sampler=dict(
            type='RandomSampler',
            num=256,
            pos_fraction=0.5,
            neg_pos_ub=-1,
            add_gt_as_proposals=False))

对于其他一些具有正负样本平衡机制的损失函数,例如 Focal Loss、GHMC 和 QualityFocalLoss,采样器不再必要。

微调损失函数

微调损失函数更多地与步骤 2、4、5 相关,大多数修改可以在配置文件中指定。这里我们以 Focal Loss (FL) 为例。以下代码片段分别是 FL 的构造方法和配置文件,它们实际上是一一对应的。

@LOSSES.register_module()
class FocalLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=True,
                 gamma=2.0,
                 alpha=0.25,
                 reduction='mean',
                 loss_weight=1.0):
loss_cls=dict(
    type='FocalLoss',
    use_sigmoid=True,
    gamma=2.0,
    alpha=0.25,
    loss_weight=1.0)

微调超参数(步骤 2)

gammabeta 是 Focal Loss 中的两个超参数。假设我们想将 gamma 的值改为 1.5,并将 alpha 改为 0.5,那么我们可以在配置文件中指定如下:

loss_cls=dict(
    type='FocalLoss',
    use_sigmoid=True,
    gamma=1.5,
    alpha=0.5,
    loss_weight=1.0)

微调缩减方式(步骤 3)

FL 的默认缩减方式为 mean。假设我们想将缩减方式从 mean 改为 sum,我们可以在配置文件中指定如下:

loss_cls=dict(
    type='FocalLoss',
    use_sigmoid=True,
    gamma=2.0,
    alpha=0.25,
    loss_weight=1.0,
    reduction='sum')

微调损失权重(步骤 5)

这里的损失权重是一个标量,它控制多任务学习中不同损失的权重,例如分类损失和回归损失。假设我们想将分类损失的损失权重改为 0.5,我们可以在配置文件中指定如下:

loss_cls=dict(
    type='FocalLoss',
    use_sigmoid=True,
    gamma=2.0,
    alpha=0.25,
    loss_weight=0.5)

加权损失函数(步骤 3)

加权损失函数是指我们**逐元素**地重新加权损失函数。更具体地说,我们用一个与损失张量具有相同形状的权重张量乘以损失张量。这样,损失的不同项就可以被不同地缩放,这就是所谓的**逐元素**加权。损失权重在不同的模型中是不同的,而且高度依赖于上下文,但总体而言,损失权重有两种类型:label_weights 用于分类损失,bbox_weights 用于 bbox 回归损失。你可以在对应头的 get_target 方法中找到它们。这里我们以 ATSSHead 为例,它继承了 AnchorHead 但覆盖了它的 get_targets 方法,该方法会产生不同的 label_weightsbbox_weights

class ATSSHead(AnchorHead):

    ...

    def get_targets(self,
                    anchor_list,
                    valid_flag_list,
                    gt_bboxes_list,
                    img_metas,
                    gt_bboxes_ignore_list=None,
                    gt_labels_list=None,
                    label_channels=1,
                    unmap_outputs=True):