Browse Source

!30429 BUG][MD][FUNC]RandomAutoContrast

Merge pull request !30429 from yangwm/autocontrast
r1.7
i-robot Gitee 4 years ago
parent
commit
8e8391dadb
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 83 additions and 1 deletions
  1. +15
    -1
      mindspore/ccsrc/minddata/dataset/kernels/image/random_auto_contrast_op.cc
  2. +68
    -0
      tests/ut/python/dataset/test_random_auto_contrast.py

+ 15
- 1
mindspore/ccsrc/minddata/dataset/kernels/image/random_auto_contrast_op.cc View File

@@ -13,10 +13,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "minddata/dataset/kernels/image/random_auto_contrast_op.h"

#include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {
@@ -26,6 +26,20 @@ const float RandomAutoContrastOp::kDefProbability = 0.5;

Status RandomAutoContrastOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
// Check input
if (input->Rank() != DEFAULT_IMAGE_RANK) {
RETURN_STATUS_UNEXPECTED("RandomAutoContrast: image shape is not <H,W,C>, got rank: " +
std::to_string(input->Rank()));
}
if (input->shape()[CHANNEL_INDEX] != DEFAULT_IMAGE_CHANNELS) {
RETURN_STATUS_UNEXPECTED(
"RandomAutoContrast: image shape is incorrect, expected num of channels is 3, "
"but got:" +
std::to_string(input->shape()[CHANNEL_INDEX]));
}
CHECK_FAIL_RETURN_UNEXPECTED(input->type().AsCVType() != kCVInvalidType,
"RandomAutoContrast: Cannot convert from OpenCV type, unknown CV type. Currently "
"supported data type: [int8, uint8, int16, uint16, int32, float16, float32, float64].");
if (distribution_(rnd_)) {
return AutoContrast(input, output, cutoff_, ignore_);
}


+ 68
- 0
tests/ut/python/dataset/test_random_auto_contrast.py View File

@@ -177,6 +177,71 @@ def test_random_auto_contrast_invalid_cutoff():
assert "Input cutoff is not within the required interval of [0, 50)." in str(error)


def test_random_auto_contrast_one_channel():
"""
Feature: RandomAutoContrast
Description: test with one channel images
Expectation: raise errors as expected
"""
logger.info("test_random_auto_contrast_one_channel")

c_op = c_vision.RandomAutoContrast()

try:
data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
data_set = data_set.map(operations=[c_vision.Decode(), c_vision.Resize((224, 224)),
lambda img: np.array(img[:, :, 0])], input_columns=["image"])

data_set = data_set.map(operations=c_op, input_columns="image")

except RuntimeError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "image shape is incorrect, expected num of channels is 3." in str(e)


def test_random_auto_contrast_four_dim():
"""
Feature: RandomAutoContrast
Description: test with four dimension images
Expectation: raise errors as expected
"""
logger.info("test_random_auto_contrast_four_dim")

c_op = c_vision.RandomAutoContrast()

try:
data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
data_set = data_set.map(operations=[c_vision.Decode(), c_vision.Resize((224, 224)),
lambda img: np.array(img[2, 200, 10, 32])], input_columns=["image"])

data_set = data_set.map(operations=c_op, input_columns="image")

except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "image shape is not <H,W,C>" in str(e)


def test_random_auto_contrast_invalid_input():
"""
Feature: RandomAutoContrast
Description: test with images in uint32 type
Expectation: raise errors as expected
"""
logger.info("test_random_invert_invalid_input")

c_op = c_vision.RandomAutoContrast()

try:
data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
data_set = data_set.map(operations=[c_vision.Decode(), c_vision.Resize((224, 224)),
lambda img: np.array(img[2, 32, 3], dtype=uint32)], input_columns=["image"])
data_set = data_set.map(operations=c_op, input_columns="image")

except TypeError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Cannot convert from OpenCV type, unknown CV type" in str(e)


if __name__ == "__main__":
test_random_auto_contrast_pipeline(plot=True)
test_random_auto_contrast_eager()
@@ -184,3 +249,6 @@ if __name__ == "__main__":
test_random_auto_contrast_invalid_prob()
test_random_auto_contrast_invalid_ignore()
test_random_auto_contrast_invalid_cutoff()
test_random_auto_contrast_one_channel()
test_random_auto_contrast_four_dim()
test_random_auto_contrast_invalid_input()

Loading…
Cancel
Save