Browse Source

!13172 [MSLITE] Fix bug of reduce.

From: @wang_shaocong
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
f3ad5b7d45
1 changed files with 38 additions and 26 deletions
  1. +38
    -26
      mindspore/lite/nnacl/infer/reduce_infer.c

+ 38
- 26
mindspore/lite/nnacl/infer/reduce_infer.c View File

@@ -16,6 +16,39 @@

#include "nnacl/infer/reduce_infer.h"

int ReduceOnAllAxes(const TensorC *input, TensorC *output, int *out_shape, size_t out_shape_size, bool keep_dims) {
if (keep_dims) {
for (size_t i = 0; i < input->shape_size_; i++) {
ShapePush(out_shape, &out_shape_size, 1);
}
}
SetShapeArray(output, out_shape, out_shape_size);
output->data_type_ = input->data_type_;
return NNACL_OK;
}

int ReduceOnSelectedAxes(const TensorC *input, size_t num_axes, int *actual_axes, TensorC *output, int *out_shape,
size_t out_shape_size, bool keep_dims) {
for (size_t i = 0; i < input->shape_size_; i++) {
bool reduce_axis = false;
for (size_t idx = 0; idx < num_axes; ++idx) {
if ((size_t)(actual_axes[idx]) == i || (size_t)(actual_axes[idx] + input->shape_size_) == i) {
reduce_axis = true;
break;
}
}
if (reduce_axis) {
if (keep_dims) {
ShapePush(out_shape, &out_shape_size, 1);
}
} else {
ShapePush(out_shape, &out_shape_size, input->shape_[i]);
}
}
SetShapeArray(output, out_shape, out_shape_size);
return NNACL_OK;
}

int ReduceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
#ifdef Debug
@@ -37,6 +70,9 @@ int ReduceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
size_t out_shape_size = 0;
// get axes from input tensor
const TensorC *axes_input = inputs[1];
if (axes_input->shape_size_ == 1 && axes_input->shape_[0] == 0) {
return ReduceOnAllAxes(input, output, out_shape, out_shape_size, keep_dims);
}
int *axes = (int *)axes_input->data_;
if (axes == NULL) {
return NNACL_NULL_PTR;
@@ -70,32 +106,8 @@ int ReduceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
}
// reduce on all axes
if (num_axes == 0) {
if (keep_dims) {
for (size_t i = 0; i < input->shape_size_; i++) {
ShapePush(out_shape, &out_shape_size, 1);
}
}
SetShapeArray(output, out_shape, out_shape_size);
output->data_type_ = input->data_type_;
return NNACL_OK;
return ReduceOnAllAxes(input, output, out_shape, out_shape_size, keep_dims);
}
// reduce on selected axes
for (size_t i = 0; i < input->shape_size_; i++) {
bool reduce_axis = false;
for (size_t idx = 0; idx < num_axes; ++idx) {
if ((size_t)(actual_axes[idx]) == i || (size_t)(actual_axes[idx] + input->shape_size_) == i) {
reduce_axis = true;
break;
}
}
if (reduce_axis) {
if (keep_dims) {
ShapePush(out_shape, &out_shape_size, 1);
}
} else {
ShapePush(out_shape, &out_shape_size, input->shape_[i]);
}
}
SetShapeArray(output, out_shape, out_shape_size);
return NNACL_OK;
return ReduceOnSelectedAxes(input, num_axes, actual_axes, output, out_shape, out_shape_size, keep_dims);
}

Loading…
Cancel
Save