Merge pull request !2590 from danishnxt/AugOps2tags/v0.6.0-beta
| @@ -30,8 +30,7 @@ Status RandomCropAndResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow | |||||
| BOUNDING_BOX_CHECK(input); | BOUNDING_BOX_CHECK(input); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Size() >= 2, "The shape of input is abnormal"); | CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Size() >= 2, "The shape of input is abnormal"); | ||||
| (*output).push_back(nullptr); // init memory for return vector | |||||
| (*output).push_back(nullptr); | |||||
| output->resize(2); | |||||
| (*output)[1] = std::move(input[1]); // move boxes over to output | (*output)[1] = std::move(input[1]); // move boxes over to output | ||||
| size_t bboxCount = input[1]->shape()[0]; // number of rows in bbox tensor | size_t bboxCount = input[1]->shape()[0]; // number of rows in bbox tensor | ||||
| @@ -36,8 +36,7 @@ Status RandomCropWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) | |||||
| int32_t padded_image_h; | int32_t padded_image_h; | ||||
| int32_t padded_image_w; | int32_t padded_image_w; | ||||
| (*output).push_back(nullptr); | |||||
| (*output).push_back(nullptr); | |||||
| output->resize(2); | |||||
| (*output)[1] = std::move(input[1]); // since some boxes may be removed | (*output)[1] = std::move(input[1]); // since some boxes may be removed | ||||
| bool crop_further = true; // Whether further cropping will be required or not, true unless required size matches | bool crop_further = true; // Whether further cropping will be required or not, true unless required size matches | ||||
| @@ -45,8 +45,7 @@ Status RandomVerticalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow * | |||||
| RETURN_IF_NOT_OK(input[1]->SetItemAt({i, 1}, newBoxCorner_y)); | RETURN_IF_NOT_OK(input[1]->SetItemAt({i, 1}, newBoxCorner_y)); | ||||
| } | } | ||||
| (*output).push_back(nullptr); | |||||
| (*output).push_back(nullptr); | |||||
| output->resize(2); | |||||
| (*output)[1] = std::move(input[1]); | (*output)[1] = std::move(input[1]); | ||||
| return VerticalFlip(input[0], &(*output)[0]); | return VerticalFlip(input[0], &(*output)[0]); | ||||
| @@ -13,17 +13,17 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """ | """ | ||||
| Testing RandomCropAndResizeWithBBox op | |||||
| Testing RandomCropAndResizeWithBBox op in DE | |||||
| """ | """ | ||||
| import numpy as np | import numpy as np | ||||
| import matplotlib.pyplot as plt | |||||
| import matplotlib.patches as patches | |||||
| import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
| import mindspore.dataset.transforms.vision.c_transforms as c_vision | import mindspore.dataset.transforms.vision.c_transforms as c_vision | ||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \ | |||||
| config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5 | |||||
| GENERATE_GOLDEN = False | |||||
| # updated VOC dataset with correct annotations | # updated VOC dataset with correct annotations | ||||
| DATA_DIR = "../data/dataset/testVOC2012_2" | DATA_DIR = "../data/dataset/testVOC2012_2" | ||||
| @@ -31,8 +31,7 @@ DATA_DIR = "../data/dataset/testVOC2012_2" | |||||
| def fix_annotate(bboxes): | def fix_annotate(bboxes): | ||||
| """ | """ | ||||
| Update Current VOC dataset format to Proposed HQ BBox format | |||||
| Fix annotations to format followed by mindspore. | |||||
| :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format | :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format | ||||
| :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format | :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format | ||||
| """ | """ | ||||
| @@ -46,112 +45,22 @@ def fix_annotate(bboxes): | |||||
| return bboxes | return bboxes | ||||
| def add_bounding_boxes(ax, bboxes): | |||||
| for bbox in bboxes: | |||||
| rect = patches.Rectangle((bbox[0], bbox[1]), | |||||
| bbox[2], bbox[3], | |||||
| linewidth=1, edgecolor='r', facecolor='none') | |||||
| # Add the patch to the Axes | |||||
| ax.add_patch(rect) | |||||
| def vis_check(orig, aug): | |||||
| if not isinstance(orig, list) or not isinstance(aug, list): | |||||
| return False | |||||
| if len(orig) != len(aug): | |||||
| return False | |||||
| return True | |||||
| def visualize(orig, aug): | |||||
| if not vis_check(orig, aug): | |||||
| return | |||||
| plotrows = 3 | |||||
| compset = int(len(orig)/plotrows) | |||||
| orig, aug = np.array(orig), np.array(aug) | |||||
| orig = np.split(orig[:compset*plotrows], compset) + [orig[compset*plotrows:]] | |||||
| aug = np.split(aug[:compset*plotrows], compset) + [aug[compset*plotrows:]] | |||||
| for ix, allData in enumerate(zip(orig, aug)): | |||||
| base_ix = ix * plotrows # will signal what base level we're on | |||||
| fig, axs = plt.subplots(len(allData[0]), 2) | |||||
| fig.tight_layout(pad=1.5) | |||||
| for x, (dataA, dataB) in enumerate(zip(allData[0], allData[1])): | |||||
| cur_ix = base_ix + x | |||||
| axs[x, 0].imshow(dataA["image"]) | |||||
| add_bounding_boxes(axs[x, 0], dataA["annotation"]) | |||||
| axs[x, 0].title.set_text("Original" + str(cur_ix+1)) | |||||
| print("Original **\n ", str(cur_ix+1), " :", dataA["annotation"]) | |||||
| axs[x, 1].imshow(dataB["image"]) | |||||
| add_bounding_boxes(axs[x, 1], dataB["annotation"]) | |||||
| axs[x, 1].title.set_text("Augmented" + str(cur_ix+1)) | |||||
| print("Augmented **\n", str(cur_ix+1), " ", dataB["annotation"], "\n") | |||||
| plt.show() | |||||
| # Functions to pass to Gen for creating invalid bounding boxes | |||||
| def gen_bad_bbox_neg_xy(im, bbox): | |||||
| im_h, im_w = im.shape[0], im.shape[1] | |||||
| bbox[0][:4] = [-50, -50, im_w - 10, im_h - 10] | |||||
| return im, bbox | |||||
| def gen_bad_bbox_overflow_width(im, bbox): | |||||
| im_h, im_w = im.shape[0], im.shape[1] | |||||
| bbox[0][:4] = [0, 0, im_w + 10, im_h - 10] | |||||
| return im, bbox | |||||
| def gen_bad_bbox_overflow_height(im, bbox): | |||||
| im_h, im_w = im.shape[0], im.shape[1] | |||||
| bbox[0][:4] = [0, 0, im_w - 10, im_h + 10] | |||||
| return im, bbox | |||||
| def gen_bad_bbox_wrong_shape(im, bbox): | |||||
| bbox = np.array([[0, 0, 0]]).astype(bbox.dtype) | |||||
| return im, bbox | |||||
| badGenFuncs = [gen_bad_bbox_neg_xy, | |||||
| gen_bad_bbox_overflow_width, | |||||
| gen_bad_bbox_overflow_height, | |||||
| gen_bad_bbox_wrong_shape] | |||||
| assertVal = ["min_x", | |||||
| "is out of bounds of the image", | |||||
| "is out of bounds of the image", | |||||
| "4 features"] | |||||
| # Gen Edge case BBox | |||||
| def gen_bbox_edge(im, bbox): | |||||
| im_h, im_w = im.shape[0], im.shape[1] | |||||
| bbox[0][:4] = [0, 0, im_w, im_h] | |||||
| return im, bbox | |||||
| def test_c_random_resized_crop_with_bbox_op(plot_vis=False): | |||||
| def test_random_resized_crop_with_bbox_op_c(plot_vis=False): | |||||
| """ | """ | ||||
| Prints images side by side with and without Aug applied + bboxes to compare and test | |||||
| Prints images and bboxes side by side with and without RandomResizedCropWithBBox Op applied, | |||||
| tests with MD5 check, expected to pass | |||||
| """ | """ | ||||
| logger.info("test_random_resized_crop_with_bbox_op_c") | |||||
| original_seed = config_get_set_seed(23415) | |||||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||||
| # Load dataset | # Load dataset | ||||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | ||||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | ||||
| test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5)) | test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5)) | ||||
| # maps to fix annotations to HQ standard | |||||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | dataVoc1 = dataVoc1.map(input_columns=["annotation"], | ||||
| output_columns=["annotation"], | output_columns=["annotation"], | ||||
| operations=fix_annotate) | operations=fix_annotate) | ||||
| @@ -164,6 +73,9 @@ def test_c_random_resized_crop_with_bbox_op(plot_vis=False): | |||||
| columns_order=["image", "annotation"], | columns_order=["image", "annotation"], | ||||
| operations=[test_op]) # Add column for "annotation" | operations=[test_op]) # Add column for "annotation" | ||||
| filename = "random_resized_crop_with_bbox_01_c_result.npz" | |||||
| save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN) | |||||
| unaugSamp, augSamp = [], [] | unaugSamp, augSamp = [], [] | ||||
| for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): | for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): | ||||
| @@ -171,20 +83,26 @@ def test_c_random_resized_crop_with_bbox_op(plot_vis=False): | |||||
| augSamp.append(Aug) | augSamp.append(Aug) | ||||
| if plot_vis: | if plot_vis: | ||||
| visualize(unaugSamp, augSamp) | |||||
| visualize_with_bounding_boxes(unaugSamp, augSamp) | |||||
| # Restore config setting | |||||
| ds.config.set_seed(original_seed) | |||||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||||
| def test_c_random_resized_crop_with_bbox_op_edge(plot_vis=False): | |||||
| def test_random_resized_crop_with_bbox_op_edge_c(plot_vis=False): | |||||
| """ | """ | ||||
| Prints images side by side with and without Aug applied + bboxes to compare and test | |||||
| Prints images and bboxes side by side with and without RandomResizedCropWithBBox Op applied, | |||||
| tests on dynamically generated edge case, expected to pass | |||||
| """ | """ | ||||
| logger.info("test_random_resized_crop_with_bbox_op_edge_c") | |||||
| # Load dataset | # Load dataset | ||||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | ||||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | ||||
| test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5)) | test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5)) | ||||
| # maps to fix annotations to HQ standard | |||||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | dataVoc1 = dataVoc1.map(input_columns=["annotation"], | ||||
| output_columns=["annotation"], | output_columns=["annotation"], | ||||
| operations=fix_annotate) | operations=fix_annotate) | ||||
| @@ -192,17 +110,17 @@ def test_c_random_resized_crop_with_bbox_op_edge(plot_vis=False): | |||||
| output_columns=["annotation"], | output_columns=["annotation"], | ||||
| operations=fix_annotate) | operations=fix_annotate) | ||||
| # Modify BBoxes to serve as valid edge cases | |||||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||||
| # maps to convert data into valid edge case data | |||||
| dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], | |||||
| output_columns=["image", "annotation"], | output_columns=["image", "annotation"], | ||||
| columns_order=["image", "annotation"], | columns_order=["image", "annotation"], | ||||
| operations=[gen_bbox_edge]) | |||||
| operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))]) | |||||
| # map to apply ops | |||||
| # Test Op added to list of Operations here | |||||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | ||||
| output_columns=["image", "annotation"], | output_columns=["image", "annotation"], | ||||
| columns_order=["image", "annotation"], | columns_order=["image", "annotation"], | ||||
| operations=[test_op]) # Add column for "annotation" | |||||
| operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op]) | |||||
| unaugSamp, augSamp = [], [] | unaugSamp, augSamp = [], [] | ||||
| @@ -211,21 +129,22 @@ def test_c_random_resized_crop_with_bbox_op_edge(plot_vis=False): | |||||
| augSamp.append(Aug) | augSamp.append(Aug) | ||||
| if plot_vis: | if plot_vis: | ||||
| visualize(unaugSamp, augSamp) | |||||
| visualize_with_bounding_boxes(unaugSamp, augSamp) | |||||
| def test_c_random_resized_crop_with_bbox_op_invalid(): | |||||
| def test_random_resized_crop_with_bbox_op_invalid_c(): | |||||
| """ | """ | ||||
| Prints images side by side with and without Aug applied + bboxes to compare and test | |||||
| Tests RandomResizedCropWithBBox on invalid constructor parameters, expected to raise ValueError | |||||
| """ | """ | ||||
| # Load dataset # only loading the to AugDataset as test will fail on this | |||||
| logger.info("test_random_resized_crop_with_bbox_op_invalid_c") | |||||
| # Load dataset, only Augmented Dataset as test will raise ValueError | |||||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | ||||
| try: | try: | ||||
| # If input range of scale is not in the order of (min, max), ValueError will be raised. | # If input range of scale is not in the order of (min, max), ValueError will be raised. | ||||
| test_op = c_vision.RandomResizedCropWithBBox((256, 512), (1, 0.5), (0.5, 0.5)) | test_op = c_vision.RandomResizedCropWithBBox((256, 512), (1, 0.5), (0.5, 0.5)) | ||||
| # maps to fix annotations to HQ standard | |||||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | dataVoc2 = dataVoc2.map(input_columns=["annotation"], | ||||
| output_columns=["annotation"], | output_columns=["annotation"], | ||||
| operations=fix_annotate) | operations=fix_annotate) | ||||
| @@ -243,10 +162,11 @@ def test_c_random_resized_crop_with_bbox_op_invalid(): | |||||
| assert "Input range is not valid" in str(err) | assert "Input range is not valid" in str(err) | ||||
| def test_c_random_resized_crop_with_bbox_op_invalid2(): | |||||
| def test_random_resized_crop_with_bbox_op_invalid2_c(): | |||||
| """ | """ | ||||
| Prints images side by side with and without Aug applied + bboxes to compare and test | |||||
| Tests RandomResizedCropWithBBox Op on invalid constructor parameters, expected to raise ValueError | |||||
| """ | """ | ||||
| logger.info("test_random_resized_crop_with_bbox_op_invalid2_c") | |||||
| # Load dataset # only loading the to AugDataset as test will fail on this | # Load dataset # only loading the to AugDataset as test will fail on this | ||||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | ||||
| @@ -254,7 +174,6 @@ def test_c_random_resized_crop_with_bbox_op_invalid2(): | |||||
| # If input range of ratio is not in the order of (min, max), ValueError will be raised. | # If input range of ratio is not in the order of (min, max), ValueError will be raised. | ||||
| test_op = c_vision.RandomResizedCropWithBBox((256, 512), (1, 1), (1, 0.5)) | test_op = c_vision.RandomResizedCropWithBBox((256, 512), (1, 1), (1, 0.5)) | ||||
| # maps to fix annotations to HQ standard | |||||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | dataVoc2 = dataVoc2.map(input_columns=["annotation"], | ||||
| output_columns=["annotation"], | output_columns=["annotation"], | ||||
| operations=fix_annotate) | operations=fix_annotate) | ||||
| @@ -272,41 +191,26 @@ def test_c_random_resized_crop_with_bbox_op_invalid2(): | |||||
| assert "Input range is not valid" in str(err) | assert "Input range is not valid" in str(err) | ||||
| def test_c_random_resized_crop_with_bbox_op_bad(): | |||||
| # Should Fail - Errors logged to logger | |||||
| for ix, badFunc in enumerate(badGenFuncs): | |||||
| try: | |||||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | |||||
| decode=True, shuffle=False) | |||||
| test_op = c_vision.RandomVerticalFlipWithBBox(1) | |||||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||||
| output_columns=["annotation"], | |||||
| operations=fix_annotate) | |||||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||||
| output_columns=["image", "annotation"], | |||||
| columns_order=["image", "annotation"], | |||||
| operations=[badFunc]) | |||||
| # map to apply ops | |||||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||||
| output_columns=["image", "annotation"], | |||||
| columns_order=["image", "annotation"], | |||||
| operations=[test_op]) | |||||
| for _ in dataVoc2.create_dict_iterator(): | |||||
| break # first sample will cause exception | |||||
| def test_random_resized_crop_with_bbox_op_bad_c(): | |||||
| """ | |||||
| Test RandomCropWithBBox op with invalid bounding boxes, expected to catch multiple errors. | |||||
| """ | |||||
| logger.info("test_random_resized_crop_with_bbox_op_bad_c") | |||||
| test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5)) | |||||
| except RuntimeError as err: | |||||
| logger.info("Got an exception in DE: {}".format(str(err))) | |||||
| assert assertVal[ix] in str(err) | |||||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image") | |||||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image") | |||||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x") | |||||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features") | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| test_c_random_resized_crop_with_bbox_op(plot_vis=True) | |||||
| test_c_random_resized_crop_with_bbox_op_edge(plot_vis=True) | |||||
| test_c_random_resized_crop_with_bbox_op_invalid() | |||||
| test_c_random_resized_crop_with_bbox_op_invalid2() | |||||
| test_c_random_resized_crop_with_bbox_op_bad() | |||||
| test_random_resized_crop_with_bbox_op_c(plot_vis=True) | |||||
| test_random_resized_crop_with_bbox_op_edge_c(plot_vis=True) | |||||
| test_random_resized_crop_with_bbox_op_invalid_c() | |||||
| test_random_resized_crop_with_bbox_op_invalid2_c() | |||||
| test_random_resized_crop_with_bbox_op_bad_c() | |||||
| @@ -13,18 +13,18 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """ | """ | ||||
| Testing RandomCropWithBBox op | |||||
| Testing RandomCropWithBBox op in DE | |||||
| """ | """ | ||||
| import numpy as np | import numpy as np | ||||
| import matplotlib.pyplot as plt | |||||
| import matplotlib.patches as patches | |||||
| import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
| import mindspore.dataset.transforms.vision.c_transforms as c_vision | import mindspore.dataset.transforms.vision.c_transforms as c_vision | ||||
| import mindspore.dataset.transforms.vision.utils as mode | import mindspore.dataset.transforms.vision.utils as mode | ||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \ | |||||
| config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5 | |||||
| GENERATE_GOLDEN = False | |||||
| # updated VOC dataset with correct annotations | # updated VOC dataset with correct annotations | ||||
| DATA_DIR = "../data/dataset/testVOC2012_2" | DATA_DIR = "../data/dataset/testVOC2012_2" | ||||
| @@ -32,8 +32,7 @@ DATA_DIR = "../data/dataset/testVOC2012_2" | |||||
| def fix_annotate(bboxes): | def fix_annotate(bboxes): | ||||
| """ | """ | ||||
| Update Current VOC dataset format to Proposed HQ BBox format | |||||
| Fix annotations to format followed by mindspore. | |||||
| :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format | :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format | ||||
| :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format | :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format | ||||
| """ | """ | ||||
| @@ -47,113 +46,19 @@ def fix_annotate(bboxes): | |||||
| return bboxes | return bboxes | ||||
| def add_bounding_boxes(ax, bboxes): | |||||
| for bbox in bboxes: | |||||
| rect = patches.Rectangle((bbox[0], bbox[1]), | |||||
| bbox[2], bbox[3], | |||||
| linewidth=1, edgecolor='r', facecolor='none') | |||||
| # Add the patch to the Axes | |||||
| ax.add_patch(rect) | |||||
| def vis_check(orig, aug): | |||||
| if not isinstance(orig, list) or not isinstance(aug, list): | |||||
| return False | |||||
| if len(orig) != len(aug): | |||||
| return False | |||||
| return True | |||||
| def visualize(orig, aug): | |||||
| if not vis_check(orig, aug): | |||||
| return | |||||
| plotrows = 3 | |||||
| compset = int(len(orig)/plotrows) | |||||
| orig, aug = np.array(orig), np.array(aug) | |||||
| orig = np.split(orig[:compset*plotrows], compset) + [orig[compset*plotrows:]] | |||||
| aug = np.split(aug[:compset*plotrows], compset) + [aug[compset*plotrows:]] | |||||
| for ix, allData in enumerate(zip(orig, aug)): | |||||
| base_ix = ix * plotrows # will signal what base level we're on | |||||
| fig, axs = plt.subplots(len(allData[0]), 2) | |||||
| fig.tight_layout(pad=1.5) | |||||
| for x, (dataA, dataB) in enumerate(zip(allData[0], allData[1])): | |||||
| cur_ix = base_ix + x | |||||
| axs[x, 0].imshow(dataA["image"]) | |||||
| add_bounding_boxes(axs[x, 0], dataA["annotation"]) | |||||
| axs[x, 0].title.set_text("Original" + str(cur_ix+1)) | |||||
| print("Original **\n ", str(cur_ix+1), " :", dataA["annotation"]) | |||||
| axs[x, 1].imshow(dataB["image"]) | |||||
| add_bounding_boxes(axs[x, 1], dataB["annotation"]) | |||||
| axs[x, 1].title.set_text("Augmented" + str(cur_ix+1)) | |||||
| print("Augmented **\n", str(cur_ix+1), " ", dataB["annotation"], "\n") | |||||
| plt.show() | |||||
| # Functions to pass to Gen for creating invalid bounding boxes | |||||
| def gen_bad_bbox_neg_xy(im, bbox): | |||||
| im_h, im_w = im.shape[0], im.shape[1] | |||||
| bbox[0][:4] = [-50, -50, im_w - 10, im_h - 10] | |||||
| return im, bbox | |||||
| def gen_bad_bbox_overflow_width(im, bbox): | |||||
| im_h, im_w = im.shape[0], im.shape[1] | |||||
| bbox[0][:4] = [0, 0, im_w + 10, im_h - 10] | |||||
| return im, bbox | |||||
| def gen_bad_bbox_overflow_height(im, bbox): | |||||
| im_h, im_w = im.shape[0], im.shape[1] | |||||
| bbox[0][:4] = [0, 0, im_w - 10, im_h + 10] | |||||
| return im, bbox | |||||
| def gen_bad_bbox_wrong_shape(im, bbox): | |||||
| bbox = np.array([[0, 0, 0]]).astype(bbox.dtype) | |||||
| return im, bbox | |||||
| badGenFuncs = [gen_bad_bbox_neg_xy, | |||||
| gen_bad_bbox_overflow_width, | |||||
| gen_bad_bbox_overflow_height, | |||||
| gen_bad_bbox_wrong_shape] | |||||
| assertVal = ["min_x", | |||||
| "is out of bounds of the image", | |||||
| "is out of bounds of the image", | |||||
| "4 features"] | |||||
| # Gen Edge case BBox | |||||
| def gen_bbox_edge(im, bbox): | |||||
| im_h, im_w = im.shape[0], im.shape[1] | |||||
| bbox[0][:4] = [0, 0, im_w, im_h] | |||||
| return im, bbox | |||||
| def test_random_crop_with_bbox_op_c(plot_vis=False): | def test_random_crop_with_bbox_op_c(plot_vis=False): | ||||
| """ | """ | ||||
| Prints images side by side with and without Aug applied + bboxes | |||||
| Prints images and bboxes side by side with and without RandomCropWithBBox Op applied | |||||
| """ | """ | ||||
| logger.info("test_random_crop_with_bbox_op_c") | |||||
| # Load dataset | # Load dataset | ||||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | ||||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | ||||
| # define test OP with values to match existing Op unit - test | |||||
| # define test OP with values to match existing Op UT | |||||
| test_op = c_vision.RandomCropWithBBox([512, 512], [200, 200, 200, 200]) | test_op = c_vision.RandomCropWithBBox([512, 512], [200, 200, 200, 200]) | ||||
| # maps to fix annotations to HQ standard | |||||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | dataVoc1 = dataVoc1.map(input_columns=["annotation"], | ||||
| output_columns=["annotation"], | output_columns=["annotation"], | ||||
| operations=fix_annotate) | operations=fix_annotate) | ||||
| @@ -173,14 +78,17 @@ def test_random_crop_with_bbox_op_c(plot_vis=False): | |||||
| augSamp.append(Aug) | augSamp.append(Aug) | ||||
| if plot_vis: | if plot_vis: | ||||
| visualize(unaugSamp, augSamp) | |||||
| visualize_with_bounding_boxes(unaugSamp, augSamp) | |||||
| def test_random_crop_with_bbox_op2_c(plot_vis=False): | def test_random_crop_with_bbox_op2_c(plot_vis=False): | ||||
| """ | """ | ||||
| Prints images side by side with and without Aug applied + bboxes | |||||
| With Fill Value | |||||
| Prints images and bboxes side by side with and without RandomCropWithBBox Op applied, | |||||
| with md5 check, expected to pass | |||||
| """ | """ | ||||
| logger.info("test_random_crop_with_bbox_op2_c") | |||||
| original_seed = config_get_set_seed(593447) | |||||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||||
| # Load dataset | # Load dataset | ||||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | ||||
| @@ -189,7 +97,6 @@ def test_random_crop_with_bbox_op2_c(plot_vis=False): | |||||
| # define test OP with values to match existing Op unit - test | # define test OP with values to match existing Op unit - test | ||||
| test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], fill_value=(255, 255, 255)) | test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], fill_value=(255, 255, 255)) | ||||
| # maps to fix annotations to HQ standard | |||||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | dataVoc1 = dataVoc1.map(input_columns=["annotation"], | ||||
| output_columns=["annotation"], | output_columns=["annotation"], | ||||
| operations=fix_annotate) | operations=fix_annotate) | ||||
| @@ -202,6 +109,9 @@ def test_random_crop_with_bbox_op2_c(plot_vis=False): | |||||
| columns_order=["image", "annotation"], | columns_order=["image", "annotation"], | ||||
| operations=[test_op]) # Add column for "annotation" | operations=[test_op]) # Add column for "annotation" | ||||
| filename = "random_crop_with_bbox_01_c_result.npz" | |||||
| save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN) | |||||
| unaugSamp, augSamp = [], [] | unaugSamp, augSamp = [], [] | ||||
| for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): | for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): | ||||
| @@ -209,14 +119,20 @@ def test_random_crop_with_bbox_op2_c(plot_vis=False): | |||||
| augSamp.append(Aug) | augSamp.append(Aug) | ||||
| if plot_vis: | if plot_vis: | ||||
| visualize(unaugSamp, augSamp) | |||||
| visualize_with_bounding_boxes(unaugSamp, augSamp) | |||||
| # Restore config setting | |||||
| ds.config.set_seed(original_seed) | |||||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||||
| def test_random_crop_with_bbox_op3_c(plot_vis=False): | def test_random_crop_with_bbox_op3_c(plot_vis=False): | ||||
| """ | """ | ||||
| Prints images side by side with and without Aug applied + bboxes | |||||
| With Padding Mode passed | |||||
| Prints images and bboxes side by side with and without RandomCropWithBBox Op applied, | |||||
| with Padding Mode explicitly passed | |||||
| """ | """ | ||||
| logger.info("test_random_crop_with_bbox_op3_c") | |||||
| # Load dataset | # Load dataset | ||||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | ||||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | ||||
| @@ -224,7 +140,6 @@ def test_random_crop_with_bbox_op3_c(plot_vis=False): | |||||
| # define test OP with values to match existing Op unit - test | # define test OP with values to match existing Op unit - test | ||||
| test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE) | test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE) | ||||
| # maps to fix annotations to HQ standard | |||||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | dataVoc1 = dataVoc1.map(input_columns=["annotation"], | ||||
| output_columns=["annotation"], | output_columns=["annotation"], | ||||
| operations=fix_annotate) | operations=fix_annotate) | ||||
| @@ -244,14 +159,16 @@ def test_random_crop_with_bbox_op3_c(plot_vis=False): | |||||
| augSamp.append(Aug) | augSamp.append(Aug) | ||||
| if plot_vis: | if plot_vis: | ||||
| visualize(unaugSamp, augSamp) | |||||
| visualize_with_bounding_boxes(unaugSamp, augSamp) | |||||
| def test_random_crop_with_bbox_op_edge_c(plot_vis=False): | def test_random_crop_with_bbox_op_edge_c(plot_vis=False): | ||||
| """ | """ | ||||
| Prints images side by side with and without Aug applied + bboxes | |||||
| Testing for an Edge case | |||||
| Prints images and bboxes side by side with and without RandomCropWithBBox Op applied, | |||||
| applied on dynamically generated edge case, expected to pass | |||||
| """ | """ | ||||
| logger.info("test_random_crop_with_bbox_op_edge_c") | |||||
| # Load dataset | # Load dataset | ||||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | ||||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | ||||
| @@ -259,7 +176,6 @@ def test_random_crop_with_bbox_op_edge_c(plot_vis=False): | |||||
| # define test OP with values to match existing Op unit - test | # define test OP with values to match existing Op unit - test | ||||
| test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE) | test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE) | ||||
| # maps to fix annotations to HQ standard | |||||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | dataVoc1 = dataVoc1.map(input_columns=["annotation"], | ||||
| output_columns=["annotation"], | output_columns=["annotation"], | ||||
| operations=fix_annotate) | operations=fix_annotate) | ||||
| @@ -267,17 +183,17 @@ def test_random_crop_with_bbox_op_edge_c(plot_vis=False): | |||||
| output_columns=["annotation"], | output_columns=["annotation"], | ||||
| operations=fix_annotate) | operations=fix_annotate) | ||||
| # Modify BBoxes to serve as valid edge cases | |||||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||||
| # maps to convert data into valid edge case data | |||||
| dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], | |||||
| output_columns=["image", "annotation"], | output_columns=["image", "annotation"], | ||||
| columns_order=["image", "annotation"], | columns_order=["image", "annotation"], | ||||
| operations=[gen_bbox_edge]) | |||||
| operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))]) | |||||
| # map to apply ops | |||||
| # Test Op added to list of Operations here | |||||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | ||||
| output_columns=["image", "annotation"], | output_columns=["image", "annotation"], | ||||
| columns_order=["image", "annotation"], | columns_order=["image", "annotation"], | ||||
| operations=[test_op]) # Add column for "annotation" | |||||
| operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op]) | |||||
| unaugSamp, augSamp = [], [] | unaugSamp, augSamp = [], [] | ||||
| @@ -286,13 +202,15 @@ def test_random_crop_with_bbox_op_edge_c(plot_vis=False): | |||||
| augSamp.append(Aug) | augSamp.append(Aug) | ||||
| if plot_vis: | if plot_vis: | ||||
| visualize(unaugSamp, augSamp) | |||||
| visualize_with_bounding_boxes(unaugSamp, augSamp) | |||||
| def test_random_crop_with_bbox_op_invalid_c(): | def test_random_crop_with_bbox_op_invalid_c(): | ||||
| """ | """ | ||||
| Checking for invalid params passed to Aug Constructor | |||||
| Test RandomCropWithBBox Op on invalid constructor parameters, expected to raise ValueError | |||||
| """ | """ | ||||
| logger.info("test_random_crop_with_bbox_op_invalid_c") | |||||
| # Load dataset | # Load dataset | ||||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | ||||
| @@ -300,8 +218,6 @@ def test_random_crop_with_bbox_op_invalid_c(): | |||||
| # define test OP with values to match existing Op unit - test | # define test OP with values to match existing Op unit - test | ||||
| test_op = c_vision.RandomCropWithBBox([512, 512, 375]) | test_op = c_vision.RandomCropWithBBox([512, 512, 375]) | ||||
| # maps to fix annotations to HQ standard | |||||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | dataVoc2 = dataVoc2.map(input_columns=["annotation"], | ||||
| output_columns=["annotation"], | output_columns=["annotation"], | ||||
| operations=fix_annotate) | operations=fix_annotate) | ||||
| @@ -320,35 +236,20 @@ def test_random_crop_with_bbox_op_invalid_c(): | |||||
| def test_random_crop_with_bbox_op_bad_c(): | def test_random_crop_with_bbox_op_bad_c(): | ||||
| # Should Fail - Errors logged to logger | |||||
| for ix, badFunc in enumerate(badGenFuncs): | |||||
| try: | |||||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | |||||
| decode=True, shuffle=False) | |||||
| test_op = c_vision.RandomCropWithBBox([512, 512], [200, 200, 200, 200]) | |||||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||||
| output_columns=["annotation"], | |||||
| operations=fix_annotate) | |||||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||||
| output_columns=["image", "annotation"], | |||||
| columns_order=["image", "annotation"], | |||||
| operations=[badFunc]) | |||||
| # map to apply ops | |||||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||||
| output_columns=["image", "annotation"], | |||||
| columns_order=["image", "annotation"], | |||||
| operations=[test_op]) | |||||
| for _ in dataVoc2.create_dict_iterator(): | |||||
| break # first sample will cause exception | |||||
| except RuntimeError as err: | |||||
| logger.info("Got an exception in DE: {}".format(str(err))) | |||||
| assert assertVal[ix] in str(err) | |||||
| """ | |||||
| Tests RandomCropWithBBox Op with invalid bounding boxes, expected to catch multiple errors. | |||||
| """ | |||||
| logger.info("test_random_crop_with_bbox_op_bad_c") | |||||
| test_op = c_vision.RandomCropWithBBox([512, 512], [200, 200, 200, 200]) | |||||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image") | |||||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image") | |||||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x") | |||||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features") | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| @@ -13,14 +13,17 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """ | """ | ||||
| Testing RandomVerticalFlipWithBBox op | |||||
| Testing RandomVerticalFlipWithBBox op in DE | |||||
| """ | """ | ||||
| import numpy as np | |||||
| import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
| import mindspore.dataset.transforms.vision.c_transforms as c_vision | import mindspore.dataset.transforms.vision.c_transforms as c_vision | ||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox | |||||
| from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \ | |||||
| config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5 | |||||
| GENERATE_GOLDEN = False | |||||
| # updated VOC dataset with correct annotations | # updated VOC dataset with correct annotations | ||||
| DATA_DIR = "../data/dataset/testVOC2012_2" | DATA_DIR = "../data/dataset/testVOC2012_2" | ||||
| @@ -28,10 +31,9 @@ DATA_DIR = "../data/dataset/testVOC2012_2" | |||||
| def fix_annotate(bboxes): | def fix_annotate(bboxes): | ||||
| """ | """ | ||||
| Update Current VOC dataset format to Proposed HQ BBox format | |||||
| :param bboxes: as [label, x_min, y_min, w, h, truncate, difficult] | |||||
| :return: annotation as [x_min, y_min, w, h, label, truncate, difficult] | |||||
| Fix annotations to format followed by mindspore. | |||||
| :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format | |||||
| :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format | |||||
| """ | """ | ||||
| for bbox in bboxes: | for bbox in bboxes: | ||||
| tmp = bbox[0] | tmp = bbox[0] | ||||
| @@ -45,9 +47,9 @@ def fix_annotate(bboxes): | |||||
| def test_random_vertical_flip_with_bbox_op_c(plot_vis=False): | def test_random_vertical_flip_with_bbox_op_c(plot_vis=False): | ||||
| """ | """ | ||||
| Prints images side by side with and without Aug applied + bboxes to | |||||
| compare and test | |||||
| Prints images and bboxes side by side with and without RandomVerticalFlipWithBBox Op applied | |||||
| """ | """ | ||||
| logger.info("test_random_vertical_flip_with_bbox_op_c") | |||||
| # Load dataset | # Load dataset | ||||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | ||||
| decode=True, shuffle=False) | decode=True, shuffle=False) | ||||
| @@ -57,7 +59,6 @@ def test_random_vertical_flip_with_bbox_op_c(plot_vis=False): | |||||
| test_op = c_vision.RandomVerticalFlipWithBBox(1) | test_op = c_vision.RandomVerticalFlipWithBBox(1) | ||||
| # maps to fix annotations to HQ standard | |||||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | dataVoc1 = dataVoc1.map(input_columns=["annotation"], | ||||
| output_columns=["annotation"], | output_columns=["annotation"], | ||||
| operations=fix_annotate) | operations=fix_annotate) | ||||
| @@ -82,9 +83,12 @@ def test_random_vertical_flip_with_bbox_op_c(plot_vis=False): | |||||
| def test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=False): | def test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=False): | ||||
| """ | """ | ||||
| Prints images side by side with and without Aug applied + bboxes to | |||||
| compare and test | |||||
| Prints images and bboxes side by side with and without RandomVerticalFlipWithBBox Op applied, | |||||
| tests with MD5 check, expected to pass | |||||
| """ | """ | ||||
| logger.info("test_random_vertical_flip_with_bbox_op_rand_c") | |||||
| original_seed = config_get_set_seed(29847) | |||||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||||
| # Load dataset | # Load dataset | ||||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | ||||
| @@ -93,9 +97,8 @@ def test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=False): | |||||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | ||||
| decode=True, shuffle=False) | decode=True, shuffle=False) | ||||
| test_op = c_vision.RandomVerticalFlipWithBBox(0.6) | |||||
| test_op = c_vision.RandomVerticalFlipWithBBox(0.8) | |||||
| # maps to fix annotations to HQ standard | |||||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | dataVoc1 = dataVoc1.map(input_columns=["annotation"], | ||||
| output_columns=["annotation"], | output_columns=["annotation"], | ||||
| operations=fix_annotate) | operations=fix_annotate) | ||||
| @@ -108,6 +111,56 @@ def test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=False): | |||||
| columns_order=["image", "annotation"], | columns_order=["image", "annotation"], | ||||
| operations=[test_op]) | operations=[test_op]) | ||||
| filename = "random_vertical_flip_with_bbox_01_c_result.npz" | |||||
| save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN) | |||||
| unaugSamp, augSamp = [], [] | |||||
| for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): | |||||
| unaugSamp.append(unAug) | |||||
| augSamp.append(Aug) | |||||
| if plot_vis: | |||||
| visualize_with_bounding_boxes(unaugSamp, augSamp) | |||||
| # Restore config setting | |||||
| ds.config.set_seed(original_seed) | |||||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||||
| def test_random_vertical_flip_with_bbox_op_edge_c(plot_vis=False): | |||||
| """ | |||||
| Prints images and bboxes side by side with and without RandomVerticalFlipWithBBox Op applied, | |||||
| applied on dynamically generated edge case, expected to pass | |||||
| """ | |||||
| logger.info("test_random_vertical_flip_with_bbox_op_edge_c") | |||||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | |||||
| decode=True, shuffle=False) | |||||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | |||||
| decode=True, shuffle=False) | |||||
| test_op = c_vision.RandomVerticalFlipWithBBox(1) | |||||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | |||||
| output_columns=["annotation"], | |||||
| operations=fix_annotate) | |||||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||||
| output_columns=["annotation"], | |||||
| operations=fix_annotate) | |||||
| # maps to convert data into valid edge case data | |||||
| dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], | |||||
| output_columns=["image", "annotation"], | |||||
| columns_order=["image", "annotation"], | |||||
| operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))]) | |||||
| # Test Op added to list of Operations here | |||||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||||
| output_columns=["image", "annotation"], | |||||
| columns_order=["image", "annotation"], | |||||
| operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op]) | |||||
| unaugSamp, augSamp = [], [] | unaugSamp, augSamp = [], [] | ||||
| for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): | for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): | ||||
| @@ -119,16 +172,15 @@ def test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=False): | |||||
| def test_random_vertical_flip_with_bbox_op_invalid_c(): | def test_random_vertical_flip_with_bbox_op_invalid_c(): | ||||
| # Should Fail | |||||
| # Load dataset | |||||
| """ | |||||
| Test RandomVerticalFlipWithBBox Op on invalid constructor parameters, expected to raise ValueError | |||||
| """ | |||||
| logger.info("test_random_vertical_flip_with_bbox_op_invalid_c") | |||||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | ||||
| decode=True, shuffle=False) | decode=True, shuffle=False) | ||||
| try: | try: | ||||
| test_op = c_vision.RandomVerticalFlipWithBBox(2) | test_op = c_vision.RandomVerticalFlipWithBBox(2) | ||||
| # maps to fix annotations to HQ standard | |||||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | dataVoc2 = dataVoc2.map(input_columns=["annotation"], | ||||
| output_columns=["annotation"], | output_columns=["annotation"], | ||||
| operations=fix_annotate) | operations=fix_annotate) | ||||
| @@ -148,9 +200,9 @@ def test_random_vertical_flip_with_bbox_op_invalid_c(): | |||||
| def test_random_vertical_flip_with_bbox_op_bad_c(): | def test_random_vertical_flip_with_bbox_op_bad_c(): | ||||
| """ | """ | ||||
| Test RandomHorizontalFlipWithBBox op with invalid bounding boxes | |||||
| Tests RandomVerticalFlipWithBBox Op with invalid bounding boxes, expected to catch multiple errors | |||||
| """ | """ | ||||
| logger.info("test_random_horizontal_bbox_invalid_bounds_c") | |||||
| logger.info("test_random_vertical_flip_with_bbox_op_bad_c") | |||||
| test_op = c_vision.RandomVerticalFlipWithBBox(1) | test_op = c_vision.RandomVerticalFlipWithBBox(1) | ||||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | ||||
| @@ -166,5 +218,6 @@ def test_random_vertical_flip_with_bbox_op_bad_c(): | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| test_random_vertical_flip_with_bbox_op_c(plot_vis=True) | test_random_vertical_flip_with_bbox_op_c(plot_vis=True) | ||||
| test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=True) | test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=True) | ||||
| test_random_vertical_flip_with_bbox_op_edge_c(plot_vis=True) | |||||
| test_random_vertical_flip_with_bbox_op_invalid_c() | test_random_vertical_flip_with_bbox_op_invalid_c() | ||||
| test_random_vertical_flip_with_bbox_op_bad_c() | test_random_vertical_flip_with_bbox_op_bad_c() | ||||
| @@ -312,34 +312,39 @@ def visualize_with_bounding_boxes(orig, aug, plot_rows=3): | |||||
| if len(orig) != len(aug) or not orig: | if len(orig) != len(aug) or not orig: | ||||
| return | return | ||||
| comp_set = int(len(orig)/plot_rows) | |||||
| batch_size = int(len(orig)/plot_rows) # creates batches of images to plot together | |||||
| split_point = batch_size * plot_rows | |||||
| orig, aug = np.array(orig), np.array(aug) | orig, aug = np.array(orig), np.array(aug) | ||||
| if len(orig) > plot_rows: | if len(orig) > plot_rows: | ||||
| orig = np.split(orig[:comp_set*plot_rows], comp_set) + [orig[comp_set*plot_rows:]] | |||||
| aug = np.split(aug[:comp_set*plot_rows], comp_set) + [aug[comp_set*plot_rows:]] | |||||
| # Create batches of required size and add remainder to last batch | |||||
| orig = np.split(orig[:split_point], batch_size) + ([orig[split_point:]] if (split_point < orig.shape[0]) else []) # check to avoid empty arrays being added | |||||
| aug = np.split(aug[:split_point], batch_size) + ([aug[split_point:]] if (split_point < aug.shape[0]) else []) | |||||
| else: | else: | ||||
| orig = [orig] | orig = [orig] | ||||
| aug = [aug] | aug = [aug] | ||||
| for ix, allData in enumerate(zip(orig, aug)): | for ix, allData in enumerate(zip(orig, aug)): | ||||
| base_ix = ix * plot_rows # will signal what base level we're on | |||||
| base_ix = ix * plot_rows # current batch starting index | |||||
| curPlot = len(allData[0]) | |||||
| sub_plot_count = 2 if (len(allData[0]) < 2) else len(allData[0]) # if 1 image remains, create subplot for 2 to simplify axis selection | |||||
| fig, axs = plt.subplots(sub_plot_count, 2) | |||||
| fig, axs = plt.subplots(curPlot, 2) | |||||
| fig.tight_layout(pad=1.5) | fig.tight_layout(pad=1.5) | ||||
| for x, (dataA, dataB) in enumerate(zip(allData[0], allData[1])): | for x, (dataA, dataB) in enumerate(zip(allData[0], allData[1])): | ||||
| cur_ix = base_ix + x | cur_ix = base_ix + x | ||||
| (axA, axB) = (axs[x, 0], axs[x, 1]) if (curPlot > 1) else (axs[0], axs[1]) # select plotting axes based on number of image rows on plot - else case when 1 row | |||||
| axs[x, 0].imshow(dataA["image"]) | |||||
| add_bounding_boxes(axs[x, 0], dataA["annotation"]) | |||||
| axs[x, 0].title.set_text("Original" + str(cur_ix+1)) | |||||
| logger.info("Original **\n{} : {}".format(str(cur_ix+1), dataA["annotation"])) | |||||
| axA.imshow(dataA["image"]) | |||||
| add_bounding_boxes(axA, dataA["annotation"]) | |||||
| axA.title.set_text("Original" + str(cur_ix+1)) | |||||
| axs[x, 1].imshow(dataB["image"]) | |||||
| add_bounding_boxes(axs[x, 1], dataB["annotation"]) | |||||
| axs[x, 1].title.set_text("Augmented" + str(cur_ix+1)) | |||||
| axB.imshow(dataB["image"]) | |||||
| add_bounding_boxes(axB, dataB["annotation"]) | |||||
| axB.title.set_text("Augmented" + str(cur_ix+1)) | |||||
| logger.info("Original **\n{} : {}".format(str(cur_ix+1), dataA["annotation"])) | |||||
| logger.info("Augmented **\n{} : {}\n".format(str(cur_ix+1), dataB["annotation"])) | logger.info("Augmented **\n{} : {}\n".format(str(cur_ix+1), dataB["annotation"])) | ||||
| plt.show() | plt.show() | ||||