You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ssd300_coco.py 2.2 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. _base_ = [
  2. '../_base_/models/ssd300.py', '../_base_/datasets/coco_detection.py',
  3. '../_base_/schedules/schedule_2x.py', '../_base_/default_runtime.py'
  4. ]
  5. # dataset settings
  6. dataset_type = 'CocoDataset'
  7. data_root = 'data/coco/'
  8. img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[1, 1, 1], to_rgb=True)
  9. train_pipeline = [
  10. dict(type='LoadImageFromFile', to_float32=True),
  11. dict(type='LoadAnnotations', with_bbox=True),
  12. dict(
  13. type='PhotoMetricDistortion',
  14. brightness_delta=32,
  15. contrast_range=(0.5, 1.5),
  16. saturation_range=(0.5, 1.5),
  17. hue_delta=18),
  18. dict(
  19. type='Expand',
  20. mean=img_norm_cfg['mean'],
  21. to_rgb=img_norm_cfg['to_rgb'],
  22. ratio_range=(1, 4)),
  23. dict(
  24. type='MinIoURandomCrop',
  25. min_ious=(0.1, 0.3, 0.5, 0.7, 0.9),
  26. min_crop_size=0.3),
  27. dict(type='Resize', img_scale=(300, 300), keep_ratio=False),
  28. dict(type='Normalize', **img_norm_cfg),
  29. dict(type='RandomFlip', flip_ratio=0.5),
  30. dict(type='DefaultFormatBundle'),
  31. dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
  32. ]
  33. test_pipeline = [
  34. dict(type='LoadImageFromFile'),
  35. dict(
  36. type='MultiScaleFlipAug',
  37. img_scale=(300, 300),
  38. flip=False,
  39. transforms=[
  40. dict(type='Resize', keep_ratio=False),
  41. dict(type='Normalize', **img_norm_cfg),
  42. dict(type='ImageToTensor', keys=['img']),
  43. dict(type='Collect', keys=['img']),
  44. ])
  45. ]
  46. data = dict(
  47. samples_per_gpu=8,
  48. workers_per_gpu=3,
  49. train=dict(
  50. _delete_=True,
  51. type='RepeatDataset',
  52. times=5,
  53. dataset=dict(
  54. type=dataset_type,
  55. ann_file=data_root + 'annotations/instances_train2017.json',
  56. img_prefix=data_root + 'train2017/',
  57. pipeline=train_pipeline)),
  58. val=dict(pipeline=test_pipeline),
  59. test=dict(pipeline=test_pipeline))
  60. # optimizer
  61. optimizer = dict(type='SGD', lr=2e-3, momentum=0.9, weight_decay=5e-4)
  62. optimizer_config = dict(_delete_=True)
  63. custom_hooks = [
  64. dict(type='NumClassCheckHook'),
  65. dict(type='CheckInvalidLossHook', interval=50, priority='VERY_LOW')
  66. ]

No Description

Contributors (2)