快捷键

有用 Hook

MMDetection 和 MMEngine 为用户提供了各种有用的 Hook,包括日志 Hook、NumClassCheckHook 等。本教程介绍了 MMDetection 中实现的 Hook 的功能和用法。有关在 MMEngine 中使用 Hook 的信息,请阅读 MMEngine 中的 API 文档.

CheckInvalidLossHook

NumClassCheckHook

MemoryProfilerHook

内存分析器 Hook 记录内存信息,包括虚拟内存、交换内存以及当前进程的内存。此 Hook 有助于掌握系统的内存使用情况并发现潜在的内存泄漏错误。要使用此 Hook,用户应该先通过 pip install memory_profiler psutil 安装 memory_profilerpsutil

用法

要使用此 Hook,用户应该将以下代码添加到配置文件中。

custom_hooks = [
    dict(type='MemoryProfilerHook', interval=50)
]

结果

在训练期间,您可以在日志中看到由 MemoryProfilerHook 记录的消息,如下所示。

The system has 250 GB (246360 MB + 9407 MB) of memory and 8 GB (5740 MB + 2452 MB) of swap memory in total. Currently 9407 MB (4.4%) of memory and 5740 MB (29.9%) of swap memory were consumed. And the current training process consumed 5434 MB of memory.
2022-04-21 08:49:56,881 - mmengine - INFO - Memory information available_memory: 246360 MB, used_memory: 9407 MB, memory_utilization: 4.4 %, available_swap_memory: 5740 MB, used_swap_memory: 2452 MB, swap_memory_utilization: 29.9 %, current_process_memory: 5434 MB

SetEpochInfoHook

SyncNormHook

SyncRandomSizeHook

YOLOXLrUpdaterHook

YOLOXModeSwitchHook

如何实现自定义 Hook

通常,从模型训练开始到结束,Hook 可以插入 20 个点。用户可以实现自定义 Hook 并将它们插入训练过程中的不同点,以完成他们想要的操作。

  • 全局点:before_runafter_run

  • 训练中的点:before_trainbefore_train_epochbefore_train_iterafter_train_iterafter_train_epochafter_train

  • 验证中的点:before_valbefore_val_epochbefore_val_iterafter_val_iterafter_val_epochafter_val

  • 测试中的点:before_testbefore_test_epochbefore_test_iterafter_test_iterafter_test_epochafter_test

  • 其他点:before_save_checkpointafter_save_checkpoint

例如,用户可以实现一个 Hook 来检查损失并在损失变为 NaN 时终止训练。要实现这一点,需要三个步骤

  1. 实现一个继承 MMEngine 中 Hook 类的新的 Hook,并实现 after_train_iter 方法,该方法检查每 n 次训练迭代后损失是否变为 NaN。

  2. 实现的 Hook 应该通过 @HOOKS.register_module() 注册在 HOOKS 中,如以下代码所示。

  3. 在配置文件中添加 custom_hooks = [dict(type='MemoryProfilerHook', interval=50)]

from typing import Optional

import torch
from mmengine.hooks import Hook
from mmengine.runner import Runner

from mmdet.registry import HOOKS


@HOOKS.register_module()
class CheckInvalidLossHook(Hook):
    """Check invalid loss hook.

    This hook will regularly check whether the loss is valid
    during training.

    Args:
        interval (int): Checking interval (every k iterations).
            Default: 50.
    """

    def __init__(self, interval: int = 50) -> None:
        self.interval = interval

    def after_train_iter(self,
                         runner: Runner,
                         batch_idx: int,
                         data_batch: Optional[dict] = None,
                         outputs: Optional[dict] = None) -> None:
        """Regularly check whether the loss is valid every n iterations.

        Args:
            runner (:obj:`Runner`): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict, Optional): Data from dataloader.
                Defaults to None.
            outputs (dict, Optional): Outputs from model. Defaults to None.
        """
        if self.every_n_train_iters(runner, self.interval):
            assert torch.isfinite(outputs['loss']), \
                runner.logger.info('loss become infinite or NaN!')

请阅读 自定义运行时,了解有关实现自定义 Hook 的更多信息。