Browse Source

added input validation to reject python op in C++ uniform augmentation operations list

tags/v0.3.0-alpha
Adel Shafiei 5 years ago
parent
commit
d15bd04bfe
5 changed files with 81 additions and 13 deletions
  1. +4
    -8
      mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc
  2. +1
    -1
      mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h
  3. +12
    -1
      mindspore/dataset/transforms/vision/c_transforms.py
  4. +2
    -2
      mindspore/dataset/transforms/vision/validators.py
  5. +62
    -1
      tests/ut/python/dataset/test_uniform_augment.py

+ 4
- 8
mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc View File

@@ -25,18 +25,14 @@ UniformAugOp::UniformAugOp(py::list op_list, int32_t num_ops) : num_ops_(num_ops
std::shared_ptr<TensorOp> tensor_op;
// iterate over the op list, cast them to TensorOp and add them to tensor_op_list_
for (auto op : op_list) {
if (py::isinstance<py::function>(op)) {
// python op
tensor_op = std::make_shared<PyFuncOp>(op.cast<py::function>());
} else if (py::isinstance<TensorOp>(op)) {
// C++ op
tensor_op = op.cast<std::shared_ptr<TensorOp>>();
}
// only C++ op is accepted
tensor_op = op.cast<std::shared_ptr<TensorOp>>();
tensor_op_list_.insert(tensor_op_list_.begin(), tensor_op);
}

rnd_.seed(GetSeed());
}

// compute method to apply uniformly random selected augmentations from a list
Status UniformAugOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input,
std::vector<std::shared_ptr<Tensor>> *output) {
@@ -57,7 +53,7 @@ Status UniformAugOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input,
continue;
}

// apply python/C++ op
// apply C++ ops (note: python OPs are not accepted)
if (count == 1) {
(**tensor_op).Compute(input, output);
} else if (count % 2 == 0) {


+ 1
- 1
mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h View File

@@ -36,7 +36,7 @@ class UniformAugOp : public TensorOp {
static const int kDefNumOps;

// Constructor for UniformAugOp
// @param list op_list: list of candidate python operations
// @param list op_list: list of candidate C++ operations
// @param list num_ops: number of augemtation operations to applied
UniformAugOp(py::list op_list, int32_t num_ops);



+ 12
- 1
mindspore/dataset/transforms/vision/c_transforms.py View File

@@ -455,8 +455,19 @@ class UniformAugment(cde.UniformAugOp):
Tensor operation to perform randomly selected augmentation

Args:
operations: list of python operations.
operations: list of C++ operations (python OPs are not accepted).
NumOps (int): number of OPs to be selected and applied.

Examples:
>>> transforms_list = [c_transforms.RandomHorizontalFlip(),
>>> c_transforms.RandomVerticalFlip(),
>>> c_transforms.RandomColorAdjust(),
>>> c_transforms.RandomRotation(degrees=45)]
>>> uni_aug = c_transforms.UniformAugment(operations=transforms_list, num_ops=2)
>>> transforms_all = [c_transforms.Decode(), c_transforms.Resize(size=[224, 224]),
>>> uni_aug, F.ToTensor()]
>>> ds_ua = ds.map(input_columns="image",
>>> operations=transforms_all, num_parallel_workers=1)
"""

@check_uniform_augmentation


+ 2
- 2
mindspore/dataset/transforms/vision/validators.py View File

@@ -837,8 +837,8 @@ def check_uniform_augmentation(method):
if not isinstance(operations, list):
raise ValueError("operations is not a python list")
for op in operations:
if not callable(op) and not isinstance(op, TensorOp):
raise ValueError("non-callable op in operations list")
if not isinstance(op, TensorOp):
raise ValueError("operations list only accepts C++ operations.")

kwargs["num_ops"] = num_ops
kwargs["operations"] = operations


+ 62
- 1
tests/ut/python/dataset/test_uniform_augment.py View File

@@ -163,7 +163,68 @@ def test_cpp_uniform_augment(plot=False, num_ops=2):
mse[i] = np.mean((images_ua[i] - images_original[i]) ** 2)
logger.info("MSE= {}".format(str(np.mean(mse))))

def test_cpp_uniform_augment_exception_pyops(num_ops=2):
"""
Test UniformAugment invalid op in operations
"""
logger.info("Test CPP UniformAugment invalid OP exception")

transforms_ua = [C.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]),
C.RandomHorizontalFlip(),
C.RandomVerticalFlip(),
C.RandomColorAdjust(),
C.RandomRotation(degrees=45),
F.Invert()]

try:
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)

except BaseException as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "operations" in str(e)

def test_cpp_uniform_augment_exception_large_numops(num_ops=6):
"""
Test UniformAugment invalid large number of ops
"""
logger.info("Test CPP UniformAugment invalid large num_ops exception")

transforms_ua = [C.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]),
C.RandomHorizontalFlip(),
C.RandomVerticalFlip(),
C.RandomColorAdjust(),
C.RandomRotation(degrees=45)]

try:
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)

except BaseException as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "num_ops" in str(e)

def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0):
"""
Test UniformAugment invalid non-positive number of ops
"""
logger.info("Test CPP UniformAugment invalid non-positive num_ops exception")

transforms_ua = [C.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]),
C.RandomHorizontalFlip(),
C.RandomVerticalFlip(),
C.RandomColorAdjust(),
C.RandomRotation(degrees=45)]

try:
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)

except BaseException as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "num_ops" in str(e)

if __name__ == "__main__":
test_uniform_augment(num_ops=1)
test_cpp_uniform_augment(num_ops=1)
test_cpp_uniform_augment_exception_pyops(num_ops=1)
test_cpp_uniform_augment_exception_large_numops(num_ops=6)
test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0)


Loading…
Cancel
Save