diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc index b820779ed1..fbaf2c9326 100644 --- a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc +++ b/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc @@ -30,8 +30,7 @@ Status RandomCropAndResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow BOUNDING_BOX_CHECK(input); CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Size() >= 2, "The shape of input is abnormal"); - (*output).push_back(nullptr); // init memory for return vector - (*output).push_back(nullptr); + output->resize(2); (*output)[1] = std::move(input[1]); // move boxes over to output size_t bboxCount = input[1]->shape()[0]; // number of rows in bbox tensor diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.cc index 2be37f1da3..c873307afd 100644 --- a/mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.cc +++ b/mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.cc @@ -36,8 +36,7 @@ Status RandomCropWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) int32_t padded_image_h; int32_t padded_image_w; - (*output).push_back(nullptr); - (*output).push_back(nullptr); + output->resize(2); (*output)[1] = std::move(input[1]); // since some boxes may be removed bool crop_further = true; // Whether further cropping will be required or not, true unless required size matches diff --git a/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc index c6aa8450a8..ffea851eac 100644 --- a/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc +++ b/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc @@ -45,8 +45,7 @@ Status RandomVerticalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow * RETURN_IF_NOT_OK(input[1]->SetItemAt({i, 1}, newBoxCorner_y)); } - (*output).push_back(nullptr); - (*output).push_back(nullptr); + output->resize(2); (*output)[1] = std::move(input[1]); return VerticalFlip(input[0], &(*output)[0]); diff --git a/tests/ut/data/dataset/golden/random_crop_with_bbox_01_c_result.npz b/tests/ut/data/dataset/golden/random_crop_with_bbox_01_c_result.npz new file mode 100644 index 0000000000..0c220fd09d Binary files /dev/null and b/tests/ut/data/dataset/golden/random_crop_with_bbox_01_c_result.npz differ diff --git a/tests/ut/data/dataset/golden/random_resized_crop_with_bbox_01_c_result.npz b/tests/ut/data/dataset/golden/random_resized_crop_with_bbox_01_c_result.npz new file mode 100644 index 0000000000..a909cbe88c Binary files /dev/null and b/tests/ut/data/dataset/golden/random_resized_crop_with_bbox_01_c_result.npz differ diff --git a/tests/ut/data/dataset/golden/random_vertical_flip_with_bbox_01_c_result.npz b/tests/ut/data/dataset/golden/random_vertical_flip_with_bbox_01_c_result.npz new file mode 100644 index 0000000000..aba6fe97b0 Binary files /dev/null and b/tests/ut/data/dataset/golden/random_vertical_flip_with_bbox_01_c_result.npz differ diff --git a/tests/ut/python/dataset/test_random_crop_and_resize_with_bbox.py b/tests/ut/python/dataset/test_random_crop_and_resize_with_bbox.py index 90269a7027..359b527dd1 100644 --- a/tests/ut/python/dataset/test_random_crop_and_resize_with_bbox.py +++ b/tests/ut/python/dataset/test_random_crop_and_resize_with_bbox.py @@ -13,17 +13,17 @@ # limitations under the License. # ============================================================================== """ -Testing RandomCropAndResizeWithBBox op +Testing RandomCropAndResizeWithBBox op in DE """ import numpy as np - -import matplotlib.pyplot as plt -import matplotlib.patches as patches - import mindspore.dataset as ds import mindspore.dataset.transforms.vision.c_transforms as c_vision from mindspore import log as logger +from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \ + config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5 + +GENERATE_GOLDEN = False # updated VOC dataset with correct annotations DATA_DIR = "../data/dataset/testVOC2012_2" @@ -31,8 +31,7 @@ DATA_DIR = "../data/dataset/testVOC2012_2" def fix_annotate(bboxes): """ - Update Current VOC dataset format to Proposed HQ BBox format - + Fix annotations to format followed by mindspore. :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format """ @@ -46,112 +45,22 @@ def fix_annotate(bboxes): return bboxes -def add_bounding_boxes(ax, bboxes): - for bbox in bboxes: - rect = patches.Rectangle((bbox[0], bbox[1]), - bbox[2], bbox[3], - linewidth=1, edgecolor='r', facecolor='none') - # Add the patch to the Axes - ax.add_patch(rect) - - -def vis_check(orig, aug): - if not isinstance(orig, list) or not isinstance(aug, list): - return False - if len(orig) != len(aug): - return False - return True - - -def visualize(orig, aug): - - if not vis_check(orig, aug): - return - - plotrows = 3 - compset = int(len(orig)/plotrows) - - orig, aug = np.array(orig), np.array(aug) - - orig = np.split(orig[:compset*plotrows], compset) + [orig[compset*plotrows:]] - aug = np.split(aug[:compset*plotrows], compset) + [aug[compset*plotrows:]] - - for ix, allData in enumerate(zip(orig, aug)): - base_ix = ix * plotrows # will signal what base level we're on - fig, axs = plt.subplots(len(allData[0]), 2) - fig.tight_layout(pad=1.5) - - for x, (dataA, dataB) in enumerate(zip(allData[0], allData[1])): - cur_ix = base_ix + x - - axs[x, 0].imshow(dataA["image"]) - add_bounding_boxes(axs[x, 0], dataA["annotation"]) - axs[x, 0].title.set_text("Original" + str(cur_ix+1)) - print("Original **\n ", str(cur_ix+1), " :", dataA["annotation"]) - - axs[x, 1].imshow(dataB["image"]) - add_bounding_boxes(axs[x, 1], dataB["annotation"]) - axs[x, 1].title.set_text("Augmented" + str(cur_ix+1)) - print("Augmented **\n", str(cur_ix+1), " ", dataB["annotation"], "\n") - - plt.show() - -# Functions to pass to Gen for creating invalid bounding boxes - - -def gen_bad_bbox_neg_xy(im, bbox): - im_h, im_w = im.shape[0], im.shape[1] - bbox[0][:4] = [-50, -50, im_w - 10, im_h - 10] - return im, bbox - - -def gen_bad_bbox_overflow_width(im, bbox): - im_h, im_w = im.shape[0], im.shape[1] - bbox[0][:4] = [0, 0, im_w + 10, im_h - 10] - return im, bbox - - -def gen_bad_bbox_overflow_height(im, bbox): - im_h, im_w = im.shape[0], im.shape[1] - bbox[0][:4] = [0, 0, im_w - 10, im_h + 10] - return im, bbox - - -def gen_bad_bbox_wrong_shape(im, bbox): - bbox = np.array([[0, 0, 0]]).astype(bbox.dtype) - return im, bbox - - -badGenFuncs = [gen_bad_bbox_neg_xy, - gen_bad_bbox_overflow_width, - gen_bad_bbox_overflow_height, - gen_bad_bbox_wrong_shape] - - -assertVal = ["min_x", - "is out of bounds of the image", - "is out of bounds of the image", - "4 features"] - - -# Gen Edge case BBox -def gen_bbox_edge(im, bbox): - im_h, im_w = im.shape[0], im.shape[1] - bbox[0][:4] = [0, 0, im_w, im_h] - return im, bbox - - -def test_c_random_resized_crop_with_bbox_op(plot_vis=False): +def test_random_resized_crop_with_bbox_op_c(plot_vis=False): """ - Prints images side by side with and without Aug applied + bboxes to compare and test + Prints images and bboxes side by side with and without RandomResizedCropWithBBox Op applied, + tests with MD5 check, expected to pass """ + logger.info("test_random_resized_crop_with_bbox_op_c") + + original_seed = config_get_set_seed(23415) + original_num_parallel_workers = config_get_set_num_parallel_workers(1) + # Load dataset dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5)) - # maps to fix annotations to HQ standard dataVoc1 = dataVoc1.map(input_columns=["annotation"], output_columns=["annotation"], operations=fix_annotate) @@ -164,6 +73,9 @@ def test_c_random_resized_crop_with_bbox_op(plot_vis=False): columns_order=["image", "annotation"], operations=[test_op]) # Add column for "annotation" + filename = "random_resized_crop_with_bbox_01_c_result.npz" + save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN) + unaugSamp, augSamp = [], [] for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): @@ -171,20 +83,26 @@ def test_c_random_resized_crop_with_bbox_op(plot_vis=False): augSamp.append(Aug) if plot_vis: - visualize(unaugSamp, augSamp) + visualize_with_bounding_boxes(unaugSamp, augSamp) + + # Restore config setting + ds.config.set_seed(original_seed) + ds.config.set_num_parallel_workers(original_num_parallel_workers) -def test_c_random_resized_crop_with_bbox_op_edge(plot_vis=False): +def test_random_resized_crop_with_bbox_op_edge_c(plot_vis=False): """ - Prints images side by side with and without Aug applied + bboxes to compare and test + Prints images and bboxes side by side with and without RandomResizedCropWithBBox Op applied, + tests on dynamically generated edge case, expected to pass """ + logger.info("test_random_resized_crop_with_bbox_op_edge_c") + # Load dataset dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5)) - # maps to fix annotations to HQ standard dataVoc1 = dataVoc1.map(input_columns=["annotation"], output_columns=["annotation"], operations=fix_annotate) @@ -192,17 +110,17 @@ def test_c_random_resized_crop_with_bbox_op_edge(plot_vis=False): output_columns=["annotation"], operations=fix_annotate) - # Modify BBoxes to serve as valid edge cases - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], + # maps to convert data into valid edge case data + dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], columns_order=["image", "annotation"], - operations=[gen_bbox_edge]) + operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))]) - # map to apply ops + # Test Op added to list of Operations here dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], columns_order=["image", "annotation"], - operations=[test_op]) # Add column for "annotation" + operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op]) unaugSamp, augSamp = [], [] @@ -211,21 +129,22 @@ def test_c_random_resized_crop_with_bbox_op_edge(plot_vis=False): augSamp.append(Aug) if plot_vis: - visualize(unaugSamp, augSamp) + visualize_with_bounding_boxes(unaugSamp, augSamp) -def test_c_random_resized_crop_with_bbox_op_invalid(): +def test_random_resized_crop_with_bbox_op_invalid_c(): """ - Prints images side by side with and without Aug applied + bboxes to compare and test + Tests RandomResizedCropWithBBox on invalid constructor parameters, expected to raise ValueError """ - # Load dataset # only loading the to AugDataset as test will fail on this + logger.info("test_random_resized_crop_with_bbox_op_invalid_c") + + # Load dataset, only Augmented Dataset as test will raise ValueError dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) try: # If input range of scale is not in the order of (min, max), ValueError will be raised. test_op = c_vision.RandomResizedCropWithBBox((256, 512), (1, 0.5), (0.5, 0.5)) - # maps to fix annotations to HQ standard dataVoc2 = dataVoc2.map(input_columns=["annotation"], output_columns=["annotation"], operations=fix_annotate) @@ -243,10 +162,11 @@ def test_c_random_resized_crop_with_bbox_op_invalid(): assert "Input range is not valid" in str(err) -def test_c_random_resized_crop_with_bbox_op_invalid2(): +def test_random_resized_crop_with_bbox_op_invalid2_c(): """ - Prints images side by side with and without Aug applied + bboxes to compare and test + Tests RandomResizedCropWithBBox Op on invalid constructor parameters, expected to raise ValueError """ + logger.info("test_random_resized_crop_with_bbox_op_invalid2_c") # Load dataset # only loading the to AugDataset as test will fail on this dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) @@ -254,7 +174,6 @@ def test_c_random_resized_crop_with_bbox_op_invalid2(): # If input range of ratio is not in the order of (min, max), ValueError will be raised. test_op = c_vision.RandomResizedCropWithBBox((256, 512), (1, 1), (1, 0.5)) - # maps to fix annotations to HQ standard dataVoc2 = dataVoc2.map(input_columns=["annotation"], output_columns=["annotation"], operations=fix_annotate) @@ -272,41 +191,26 @@ def test_c_random_resized_crop_with_bbox_op_invalid2(): assert "Input range is not valid" in str(err) -def test_c_random_resized_crop_with_bbox_op_bad(): - # Should Fail - Errors logged to logger - for ix, badFunc in enumerate(badGenFuncs): - try: - dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", - decode=True, shuffle=False) - - test_op = c_vision.RandomVerticalFlipWithBBox(1) - - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], - operations=[badFunc]) - - # map to apply ops - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], - operations=[test_op]) - - for _ in dataVoc2.create_dict_iterator(): - break # first sample will cause exception +def test_random_resized_crop_with_bbox_op_bad_c(): + """ + Test RandomCropWithBBox op with invalid bounding boxes, expected to catch multiple errors. + """ + logger.info("test_random_resized_crop_with_bbox_op_bad_c") + test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5)) - except RuntimeError as err: - logger.info("Got an exception in DE: {}".format(str(err))) - assert assertVal[ix] in str(err) + data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image") + data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image") + data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x") + data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features") if __name__ == "__main__": - test_c_random_resized_crop_with_bbox_op(plot_vis=True) - test_c_random_resized_crop_with_bbox_op_edge(plot_vis=True) - test_c_random_resized_crop_with_bbox_op_invalid() - test_c_random_resized_crop_with_bbox_op_invalid2() - test_c_random_resized_crop_with_bbox_op_bad() + test_random_resized_crop_with_bbox_op_c(plot_vis=True) + test_random_resized_crop_with_bbox_op_edge_c(plot_vis=True) + test_random_resized_crop_with_bbox_op_invalid_c() + test_random_resized_crop_with_bbox_op_invalid2_c() + test_random_resized_crop_with_bbox_op_bad_c() diff --git a/tests/ut/python/dataset/test_random_crop_with_bbox.py b/tests/ut/python/dataset/test_random_crop_with_bbox.py index 7f5fa46512..08233cb3bc 100644 --- a/tests/ut/python/dataset/test_random_crop_with_bbox.py +++ b/tests/ut/python/dataset/test_random_crop_with_bbox.py @@ -13,18 +13,18 @@ # limitations under the License. # ============================================================================== """ -Testing RandomCropWithBBox op +Testing RandomCropWithBBox op in DE """ - import numpy as np -import matplotlib.pyplot as plt -import matplotlib.patches as patches - import mindspore.dataset as ds import mindspore.dataset.transforms.vision.c_transforms as c_vision import mindspore.dataset.transforms.vision.utils as mode from mindspore import log as logger +from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \ + config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5 + +GENERATE_GOLDEN = False # updated VOC dataset with correct annotations DATA_DIR = "../data/dataset/testVOC2012_2" @@ -32,8 +32,7 @@ DATA_DIR = "../data/dataset/testVOC2012_2" def fix_annotate(bboxes): """ - Update Current VOC dataset format to Proposed HQ BBox format - + Fix annotations to format followed by mindspore. :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format """ @@ -47,113 +46,19 @@ def fix_annotate(bboxes): return bboxes -def add_bounding_boxes(ax, bboxes): - for bbox in bboxes: - rect = patches.Rectangle((bbox[0], bbox[1]), - bbox[2], bbox[3], - linewidth=1, edgecolor='r', facecolor='none') - # Add the patch to the Axes - ax.add_patch(rect) - - -def vis_check(orig, aug): - if not isinstance(orig, list) or not isinstance(aug, list): - return False - if len(orig) != len(aug): - return False - return True - - -def visualize(orig, aug): - - if not vis_check(orig, aug): - return - - plotrows = 3 - compset = int(len(orig)/plotrows) - - orig, aug = np.array(orig), np.array(aug) - - orig = np.split(orig[:compset*plotrows], compset) + [orig[compset*plotrows:]] - aug = np.split(aug[:compset*plotrows], compset) + [aug[compset*plotrows:]] - - for ix, allData in enumerate(zip(orig, aug)): - base_ix = ix * plotrows # will signal what base level we're on - fig, axs = plt.subplots(len(allData[0]), 2) - fig.tight_layout(pad=1.5) - - for x, (dataA, dataB) in enumerate(zip(allData[0], allData[1])): - cur_ix = base_ix + x - - axs[x, 0].imshow(dataA["image"]) - add_bounding_boxes(axs[x, 0], dataA["annotation"]) - axs[x, 0].title.set_text("Original" + str(cur_ix+1)) - print("Original **\n ", str(cur_ix+1), " :", dataA["annotation"]) - - axs[x, 1].imshow(dataB["image"]) - add_bounding_boxes(axs[x, 1], dataB["annotation"]) - axs[x, 1].title.set_text("Augmented" + str(cur_ix+1)) - print("Augmented **\n", str(cur_ix+1), " ", dataB["annotation"], "\n") - - plt.show() - -# Functions to pass to Gen for creating invalid bounding boxes - - -def gen_bad_bbox_neg_xy(im, bbox): - im_h, im_w = im.shape[0], im.shape[1] - bbox[0][:4] = [-50, -50, im_w - 10, im_h - 10] - return im, bbox - - -def gen_bad_bbox_overflow_width(im, bbox): - im_h, im_w = im.shape[0], im.shape[1] - bbox[0][:4] = [0, 0, im_w + 10, im_h - 10] - return im, bbox - - -def gen_bad_bbox_overflow_height(im, bbox): - im_h, im_w = im.shape[0], im.shape[1] - bbox[0][:4] = [0, 0, im_w - 10, im_h + 10] - return im, bbox - - -def gen_bad_bbox_wrong_shape(im, bbox): - bbox = np.array([[0, 0, 0]]).astype(bbox.dtype) - return im, bbox - - -badGenFuncs = [gen_bad_bbox_neg_xy, - gen_bad_bbox_overflow_width, - gen_bad_bbox_overflow_height, - gen_bad_bbox_wrong_shape] - - -assertVal = ["min_x", - "is out of bounds of the image", - "is out of bounds of the image", - "4 features"] - - -# Gen Edge case BBox -def gen_bbox_edge(im, bbox): - im_h, im_w = im.shape[0], im.shape[1] - bbox[0][:4] = [0, 0, im_w, im_h] - return im, bbox - - def test_random_crop_with_bbox_op_c(plot_vis=False): """ - Prints images side by side with and without Aug applied + bboxes + Prints images and bboxes side by side with and without RandomCropWithBBox Op applied """ + logger.info("test_random_crop_with_bbox_op_c") + # Load dataset dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) - # define test OP with values to match existing Op unit - test + # define test OP with values to match existing Op UT test_op = c_vision.RandomCropWithBBox([512, 512], [200, 200, 200, 200]) - # maps to fix annotations to HQ standard dataVoc1 = dataVoc1.map(input_columns=["annotation"], output_columns=["annotation"], operations=fix_annotate) @@ -173,14 +78,17 @@ def test_random_crop_with_bbox_op_c(plot_vis=False): augSamp.append(Aug) if plot_vis: - visualize(unaugSamp, augSamp) + visualize_with_bounding_boxes(unaugSamp, augSamp) def test_random_crop_with_bbox_op2_c(plot_vis=False): """ - Prints images side by side with and without Aug applied + bboxes - With Fill Value + Prints images and bboxes side by side with and without RandomCropWithBBox Op applied, + with md5 check, expected to pass """ + logger.info("test_random_crop_with_bbox_op2_c") + original_seed = config_get_set_seed(593447) + original_num_parallel_workers = config_get_set_num_parallel_workers(1) # Load dataset dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) @@ -189,7 +97,6 @@ def test_random_crop_with_bbox_op2_c(plot_vis=False): # define test OP with values to match existing Op unit - test test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], fill_value=(255, 255, 255)) - # maps to fix annotations to HQ standard dataVoc1 = dataVoc1.map(input_columns=["annotation"], output_columns=["annotation"], operations=fix_annotate) @@ -202,6 +109,9 @@ def test_random_crop_with_bbox_op2_c(plot_vis=False): columns_order=["image", "annotation"], operations=[test_op]) # Add column for "annotation" + filename = "random_crop_with_bbox_01_c_result.npz" + save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN) + unaugSamp, augSamp = [], [] for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): @@ -209,14 +119,20 @@ def test_random_crop_with_bbox_op2_c(plot_vis=False): augSamp.append(Aug) if plot_vis: - visualize(unaugSamp, augSamp) + visualize_with_bounding_boxes(unaugSamp, augSamp) + + # Restore config setting + ds.config.set_seed(original_seed) + ds.config.set_num_parallel_workers(original_num_parallel_workers) def test_random_crop_with_bbox_op3_c(plot_vis=False): """ - Prints images side by side with and without Aug applied + bboxes - With Padding Mode passed + Prints images and bboxes side by side with and without RandomCropWithBBox Op applied, + with Padding Mode explicitly passed """ + logger.info("test_random_crop_with_bbox_op3_c") + # Load dataset dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) @@ -224,7 +140,6 @@ def test_random_crop_with_bbox_op3_c(plot_vis=False): # define test OP with values to match existing Op unit - test test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE) - # maps to fix annotations to HQ standard dataVoc1 = dataVoc1.map(input_columns=["annotation"], output_columns=["annotation"], operations=fix_annotate) @@ -244,14 +159,16 @@ def test_random_crop_with_bbox_op3_c(plot_vis=False): augSamp.append(Aug) if plot_vis: - visualize(unaugSamp, augSamp) + visualize_with_bounding_boxes(unaugSamp, augSamp) def test_random_crop_with_bbox_op_edge_c(plot_vis=False): """ - Prints images side by side with and without Aug applied + bboxes - Testing for an Edge case + Prints images and bboxes side by side with and without RandomCropWithBBox Op applied, + applied on dynamically generated edge case, expected to pass """ + logger.info("test_random_crop_with_bbox_op_edge_c") + # Load dataset dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) @@ -259,7 +176,6 @@ def test_random_crop_with_bbox_op_edge_c(plot_vis=False): # define test OP with values to match existing Op unit - test test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE) - # maps to fix annotations to HQ standard dataVoc1 = dataVoc1.map(input_columns=["annotation"], output_columns=["annotation"], operations=fix_annotate) @@ -267,17 +183,17 @@ def test_random_crop_with_bbox_op_edge_c(plot_vis=False): output_columns=["annotation"], operations=fix_annotate) - # Modify BBoxes to serve as valid edge cases - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], + # maps to convert data into valid edge case data + dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], columns_order=["image", "annotation"], - operations=[gen_bbox_edge]) + operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))]) - # map to apply ops + # Test Op added to list of Operations here dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], columns_order=["image", "annotation"], - operations=[test_op]) # Add column for "annotation" + operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op]) unaugSamp, augSamp = [], [] @@ -286,13 +202,15 @@ def test_random_crop_with_bbox_op_edge_c(plot_vis=False): augSamp.append(Aug) if plot_vis: - visualize(unaugSamp, augSamp) + visualize_with_bounding_boxes(unaugSamp, augSamp) def test_random_crop_with_bbox_op_invalid_c(): """ - Checking for invalid params passed to Aug Constructor + Test RandomCropWithBBox Op on invalid constructor parameters, expected to raise ValueError """ + logger.info("test_random_crop_with_bbox_op_invalid_c") + # Load dataset dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) @@ -300,8 +218,6 @@ def test_random_crop_with_bbox_op_invalid_c(): # define test OP with values to match existing Op unit - test test_op = c_vision.RandomCropWithBBox([512, 512, 375]) - # maps to fix annotations to HQ standard - dataVoc2 = dataVoc2.map(input_columns=["annotation"], output_columns=["annotation"], operations=fix_annotate) @@ -320,35 +236,20 @@ def test_random_crop_with_bbox_op_invalid_c(): def test_random_crop_with_bbox_op_bad_c(): - # Should Fail - Errors logged to logger - for ix, badFunc in enumerate(badGenFuncs): - try: - dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", - decode=True, shuffle=False) - - test_op = c_vision.RandomCropWithBBox([512, 512], [200, 200, 200, 200]) - - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], - operations=[badFunc]) - - # map to apply ops - dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], - operations=[test_op]) - - for _ in dataVoc2.create_dict_iterator(): - break # first sample will cause exception - - except RuntimeError as err: - logger.info("Got an exception in DE: {}".format(str(err))) - assert assertVal[ix] in str(err) + """ + Tests RandomCropWithBBox Op with invalid bounding boxes, expected to catch multiple errors. + """ + logger.info("test_random_crop_with_bbox_op_bad_c") + test_op = c_vision.RandomCropWithBBox([512, 512], [200, 200, 200, 200]) + + data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image") + data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image") + data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x") + data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features") if __name__ == "__main__": diff --git a/tests/ut/python/dataset/test_random_vertical_flip_with_bbox.py b/tests/ut/python/dataset/test_random_vertical_flip_with_bbox.py index b1bb4bc459..72c40c0cad 100644 --- a/tests/ut/python/dataset/test_random_vertical_flip_with_bbox.py +++ b/tests/ut/python/dataset/test_random_vertical_flip_with_bbox.py @@ -13,14 +13,17 @@ # limitations under the License. # ============================================================================== """ -Testing RandomVerticalFlipWithBBox op +Testing RandomVerticalFlipWithBBox op in DE """ - +import numpy as np import mindspore.dataset as ds import mindspore.dataset.transforms.vision.c_transforms as c_vision from mindspore import log as logger -from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox +from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \ + config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5 + +GENERATE_GOLDEN = False # updated VOC dataset with correct annotations DATA_DIR = "../data/dataset/testVOC2012_2" @@ -28,10 +31,9 @@ DATA_DIR = "../data/dataset/testVOC2012_2" def fix_annotate(bboxes): """ - Update Current VOC dataset format to Proposed HQ BBox format - - :param bboxes: as [label, x_min, y_min, w, h, truncate, difficult] - :return: annotation as [x_min, y_min, w, h, label, truncate, difficult] + Fix annotations to format followed by mindspore. + :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format + :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format """ for bbox in bboxes: tmp = bbox[0] @@ -45,9 +47,9 @@ def fix_annotate(bboxes): def test_random_vertical_flip_with_bbox_op_c(plot_vis=False): """ - Prints images side by side with and without Aug applied + bboxes to - compare and test + Prints images and bboxes side by side with and without RandomVerticalFlipWithBBox Op applied """ + logger.info("test_random_vertical_flip_with_bbox_op_c") # Load dataset dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) @@ -57,7 +59,6 @@ def test_random_vertical_flip_with_bbox_op_c(plot_vis=False): test_op = c_vision.RandomVerticalFlipWithBBox(1) - # maps to fix annotations to HQ standard dataVoc1 = dataVoc1.map(input_columns=["annotation"], output_columns=["annotation"], operations=fix_annotate) @@ -82,9 +83,12 @@ def test_random_vertical_flip_with_bbox_op_c(plot_vis=False): def test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=False): """ - Prints images side by side with and without Aug applied + bboxes to - compare and test + Prints images and bboxes side by side with and without RandomVerticalFlipWithBBox Op applied, + tests with MD5 check, expected to pass """ + logger.info("test_random_vertical_flip_with_bbox_op_rand_c") + original_seed = config_get_set_seed(29847) + original_num_parallel_workers = config_get_set_num_parallel_workers(1) # Load dataset dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", @@ -93,9 +97,8 @@ def test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=False): dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) - test_op = c_vision.RandomVerticalFlipWithBBox(0.6) + test_op = c_vision.RandomVerticalFlipWithBBox(0.8) - # maps to fix annotations to HQ standard dataVoc1 = dataVoc1.map(input_columns=["annotation"], output_columns=["annotation"], operations=fix_annotate) @@ -108,6 +111,56 @@ def test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=False): columns_order=["image", "annotation"], operations=[test_op]) + filename = "random_vertical_flip_with_bbox_01_c_result.npz" + save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN) + + unaugSamp, augSamp = [], [] + + for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): + unaugSamp.append(unAug) + augSamp.append(Aug) + + if plot_vis: + visualize_with_bounding_boxes(unaugSamp, augSamp) + + # Restore config setting + ds.config.set_seed(original_seed) + ds.config.set_num_parallel_workers(original_num_parallel_workers) + + +def test_random_vertical_flip_with_bbox_op_edge_c(plot_vis=False): + """ + Prints images and bboxes side by side with and without RandomVerticalFlipWithBBox Op applied, + applied on dynamically generated edge case, expected to pass + """ + logger.info("test_random_vertical_flip_with_bbox_op_edge_c") + dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", + decode=True, shuffle=False) + + dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", + decode=True, shuffle=False) + + test_op = c_vision.RandomVerticalFlipWithBBox(1) + + dataVoc1 = dataVoc1.map(input_columns=["annotation"], + output_columns=["annotation"], + operations=fix_annotate) + dataVoc2 = dataVoc2.map(input_columns=["annotation"], + output_columns=["annotation"], + operations=fix_annotate) + + # maps to convert data into valid edge case data + dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], + output_columns=["image", "annotation"], + columns_order=["image", "annotation"], + operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))]) + + # Test Op added to list of Operations here + dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], + output_columns=["image", "annotation"], + columns_order=["image", "annotation"], + operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op]) + unaugSamp, augSamp = [], [] for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): @@ -119,16 +172,15 @@ def test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=False): def test_random_vertical_flip_with_bbox_op_invalid_c(): - # Should Fail - # Load dataset + """ + Test RandomVerticalFlipWithBBox Op on invalid constructor parameters, expected to raise ValueError + """ + logger.info("test_random_vertical_flip_with_bbox_op_invalid_c") dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) try: test_op = c_vision.RandomVerticalFlipWithBBox(2) - - # maps to fix annotations to HQ standard - dataVoc2 = dataVoc2.map(input_columns=["annotation"], output_columns=["annotation"], operations=fix_annotate) @@ -148,9 +200,9 @@ def test_random_vertical_flip_with_bbox_op_invalid_c(): def test_random_vertical_flip_with_bbox_op_bad_c(): """ - Test RandomHorizontalFlipWithBBox op with invalid bounding boxes + Tests RandomVerticalFlipWithBBox Op with invalid bounding boxes, expected to catch multiple errors """ - logger.info("test_random_horizontal_bbox_invalid_bounds_c") + logger.info("test_random_vertical_flip_with_bbox_op_bad_c") test_op = c_vision.RandomVerticalFlipWithBBox(1) data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) @@ -166,5 +218,6 @@ def test_random_vertical_flip_with_bbox_op_bad_c(): if __name__ == "__main__": test_random_vertical_flip_with_bbox_op_c(plot_vis=True) test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=True) + test_random_vertical_flip_with_bbox_op_edge_c(plot_vis=True) test_random_vertical_flip_with_bbox_op_invalid_c() test_random_vertical_flip_with_bbox_op_bad_c() diff --git a/tests/ut/python/dataset/util.py b/tests/ut/python/dataset/util.py index 00a2c7ef57..2a8e93cd0b 100644 --- a/tests/ut/python/dataset/util.py +++ b/tests/ut/python/dataset/util.py @@ -312,34 +312,39 @@ def visualize_with_bounding_boxes(orig, aug, plot_rows=3): if len(orig) != len(aug) or not orig: return - comp_set = int(len(orig)/plot_rows) + batch_size = int(len(orig)/plot_rows) # creates batches of images to plot together + split_point = batch_size * plot_rows + orig, aug = np.array(orig), np.array(aug) if len(orig) > plot_rows: - orig = np.split(orig[:comp_set*plot_rows], comp_set) + [orig[comp_set*plot_rows:]] - aug = np.split(aug[:comp_set*plot_rows], comp_set) + [aug[comp_set*plot_rows:]] + # Create batches of required size and add remainder to last batch + orig = np.split(orig[:split_point], batch_size) + ([orig[split_point:]] if (split_point < orig.shape[0]) else []) # check to avoid empty arrays being added + aug = np.split(aug[:split_point], batch_size) + ([aug[split_point:]] if (split_point < aug.shape[0]) else []) else: orig = [orig] aug = [aug] for ix, allData in enumerate(zip(orig, aug)): - base_ix = ix * plot_rows # will signal what base level we're on + base_ix = ix * plot_rows # current batch starting index + curPlot = len(allData[0]) - sub_plot_count = 2 if (len(allData[0]) < 2) else len(allData[0]) # if 1 image remains, create subplot for 2 to simplify axis selection - fig, axs = plt.subplots(sub_plot_count, 2) + fig, axs = plt.subplots(curPlot, 2) fig.tight_layout(pad=1.5) for x, (dataA, dataB) in enumerate(zip(allData[0], allData[1])): cur_ix = base_ix + x + (axA, axB) = (axs[x, 0], axs[x, 1]) if (curPlot > 1) else (axs[0], axs[1]) # select plotting axes based on number of image rows on plot - else case when 1 row - axs[x, 0].imshow(dataA["image"]) - add_bounding_boxes(axs[x, 0], dataA["annotation"]) - axs[x, 0].title.set_text("Original" + str(cur_ix+1)) - logger.info("Original **\n{} : {}".format(str(cur_ix+1), dataA["annotation"])) + axA.imshow(dataA["image"]) + add_bounding_boxes(axA, dataA["annotation"]) + axA.title.set_text("Original" + str(cur_ix+1)) - axs[x, 1].imshow(dataB["image"]) - add_bounding_boxes(axs[x, 1], dataB["annotation"]) - axs[x, 1].title.set_text("Augmented" + str(cur_ix+1)) + axB.imshow(dataB["image"]) + add_bounding_boxes(axB, dataB["annotation"]) + axB.title.set_text("Augmented" + str(cur_ix+1)) + + logger.info("Original **\n{} : {}".format(str(cur_ix+1), dataA["annotation"])) logger.info("Augmented **\n{} : {}\n".format(str(cur_ix+1), dataB["annotation"])) plt.show()