Merge pull request !2532 from ava/NNNewMastertags/v0.6.0-beta
| @@ -63,12 +63,14 @@ | |||
| #include "dataset/kernels/image/random_horizontal_flip_bbox_op.h" | |||
| #include "dataset/kernels/image/random_horizontal_flip_op.h" | |||
| #include "dataset/kernels/image/random_resize_op.h" | |||
| #include "dataset/kernels/image/random_resize_with_bbox_op.h" | |||
| #include "dataset/kernels/image/random_rotation_op.h" | |||
| #include "dataset/kernels/image/random_vertical_flip_op.h" | |||
| #include "dataset/kernels/image/random_vertical_flip_with_bbox_op.h" | |||
| #include "dataset/kernels/image/rescale_op.h" | |||
| #include "dataset/kernels/image/resize_bilinear_op.h" | |||
| #include "dataset/kernels/image/resize_op.h" | |||
| #include "dataset/kernels/image/resize_with_bbox_op.h" | |||
| #include "dataset/kernels/image/uniform_aug_op.h" | |||
| #include "dataset/kernels/no_op.h" | |||
| #include "dataset/text/kernels/jieba_tokenizer_op.h" | |||
| @@ -348,6 +350,18 @@ void bindTensorOps1(py::module *m) { | |||
| .def(py::init<int32_t, int32_t, InterpolationMode>(), py::arg("targetHeight"), | |||
| py::arg("targetWidth") = ResizeOp::kDefWidth, py::arg("interpolation") = ResizeOp::kDefInterpolation); | |||
| (void)py::class_<ResizeWithBBoxOp, TensorOp, std::shared_ptr<ResizeWithBBoxOp>>( | |||
| *m, "ResizeWithBBoxOp", "Tensor operation to resize an image. Takes height, width and mode.") | |||
| .def(py::init<int32_t, int32_t, InterpolationMode>(), py::arg("targetHeight"), | |||
| py::arg("targetWidth") = ResizeWithBBoxOp::kDefWidth, | |||
| py::arg("interpolation") = ResizeWithBBoxOp::kDefInterpolation); | |||
| (void)py::class_<RandomResizeWithBBoxOp, TensorOp, std::shared_ptr<RandomResizeWithBBoxOp>>( | |||
| *m, "RandomResizeWithBBoxOp", | |||
| "Tensor operation to resize an image using a randomly selected interpolation. Takes height and width.") | |||
| .def(py::init<int32_t, int32_t>(), py::arg("targetHeight"), | |||
| py::arg("targetWidth") = RandomResizeWithBBoxOp::kDefTargetWidth); | |||
| (void)py::class_<UniformAugOp, TensorOp, std::shared_ptr<UniformAugOp>>( | |||
| *m, "UniformAugOp", "Tensor operation to apply random augmentation(s).") | |||
| .def(py::init<std::vector<std::shared_ptr<TensorOp>>, int32_t>(), py::arg("operations"), | |||
| @@ -25,4 +25,6 @@ add_library(kernels-image OBJECT | |||
| resize_bilinear_op.cc | |||
| resize_op.cc | |||
| uniform_aug_op.cc | |||
| resize_with_bbox_op.cc | |||
| random_resize_with_bbox_op.cc | |||
| ) | |||
| @@ -0,0 +1,33 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/kernels/image/random_resize_with_bbox_op.h" | |||
| #include "dataset/kernels/image/resize_with_bbox_op.h" | |||
| #include "dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| const int32_t RandomResizeWithBBoxOp::kDefTargetWidth = 0; | |||
| Status RandomResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) { | |||
| // Randomly selects from the following four interpolation methods | |||
| // 0-bilinear, 1-nearest_neighbor, 2-bicubic, 3-area | |||
| interpolation_ = static_cast<InterpolationMode>(distribution_(random_generator_)); | |||
| RETURN_IF_NOT_OK(ResizeWithBBoxOp::Compute(input, output)); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,56 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_KERNELS_IMAGE_RANDOM_RESIZE_WITH_BBOX_OP_H | |||
| #define DATASET_KERNELS_IMAGE_RANDOM_RESIZE_WITH_BBOX_OP_H | |||
| #include <memory> | |||
| #include <random> | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/kernels/image/resize_op.h" | |||
| #include "dataset/kernels/image/resize_with_bbox_op.h" | |||
| #include "dataset/kernels/tensor_op.h" | |||
| #include "dataset/util/random.h" | |||
| #include "dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class RandomResizeWithBBoxOp : public ResizeWithBBoxOp { | |||
| public: | |||
| // Default values, also used by python_bindings.cc | |||
| static const int32_t kDefTargetWidth; | |||
| explicit RandomResizeWithBBoxOp(int32_t size_1, int32_t size_2 = kDefTargetWidth) : ResizeWithBBoxOp(size_1, size_2) { | |||
| random_generator_.seed(GetSeed()); | |||
| } | |||
| ~RandomResizeWithBBoxOp() = default; | |||
| // Description: A function that prints info about the node | |||
| void Print(std::ostream &out) const override { | |||
| out << "RandomResizeWithBBoxOp: " << ResizeWithBBoxOp::size1_ << " " << ResizeWithBBoxOp::size2_; | |||
| } | |||
| Status Compute(const TensorRow &input, TensorRow *output) override; | |||
| private: | |||
| std::mt19937 random_generator_; | |||
| std::uniform_int_distribution<int> distribution_{0, 3}; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_KERNELS_IMAGE_RANDOM_RESIZE_WITH_BBOX_OP_H | |||
| @@ -0,0 +1,53 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/kernels/image/resize_with_bbox_op.h" | |||
| #include <utility> | |||
| #include <memory> | |||
| #include "dataset/kernels/image/resize_op.h" | |||
| #include "dataset/kernels/image/image_utils.h" | |||
| #include "dataset/core/cv_tensor.h" | |||
| #include "dataset/core/pybind_support.h" | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/kernels/tensor_op.h" | |||
| #include "dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| Status ResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) { | |||
| IO_CHECK_VECTOR(input, output); | |||
| BOUNDING_BOX_CHECK(input); | |||
| int32_t input_h = input[0]->shape()[0]; | |||
| int32_t input_w = input[0]->shape()[1]; | |||
| output->resize(2); | |||
| (*output)[1] = std::move(input[1]); // move boxes over to output | |||
| std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(std::move(input[0])); | |||
| RETURN_IF_NOT_OK(ResizeOp::Compute(std::static_pointer_cast<Tensor>(input_cv), &(*output)[0])); | |||
| int32_t output_h = (*output)[0]->shape()[0]; // output height if ResizeWithBBox | |||
| int32_t output_w = (*output)[0]->shape()[1]; // output width if ResizeWithBBox | |||
| size_t bboxCount = input[1]->shape()[0]; // number of rows in bbox tensor | |||
| RETURN_IF_NOT_OK(UpdateBBoxesForResize((*output)[1], bboxCount, output_w, output_h, input_w, input_h)); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H | |||
| #define DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/kernels/image/image_utils.h" | |||
| #include "dataset/kernels/tensor_op.h" | |||
| #include "dataset/util/status.h" | |||
| #include "dataset/kernels/image/resize_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class ResizeWithBBoxOp : public ResizeOp { | |||
| public: | |||
| // Constructor for ResizeWithBBoxOp, with default value and passing to base class constructor | |||
| explicit ResizeWithBBoxOp(int32_t size_1, int32_t size_2 = kDefWidth, | |||
| InterpolationMode mInterpolation = kDefInterpolation) | |||
| : ResizeOp(size_1, size_2, mInterpolation) {} | |||
| ~ResizeWithBBoxOp() override = default; | |||
| void Print(std::ostream &out) const override { out << "ResizeWithBBoxOp: " << size1_ << " " << size2_; } | |||
| Status Compute(const TensorRow &input, TensorRow *output) override; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H | |||
| @@ -265,6 +265,7 @@ class BoundingBoxAugment(cde.BoundingBoxAugmentOp): | |||
| ratio (float, optional): Ratio of bounding boxes to apply augmentation on. | |||
| Range: [0,1] (default=0.3). | |||
| """ | |||
| @check_bounding_box_augment_cpp | |||
| def __init__(self, transform, ratio=0.3): | |||
| self.ratio = ratio | |||
| @@ -302,6 +303,36 @@ class Resize(cde.ResizeOp): | |||
| super().__init__(*size, interpoltn) | |||
| class ResizeWithBBox(cde.ResizeWithBBoxOp): | |||
| """ | |||
| Resize the input image to the given size and adjust the bounding boxes accordingly. | |||
| Args: | |||
| size (int or sequence): The output size of the resized image. | |||
| If size is an int, smaller edge of the image will be resized to this value with | |||
| the same image aspect ratio. | |||
| If size is a sequence of length 2, it should be (height, width). | |||
| interpolation (Inter mode, optional): Image interpolation mode (default=Inter.LINEAR). | |||
| It can be any of [Inter.LINEAR, Inter.NEAREST, Inter.BICUBIC]. | |||
| - Inter.LINEAR, means interpolation method is bilinear interpolation. | |||
| - Inter.NEAREST, means interpolation method is nearest-neighbor interpolation. | |||
| - Inter.BICUBIC, means interpolation method is bicubic interpolation. | |||
| """ | |||
| @check_resize_interpolation | |||
| def __init__(self, size, interpolation=Inter.LINEAR): | |||
| self.size = size | |||
| self.interpolation = interpolation | |||
| interpoltn = DE_C_INTER_MODE[interpolation] | |||
| if isinstance(size, int): | |||
| super().__init__(size, interpolation=interpoltn) | |||
| else: | |||
| super().__init__(*size, interpoltn) | |||
| class RandomResizedCropWithBBox(cde.RandomCropAndResizeWithBBoxOp): | |||
| """ | |||
| Crop the input image to a random size and aspect ratio and adjust the Bounding Boxes accordingly | |||
| @@ -326,6 +357,7 @@ class RandomResizedCropWithBBox(cde.RandomCropAndResizeWithBBoxOp): | |||
| max_attempts (int, optional): The maximum number of attempts to propose a valid | |||
| crop_area (default=10). If exceeded, fall back to use center_crop instead. | |||
| """ | |||
| @check_random_resize_crop | |||
| def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), | |||
| interpolation=Inter.BILINEAR, max_attempts=10): | |||
| @@ -499,6 +531,27 @@ class RandomResize(cde.RandomResizeOp): | |||
| super().__init__(*size) | |||
| class RandomResizeWithBBox(cde.RandomResizeWithBBoxOp): | |||
| """ | |||
| Tensor operation to resize the input image using a randomly selected interpolation mode and adjust | |||
| the bounding boxes accordingly. | |||
| Args: | |||
| size (int or sequence): The output size of the resized image. | |||
| If size is an int, smaller edge of the image will be resized to this value with | |||
| the same image aspect ratio. | |||
| If size is a sequence of length 2, it should be (height, width). | |||
| """ | |||
| @check_resize | |||
| def __init__(self, size): | |||
| self.size = size | |||
| if isinstance(size, int): | |||
| super().__init__(size) | |||
| else: | |||
| super().__init__(*size) | |||
| class HWC2CHW(cde.ChannelSwapOp): | |||
| """ | |||
| Transpose the input image; shape (H, W, C) to shape (C, H, W). | |||
| @@ -0,0 +1,265 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing the random resize with bounding boxes op in DE | |||
| """ | |||
| from enum import Enum | |||
| import matplotlib.pyplot as plt | |||
| import matplotlib.patches as patches | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| from mindspore import log as logger | |||
| import mindspore.dataset.transforms.vision.c_transforms as c_vision | |||
| GENERATE_GOLDEN = False | |||
| DATA_DIR = "../data/dataset/testVOC2012" | |||
| def fix_annotate(bboxes): | |||
| """ | |||
| :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format | |||
| :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format | |||
| """ | |||
| for bbox in bboxes: | |||
| tmp = bbox[0] | |||
| bbox[0] = bbox[1] | |||
| bbox[1] = bbox[2] | |||
| bbox[2] = bbox[3] | |||
| bbox[3] = bbox[4] | |||
| bbox[4] = tmp | |||
| return bboxes | |||
| class BoxType(Enum): | |||
| """ | |||
| Defines box types for test cases | |||
| """ | |||
| WidthOverflow = 1 | |||
| HeightOverflow = 2 | |||
| NegativeXY = 3 | |||
| OnEdge = 4 | |||
| WrongShape = 5 | |||
| class AddBadAnnotation: # pylint: disable=too-few-public-methods | |||
| """ | |||
| Used to add erroneous bounding boxes to object detection pipelines. | |||
| Usage: | |||
| >>> # Adds a box that covers the whole image. Good for testing edge cases | |||
| >>> de = de.map(input_columns=["image", "annotation"], | |||
| >>> output_columns=["image", "annotation"], | |||
| >>> operations=AddBadAnnotation(BoxType.OnEdge)) | |||
| """ | |||
| def __init__(self, box_type): | |||
| self.box_type = box_type | |||
| def __call__(self, img, bboxes): | |||
| """ | |||
| Used to generate erroneous bounding box examples on given img. | |||
| :param img: image where the bounding boxes are. | |||
| :param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format | |||
| :return: bboxes with bad examples added | |||
| """ | |||
| height = img.shape[0] | |||
| width = img.shape[1] | |||
| if self.box_type == BoxType.WidthOverflow: | |||
| # use box that overflows on width | |||
| return img, np.array([[0, 0, width + 1, height - 1, 0, 0, 0]]).astype(np.uint32) | |||
| if self.box_type == BoxType.HeightOverflow: | |||
| # use box that overflows on height | |||
| return img, np.array([[0, 0, width - 1, height + 1, 0, 0, 0]]).astype(np.uint32) | |||
| if self.box_type == BoxType.NegativeXY: | |||
| # use box with negative xy | |||
| return img, np.array([[-10, -10, width - 1, height - 1, 0, 0, 0]]).astype(np.uint32) | |||
| if self.box_type == BoxType.OnEdge: | |||
| # use box that covers the whole image | |||
| return img, np.array([[0, 0, width - 1, height - 1, 0, 0, 0]]).astype(np.uint32) | |||
| if self.box_type == BoxType.WrongShape: | |||
| # use box that covers the whole image | |||
| return img, np.array([[0, 0, width - 1]]).astype(np.uint32) | |||
| return img, bboxes | |||
| def check_bad_box(data, box_type, expected_error): | |||
| try: | |||
| test_op = c_vision.RandomResizeWithBBox(100) # DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM) | |||
| data = data.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # map to use width overflow | |||
| data = data.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=AddBadAnnotation(box_type)) # Add column for "annotation" | |||
| # map to apply ops | |||
| data = data.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) # Add column for "annotation" | |||
| for _, _ in enumerate(data.create_dict_iterator()): | |||
| break | |||
| except RuntimeError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert expected_error in str(e) | |||
| def add_bounding_boxes(axis, bboxes): | |||
| """ | |||
| :param axis: axis to modify | |||
| :param bboxes: bounding boxes to draw on the axis | |||
| :return: None | |||
| """ | |||
| for bbox in bboxes: | |||
| rect = patches.Rectangle((bbox[0], bbox[1]), | |||
| bbox[2], bbox[3], | |||
| linewidth=1, edgecolor='r', facecolor='none') | |||
| # Add the patch to the Axes | |||
| axis.add_patch(rect) | |||
| def visualize(unaugmented_data, augment_data): | |||
| for idx, (un_aug_item, aug_item) in \ | |||
| enumerate(zip(unaugmented_data.create_dict_iterator(), augment_data.create_dict_iterator())): | |||
| axis = plt.subplot(141) | |||
| plt.imshow(un_aug_item["image"]) | |||
| add_bounding_boxes(axis, un_aug_item["annotation"]) # add Orig BBoxes | |||
| plt.title("Original" + str(idx + 1)) | |||
| logger.info("Original ", str(idx + 1), " :", un_aug_item["annotation"]) | |||
| axis = plt.subplot(142) | |||
| plt.imshow(aug_item["image"]) | |||
| add_bounding_boxes(axis, aug_item["annotation"]) # add AugBBoxes | |||
| plt.title("Augmented" + str(idx + 1)) | |||
| logger.info("Augmented ", str(idx + 1), " ", aug_item["annotation"], "\n") | |||
| plt.show() | |||
| def test_random_resize_with_bbox_op(plot=False): | |||
| """ | |||
| Test random_resize_with_bbox_op | |||
| """ | |||
| logger.info("Test random resize with bbox") | |||
| # original images | |||
| data_original = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| # augmented images | |||
| data_augmented = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| data_original = data_original.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| data_augmented = data_augmented.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # define map operations | |||
| test_op = c_vision.RandomResizeWithBBox(100) # input value being the target size of resizeOp | |||
| data_augmented = data_augmented.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], operations=[test_op]) | |||
| if plot: | |||
| visualize(data_original, data_augmented) | |||
| def test_random_resize_with_bbox_invalid_bounds(): | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_box(data_voc2, BoxType.WidthOverflow, "bounding boxes is out of bounds of the image") | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_box(data_voc2, BoxType.HeightOverflow, "bounding boxes is out of bounds of the image") | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_box(data_voc2, BoxType.NegativeXY, "min_x") | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_box(data_voc2, BoxType.WrongShape, "4 features") | |||
| def test_random_resize_with_bbox_invalid_size(): | |||
| """ | |||
| Test random_resize_with_bbox_op | |||
| """ | |||
| logger.info("Test random resize with bbox with invalid target size") | |||
| # original images | |||
| data = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| data = data.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # negative target size as input | |||
| try: | |||
| test_op = c_vision.RandomResizeWithBBox(-10) # DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM) | |||
| # map to apply ops | |||
| data = data.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) # Add column for "annotation" | |||
| for _, _ in enumerate(data.create_dict_iterator()): | |||
| break | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| print(e) | |||
| assert "Input is not" in str(e) | |||
| # zero target size as input | |||
| try: | |||
| test_op = c_vision.RandomResizeWithBBox(0) # DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM) | |||
| # map to apply ops | |||
| data = data.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) # Add column for "annotation" | |||
| for _, _ in enumerate(data.create_dict_iterator()): | |||
| break | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Input is not" in str(e) | |||
| # invalid input shape | |||
| try: | |||
| test_op = c_vision.RandomResizeWithBBox((10, 10, 10)) # DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM) | |||
| # map to apply ops | |||
| data = data.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) # Add column for "annotation" | |||
| for _, _ in enumerate(data.create_dict_iterator()): | |||
| break | |||
| except TypeError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Size should be" in str(e) | |||
| if __name__ == "__main__": | |||
| test_random_resize_with_bbox_op(plot=False) | |||
| test_random_resize_with_bbox_invalid_bounds() | |||
| test_random_resize_with_bbox_invalid_size() | |||
| @@ -0,0 +1,295 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing the resize with bounding boxes op in DE | |||
| """ | |||
| from enum import Enum | |||
| import numpy as np | |||
| import matplotlib.patches as patches | |||
| import matplotlib.pyplot as plt | |||
| import mindspore.dataset.transforms.vision.c_transforms as c_vision | |||
| from mindspore import log as logger | |||
| import mindspore.dataset as ds | |||
| GENERATE_GOLDEN = False | |||
| DATA_DIR = "../data/dataset/testVOC2012" | |||
| def fix_annotate(bboxes): | |||
| """ | |||
| :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format | |||
| :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format | |||
| """ | |||
| for bbox in bboxes: | |||
| tmp = bbox[0] | |||
| bbox[0] = bbox[1] | |||
| bbox[1] = bbox[2] | |||
| bbox[2] = bbox[3] | |||
| bbox[3] = bbox[4] | |||
| bbox[4] = tmp | |||
| return bboxes | |||
| class BoxType(Enum): | |||
| """ | |||
| Defines box types for test cases | |||
| """ | |||
| WidthOverflow = 1 | |||
| HeightOverflow = 2 | |||
| NegativeXY = 3 | |||
| OnEdge = 4 | |||
| WrongShape = 5 | |||
| class AddBadAnnotation: # pylint: disable=too-few-public-methods | |||
| """ | |||
| Used to add erroneous bounding boxes to object detection pipelines. | |||
| Usage: | |||
| >>> # Adds a box that covers the whole image. Good for testing edge cases | |||
| >>> de = de.map(input_columns=["image", "annotation"], | |||
| >>> output_columns=["image", "annotation"], | |||
| >>> operations=AddBadAnnotation(BoxType.OnEdge)) | |||
| """ | |||
| def __init__(self, box_type): | |||
| self.box_type = box_type | |||
| def __call__(self, img, bboxes): | |||
| """ | |||
| Used to generate erroneous bounding box examples on given img. | |||
| :param img: image where the bounding boxes are. | |||
| :param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format | |||
| :return: bboxes with bad examples added | |||
| """ | |||
| height = img.shape[0] | |||
| width = img.shape[1] | |||
| if self.box_type == BoxType.WidthOverflow: | |||
| # use box that overflows on width | |||
| return img, np.array([[0, 0, width + 1, height - 1, 0, 0, 0]]).astype(np.uint32) | |||
| if self.box_type == BoxType.HeightOverflow: | |||
| # use box that overflows on height | |||
| return img, np.array([[0, 0, width - 1, height + 1, 0, 0, 0]]).astype(np.uint32) | |||
| if self.box_type == BoxType.NegativeXY: | |||
| # use box with negative xy | |||
| return img, np.array([[-10, -10, width - 1, height - 1, 0, 0, 0]]).astype(np.uint32) | |||
| if self.box_type == BoxType.OnEdge: | |||
| # use box that covers the whole image | |||
| return img, np.array([[0, 0, width - 1, height - 1, 0, 0, 0]]).astype(np.uint32) | |||
| if self.box_type == BoxType.WrongShape: | |||
| # use box that covers the whole image | |||
| return img, np.array([[0, 0, width - 1]]).astype(np.uint32) | |||
| return img, bboxes | |||
| def check_bad_box(data, box_type, expected_error): | |||
| try: | |||
| test_op = c_vision.ResizeWithBBox(100) | |||
| data = data.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # map to use width overflow | |||
| data = data.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=AddBadAnnotation(box_type)) # Add column for "annotation" | |||
| # map to apply ops | |||
| data = data.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) # Add column for "annotation" | |||
| for _, _ in enumerate(data.create_dict_iterator()): | |||
| break | |||
| except RuntimeError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert expected_error in str(e) | |||
| def add_bounding_boxes(axis, bboxes): | |||
| """ | |||
| :param axis: axis to modify | |||
| :param bboxes: bounding boxes to draw on the axis | |||
| :return: None | |||
| """ | |||
| for bbox in bboxes: | |||
| rect = patches.Rectangle((bbox[0], bbox[1]), | |||
| bbox[2], bbox[3], | |||
| linewidth=1, edgecolor='r', facecolor='none') | |||
| # Add the patch to the Axes | |||
| axis.add_patch(rect) | |||
| def visualize(unaugmented_data, augment_data): | |||
| for idx, (un_aug_item, aug_item) in enumerate( | |||
| zip(unaugmented_data.create_dict_iterator(), augment_data.create_dict_iterator())): | |||
| axis = plt.subplot(141) | |||
| plt.imshow(un_aug_item["image"]) | |||
| add_bounding_boxes(axis, un_aug_item["annotation"]) # add Orig BBoxes | |||
| plt.title("Original" + str(idx + 1)) | |||
| logger.info("Original ", str(idx + 1), " :", un_aug_item["annotation"]) | |||
| axis = plt.subplot(142) | |||
| plt.imshow(aug_item["image"]) | |||
| add_bounding_boxes(axis, aug_item["annotation"]) # add AugBBoxes | |||
| plt.title("Augmented" + str(idx + 1)) | |||
| logger.info("Augmented ", str(idx + 1), " ", aug_item["annotation"], "\n") | |||
| plt.show() | |||
| def test_resize_with_bbox_op(plot=False): | |||
| """ | |||
| Test resize_with_bbox_op | |||
| """ | |||
| logger.info("Test resize with bbox") | |||
| # original images | |||
| data_original = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| # augmented images | |||
| data_augmented = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| data_original = data_original.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| data_augmented = data_augmented.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # define map operations | |||
| test_op = c_vision.ResizeWithBBox(100) # input value being the target size of resizeOp | |||
| data_augmented = data_augmented.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], operations=[test_op]) | |||
| if plot: | |||
| visualize(data_original, data_augmented) | |||
| def test_resize_with_bbox_invalid_bounds(): | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_box(data_voc2, BoxType.WidthOverflow, "bounding boxes is out of bounds of the image") | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_box(data_voc2, BoxType.HeightOverflow, "bounding boxes is out of bounds of the image") | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_box(data_voc2, BoxType.NegativeXY, "min_x") | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_box(data_voc2, BoxType.WrongShape, "4 features") | |||
| def test_resize_with_bbox_invalid_size(): | |||
| """ | |||
| Test resize_with_bbox_op | |||
| """ | |||
| logger.info("Test resize with bbox with invalid target size") | |||
| # original images | |||
| data = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| data = data.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # negative target size as input | |||
| try: | |||
| test_op = c_vision.ResizeWithBBox(-10) | |||
| # map to apply ops | |||
| data = data.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) # Add column for "annotation" | |||
| for _, _ in enumerate(data.create_dict_iterator()): | |||
| break | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Input is not" in str(e) | |||
| # zero target size as input | |||
| try: | |||
| test_op = c_vision.ResizeWithBBox(0) | |||
| # map to apply ops | |||
| data = data.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) # Add column for "annotation" | |||
| for _, _ in enumerate(data.create_dict_iterator()): | |||
| break | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Input is not" in str(e) | |||
| # invalid input shape | |||
| try: | |||
| test_op = c_vision.ResizeWithBBox((10, 10, 10)) | |||
| # map to apply ops | |||
| data = data.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) # Add column for "annotation" | |||
| for _, _ in enumerate(data.create_dict_iterator()): | |||
| break | |||
| except TypeError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Size should be" in str(e) | |||
| def test_resize_with_bbox_invalid_interpolation(): | |||
| """ | |||
| Test resize_with_bbox_op | |||
| """ | |||
| logger.info("Test resize with bbox with invalid interpolation size") | |||
| # original images | |||
| data = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| data = data.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # invalid interpolation | |||
| try: | |||
| test_op = c_vision.ResizeWithBBox(100, interpolation="invalid") | |||
| # map to apply ops | |||
| data = data.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) # Add column for "annotation" | |||
| for _, _ in enumerate(data.create_dict_iterator()): | |||
| break | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "interpolation" in str(e) | |||
| if __name__ == "__main__": | |||
| test_resize_with_bbox_op(plot=False) | |||
| test_resize_with_bbox_invalid_bounds() | |||
| test_resize_with_bbox_invalid_size() | |||
| test_resize_with_bbox_invalid_interpolation() | |||