Browse Source

!14160 [MSLITE] fix minumgrad bug

From: @zhengjun10
Reviewed-by: @HilbertDavid
Signed-off-by: @HilbertDavid
pull/14160/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
04349c27eb
6 changed files with 18 additions and 17 deletions
  1. +1
    -0
      mindspore/lite/examples/export_models/models_train.cfg
  2. +1
    -1
      mindspore/lite/micro/cmake/file_list.cmake
  3. +0
    -1
      mindspore/lite/nnacl/infer/arithmetic_grad_infer.c
  4. +5
    -4
      mindspore/lite/nnacl/infer/max_min_grad_infer.c
  5. +5
    -5
      mindspore/lite/nnacl/infer/max_min_grad_infer.h
  6. +6
    -6
      mindspore/lite/test/ut/nnacl/infer/max_min_grad_infer_test.cc

+ 1
- 0
mindspore/lite/examples/export_models/models_train.cfg View File

@@ -12,3 +12,4 @@ densenet
shufflenetv2 shufflenetv2
vgg noarm32 vgg noarm32
xception xception
albert_mlm

+ 1
- 1
mindspore/lite/micro/cmake/file_list.cmake View File

@@ -236,7 +236,7 @@ set(LITE_KERNEL_SRC
${LITE_DIR}/nnacl/infer/lsh_projection_infer.c ${LITE_DIR}/nnacl/infer/lsh_projection_infer.c
${LITE_DIR}/nnacl/infer/lstm_infer.c ${LITE_DIR}/nnacl/infer/lstm_infer.c
${LITE_DIR}/nnacl/infer/matmul_infer.c ${LITE_DIR}/nnacl/infer/matmul_infer.c
${LITE_DIR}/nnacl/infer/maximum_grad_infer.c
${LITE_DIR}/nnacl/infer/max_min_grad_infer.c
${LITE_DIR}/nnacl/infer/mean_infer.c ${LITE_DIR}/nnacl/infer/mean_infer.c
${LITE_DIR}/nnacl/infer/pooling_grad_infer.c ${LITE_DIR}/nnacl/infer/pooling_grad_infer.c
${LITE_DIR}/nnacl/infer/pooling_infer.c ${LITE_DIR}/nnacl/infer/pooling_infer.c


+ 0
- 1
mindspore/lite/nnacl/infer/arithmetic_grad_infer.c View File

@@ -103,4 +103,3 @@ int ArithmeticGradInferShape(const TensorC *const *inputs, size_t inputs_size, T


REG_INFER(DivGrad, PrimType_DivGrad, ArithmeticGradInferShape) REG_INFER(DivGrad, PrimType_DivGrad, ArithmeticGradInferShape)
REG_INFER(MulGrad, PrimType_MulGrad, ArithmeticGradInferShape) REG_INFER(MulGrad, PrimType_MulGrad, ArithmeticGradInferShape)
REG_INFER(MinimumGrad, PrimType_MinimumGrad, ArithmeticGradInferShape)

mindspore/lite/nnacl/infer/maximum_grad_infer.c → mindspore/lite/nnacl/infer/max_min_grad_infer.c View File

@@ -14,12 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */


#include "nnacl/infer/maximum_grad_infer.h"
#include "nnacl/infer/max_min_grad_infer.h"
#include "nnacl/arithmetic.h" #include "nnacl/arithmetic.h"
#include "nnacl/infer/infer_register.h" #include "nnacl/infer/infer_register.h"


int MaximumGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
int MaxMinGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
#ifdef Debug #ifdef Debug
int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 2); int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 2);
if (check_ret != NNACL_OK) { if (check_ret != NNACL_OK) {
@@ -60,4 +60,5 @@ int MaximumGradInferShape(const TensorC *const *inputs, size_t inputs_size, Tens
return NNACL_OK; return NNACL_OK;
} }


REG_INFER(MaximumGrad, PrimType_MaximumGrad, MaximumGradInferShape)
REG_INFER(MaximumGrad, PrimType_MaximumGrad, MaxMinGradInferShape)
REG_INFER(MinimumGrad, PrimType_MinimumGrad, MaxMinGradInferShape)

mindspore/lite/nnacl/infer/maximum_grad_infer.h → mindspore/lite/nnacl/infer/max_min_grad_infer.h View File

@@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_LITE_NNACL_INFER_MAXIMUM_GRAD_INFER_H_
#define MINDSPORE_LITE_NNACL_INFER_MAXIMUM_GRAD_INFER_H_
#ifndef MINDSPORE_LITE_NNACL_INFER_MAX_MIN_GRAD_INFER_H_
#define MINDSPORE_LITE_NNACL_INFER_MAX_MIN_GRAD_INFER_H_


#include "nnacl/infer/common_infer.h" #include "nnacl/infer/common_infer.h"


@@ -22,10 +22,10 @@
extern "C" { extern "C" {
#endif #endif


int MaximumGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter);
int MaxMinGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter);


#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif
#endif // MINDSPORE_LITE_NNACL_INFER_MAXIMUM_GRAD_INFER_H_
#endif // MINDSPORE_LITE_NNACL_INFER_MAX_MIN_GRAD_INFER_H_

mindspore/lite/test/ut/nnacl/infer/maximum_grad_infer_test.cc → mindspore/lite/test/ut/nnacl/infer/max_min_grad_infer_test.cc View File

@@ -14,17 +14,17 @@
* limitations under the License. * limitations under the License.
*/ */
#include "common/common_test.h" #include "common/common_test.h"
#include "mindspore/lite/nnacl/infer/maximum_grad_infer.h"
#include "mindspore/lite/nnacl/infer/max_min_grad_infer.h"
#include "mindspore/lite/nnacl/arithmetic.h" #include "mindspore/lite/nnacl/arithmetic.h"


namespace mindspore { namespace mindspore {


class MaximumGradInferTest : public mindspore::CommonTest {
class MaxMinGradInferTest : public mindspore::CommonTest {
public: public:
MaximumGradInferTest() {}
MaxMinGradInferTest() {}
}; };


TEST_F(MaximumGradInferTest, MaximumGradInferTest0) {
TEST_F(MaxMinGradInferTest, MaxMinGradInferTest0) {
size_t inputs_size = 3; size_t inputs_size = 3;
std::vector<TensorC *> inputs(inputs_size, NULL); std::vector<TensorC *> inputs(inputs_size, NULL);
inputs[0] = new TensorC; inputs[0] = new TensorC;
@@ -47,8 +47,8 @@ TEST_F(MaximumGradInferTest, MaximumGradInferTest0) {
outputs[1] = new TensorC; outputs[1] = new TensorC;
ArithmeticParameter *parameter = new ArithmeticParameter; ArithmeticParameter *parameter = new ArithmeticParameter;
parameter->op_parameter_.infer_flag_ = true; parameter->op_parameter_.infer_flag_ = true;
int ret = MaximumGradInferShape((const TensorC **)inputs.data(), inputs.size(), outputs.data(), outputs.size(),
reinterpret_cast<OpParameter *>(parameter));
int ret = MaxMinGradInferShape((const TensorC **)inputs.data(), inputs.size(), outputs.data(), outputs.size(),
reinterpret_cast<OpParameter *>(parameter));
ASSERT_EQ(ret, NNACL_OK); ASSERT_EQ(ret, NNACL_OK);
ASSERT_EQ(outputs[0]->shape_size_, 2); ASSERT_EQ(outputs[0]->shape_size_, 2);
ASSERT_EQ(outputs[0]->shape_[0], 4); ASSERT_EQ(outputs[0]->shape_[0], 4);

Loading…
Cancel
Save