# Copyright (c) OpenMMLab. All rights reserved.
import bisect
import math
from collections import defaultdict
from unittest.mock import MagicMock, patch

import numpy as np

from mmdet.datasets import (ClassBalancedDataset, ConcatDataset, CustomDataset,
                            MultiImageMixDataset, RepeatDataset)


def _random_cat_ids(num_imgs):
    """Return ``num_imgs`` lists of random category ids.

    Each list holds between 1 and 19 ids drawn from ``[0, 80)``.
    """
    return [
        np.random.randint(0, 80, num).tolist()
        for num in np.random.randint(1, 20, num_imgs)
    ]


def _make_mock_dataset(cat_ids_list):
    """Build a ``CustomDataset`` whose length and category ids are mocked.

    Must be called while ``CustomDataset.load_annotations`` is patched,
    because ``ann_file`` is only a ``MagicMock`` placeholder.

    Args:
        cat_ids_list (list[list[int]]): Category ids per image; its length
            becomes the reported dataset length.

    Returns:
        CustomDataset: Dataset whose ``data_infos`` and ``get_cat_ids`` are
        mocks backed by ``cat_ids_list``.
    """
    dataset = CustomDataset(
        ann_file=MagicMock(), pipeline=[], test_mode=True, img_prefix='')
    dataset.data_infos = MagicMock()
    dataset.data_infos.__len__.return_value = len(cat_ids_list)
    dataset.get_cat_ids = MagicMock(side_effect=lambda idx: cat_ids_list[idx])
    return dataset


def _reference_repeat_factors(cat_ids_list):
    """Recompute the repeat threshold and per-image repeat factors.

    This is the reference the test compares ``ClassBalancedDataset``
    against: the threshold is the mean image-level category frequency and
    every image is repeated ``ceil(max(1, sqrt(thr / freq)))`` times for its
    rarest category.

    Args:
        cat_ids_list (list[list[int]]): Category ids per image.

    Returns:
        tuple[float, list[int]]: ``(repeat_thr, repeat_factors)``.
    """
    num_imgs = len(cat_ids_list)
    category_count = defaultdict(int)
    for cat_ids in cat_ids_list:
        # A category counts once per image, however many boxes it has.
        for cat_id in set(cat_ids):
            category_count[cat_id] += 1
    category_freq = {
        cat_id: count / num_imgs
        for cat_id, count in category_count.items()
    }

    repeat_thr = np.mean(list(category_freq.values()))
    category_repeat = {
        cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
        for cat_id, cat_freq in category_freq.items()
    }

    repeat_factors = []
    for cat_ids in cat_ids_list:
        repeat_factor = max(
            {category_repeat[cat_id]
             for cat_id in set(cat_ids)})
        repeat_factors.append(math.ceil(repeat_factor))
    return repeat_thr, repeat_factors


def _check_index_wrappers():
    """Check ConcatDataset, RepeatDataset and ClassBalancedDataset.

    Expects ``CustomDataset.__getitem__`` to be patched so that it returns
    the requested index itself.
    """
    cat_ids_list_a = _random_cat_ids(10)
    dataset_a = _make_mock_dataset(cat_ids_list_a)
    cat_ids_list_b = _random_cat_ids(20)
    dataset_b = _make_mock_dataset(cat_ids_list_b)

    concat_dataset = ConcatDataset([dataset_a, dataset_b])
    assert concat_dataset[5] == 5
    assert concat_dataset[25] == 15
    assert concat_dataset.get_cat_ids(5) == cat_ids_list_a[5]
    assert concat_dataset.get_cat_ids(25) == cat_ids_list_b[15]
    assert len(concat_dataset) == len(dataset_a) + len(dataset_b)

    repeat_dataset = RepeatDataset(dataset_a, 10)
    assert repeat_dataset[5] == 5
    assert repeat_dataset[15] == 5
    assert repeat_dataset[27] == 7
    assert repeat_dataset.get_cat_ids(5) == cat_ids_list_a[5]
    assert repeat_dataset.get_cat_ids(15) == cat_ids_list_a[5]
    assert repeat_dataset.get_cat_ids(27) == cat_ids_list_a[7]
    assert len(repeat_dataset) == 10 * len(dataset_a)

    repeat_thr, repeat_factors = _reference_repeat_factors(cat_ids_list_a)
    repeat_factors_cumsum = np.cumsum(repeat_factors)
    repeat_factor_dataset = ClassBalancedDataset(dataset_a, repeat_thr)
    assert len(repeat_factor_dataset) == repeat_factors_cumsum[-1]
    for idx in np.random.randint(0, len(repeat_factor_dataset), 3):
        # The wrapped index is the image whose cumulative repeat range
        # contains ``idx``.
        assert repeat_factor_dataset[idx] == bisect.bisect_right(
            repeat_factors_cumsum, idx)


def _check_multi_image_mix(img_scale, dynamic_scale):
    """Check the output image shape of MultiImageMixDataset.

    Expects ``CustomDataset.__getitem__`` to be patched so that it returns
    result dicts with ``img``, ``gt_bboxes`` and ``gt_labels``, and that at
    least two indices are valid.

    Args:
        img_scale (tuple[int, int]): Scale given to Mosaic and MixUp.
        dynamic_scale (tuple[int, int]): Scale given to the wrapper.
    """
    pipeline = [
        dict(type='Mosaic', img_scale=img_scale, pad_val=114.0),
        dict(
            type='RandomAffine',
            scaling_ratio_range=(0.1, 2),
            border=(-img_scale[0] // 2, -img_scale[1] // 2)),
        dict(
            type='MixUp',
            img_scale=img_scale,
            ratio_range=(0.8, 1.6),
            pad_val=114.0),
        dict(type='RandomFlip', flip_ratio=0.5),
        dict(type='Resize', keep_ratio=True),
        dict(type='Pad', pad_to_square=True, pad_val=114.0),
    ]

    len_a = 2
    dataset_a = _make_mock_dataset(_random_cat_ids(len_a))

    multi_image_mix_dataset = MultiImageMixDataset(dataset_a, pipeline,
                                                   dynamic_scale)
    for idx in range(len_a):
        results_ = multi_image_mix_dataset[idx]
        assert results_['img'].shape == (dynamic_scale[0], dynamic_scale[1],
                                         3)

    # test skip_type_keys
    multi_image_mix_dataset = MultiImageMixDataset(
        dataset_a,
        pipeline,
        dynamic_scale,
        skip_type_keys=('MixUp', 'RandomFlip', 'Resize', 'Pad'))
    for idx in range(len_a):
        results_ = multi_image_mix_dataset[idx]
        assert results_['img'].shape == (img_scale[0], img_scale[1], 3)


def test_dataset_wrapper():
    """Test the dataset wrappers on top of a mocked ``CustomDataset``.

    ``CustomDataset.load_annotations`` and ``CustomDataset.__getitem__`` are
    replaced only for the duration of this test via ``patch.object``; the
    original attributes are restored on exit (also on failure) so the mocks
    cannot leak into tests that run afterwards.
    """
    with patch.object(CustomDataset, 'load_annotations', MagicMock()):
        # Phase 1: items are their own index, which makes index remapping
        # of the wrappers directly observable.
        with patch.object(CustomDataset, '__getitem__',
                          MagicMock(side_effect=lambda idx: idx)):
            _check_index_wrappers()

        # Phase 2: items are small random images with two boxes each.
        results = []
        for _ in range(2):
            height = np.random.randint(10, 30)
            width = np.random.randint(10, 30)
            img = np.ones((height, width, 3))
            # Bottom-right corner is top-left + 5, so boxes are never empty.
            gt_bbox = np.concatenate([
                np.random.randint(1, 5, (2, 2)),
                np.random.randint(1, 5, (2, 2)) + 5
            ],
                                     axis=1)
            gt_labels = np.random.randint(0, 80, 2)
            results.append(
                dict(gt_bboxes=gt_bbox, gt_labels=gt_labels, img=img))

        with patch.object(CustomDataset, '__getitem__',
                          MagicMock(side_effect=lambda idx: results[idx])):
            _check_multi_image_mix(
                img_scale=(60, 60), dynamic_scale=(80, 80))