From 0c21548b430cbfbedb2868b2f4a65b53e30f25fd Mon Sep 17 00:00:00 2001 From: wang_shaocong Date: Thu, 11 Mar 2021 16:13:48 +0800 Subject: [PATCH] [MSLITE] fix bug of reduce infershape --- mindspore/lite/nnacl/infer/reduce_infer.c | 64 ++++++++++++++--------- 1 file changed, 38 insertions(+), 26 deletions(-) diff --git a/mindspore/lite/nnacl/infer/reduce_infer.c b/mindspore/lite/nnacl/infer/reduce_infer.c index fbda6901e5..d712437568 100644 --- a/mindspore/lite/nnacl/infer/reduce_infer.c +++ b/mindspore/lite/nnacl/infer/reduce_infer.c @@ -16,6 +16,39 @@ #include "nnacl/infer/reduce_infer.h" +int ReduceOnAllAxes(const TensorC *input, TensorC *output, int *out_shape, size_t out_shape_size, bool keep_dims) { + if (keep_dims) { + for (size_t i = 0; i < input->shape_size_; i++) { + ShapePush(out_shape, &out_shape_size, 1); + } + } + SetShapeArray(output, out_shape, out_shape_size); + output->data_type_ = input->data_type_; + return NNACL_OK; +} + +int ReduceOnSelectedAxes(const TensorC *input, size_t num_axes, int *actual_axes, TensorC *output, int *out_shape, + size_t out_shape_size, bool keep_dims) { + for (size_t i = 0; i < input->shape_size_; i++) { + bool reduce_axis = false; + for (size_t idx = 0; idx < num_axes; ++idx) { + if ((size_t)(actual_axes[idx]) == i || (size_t)(actual_axes[idx] + input->shape_size_) == i) { + reduce_axis = true; + break; + } + } + if (reduce_axis) { + if (keep_dims) { + ShapePush(out_shape, &out_shape_size, 1); + } + } else { + ShapePush(out_shape, &out_shape_size, input->shape_[i]); + } + } + SetShapeArray(output, out_shape, out_shape_size); + return NNACL_OK; +} + int ReduceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, OpParameter *parameter) { #ifdef Debug @@ -37,6 +70,9 @@ int ReduceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * size_t out_shape_size = 0; // get axes from input tensor const TensorC *axes_input = inputs[1]; + if (axes_input->shape_size_ == 1 && axes_input->shape_[0] == 0) { + return ReduceOnAllAxes(input, output, out_shape, out_shape_size, keep_dims); + } int *axes = (int *)axes_input->data_; if (axes == NULL) { return NNACL_NULL_PTR; @@ -70,32 +106,8 @@ int ReduceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * } // reduce on all axes if (num_axes == 0) { - if (keep_dims) { - for (size_t i = 0; i < input->shape_size_; i++) { - ShapePush(out_shape, &out_shape_size, 1); - } - } - SetShapeArray(output, out_shape, out_shape_size); - output->data_type_ = input->data_type_; - return NNACL_OK; + return ReduceOnAllAxes(input, output, out_shape, out_shape_size, keep_dims); } // reduce on selected axes - for (size_t i = 0; i < input->shape_size_; i++) { - bool reduce_axis = false; - for (size_t idx = 0; idx < num_axes; ++idx) { - if ((size_t)(actual_axes[idx]) == i || (size_t)(actual_axes[idx] + input->shape_size_) == i) { - reduce_axis = true; - break; - } - } - if (reduce_axis) { - if (keep_dims) { - ShapePush(out_shape, &out_shape_size, 1); - } - } else { - ShapePush(out_shape, &out_shape_size, input->shape_[i]); - } - } - SetShapeArray(output, out_shape, out_shape_size); - return NNACL_OK; + return ReduceOnSelectedAxes(input, num_axes, actual_axes, output, out_shape, out_shape_size, keep_dims); }