有用 Hook¶
MMDetection 和 MMEngine 为用户提供了各种有用的 Hook,包括日志 Hook、NumClassCheckHook
等。本教程介绍了 MMDetection 中实现的 Hook 的功能和用法。有关在 MMEngine 中使用 Hook 的信息,请阅读 MMEngine 中的 API 文档.
CheckInvalidLossHook¶
NumClassCheckHook¶
MemoryProfilerHook¶
内存分析器 Hook 记录内存信息,包括虚拟内存、交换内存以及当前进程的内存。此 Hook 有助于掌握系统的内存使用情况并发现潜在的内存泄漏错误。要使用此 Hook,用户应该先通过 pip install memory_profiler psutil
安装 memory_profiler
和 psutil
。
结果¶
在训练期间,您可以在日志中看到由 MemoryProfilerHook
记录的消息,如下所示。
The system has 250 GB (246360 MB + 9407 MB) of memory and 8 GB (5740 MB + 2452 MB) of swap memory in total. Currently 9407 MB (4.4%) of memory and 5740 MB (29.9%) of swap memory were consumed. And the current training process consumed 5434 MB of memory.
2022-04-21 08:49:56,881 - mmengine - INFO - Memory information available_memory: 246360 MB, used_memory: 9407 MB, memory_utilization: 4.4 %, available_swap_memory: 5740 MB, used_swap_memory: 2452 MB, swap_memory_utilization: 29.9 %, current_process_memory: 5434 MB
SetEpochInfoHook¶
SyncNormHook¶
SyncRandomSizeHook¶
YOLOXLrUpdaterHook¶
YOLOXModeSwitchHook¶
如何实现自定义 Hook¶
通常,从模型训练开始到结束,Hook 可以插入 20 个点。用户可以实现自定义 Hook 并将它们插入训练过程中的不同点,以完成他们想要的操作。
全局点:
before_run
、after_run
训练中的点:
before_train
、before_train_epoch
、before_train_iter
、after_train_iter
、after_train_epoch
、after_train
验证中的点:
before_val
、before_val_epoch
、before_val_iter
、after_val_iter
、after_val_epoch
、after_val
测试中的点:
before_test
、before_test_epoch
、before_test_iter
、after_test_iter
、after_test_epoch
、after_test
其他点:
before_save_checkpoint
、after_save_checkpoint
例如,用户可以实现一个 Hook 来检查损失并在损失变为 NaN 时终止训练。要实现这一点,需要三个步骤
实现一个继承 MMEngine 中
Hook
类的新的 Hook,并实现after_train_iter
方法,该方法检查每n
次训练迭代后损失是否变为 NaN。实现的 Hook 应该通过
@HOOKS.register_module()
注册在HOOKS
中,如以下代码所示。在配置文件中添加
custom_hooks = [dict(type='MemoryProfilerHook', interval=50)]
。
from typing import Optional
import torch
from mmengine.hooks import Hook
from mmengine.runner import Runner
from mmdet.registry import HOOKS
@HOOKS.register_module()
class CheckInvalidLossHook(Hook):
"""Check invalid loss hook.
This hook will regularly check whether the loss is valid
during training.
Args:
interval (int): Checking interval (every k iterations).
Default: 50.
"""
def __init__(self, interval: int = 50) -> None:
self.interval = interval
def after_train_iter(self,
runner: Runner,
batch_idx: int,
data_batch: Optional[dict] = None,
outputs: Optional[dict] = None) -> None:
"""Regularly check whether the loss is valid every n iterations.
Args:
runner (:obj:`Runner`): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (dict, Optional): Data from dataloader.
Defaults to None.
outputs (dict, Optional): Outputs from model. Defaults to None.
"""
if self.every_n_train_iters(runner, self.interval):
assert torch.isfinite(outputs['loss']), \
runner.logger.info('loss become infinite or NaN!')
请阅读 自定义运行时,了解有关实现自定义 Hook 的更多信息。