|
|
|
@@ -13,23 +13,16 @@ |
|
|
|
* See the License for the specific language governing permissions and |
|
|
|
* limitations under the License. |
|
|
|
*/ |
|
|
|
#include <utility> |
|
|
|
#include "dataset/kernels/image/uniform_aug_op.h" |
|
|
|
#include "dataset/kernels/py_func_op.h" |
|
|
|
#include "dataset/util/random.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace dataset { |
|
|
|
const int UniformAugOp::kDefNumOps = 2; |
|
|
|
|
|
|
|
UniformAugOp::UniformAugOp(py::list op_list, int32_t num_ops) : num_ops_(num_ops) { |
|
|
|
std::shared_ptr<TensorOp> tensor_op; |
|
|
|
// iterate over the op list, cast them to TensorOp and add them to tensor_op_list_ |
|
|
|
for (auto op : op_list) { |
|
|
|
// only C++ op is accepted |
|
|
|
tensor_op = op.cast<std::shared_ptr<TensorOp>>(); |
|
|
|
tensor_op_list_.insert(tensor_op_list_.begin(), tensor_op); |
|
|
|
} |
|
|
|
|
|
|
|
UniformAugOp::UniformAugOp(std::vector<std::shared_ptr<TensorOp>> op_list, int32_t num_ops) |
|
|
|
: tensor_op_list_(op_list), num_ops_(num_ops) { |
|
|
|
rnd_.seed(GetSeed()); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -38,37 +31,28 @@ Status UniformAugOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input, |
|
|
|
std::vector<std::shared_ptr<Tensor>> *output) { |
|
|
|
IO_CHECK_VECTOR(input, output); |
|
|
|
|
|
|
|
// variables to copy the result to output if it is not already |
|
|
|
std::vector<std::shared_ptr<Tensor>> even_out; |
|
|
|
std::vector<std::shared_ptr<Tensor>> *even_out_ptr = &even_out; |
|
|
|
int count = 1; |
|
|
|
|
|
|
|
// randomly select ops to be applied |
|
|
|
std::vector<std::shared_ptr<TensorOp>> selected_tensor_ops; |
|
|
|
std::sample(tensor_op_list_.begin(), tensor_op_list_.end(), std::back_inserter(selected_tensor_ops), num_ops_, rnd_); |
|
|
|
|
|
|
|
for (auto tensor_op = selected_tensor_ops.begin(); tensor_op != selected_tensor_ops.end(); ++tensor_op) { |
|
|
|
bool first = true; |
|
|
|
for (const auto &tensor_op : selected_tensor_ops) { |
|
|
|
// Do NOT apply the op, if second random generator returned zero |
|
|
|
if (std::uniform_int_distribution<int>(0, 1)(rnd_)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
// apply C++ ops (note: python OPs are not accepted) |
|
|
|
if (count == 1) { |
|
|
|
RETURN_IF_NOT_OK((**tensor_op).Compute(input, output)); |
|
|
|
} else if (count % 2 == 0) { |
|
|
|
RETURN_IF_NOT_OK((**tensor_op).Compute(*output, even_out_ptr)); |
|
|
|
if (first) { |
|
|
|
RETURN_IF_NOT_OK(tensor_op->Compute(input, output)); |
|
|
|
first = false; |
|
|
|
} else { |
|
|
|
RETURN_IF_NOT_OK((**tensor_op).Compute(even_out, output)); |
|
|
|
RETURN_IF_NOT_OK(tensor_op->Compute(std::move(*output), output)); |
|
|
|
} |
|
|
|
count++; |
|
|
|
} |
|
|
|
|
|
|
|
// copy the result to output if it is not in output |
|
|
|
if (count == 1) { |
|
|
|
// The case where no tensor op is applied. |
|
|
|
if (output->empty()) { |
|
|
|
*output = input; |
|
|
|
} else if ((count % 2 == 1)) { |
|
|
|
(*output).swap(even_out); |
|
|
|
} |
|
|
|
|
|
|
|
return Status::OK(); |
|
|
|
|