Browse Source

Added float32 support for CutMixBatch

tags/v1.0.0
Mahdi 5 years ago
parent
commit
a2c38d89f9
5 changed files with 42 additions and 17 deletions
  1. +3
    -3
      mindspore/ccsrc/minddata/dataset/kernels/image/cutmix_batch_op.cc
  2. +21
    -9
      mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc
  3. +11
    -0
      mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h
  4. +4
    -4
      mindspore/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.cc
  5. +3
    -1
      tests/ut/python/dataset/test_cutmix_batch_op.py

+ 3
- 3
mindspore/ccsrc/minddata/dataset/kernels/image/cutmix_batch_op.cc View File

@@ -50,7 +50,7 @@ void CutMixBatchOp::GetCropBox(int height, int width, float lam, int *x, int *y,


Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) { Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
if (input.size() < 2) { if (input.size() < 2) {
RETURN_STATUS_UNEXPECTED("Both images and labels columns are required for this operation");
RETURN_STATUS_UNEXPECTED("Both images and labels columns are required for this operation.");
} }


std::vector<std::shared_ptr<Tensor>> images; std::vector<std::shared_ptr<Tensor>> images;
@@ -59,10 +59,10 @@ Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {


// Check inputs // Check inputs
if (image_shape.size() != 4 || image_shape[0] != label_shape[0]) { if (image_shape.size() != 4 || image_shape[0] != label_shape[0]) {
RETURN_STATUS_UNEXPECTED("You must make sure images are HWC or CHW and batch before calling CutMixBatch.");
RETURN_STATUS_UNEXPECTED("You must make sure images are HWC or CHW and batched before calling CutMixBatch.");
} }
if (label_shape.size() != 2) { if (label_shape.size() != 2) {
RETURN_STATUS_UNEXPECTED("CutMixBatch: Label's must be in one-hot format and in a batch");
RETURN_STATUS_UNEXPECTED("CutMixBatch: Label's must be in one-hot format and in a batch.");
} }
if ((image_shape[1] != 1 && image_shape[1] != 3) && image_batch_format_ == ImageBatchFormat::kNCHW) { if ((image_shape[1] != 1 && image_shape[1] != 3) && image_batch_format_ == ImageBatchFormat::kNCHW) {
RETURN_STATUS_UNEXPECTED("CutMixBatch: Image doesn't match the given image format."); RETURN_STATUS_UNEXPECTED("CutMixBatch: Image doesn't match the given image format.");


+ 21
- 9
mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc View File

@@ -415,9 +415,7 @@ Status MaskWithTensor(const std::shared_ptr<Tensor> &sub_mat, std::shared_ptr<Te
for (int i = 0; i < crop_width; i++) { for (int i = 0; i < crop_width; i++) {
for (int j = 0; j < crop_height; j++) { for (int j = 0; j < crop_height; j++) {
for (int c = 0; c < number_of_channels; c++) { for (int c = 0; c < number_of_channels; c++) {
uint8_t pixel_value;
RETURN_IF_NOT_OK(sub_mat->GetItemAt(&pixel_value, {j, i, c}));
RETURN_IF_NOT_OK((*input)->SetItemAt({y + j, x + i, c}, pixel_value));
RETURN_IF_NOT_OK(CopyTensorValue(sub_mat, input, {j, i, c}, {y + j, x + i, c}));
} }
} }
} }
@@ -432,9 +430,7 @@ Status MaskWithTensor(const std::shared_ptr<Tensor> &sub_mat, std::shared_ptr<Te
for (int i = 0; i < crop_width; i++) { for (int i = 0; i < crop_width; i++) {
for (int j = 0; j < crop_height; j++) { for (int j = 0; j < crop_height; j++) {
for (int c = 0; c < number_of_channels; c++) { for (int c = 0; c < number_of_channels; c++) {
uint8_t pixel_value;
RETURN_IF_NOT_OK(sub_mat->GetItemAt(&pixel_value, {c, j, i}));
RETURN_IF_NOT_OK((*input)->SetItemAt({c, y + j, x + i}, pixel_value));
RETURN_IF_NOT_OK(CopyTensorValue(sub_mat, input, {c, j, i}, {c, y + j, x + i}));
} }
} }
} }
@@ -447,9 +443,7 @@ Status MaskWithTensor(const std::shared_ptr<Tensor> &sub_mat, std::shared_ptr<Te
} }
for (int i = 0; i < crop_width; i++) { for (int i = 0; i < crop_width; i++) {
for (int j = 0; j < crop_height; j++) { for (int j = 0; j < crop_height; j++) {
uint8_t pixel_value;
RETURN_IF_NOT_OK(sub_mat->GetItemAt(&pixel_value, {j, i}));
RETURN_IF_NOT_OK((*input)->SetItemAt({y + j, x + i}, pixel_value));
RETURN_IF_NOT_OK(CopyTensorValue(sub_mat, input, {j, i}, {y + j, x + i}));
} }
} }
} else { } else {
@@ -458,6 +452,24 @@ Status MaskWithTensor(const std::shared_ptr<Tensor> &sub_mat, std::shared_ptr<Te
return Status::OK(); return Status::OK();
} }


Status CopyTensorValue(const std::shared_ptr<Tensor> &source_tensor, std::shared_ptr<Tensor> *dest_tensor,
const std::vector<int64_t> &source_indx, const std::vector<int64_t> &dest_indx) {
if (source_tensor->type() != (*dest_tensor)->type())
RETURN_STATUS_UNEXPECTED("CopyTensorValue: source and destination tensor must have the same type.");
if (source_tensor->type() == DataType::DE_UINT8) {
uint8_t pixel_value;
RETURN_IF_NOT_OK(source_tensor->GetItemAt(&pixel_value, source_indx));
RETURN_IF_NOT_OK((*dest_tensor)->SetItemAt(dest_indx, pixel_value));
} else if (source_tensor->type() == DataType::DE_FLOAT32) {
float pixel_value;
RETURN_IF_NOT_OK(source_tensor->GetItemAt(&pixel_value, source_indx));
RETURN_IF_NOT_OK((*dest_tensor)->SetItemAt(dest_indx, pixel_value));
} else {
RETURN_STATUS_UNEXPECTED("CopyTensorValue: Tensor type is not supported. Tensor type must be float32 or uint8.");
}
return Status::OK();
}

Status SwapRedAndBlue(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output) { Status SwapRedAndBlue(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output) {
try { try {
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(std::move(input)); std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(std::move(input));


+ 11
- 0
mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h View File

@@ -133,6 +133,17 @@ Status HwcToChw(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output);
Status MaskWithTensor(const std::shared_ptr<Tensor> &sub_mat, std::shared_ptr<Tensor> *input, int x, int y, int width, Status MaskWithTensor(const std::shared_ptr<Tensor> &sub_mat, std::shared_ptr<Tensor> *input, int x, int y, int width,
int height, ImageFormat image_format); int height, ImageFormat image_format);


/// \brief Copies a value from a source tensor into a destination tensor
/// \note This is meant for images and therefore only works if tensor is uint8 or float32
/// \param[in] source_tensor The tensor we take the value from
/// \param[in] dest_tensor The pointer to the tensor we want to copy the value to
/// \param[in] source_indx index of the value in the source tensor
/// \param[in] dest_indx index of the value in the destination tensor
/// \param[out] dest_tensor Copies the value to the given dest_tensor and returns it
/// @return Status ok/error
Status CopyTensorValue(const std::shared_ptr<Tensor> &source_tensor, std::shared_ptr<Tensor> *dest_tensor,
const std::vector<int64_t> &source_indx, const std::vector<int64_t> &dest_indx);

/// \brief Swap the red and blue pixels (RGB <-> BGR) /// \brief Swap the red and blue pixels (RGB <-> BGR)
/// \param input: Tensor of shape <H,W,3> and any OpenCv compatible type, see CVTensor. /// \param input: Tensor of shape <H,W,3> and any OpenCv compatible type, see CVTensor.
/// \param output: Swapped image of same shape and type /// \param output: Swapped image of same shape and type


+ 4
- 4
mindspore/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.cc View File

@@ -29,7 +29,7 @@ MixUpBatchOp::MixUpBatchOp(float alpha) : alpha_(alpha) { rnd_.seed(GetSeed());


Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) { Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) {
if (input.size() < 2) { if (input.size() < 2) {
RETURN_STATUS_UNEXPECTED("Both images and labels columns are required for this operation");
RETURN_STATUS_UNEXPECTED("Both images and labels columns are required for this operation.");
} }


std::vector<std::shared_ptr<CVTensor>> images; std::vector<std::shared_ptr<CVTensor>> images;
@@ -38,13 +38,13 @@ Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) {


// Check inputs // Check inputs
if (image_shape.size() != 4 || image_shape[0] != label_shape[0]) { if (image_shape.size() != 4 || image_shape[0] != label_shape[0]) {
RETURN_STATUS_UNEXPECTED("You must make sure images are HWC or CHW and batch before calling MixUpBatch");
RETURN_STATUS_UNEXPECTED("You must make sure images are HWC or CHW and batched before calling MixUpBatch.");
} }
if (label_shape.size() != 2) { if (label_shape.size() != 2) {
RETURN_STATUS_UNEXPECTED("MixUpBatch: Label's must be in one-hot format and in a batch");
RETURN_STATUS_UNEXPECTED("MixUpBatch: Label's must be in one-hot format and in a batch.");
} }
if ((image_shape[1] != 1 && image_shape[1] != 3) && (image_shape[3] != 1 && image_shape[3] != 3)) { if ((image_shape[1] != 1 && image_shape[1] != 3) && (image_shape[3] != 1 && image_shape[3] != 3)) {
RETURN_STATUS_UNEXPECTED("MixUpBatch: Images must be in the shape of HWC or CHW");
RETURN_STATUS_UNEXPECTED("MixUpBatch: Images must be in the shape of HWC or CHW.");
} }


// Move images into a vector of CVTensors // Move images into a vector of CVTensors


+ 3
- 1
tests/ut/python/dataset/test_cutmix_batch_op.py View File

@@ -76,7 +76,7 @@ def test_cutmix_batch_success1(plot=False):


def test_cutmix_batch_success2(plot=False): def test_cutmix_batch_success2(plot=False):
""" """
Test CutMixBatch op with default values for alpha and prob on a batch of HWC images
Test CutMixBatch op with default values for alpha and prob on a batch of rescaled HWC images
""" """
logger.info("test_cutmix_batch_success2") logger.info("test_cutmix_batch_success2")


@@ -95,6 +95,8 @@ def test_cutmix_batch_success2(plot=False):
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False) data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
one_hot_op = data_trans.OneHot(num_classes=10) one_hot_op = data_trans.OneHot(num_classes=10)
data1 = data1.map(input_columns=["label"], operations=one_hot_op) data1 = data1.map(input_columns=["label"], operations=one_hot_op)
rescale_op = vision.Rescale((1.0/255.0), 0.0)
data1 = data1.map(input_columns=["image"], operations=rescale_op)
cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC) cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
data1 = data1.batch(5, drop_remainder=True) data1 = data1.batch(5, drop_remainder=True)
data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op) data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op)


Loading…
Cancel
Save