You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

data_pipeline.md 5.8 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. # 教程 3: 自定义数据预处理流程
  2. ## 数据流程的设计
  3. 按照惯例,我们使用 `Dataset` 和 `DataLoader` 进行多进程的数据加载。`Dataset` 返回字典类型的数据,数据内容为模型 `forward` 方法的各个参数。由于在目标检测中,输入的图像数据具有不同的大小,我们在 `MMCV` 里引入一个新的 `DataContainer` 类去收集和分发不同大小的输入数据。更多细节请参考[这里](https://github.com/open-mmlab/mmcv/blob/master/mmcv/parallel/data_container.py)。
  4. 数据的准备流程和数据集是解耦的。通常一个数据集定义了如何处理标注数据(annotations)信息,而一个数据流程定义了准备一个数据字典的所有步骤。一个流程包括一系列的操作,每个操作都把一个字典作为输入,然后再输出一个新的字典给下一个变换操作。
  5. 我们在下图展示了一个经典的数据处理流程。蓝色块是数据处理操作,随着数据流程的处理,每个操作都可以在结果字典中加入新的键(标记为绿色)或更新现有的键(标记为橙色)。
  6. ![pipeline figure](../../resources/data_pipeline.png)
  7. 这些操作可以分为数据加载(data loading)、预处理(pre-processing)、格式变化(formatting)和测试时数据增强(test-time augmentation)。
  8. 下面的例子是 `Faster R-CNN` 的一个流程:
  9. ```python
  10. img_norm_cfg = dict(
  11. mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
  12. train_pipeline = [
  13. dict(type='LoadImageFromFile'),
  14. dict(type='LoadAnnotations', with_bbox=True),
  15. dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
  16. dict(type='RandomFlip', flip_ratio=0.5),
  17. dict(type='Normalize', **img_norm_cfg),
  18. dict(type='Pad', size_divisor=32),
  19. dict(type='DefaultFormatBundle'),
  20. dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
  21. ]
  22. test_pipeline = [
  23. dict(type='LoadImageFromFile'),
  24. dict(
  25. type='MultiScaleFlipAug',
  26. img_scale=(1333, 800),
  27. flip=False,
  28. transforms=[
  29. dict(type='Resize', keep_ratio=True),
  30. dict(type='RandomFlip'),
  31. dict(type='Normalize', **img_norm_cfg),
  32. dict(type='Pad', size_divisor=32),
  33. dict(type='ImageToTensor', keys=['img']),
  34. dict(type='Collect', keys=['img']),
  35. ])
  36. ]
  37. ```
  38. 对于每个操作,我们列出它添加、更新、移除的相关字典域 (dict fields):
  39. ### 数据加载 Data loading
  40. `LoadImageFromFile`
  41. - 增加:img, img_shape, ori_shape
  42. `LoadAnnotations`
  43. - 增加:gt_bboxes, gt_bboxes_ignore, gt_labels, gt_masks, gt_semantic_seg, bbox_fields, mask_fields
  44. `LoadProposals`
  45. - 增加:proposals
  46. ### 预处理 Pre-processing
  47. `Resize`
  48. - 增加:scale, scale_idx, pad_shape, scale_factor, keep_ratio
  49. - 更新:img, img_shape, *bbox_fields, *mask_fields, *seg_fields
  50. `RandomFlip`
  51. - 增加:flip
  52. - 更新:img, *bbox_fields, *mask_fields, *seg_fields
  53. `Pad`
  54. - 增加:pad_fixed_size, pad_size_divisor
  55. - 更新:img, pad_shape, *mask_fields, *seg_fields
  56. `RandomCrop`
  57. - 更新:img, pad_shape, gt_bboxes, gt_labels, gt_masks, *bbox_fields
  58. `Normalize`
  59. - 增加:img_norm_cfg
  60. - 更新:img
  61. `SegRescale`
  62. - 更新:gt_semantic_seg
  63. `PhotoMetricDistortion`
  64. - 更新:img
  65. `Expand`
  66. - 更新:img, gt_bboxes
  67. `MinIoURandomCrop`
  68. - 更新:img, gt_bboxes, gt_labels
  69. `Corrupt`
  70. - 更新:img
  71. ### 格式 Formatting
  72. `ToTensor`
  73. - 更新:由 `keys` 指定
  74. `ImageToTensor`
  75. - 更新:由 `keys` 指定
  76. `Transpose`
  77. - 更新:由 `keys` 指定
  78. `ToDataContainer`
  79. - 更新:由 `keys` 指定
  80. `DefaultFormatBundle`
  81. - 更新:img, proposals, gt_bboxes, gt_bboxes_ignore, gt_labels, gt_masks, gt_semantic_seg
  82. `Collect`
  83. - 增加:img_metas(img_metas 的键(key)被 `meta_keys` 指定)
  84. - 移除:除了 `keys` 指定的键(key)之外的所有其他的键(key)
  85. ### 测试时数据增强 Test time augmentation
  86. `MultiScaleFlipAug`
  87. ## 拓展和使用自定义的流程
  88. 1. 在任意文件里写一个新的流程,例如在 `my_pipeline.py`,它以一个字典作为输入并且输出一个字典:
  89. ```python
  90. import random
  91. from mmdet.datasets import PIPELINES
  92. @PIPELINES.register_module()
  93. class MyTransform:
  94. """Add your transform
  95. Args:
  96. p (float): Probability of shifts. Default 0.5.
  97. """
  98. def __init__(self, p=0.5):
  99. self.p = p
  100. def __call__(self, results):
  101. if random.random() > self.p:
  102. results['dummy'] = True
  103. return results
  104. ```
  105. 2. 在配置文件里调用并使用你写的数据处理流程,需要确保你的训练脚本能够正确导入新增模块:
  106. ```python
  107. custom_imports = dict(imports=['path.to.my_pipeline'], allow_failed_imports=False)
  108. img_norm_cfg = dict(
  109. mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
  110. train_pipeline = [
  111. dict(type='LoadImageFromFile'),
  112. dict(type='LoadAnnotations', with_bbox=True),
  113. dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
  114. dict(type='RandomFlip', flip_ratio=0.5),
  115. dict(type='Normalize', **img_norm_cfg),
  116. dict(type='Pad', size_divisor=32),
  117. dict(type='MyTransform', p=0.2),
  118. dict(type='DefaultFormatBundle'),
  119. dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
  120. ]
  121. ```
  122. 3. 可视化数据增强处理流程的结果
  123. 如果想要可视化数据增强处理流程的结果,可以使用 `tools/misc/browse_dataset.py` 直观
  124. 地浏览检测数据集(图像和标注信息),或将图像保存到指定目录。
  125. 使用方法请参考[日志分析](../useful_tools.md)

No Description

Contributors (1)