diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 3ae5fe3b98..8fb3edac2a 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -290,7 +290,8 @@ void bindTensorOps1(py::module *m) { (void)py::class_>( *m, "UniformAugOp", "Tensor operation to apply random augmentation(s).") - .def(py::init(), py::arg("operations"), py::arg("NumOps") = UniformAugOp::kDefNumOps); + .def(py::init>, int32_t>(), py::arg("operations"), + py::arg("NumOps") = UniformAugOp::kDefNumOps); (void)py::class_>( *m, "ResizeBilinearOp", diff --git a/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc b/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc index 0ed8140a08..147955ebac 100644 --- a/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc +++ b/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc @@ -13,23 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include "dataset/kernels/image/uniform_aug_op.h" -#include "dataset/kernels/py_func_op.h" #include "dataset/util/random.h" namespace mindspore { namespace dataset { const int UniformAugOp::kDefNumOps = 2; -UniformAugOp::UniformAugOp(py::list op_list, int32_t num_ops) : num_ops_(num_ops) { - std::shared_ptr tensor_op; - // iterate over the op list, cast them to TensorOp and add them to tensor_op_list_ - for (auto op : op_list) { - // only C++ op is accepted - tensor_op = op.cast>(); - tensor_op_list_.insert(tensor_op_list_.begin(), tensor_op); - } - +UniformAugOp::UniformAugOp(std::vector> op_list, int32_t num_ops) + : tensor_op_list_(op_list), num_ops_(num_ops) { rnd_.seed(GetSeed()); } @@ -38,37 +31,28 @@ Status UniformAugOp::Compute(const std::vector> &input, std::vector> *output) { IO_CHECK_VECTOR(input, output); - // variables to copy the result to output if it is not already - std::vector> even_out; - std::vector> *even_out_ptr = &even_out; - int count = 1; - // randomly select ops to be applied std::vector> selected_tensor_ops; std::sample(tensor_op_list_.begin(), tensor_op_list_.end(), std::back_inserter(selected_tensor_ops), num_ops_, rnd_); - for (auto tensor_op = selected_tensor_ops.begin(); tensor_op != selected_tensor_ops.end(); ++tensor_op) { + bool first = true; + for (const auto &tensor_op : selected_tensor_ops) { // Do NOT apply the op, if second random generator returned zero if (std::uniform_int_distribution(0, 1)(rnd_)) { continue; } - // apply C++ ops (note: python OPs are not accepted) - if (count == 1) { - RETURN_IF_NOT_OK((**tensor_op).Compute(input, output)); - } else if (count % 2 == 0) { - RETURN_IF_NOT_OK((**tensor_op).Compute(*output, even_out_ptr)); + if (first) { + RETURN_IF_NOT_OK(tensor_op->Compute(input, output)); + first = false; } else { - RETURN_IF_NOT_OK((**tensor_op).Compute(even_out, output)); + RETURN_IF_NOT_OK(tensor_op->Compute(std::move(*output), output)); } - count++; } - // copy the result to output if it is not in output - if (count == 1) { + // The case where no tensor op is applied. + if (output->empty()) { *output = input; - } else if ((count % 2 == 1)) { - (*output).swap(even_out); } return Status::OK(); diff --git a/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h b/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h index a70edc2777..605f510746 100644 --- a/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h +++ b/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h @@ -24,9 +24,6 @@ #include "dataset/core/tensor.h" #include "dataset/kernels/tensor_op.h" #include "dataset/util/status.h" -#include "dataset/kernels/py_func_op.h" - -#include "pybind11/stl.h" namespace mindspore { namespace dataset { @@ -36,10 +33,11 @@ class UniformAugOp : public TensorOp { static const int kDefNumOps; // Constructor for UniformAugOp - // @param list op_list: list of candidate C++ operations - // @param list num_ops: number of augemtation operations to applied - UniformAugOp(py::list op_list, int32_t num_ops); + // @param std::vector> op_list: list of candidate C++ operations + // @param int32_t num_ops: number of augemtation operations to applied + UniformAugOp(std::vector> op_list, int32_t num_ops); + // Destructor ~UniformAugOp() override = default; void Print(std::ostream &out) const override { out << "UniformAugOp:: number of ops " << num_ops_; }