@@ -353,10 +353,10 @@ void bindTensorOps1(py::module *m) {
     .def(py::init<std::vector<std::shared_ptr<TensorOp>>, int32_t>(), py::arg("operations"),
          py::arg("NumOps") = UniformAugOp::kDefNumOps);
-  (void)py::class_<BoundingBoxAugOp, TensorOp, std::shared_ptr<BoundingBoxAugOp>>(
-    *m, "BoundingBoxAugOp", "Tensor operation to apply a transformation on a random choice of bounding boxes.")
+  (void)py::class_<BoundingBoxAugmentOp, TensorOp, std::shared_ptr<BoundingBoxAugmentOp>>(
+    *m, "BoundingBoxAugmentOp", "Tensor operation to apply a transformation on a random choice of bounding boxes.")
     .def(py::init<std::shared_ptr<TensorOp>, float>(), py::arg("transform"),
-         py::arg("ratio") = BoundingBoxAugOp::defRatio);
+         py::arg("ratio") = BoundingBoxAugmentOp::kDefRatio);
   (void)py::class_<ResizeBilinearOp, TensorOp, std::shared_ptr<ResizeBilinearOp>>(
     *m, "ResizeBilinearOp",
@@ -23,12 +23,14 @@
 namespace mindspore {
 namespace dataset {
-const float BoundingBoxAugOp::defRatio = 0.3;
+const float BoundingBoxAugmentOp::kDefRatio = 0.3;
-BoundingBoxAugOp::BoundingBoxAugOp(std::shared_ptr<TensorOp> transform, float ratio)
-    : ratio_(ratio), transform_(std::move(transform)) {}
+BoundingBoxAugmentOp::BoundingBoxAugmentOp(std::shared_ptr<TensorOp> transform, float ratio)
+    : ratio_(ratio), transform_(std::move(transform)) {
+  rnd_.seed(GetSeed());
+}
-Status BoundingBoxAugOp::Compute(const TensorRow &input, TensorRow *output) {
+Status BoundingBoxAugmentOp::Compute(const TensorRow &input, TensorRow *output) {
   IO_CHECK_VECTOR(input, output);
   BOUNDING_BOX_CHECK(input);  // check if bounding boxes are valid
   uint32_t num_of_boxes = input[1]->shape()[0];
@@ -37,8 +39,7 @@ Status BoundingBoxAugOp::Compute(const TensorRow &input, TensorRow *output) {
   std::vector<uint32_t> selected_boxes;
   for (uint32_t i = 0; i < num_of_boxes; i++) boxes[i] = i;
   // sample bboxes according to ratio picked by user
-  std::random_device rd;
-  std::sample(boxes.begin(), boxes.end(), std::back_inserter(selected_boxes), num_to_aug, std::mt19937(rd()));
+  std::sample(boxes.begin(), boxes.end(), std::back_inserter(selected_boxes), num_to_aug, rnd_);
   std::shared_ptr<Tensor> crop_out;
   std::shared_ptr<Tensor> res_out;
   std::shared_ptr<CVTensor> input_restore = CVTensor::AsCVTensor(input[0]);
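For orientation, a rough Python sketch of what the reworked Compute does with the member RNG: the generator is seeded once from GetSeed() in the constructor, so the subset of boxes chosen is reproducible under a fixed dataset seed. The helper name, the ceil rounding for num_to_aug, and the crop/paste-back step are illustrative assumptions, not the exact C++ code:

    import math
    import random

    def augment_random_boxes(image, boxes, transform, ratio=0.3, seed=0):
        # In the C++ op the mt19937 is a class member seeded once from GetSeed();
        # a locally seeded generator stands in for it here.
        rng = random.Random(seed)
        num_to_aug = int(math.ceil(ratio * len(boxes)))        # assumed rounding
        for i in rng.sample(range(len(boxes)), num_to_aug):    # analogue of std::sample(..., rnd_)
            x, y, w, h = (int(v) for v in boxes[i][:4])
            crop = image[y:y + h, x:x + w]                     # image: H x W x C NumPy array
            image[y:y + h, x:x + w] = transform(crop)          # transform must preserve the crop size
        return image, boxes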
@@ -24,33 +24,35 @@
 #include "dataset/core/tensor.h"
 #include "dataset/kernels/tensor_op.h"
 #include "dataset/util/status.h"
+#include "dataset/util/random.h"
 namespace mindspore {
 namespace dataset {
-class BoundingBoxAugOp : public TensorOp {
+class BoundingBoxAugmentOp : public TensorOp {
  public:
   // Default values, also used by python_bindings.cc
-  static const float defRatio;
+  static const float kDefRatio;
   // Constructor for BoundingBoxAugmentOp
   // @param std::shared_ptr<TensorOp> transform: C++ operation to apply on selected bounding boxes
   // @param float ratio: ratio of bounding boxes to have the transform applied on
-  BoundingBoxAugOp(std::shared_ptr<TensorOp> transform, float ratio);
+  BoundingBoxAugmentOp(std::shared_ptr<TensorOp> transform, float ratio);
-  ~BoundingBoxAugOp() override = default;
+  ~BoundingBoxAugmentOp() override = default;
   // Provide stream operator for displaying it
-  friend std::ostream &operator<<(std::ostream &out, const BoundingBoxAugOp &so) {
+  friend std::ostream &operator<<(std::ostream &out, const BoundingBoxAugmentOp &so) {
     so.Print(out);
     return out;
   }
-  void Print(std::ostream &out) const override { out << "BoundingBoxAugOp"; }
+  void Print(std::ostream &out) const override { out << "BoundingBoxAugmentOp"; }
   Status Compute(const TensorRow &input, TensorRow *output) override;
  private:
   float ratio_;
+  std::mt19937 rnd_;
   std::shared_ptr<TensorOp> transform_;
 };
 }  // namespace dataset
@@ -29,20 +29,19 @@ Status RandomHorizontalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow
   BOUNDING_BOX_CHECK(input);
   if (distribution_(rnd_)) {
     // To test bounding boxes algorithm, create random bboxes from image dims
-    size_t numOfBBoxes = input[1]->shape()[0];     // set to give number of bboxes
-    float imgCenter = (input[0]->shape()[1] / 2);  // get the center of the image
+    size_t num_of_boxes = input[1]->shape()[0];    // set to give number of bboxes
+    float img_center = (input[0]->shape()[1] / 2); // get the center of the image
-    for (int i = 0; i < numOfBBoxes; i++) {
+    for (int i = 0; i < num_of_boxes; i++) {
       uint32_t b_w = 0;  // bounding box width
       uint32_t min_x = 0;
       // get the required items
       input[1]->GetItemAt<uint32_t>(&min_x, {i, 0});
       input[1]->GetItemAt<uint32_t>(&b_w, {i, 2});
       // do the flip
-      float diff = imgCenter - min_x;          // get distance from min_x to center
-      uint32_t refl_min_x = diff + imgCenter;  // get reflection of min_x
-      uint32_t new_min_x = refl_min_x - b_w;   // subtract from the reflected min_x to get the new one
+      float diff = img_center - min_x;          // get distance from min_x to center
+      uint32_t refl_min_x = diff + img_center;  // get reflection of min_x
+      uint32_t new_min_x = refl_min_x - b_w;    // subtract from the reflected min_x to get the new one
       input[1]->SetItemAt<uint32_t>({i, 0}, new_min_x);
     }
     (*output).push_back(nullptr);
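The reflection arithmetic above amounts to new_min_x = 2 * img_center - min_x - b_w, i.e. the flipped box starts where the mirrored right edge of the original box lands. A small worked example with made-up numbers:

    width = 640
    img_center = width / 2              # 320.0
    min_x, b_w = 100, 50                # original box covers x in [100, 150)
    diff = img_center - min_x           # 220.0 : distance from min_x to the centre
    refl_min_x = diff + img_center      # 540.0 : mirror image of the old left edge
    new_min_x = refl_min_x - b_w        # 490.0 : flipped box covers x in [490, 540)
    assert new_min_x == width - min_x - b_w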
@@ -45,6 +45,10 @@
 #define BOUNDING_BOX_CHECK(input)                                                  \
   do {                                                                             \
+    if (input[1]->shape().Size() < 2) {                                            \
+      return Status(StatusCode::kBoundingBoxInvalidShape, __LINE__, __FILE__,      \
+                    "Bounding boxes shape should have at least two dims");         \
+    }                                                                              \
     uint32_t num_of_features = input[1]->shape()[1];                               \
     if (num_of_features < 4) {                                                     \
       return Status(StatusCode::kBoundingBoxInvalidShape, __LINE__, __FILE__,      \
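The new guard simply makes sure the annotation tensor is at least 2-D before shape()[1] is read. In Python terms the ordering the macro enforces looks roughly like this (a sketch, with assumed error wording, not the macro itself):

    import numpy as np

    def check_bbox_tensor(bboxes):
        if bboxes.ndim < 2:          # rank check first, so shape[1] below is safe to read
            raise ValueError("Bounding boxes shape should have at least two dims")
        if bboxes.shape[1] < 4:      # then the per-box feature count
            raise ValueError("Bounding boxes should have at least 4 features")

    check_bbox_tensor(np.array([[0, 0, 10, 10]], dtype=np.uint32))    # passes
    # check_bbox_tensor(np.array([0, 0, 9], dtype=np.uint32))         # would raise: tensor is only 1-D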
@@ -254,13 +254,16 @@ class RandomVerticalFlipWithBBox(cde.RandomVerticalFlipWithBBoxOp):
         super().__init__(prob)

-class BoundingBoxAug(cde.BoundingBoxAugOp):
+class BoundingBoxAugment(cde.BoundingBoxAugmentOp):
     """
-    Flip the input image vertically, randomly with a given probability.
+    Apply a given image transform on a random selection of bounding box regions
+    of a given image.

     Args:
-        transform: C++ operation (python OPs are not accepted).
-        ratio (float): Ratio of bounding boxes to apply augmentation on. Range: [0,1] (default=1).
+        transform: C++ transformation function to be applied on a random selection
+            of bounding box regions of a given image.
+        ratio (float, optional): Ratio of bounding boxes to apply augmentation on.
+            Range: [0,1] (default=0.3).
     """

     @check_bounding_box_augment_cpp
     def __init__(self, transform, ratio=0.3):
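A typical pipeline use of the renamed class, following the map() call pattern used by the tests further down. DATA_DIR is a placeholder path, and the annotations are assumed to already be in the expected [x, y, w, h, ...] layout:

    import mindspore.dataset as ds
    import mindspore.dataset.transforms.vision.c_transforms as c_vision

    DATA_DIR = "../data/dataset/testVOC2012"   # placeholder
    data = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)

    # flip roughly 30% of the boxes in each sample
    bbox_aug_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), ratio=0.3)
    data = data.map(input_columns=["image", "annotation"],
                    output_columns=["image", "annotation"],
                    columns_order=["image", "annotation"],
                    operations=bbox_aug_op)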
@@ -862,13 +862,13 @@ def check_bounding_box_augment_cpp(method):
         transform = kwargs.get("transform")
         if "ratio" in kwargs:
             ratio = kwargs.get("ratio")
-            if not isinstance(ratio, float) and not isinstance(ratio, int):
-                raise ValueError("Ratio should be an int or float.")
             if ratio is not None:
                 check_value(ratio, [0., 1.])
                 kwargs["ratio"] = ratio
             else:
                 ratio = 0.3
+            if not isinstance(ratio, float) and not isinstance(ratio, int):
+                raise ValueError("Ratio should be an int or float.")
         if not isinstance(transform, TensorOp):
             raise ValueError("Transform can only be a C++ operation.")
         kwargs["transform"] = transform
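Illustrative calls and the outcome the reordered checks are intended to give (assuming the decorator has already normalised positional arguments into kwargs; exact exception types are not shown here):

    import mindspore.dataset.transforms.vision.c_transforms as c_vision

    c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1))             # ok: ratio defaults to 0.3
    c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), ratio=1)    # ok: int in [0, 1] accepted
    c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), ratio=1.5)  # rejected: outside [0, 1]
    c_vision.BoundingBoxAugment(lambda img: img, ratio=0.5)                   # rejected: transform must be a C++ op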
@@ -16,7 +16,7 @@
 Testing the bounding box augment op in DE
 """
 from enum import Enum
-from mindspore import log as logger
+import mindspore.log as logger
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.vision.c_transforms as c_vision
 import matplotlib.pyplot as plt
@@ -39,59 +39,36 @@ class BoxType(Enum):
     WrongShape = 5

-class AddBadAnnotation:  # pylint: disable=too-few-public-methods
-    """
-    Used to add erroneous bounding boxes to object detection pipelines.
-    Usage:
-    >>> # Adds a box that covers the whole image. Good for testing edge cases
-    >>> de = de.map(input_columns=["image", "annotation"],
-    >>>             output_columns=["image", "annotation"],
-    >>>             operations=AddBadAnnotation(BoxType.OnEdge))
-    """
-
-    def __init__(self, box_type):
-        self.box_type = box_type
-
-    def __call__(self, img, bboxes):
-        """
-        Used to generate erroneous bounding box examples on given img.
-        :param img: image where the bounding boxes are.
-        :param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format
-        :return: bboxes with bad examples added
-        """
-        height = img.shape[0]
-        width = img.shape[1]
-        if self.box_type == BoxType.WidthOverflow:
-            # use box that overflows on width
-            return img, np.array([[0, 0, width + 1, height, 0, 0, 0]]).astype(np.uint32)
-        if self.box_type == BoxType.HeightOverflow:
-            # use box that overflows on height
-            return img, np.array([[0, 0, width, height + 1, 0, 0, 0]]).astype(np.uint32)
-        if self.box_type == BoxType.NegativeXY:
-            # use box with negative xy
-            return img, np.array([[-10, -10, width, height, 0, 0, 0]]).astype(np.uint32)
-        if self.box_type == BoxType.OnEdge:
-            # use box that covers the whole image
-            return img, np.array([[0, 0, width, height, 0, 0, 0]]).astype(np.uint32)
-        if self.box_type == BoxType.WrongShape:
-            # use box with too few columns (wrong shape)
-            return img, np.array([[0, 0, width - 1]]).astype(np.uint32)
-        return img, bboxes
-
-def h_flip(image):
-    """
-    Apply the random_horizontal
-    """
-    # with the seed provided in this test case, it will always flip.
-    # that's why we flip here too
-    image = image[:, ::-1, :]
-    return image
+def add_bad_annotation(img, bboxes, box_type):
+    """
+    Used to generate erroneous bounding box examples on given img.
+    :param img: image where the bounding boxes are.
+    :param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format
+    :param box_type: type of bad box
+    :return: bboxes with bad examples added
+    """
+    height = img.shape[0]
+    width = img.shape[1]
+    if box_type == BoxType.WidthOverflow:
+        # use box that overflows on width
+        return img, np.array([[0, 0, width + 1, height, 0, 0, 0]]).astype(np.uint32)
+    if box_type == BoxType.HeightOverflow:
+        # use box that overflows on height
+        return img, np.array([[0, 0, width, height + 1, 0, 0, 0]]).astype(np.uint32)
+    if box_type == BoxType.NegativeXY:
+        # use box with negative xy
+        return img, np.array([[-10, -10, width, height, 0, 0, 0]]).astype(np.uint32)
+    if box_type == BoxType.OnEdge:
+        # use box that covers the whole image
+        return img, np.array([[0, 0, width, height, 0, 0, 0]]).astype(np.uint32)
+    if box_type == BoxType.WrongShape:
+        # use box with too few columns (wrong shape)
+        return img, np.array([[0, 0, width - 1]]).astype(np.uint32)
+    return img, bboxes

 def check_bad_box(data, box_type, expected_error):
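Now that the helper is a free function, it is also easy to exercise directly, outside a pipeline; a quick hypothetical check of the WidthOverflow case (add_bad_annotation and BoxType are the definitions from the hunk above):

    import numpy as np

    img = np.zeros((300, 400, 3), dtype=np.uint8)
    bboxes = np.array([[10, 10, 50, 50, 0, 0, 0]], dtype=np.uint32)
    _, bad = add_bad_annotation(img, bboxes, BoxType.WidthOverflow)
    assert bad[0][2] == img.shape[1] + 1    # returned box is one pixel wider than the 400-px image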
@@ -102,8 +79,8 @@ def check_bad_box(data, box_type, expected_error)
     :return: None
     """
     try:
-        test_op = c_vision.BoundingBoxAug(c_vision.RandomHorizontalFlip(1),
-                                          1)  # DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM)
+        test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1),
+                                              1)  # DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM)
         data = data.map(input_columns=["annotation"],
                         output_columns=["annotation"],
                         operations=fix_annotate)
@@ -111,7 +88,7 @@ def check_bad_box(data, box_type, expected_error)
         data = data.map(input_columns=["image", "annotation"],
                         output_columns=["image", "annotation"],
                         columns_order=["image", "annotation"],
-                        operations=AddBadAnnotation(box_type))  # Add column for "annotation"
+                        operations=lambda img, bboxes: add_bad_annotation(img, bboxes, box_type))
         # map to apply ops
         data = data.map(input_columns=["image", "annotation"],
                         output_columns=["image", "annotation"],
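The lambda is just one way to bind the extra box_type argument; functools.partial is an equivalent spelling if a named callable is preferred (an alternative sketch, not what this change uses):

    from functools import partial

    data = data.map(input_columns=["image", "annotation"],
                    output_columns=["image", "annotation"],
                    columns_order=["image", "annotation"],
                    operations=partial(add_bad_annotation, box_type=box_type))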
@@ -187,7 +164,7 @@ def test_bounding_box_augment_with_rotation_op(plot=False):
     data_voc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
     data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
-    test_op = c_vision.BoundingBoxAug(c_vision.RandomRotation(90), 1)
+    test_op = c_vision.BoundingBoxAugment(c_vision.RandomRotation(90), 1)
     # DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM)
     # maps to fix annotations to minddata standard
@@ -216,7 +193,7 @@ def test_bounding_box_augment_with_crop_op(plot=False):
     data_voc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
     data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
-    test_op = c_vision.BoundingBoxAug(c_vision.RandomCrop(90), 1)
+    test_op = c_vision.BoundingBoxAugment(c_vision.RandomCrop(90), 1)
     # maps to fix annotations to minddata standard
     data_voc1 = data_voc1.map(input_columns=["annotation"],
@@ -244,7 +221,7 @@ def test_bounding_box_augment_valid_ratio_c(plot=False):
     data_voc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
     data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
-    test_op = c_vision.BoundingBoxAug(c_vision.RandomHorizontalFlip(1), 0.9)
+    test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 0.9)
     # DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM)
     # maps to fix annotations to minddata standard
@@ -274,7 +251,7 @@ def test_bounding_box_augment_invalid_ratio_c():
     try:
         # ratio range is from 0 - 1
-        test_op = c_vision.BoundingBoxAug(c_vision.RandomHorizontalFlip(1), 1.5)
+        test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 1.5)
         # maps to fix annotations to minddata standard
         data_voc1 = data_voc1.map(input_columns=["annotation"],
                                   output_columns=["annotation"],
@@ -16,12 +16,12 @@
 Testing the random horizontal flip with bounding boxes op in DE
 """
 from enum import Enum
-from mindspore import log as logger
-import mindspore.dataset as ds
-import mindspore.dataset.transforms.vision.c_transforms as c_vision
 import matplotlib.pyplot as plt
 import matplotlib.patches as patches
 import numpy as np
+import mindspore.log as logger
+import mindspore.dataset as ds
+import mindspore.dataset.transforms.vision.c_transforms as c_vision

 GENERATE_GOLDEN = False
@@ -38,57 +38,42 @@ class BoxType(Enum):
     OnEdge = 4
     WrongShape = 5

-class AddBadAnnotation:  # pylint: disable=too-few-public-methods
-    """
-    Used to add erroneous bounding boxes to object detection pipelines.
-    Usage:
-    >>> # Adds a box that covers the whole image. Good for testing edge cases
-    >>> de = de.map(input_columns=["image", "annotation"],
-    >>>             output_columns=["image", "annotation"],
-    >>>             operations=AddBadAnnotation(BoxType.OnEdge))
-    """
-
-    def __init__(self, box_type):
-        self.box_type = box_type
-
-    def __call__(self, img, bboxes):
-        """
-        Used to generate erroneous bounding box examples on given img.
-        :param img: image where the bounding boxes are.
-        :param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format
-        :return: bboxes with bad examples added
-        """
-        height = img.shape[0]
-        width = img.shape[1]
-        if self.box_type == BoxType.WidthOverflow:
-            # use box that overflows on width
-            return img, np.array([[0, 0, width + 1, height, 0, 0, 0]]).astype(np.uint32)
-        if self.box_type == BoxType.HeightOverflow:
-            # use box that overflows on height
-            return img, np.array([[0, 0, width, height + 1, 0, 0, 0]]).astype(np.uint32)
-        if self.box_type == BoxType.NegativeXY:
-            # use box with negative xy
-            return img, np.array([[-10, -10, width, height, 0, 0, 0]]).astype(np.uint32)
-        if self.box_type == BoxType.OnEdge:
-            # use box that covers the whole image
-            return img, np.array([[0, 0, width, height, 0, 0, 0]]).astype(np.uint32)
-        if self.box_type == BoxType.WrongShape:
-            # use box with too few columns (wrong shape)
-            return img, np.array([[0, 0, width - 1]]).astype(np.uint32)
-        return img, bboxes
+def add_bad_annotation(img, bboxes, box_type):
+    """
+    Used to generate erroneous bounding box examples on given img.
+    :param img: image where the bounding boxes are.
+    :param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format
+    :param box_type: type of bad box
+    :return: bboxes with bad examples added
+    """
+    height = img.shape[0]
+    width = img.shape[1]
+    if box_type == BoxType.WidthOverflow:
+        # use box that overflows on width
+        return img, np.array([[0, 0, width + 1, height, 0, 0, 0]]).astype(np.uint32)
+    if box_type == BoxType.HeightOverflow:
+        # use box that overflows on height
+        return img, np.array([[0, 0, width, height + 1, 0, 0, 0]]).astype(np.uint32)
+    if box_type == BoxType.NegativeXY:
+        # use box with negative xy
+        return img, np.array([[-10, -10, width, height, 0, 0, 0]]).astype(np.uint32)
+    if box_type == BoxType.OnEdge:
+        # use box that covers the whole image
+        return img, np.array([[0, 0, width, height, 0, 0, 0]]).astype(np.uint32)
+    if box_type == BoxType.WrongShape:
+        # use box with too few columns (wrong shape)
+        return img, np.array([[0, 0, width - 1]]).astype(np.uint32)
+    return img, bboxes

 def h_flip(image):
     """
     Apply the random_horizontal
     """
-    # with the seed provided in this test case, it will always flip.
     # that's why we flip here too
     image = image[:, ::-1, :]
     return image
@@ -111,7 +96,7 @@ def check_bad_box(data, box_type, expected_error):
         data = data.map(input_columns=["image", "annotation"],
                         output_columns=["image", "annotation"],
                         columns_order=["image", "annotation"],
-                        operations=AddBadAnnotation(box_type))  # Add column for "annotation"
+                        operations=lambda img, bboxes: add_bad_annotation(img, bboxes, box_type))
         # map to apply ops
         data = data.map(input_columns=["image", "annotation"],
                         output_columns=["image", "annotation"],