Browse Source

addressed comments from reviews

tags/v0.5.0-beta
islam_amin 5 years ago
parent
commit
edc42c5b85
9 changed files with 99 additions and 128 deletions
  1. +3
    -3
      mindspore/ccsrc/dataset/api/python_bindings.cc
  2. +7
    -6
      mindspore/ccsrc/dataset/kernels/image/bounding_box_augment_op.cc
  3. +8
    -6
      mindspore/ccsrc/dataset/kernels/image/bounding_box_augment_op.h
  4. +6
    -7
      mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_bbox_op.cc
  5. +4
    -0
      mindspore/ccsrc/dataset/kernels/tensor_op.h
  6. +7
    -4
      mindspore/dataset/transforms/vision/c_transforms.py
  7. +2
    -2
      mindspore/dataset/transforms/vision/validators.py
  8. +34
    -57
      tests/ut/python/dataset/test_bounding_box_augment.py
  9. +28
    -43
      tests/ut/python/dataset/test_random_horizontal_flip_bbox.py

+ 3
- 3
mindspore/ccsrc/dataset/api/python_bindings.cc View File

@@ -353,10 +353,10 @@ void bindTensorOps1(py::module *m) {
.def(py::init<std::vector<std::shared_ptr<TensorOp>>, int32_t>(), py::arg("operations"), .def(py::init<std::vector<std::shared_ptr<TensorOp>>, int32_t>(), py::arg("operations"),
py::arg("NumOps") = UniformAugOp::kDefNumOps); py::arg("NumOps") = UniformAugOp::kDefNumOps);


(void)py::class_<BoundingBoxAugOp, TensorOp, std::shared_ptr<BoundingBoxAugOp>>(
*m, "BoundingBoxAugOp", "Tensor operation to apply a transformation on a random choice of bounding boxes.")
(void)py::class_<BoundingBoxAugmentOp, TensorOp, std::shared_ptr<BoundingBoxAugmentOp>>(
*m, "BoundingBoxAugmentOp", "Tensor operation to apply a transformation on a random choice of bounding boxes.")
.def(py::init<std::shared_ptr<TensorOp>, float>(), py::arg("transform"), .def(py::init<std::shared_ptr<TensorOp>, float>(), py::arg("transform"),
py::arg("ratio") = BoundingBoxAugOp::defRatio);
py::arg("ratio") = BoundingBoxAugmentOp::kDefRatio);


(void)py::class_<ResizeBilinearOp, TensorOp, std::shared_ptr<ResizeBilinearOp>>( (void)py::class_<ResizeBilinearOp, TensorOp, std::shared_ptr<ResizeBilinearOp>>(
*m, "ResizeBilinearOp", *m, "ResizeBilinearOp",


+ 7
- 6
mindspore/ccsrc/dataset/kernels/image/bounding_box_augment_op.cc View File

@@ -23,12 +23,14 @@


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
const float BoundingBoxAugOp::defRatio = 0.3;
const float BoundingBoxAugmentOp::kDefRatio = 0.3;


BoundingBoxAugOp::BoundingBoxAugOp(std::shared_ptr<TensorOp> transform, float ratio)
: ratio_(ratio), transform_(std::move(transform)) {}
BoundingBoxAugmentOp::BoundingBoxAugmentOp(std::shared_ptr<TensorOp> transform, float ratio)
: ratio_(ratio), transform_(std::move(transform)) {
rnd_.seed(GetSeed());
}


Status BoundingBoxAugOp::Compute(const TensorRow &input, TensorRow *output) {
Status BoundingBoxAugmentOp::Compute(const TensorRow &input, TensorRow *output) {
IO_CHECK_VECTOR(input, output); IO_CHECK_VECTOR(input, output);
BOUNDING_BOX_CHECK(input); // check if bounding boxes are valid BOUNDING_BOX_CHECK(input); // check if bounding boxes are valid
uint32_t num_of_boxes = input[1]->shape()[0]; uint32_t num_of_boxes = input[1]->shape()[0];
@@ -37,8 +39,7 @@ Status BoundingBoxAugOp::Compute(const TensorRow &input, TensorRow *output) {
std::vector<uint32_t> selected_boxes; std::vector<uint32_t> selected_boxes;
for (uint32_t i = 0; i < num_of_boxes; i++) boxes[i] = i; for (uint32_t i = 0; i < num_of_boxes; i++) boxes[i] = i;
// sample bboxes according to ratio picked by user // sample bboxes according to ratio picked by user
std::random_device rd;
std::sample(boxes.begin(), boxes.end(), std::back_inserter(selected_boxes), num_to_aug, std::mt19937(rd()));
std::sample(boxes.begin(), boxes.end(), std::back_inserter(selected_boxes), num_to_aug, rnd_);
std::shared_ptr<Tensor> crop_out; std::shared_ptr<Tensor> crop_out;
std::shared_ptr<Tensor> res_out; std::shared_ptr<Tensor> res_out;
std::shared_ptr<CVTensor> input_restore = CVTensor::AsCVTensor(input[0]); std::shared_ptr<CVTensor> input_restore = CVTensor::AsCVTensor(input[0]);


+ 8
- 6
mindspore/ccsrc/dataset/kernels/image/bounding_box_augment_op.h View File

@@ -24,33 +24,35 @@
#include "dataset/core/tensor.h" #include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h" #include "dataset/kernels/tensor_op.h"
#include "dataset/util/status.h" #include "dataset/util/status.h"
#include "dataset/util/random.h"


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
class BoundingBoxAugOp : public TensorOp {
class BoundingBoxAugmentOp : public TensorOp {
public: public:
// Default values, also used by python_bindings.cc // Default values, also used by python_bindings.cc
static const float defRatio;
static const float kDefRatio;


// Constructor for BoundingBoxAugmentOp // Constructor for BoundingBoxAugmentOp
// @param std::shared_ptr<TensorOp> transform transform: C++ opration to apply on select bounding boxes // @param std::shared_ptr<TensorOp> transform transform: C++ opration to apply on select bounding boxes
// @param float ratio: ratio of bounding boxes to have the transform applied on // @param float ratio: ratio of bounding boxes to have the transform applied on
BoundingBoxAugOp(std::shared_ptr<TensorOp> transform, float ratio);
BoundingBoxAugmentOp(std::shared_ptr<TensorOp> transform, float ratio);


~BoundingBoxAugOp() override = default;
~BoundingBoxAugmentOp() override = default;


// Provide stream operator for displaying it // Provide stream operator for displaying it
friend std::ostream &operator<<(std::ostream &out, const BoundingBoxAugOp &so) {
friend std::ostream &operator<<(std::ostream &out, const BoundingBoxAugmentOp &so) {
so.Print(out); so.Print(out);
return out; return out;
} }


void Print(std::ostream &out) const override { out << "BoundingBoxAugOp"; }
void Print(std::ostream &out) const override { out << "BoundingBoxAugmentOp"; }


Status Compute(const TensorRow &input, TensorRow *output) override; Status Compute(const TensorRow &input, TensorRow *output) override;


private: private:
float ratio_; float ratio_;
std::mt19937 rnd_;
std::shared_ptr<TensorOp> transform_; std::shared_ptr<TensorOp> transform_;
}; };
} // namespace dataset } // namespace dataset


+ 6
- 7
mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_bbox_op.cc View File

@@ -29,20 +29,19 @@ Status RandomHorizontalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow
BOUNDING_BOX_CHECK(input); BOUNDING_BOX_CHECK(input);
if (distribution_(rnd_)) { if (distribution_(rnd_)) {
// To test bounding boxes algorithm, create random bboxes from image dims // To test bounding boxes algorithm, create random bboxes from image dims
size_t numOfBBoxes = input[1]->shape()[0]; // set to give number of bboxes
float imgCenter = (input[0]->shape()[1] / 2); // get the center of the image
size_t num_of_boxes = input[1]->shape()[0]; // set to give number of bboxes
float img_center = (input[0]->shape()[1] / 2); // get the center of the image


for (int i = 0; i < numOfBBoxes; i++) {
for (int i = 0; i < num_of_boxes; i++) {
uint32_t b_w = 0; // bounding box width uint32_t b_w = 0; // bounding box width
uint32_t min_x = 0; uint32_t min_x = 0;
// get the required items // get the required items
input[1]->GetItemAt<uint32_t>(&min_x, {i, 0}); input[1]->GetItemAt<uint32_t>(&min_x, {i, 0});
input[1]->GetItemAt<uint32_t>(&b_w, {i, 2}); input[1]->GetItemAt<uint32_t>(&b_w, {i, 2});
// do the flip // do the flip
float diff = imgCenter - min_x; // get distance from min_x to center
uint32_t refl_min_x = diff + imgCenter; // get reflection of min_x
uint32_t new_min_x = refl_min_x - b_w; // subtract from the reflected min_x to get the new one

float diff = img_center - min_x; // get distance from min_x to center
uint32_t refl_min_x = diff + img_center; // get reflection of min_x
uint32_t new_min_x = refl_min_x - b_w; // subtract from the reflected min_x to get the new one
input[1]->SetItemAt<uint32_t>({i, 0}, new_min_x); input[1]->SetItemAt<uint32_t>({i, 0}, new_min_x);
} }
(*output).push_back(nullptr); (*output).push_back(nullptr);


+ 4
- 0
mindspore/ccsrc/dataset/kernels/tensor_op.h View File

@@ -45,6 +45,10 @@


#define BOUNDING_BOX_CHECK(input) \ #define BOUNDING_BOX_CHECK(input) \
do { \ do { \
if (input[1]->shape().Size() < 2) { \
return Status(StatusCode::kBoundingBoxInvalidShape, __LINE__, __FILE__, \
"Bounding boxes shape should have at least two dims"); \
} \
uint32_t num_of_features = input[1]->shape()[1]; \ uint32_t num_of_features = input[1]->shape()[1]; \
if (num_of_features < 4) { \ if (num_of_features < 4) { \
return Status(StatusCode::kBoundingBoxInvalidShape, __LINE__, __FILE__, \ return Status(StatusCode::kBoundingBoxInvalidShape, __LINE__, __FILE__, \


+ 7
- 4
mindspore/dataset/transforms/vision/c_transforms.py View File

@@ -254,13 +254,16 @@ class RandomVerticalFlipWithBBox(cde.RandomVerticalFlipWithBBoxOp):
super().__init__(prob) super().__init__(prob)




class BoundingBoxAug(cde.BoundingBoxAugOp):
class BoundingBoxAugment(cde.BoundingBoxAugmentOp):
""" """
Flip the input image vertically, randomly with a given probability.
Apply a given image transform on a random selection of bounding box regions
of a given image.


Args: Args:
transform: C++ operation (python OPs are not accepted).
ratio (float): Ratio of bounding boxes to apply augmentation on. Range: [0,1] (default=1).
transform: C++ transformation function to be applied on random selection
of bounding box regions of a given image.
ratio (float, optional): Ratio of bounding boxes to apply augmentation on.
Range: [0,1] (default=0.3).
""" """
@check_bounding_box_augment_cpp @check_bounding_box_augment_cpp
def __init__(self, transform, ratio=0.3): def __init__(self, transform, ratio=0.3):


+ 2
- 2
mindspore/dataset/transforms/vision/validators.py View File

@@ -862,13 +862,13 @@ def check_bounding_box_augment_cpp(method):
transform = kwargs.get("transform") transform = kwargs.get("transform")
if "ratio" in kwargs: if "ratio" in kwargs:
ratio = kwargs.get("ratio") ratio = kwargs.get("ratio")
if not isinstance(ratio, float) and not isinstance(ratio, int):
raise ValueError("Ratio should be an int or float.")
if ratio is not None: if ratio is not None:
check_value(ratio, [0., 1.]) check_value(ratio, [0., 1.])
kwargs["ratio"] = ratio kwargs["ratio"] = ratio
else: else:
ratio = 0.3 ratio = 0.3
if not isinstance(ratio, float) and not isinstance(ratio, int):
raise ValueError("Ratio should be an int or float.")
if not isinstance(transform, TensorOp): if not isinstance(transform, TensorOp):
raise ValueError("Transform can only be a C++ operation.") raise ValueError("Transform can only be a C++ operation.")
kwargs["transform"] = transform kwargs["transform"] = transform


+ 34
- 57
tests/ut/python/dataset/test_bounding_box_augment.py View File

@@ -16,7 +16,7 @@
Testing the bounding box augment op in DE Testing the bounding box augment op in DE
""" """
from enum import Enum from enum import Enum
from mindspore import log as logger
import mindspore.log as logger
import mindspore.dataset as ds import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as c_vision import mindspore.dataset.transforms.vision.c_transforms as c_vision
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@@ -39,59 +39,36 @@ class BoxType(Enum):
WrongShape = 5 WrongShape = 5




class AddBadAnnotation: # pylint: disable=too-few-public-methods
def add_bad_annotation(img, bboxes, box_type):
""" """
Used to add erroneous bounding boxes to object detection pipelines.
Usage:
>>> # Adds a box that covers the whole image. Good for testing edge cases
>>> de = de.map(input_columns=["image", "annotation"],
>>> output_columns=["image", "annotation"],
>>> operations=AddBadAnnotation(BoxType.OnEdge))
Used to generate erroneous bounding box examples on given img.
:param img: image where the bounding boxes are.
:param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format
:param box_type: type of bad box
:return: bboxes with bad examples added
""" """
height = img.shape[0]
width = img.shape[1]
if box_type == BoxType.WidthOverflow:
# use box that overflows on width
return img, np.array([[0, 0, width + 1, height, 0, 0, 0]]).astype(np.uint32)


def __init__(self, box_type):
self.box_type = box_type

def __call__(self, img, bboxes):
"""
Used to generate erroneous bounding box examples on given img.
:param img: image where the bounding boxes are.
:param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format
:return: bboxes with bad examples added
"""
height = img.shape[0]
width = img.shape[1]
if self.box_type == BoxType.WidthOverflow:
# use box that overflows on width
return img, np.array([[0, 0, width + 1, height, 0, 0, 0]]).astype(np.uint32)

if self.box_type == BoxType.HeightOverflow:
# use box that overflows on height
return img, np.array([[0, 0, width, height + 1, 0, 0, 0]]).astype(np.uint32)

if self.box_type == BoxType.NegativeXY:
# use box with negative xy
return img, np.array([[-10, -10, width, height, 0, 0, 0]]).astype(np.uint32)

if self.box_type == BoxType.OnEdge:
# use box that covers the whole image
return img, np.array([[0, 0, width, height, 0, 0, 0]]).astype(np.uint32)

if self.box_type == BoxType.WrongShape:
# use box that covers the whole image
return img, np.array([[0, 0, width - 1]]).astype(np.uint32)
return img, bboxes


def h_flip(image):
"""
Apply the random_horizontal
"""
if box_type == BoxType.HeightOverflow:
# use box that overflows on height
return img, np.array([[0, 0, width, height + 1, 0, 0, 0]]).astype(np.uint32)

if box_type == BoxType.NegativeXY:
# use box with negative xy
return img, np.array([[-10, -10, width, height, 0, 0, 0]]).astype(np.uint32)

if box_type == BoxType.OnEdge:
# use box that covers the whole image
return img, np.array([[0, 0, width, height, 0, 0, 0]]).astype(np.uint32)


# with the seed provided in this test case, it will always flip.
# that's why we flip here too
image = image[:, ::-1, :]
return image
if box_type == BoxType.WrongShape:
# use box that covers the whole image
return img, np.array([[0, 0, width - 1]]).astype(np.uint32)
return img, bboxes




def check_bad_box(data, box_type, expected_error): def check_bad_box(data, box_type, expected_error):
@@ -102,8 +79,8 @@ def check_bad_box(data, box_type, expected_error):
:return: None :return: None
""" """
try: try:
test_op = c_vision.BoundingBoxAug(c_vision.RandomHorizontalFlip(1),
1) # DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM)
test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1),
1) # DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM)
data = data.map(input_columns=["annotation"], data = data.map(input_columns=["annotation"],
output_columns=["annotation"], output_columns=["annotation"],
operations=fix_annotate) operations=fix_annotate)
@@ -111,7 +88,7 @@ def check_bad_box(data, box_type, expected_error):
data = data.map(input_columns=["image", "annotation"], data = data.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"], output_columns=["image", "annotation"],
columns_order=["image", "annotation"], columns_order=["image", "annotation"],
operations=AddBadAnnotation(box_type)) # Add column for "annotation"
operations=lambda img, bboxes: add_bad_annotation(img, bboxes, box_type))
# map to apply ops # map to apply ops
data = data.map(input_columns=["image", "annotation"], data = data.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"], output_columns=["image", "annotation"],
@@ -187,7 +164,7 @@ def test_bounding_box_augment_with_rotation_op(plot=False):
data_voc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) data_voc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)


test_op = c_vision.BoundingBoxAug(c_vision.RandomRotation(90), 1)
test_op = c_vision.BoundingBoxAugment(c_vision.RandomRotation(90), 1)
# DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM) # DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM)


# maps to fix annotations to minddata standard # maps to fix annotations to minddata standard
@@ -216,7 +193,7 @@ def test_bounding_box_augment_with_crop_op(plot=False):
data_voc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) data_voc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)


test_op = c_vision.BoundingBoxAug(c_vision.RandomCrop(90), 1)
test_op = c_vision.BoundingBoxAugment(c_vision.RandomCrop(90), 1)


# maps to fix annotations to minddata standard # maps to fix annotations to minddata standard
data_voc1 = data_voc1.map(input_columns=["annotation"], data_voc1 = data_voc1.map(input_columns=["annotation"],
@@ -244,7 +221,7 @@ def test_bounding_box_augment_valid_ratio_c(plot=False):
data_voc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) data_voc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)


test_op = c_vision.BoundingBoxAug(c_vision.RandomHorizontalFlip(1), 0.9)
test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 0.9)
# DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM) # DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM)


# maps to fix annotations to minddata standard # maps to fix annotations to minddata standard
@@ -274,7 +251,7 @@ def test_bounding_box_augment_invalid_ratio_c():


try: try:
# ratio range is from 0 - 1 # ratio range is from 0 - 1
test_op = c_vision.BoundingBoxAug(c_vision.RandomHorizontalFlip(1), 1.5)
test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 1.5)
# maps to fix annotations to minddata standard # maps to fix annotations to minddata standard
data_voc1 = data_voc1.map(input_columns=["annotation"], data_voc1 = data_voc1.map(input_columns=["annotation"],
output_columns=["annotation"], output_columns=["annotation"],


+ 28
- 43
tests/ut/python/dataset/test_random_horizontal_flip_bbox.py View File

@@ -16,12 +16,12 @@
Testing the random horizontal flip with bounding boxes op in DE Testing the random horizontal flip with bounding boxes op in DE
""" """
from enum import Enum from enum import Enum
from mindspore import log as logger
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as c_vision
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.patches as patches import matplotlib.patches as patches
import numpy as np import numpy as np
import mindspore.log as logger
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as c_vision


GENERATE_GOLDEN = False GENERATE_GOLDEN = False


@@ -38,57 +38,42 @@ class BoxType(Enum):
OnEdge = 4 OnEdge = 4
WrongShape = 5 WrongShape = 5



class AddBadAnnotation: # pylint: disable=too-few-public-methods
def add_bad_annotation(img, bboxes, box_type):
""" """
Used to add erroneous bounding boxes to object detection pipelines.
Usage:
>>> # Adds a box that covers the whole image. Good for testing edge cases
>>> de = de.map(input_columns=["image", "annotation"],
>>> output_columns=["image", "annotation"],
>>> operations=AddBadAnnotation(BoxType.OnEdge))
Used to generate erroneous bounding box examples on given img.
:param img: image where the bounding boxes are.
:param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format
:param box_type: type of bad box
:return: bboxes with bad examples added
""" """
height = img.shape[0]
width = img.shape[1]
if box_type == BoxType.WidthOverflow:
# use box that overflows on width
return img, np.array([[0, 0, width + 1, height, 0, 0, 0]]).astype(np.uint32)


def __init__(self, box_type):
self.box_type = box_type

def __call__(self, img, bboxes):
"""
Used to generate erroneous bounding box examples on given img.
:param img: image where the bounding boxes are.
:param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format
:return: bboxes with bad examples added
"""
height = img.shape[0]
width = img.shape[1]
if self.box_type == BoxType.WidthOverflow:
# use box that overflows on width
return img, np.array([[0, 0, width + 1, height, 0, 0, 0]]).astype(np.uint32)
if box_type == BoxType.HeightOverflow:
# use box that overflows on height
return img, np.array([[0, 0, width, height + 1, 0, 0, 0]]).astype(np.uint32)


if self.box_type == BoxType.HeightOverflow:
# use box that overflows on height
return img, np.array([[0, 0, width, height + 1, 0, 0, 0]]).astype(np.uint32)
if box_type == BoxType.NegativeXY:
# use box with negative xy
return img, np.array([[-10, -10, width, height, 0, 0, 0]]).astype(np.uint32)


if self.box_type == BoxType.NegativeXY:
# use box with negative xy
return img, np.array([[-10, -10, width, height, 0, 0, 0]]).astype(np.uint32)
if box_type == BoxType.OnEdge:
# use box that covers the whole image
return img, np.array([[0, 0, width, height, 0, 0, 0]]).astype(np.uint32)


if self.box_type == BoxType.OnEdge:
# use box that covers the whole image
return img, np.array([[0, 0, width, height, 0, 0, 0]]).astype(np.uint32)

if self.box_type == BoxType.WrongShape:
# use box that covers the whole image
return img, np.array([[0, 0, width - 1]]).astype(np.uint32)
return img, bboxes
if box_type == BoxType.WrongShape:
# use box that covers the whole image
return img, np.array([[0, 0, width - 1]]).astype(np.uint32)
return img, bboxes




def h_flip(image): def h_flip(image):
""" """
Apply the random_horizontal Apply the random_horizontal
""" """

# with the seed provided in this test case, it will always flip.
# that's why we flip here too # that's why we flip here too
image = image[:, ::-1, :] image = image[:, ::-1, :]
return image return image
@@ -111,7 +96,7 @@ def check_bad_box(data, box_type, expected_error):
data = data.map(input_columns=["image", "annotation"], data = data.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"], output_columns=["image", "annotation"],
columns_order=["image", "annotation"], columns_order=["image", "annotation"],
operations=AddBadAnnotation(box_type)) # Add column for "annotation"
operations=lambda img, bboxes: add_bad_annotation(img, bboxes, box_type))
# map to apply ops # map to apply ops
data = data.map(input_columns=["image", "annotation"], data = data.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"], output_columns=["image", "annotation"],


Loading…
Cancel
Save