diff --git a/mindspore/lite/nnacl/base/arithmetic_base.c b/mindspore/lite/nnacl/base/arithmetic_base.c
index 5a04ccc07d..8b5f154799 100644
--- a/mindspore/lite/nnacl/base/arithmetic_base.c
+++ b/mindspore/lite/nnacl/base/arithmetic_base.c
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,8 +17,6 @@
 #include "nnacl/base/arithmetic_base.h"
 
 void CalcMultiplesAndStrides(ArithmeticParameter *param) {
-  NNACL_ASSERT(param->in_shape0_[i] != 0);
-  NNACL_ASSERT(param->in_shape1_[i] != 0);
   for (size_t i = 0; i < param->ndim_; i++) {
     if (param->in_shape0_[i] != 0) {
       param->multiples0_[i] = param->out_shape_[i] / param->in_shape0_[i];
diff --git a/mindspore/lite/nnacl/fp16/arithmetic_fp16.c b/mindspore/lite/nnacl/fp16/arithmetic_fp16.c
index 97b2b5fbdc..c8aa9b92b4 100644
--- a/mindspore/lite/nnacl/fp16/arithmetic_fp16.c
+++ b/mindspore/lite/nnacl/fp16/arithmetic_fp16.c
@@ -52,7 +52,7 @@ void TileDimensionsFp16(const float16_t *data0, const float16_t *data1, float16_
                     param->multiples1_);
 }
 
-int ElementMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
+int ElementMulFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) {
   int index = 0;
 #ifdef ENABLE_NEON
   for (; index <= element_size - 8; index += C8NUM) {
@@ -68,7 +68,7 @@ int ElementMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int
   return NNACL_OK;
 }
 
-int ElementOptMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
+int ElementOptMulFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
                       ArithmeticParameter *param) {
 #ifdef ENABLE_NEON
   float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
@@ -101,7 +101,7 @@ int ElementOptMulFp16(float16_t *input0, float16_t *input1, float16_t *output, i
   return NNACL_OK;
 }
 
-int ElementMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
+int ElementMulReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) {
 #ifdef ENABLE_NEON
   float16x8_t zeros = vdupq_n_f16(0.0);
 #endif
@@ -122,7 +122,7 @@ int ElementMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output,
   return NNACL_OK;
 }
 
-int ElementOptMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
+int ElementOptMulReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
                           ArithmeticParameter *param) {
 #ifdef ENABLE_NEON
   float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
@@ -160,7 +160,7 @@ int ElementOptMulReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu
   return NNACL_OK;
 }
 
-int ElementMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
+int ElementMulRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) {
   int index = 0;
 #ifdef ENABLE_NEON
   float16x8_t zeros = vdupq_n_f16(0.0);
@@ -179,7 +179,7 @@ int ElementMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output,
   return NNACL_OK;
 }
 
-int ElementOptMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
+int ElementOptMulRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
                            ArithmeticParameter *param) {
 #ifdef ENABLE_NEON
   float16x8_t
vin0_opt = vdupq_n_f16(input0[0]); @@ -216,7 +216,7 @@ int ElementOptMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp return NNACL_OK; } -int ElementAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { +int ElementAddFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { int index = 0; #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { @@ -238,7 +238,7 @@ int ElementAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int return NNACL_OK; } -int ElementOptAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptAddFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param) { #ifdef ENABLE_NEON float16x8_t vin0_opt = vdupq_n_f16(input0[0]); @@ -271,7 +271,7 @@ int ElementOptAddFp16(float16_t *input0, float16_t *input1, float16_t *output, i return NNACL_OK; } -int ElementAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { +int ElementAddReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { int index = 0; #ifdef ENABLE_NEON float16x8_t zeros = vdupq_n_f16(0.0); @@ -298,7 +298,7 @@ int ElementAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output, return NNACL_OK; } -int ElementOptAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptAddReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param) { #ifdef ENABLE_NEON float16x8_t vin0_opt = vdupq_n_f16(input0[0]); @@ -336,7 +336,7 @@ int ElementOptAddReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu return NNACL_OK; } -int ElementAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { +int ElementAddRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { int index = 0; #ifdef ENABLE_NEON float16x8_t zeros = vdupq_n_f16(0.0); @@ -364,7 +364,7 @@ int ElementAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, return NNACL_OK; } -int ElementOptAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptAddRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param) { #ifdef ENABLE_NEON float16x8_t vin0_opt = vdupq_n_f16(input0[0]); @@ -401,7 +401,7 @@ int ElementOptAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp return NNACL_OK; } -int ElementSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { +int ElementSubFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { int index = 0; #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { @@ -417,7 +417,7 @@ int ElementSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int return NNACL_OK; } -int ElementOptSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptSubFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param) { #ifdef ENABLE_NEON float16x8_t vin0_opt = vdupq_n_f16(input0[0]); @@ -450,7 +450,7 @@ int ElementOptSubFp16(float16_t *input0, float16_t *input1, float16_t *output, i return NNACL_OK; } -int ElementSubReluFp16(float16_t *input0, 
float16_t *input1, float16_t *output, int element_size) { +int ElementSubReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { int index = 0; #ifdef ENABLE_NEON float16x8_t zeros = vdupq_n_f16(0.0); @@ -469,7 +469,7 @@ int ElementSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output, return NNACL_OK; } -int ElementOptSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptSubReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param) { #ifdef ENABLE_NEON float16x8_t vin0_opt = vdupq_n_f16(input0[0]); @@ -507,7 +507,7 @@ int ElementOptSubReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu return NNACL_OK; } -int ElementSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { +int ElementSubRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { int index = 0; #ifdef ENABLE_NEON float16x8_t zeros = vdupq_n_f16(0.0); @@ -526,7 +526,7 @@ int ElementSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, return NNACL_OK; } -int ElementOptSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptSubRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param) { #ifdef ENABLE_NEON float16x8_t vin0_opt = vdupq_n_f16(input0[0]); @@ -563,7 +563,7 @@ int ElementOptSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp return NNACL_OK; } -int ElementDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { +int ElementDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { int index = 0; #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { @@ -580,7 +580,7 @@ int ElementDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int return NNACL_OK; } -int ElementOptDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param) { #ifdef ENABLE_NEON float16x8_t vin0_opt = vdupq_n_f16(input0[0]); @@ -617,7 +617,7 @@ int ElementOptDivFp16(float16_t *input0, float16_t *input1, float16_t *output, i return NNACL_OK; } -int ElementDivReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { +int ElementDivReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { int index = 0; #ifdef ENABLE_NEON float16x8_t zeros = vdupq_n_f16(0.0); @@ -640,7 +640,7 @@ int ElementDivReluFp16(float16_t *input0, float16_t *input1, float16_t *output, return NNACL_OK; } -int ElementOptDivReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptDivReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param) { #ifdef ENABLE_NEON float16x8_t vin0_opt = vdupq_n_f16(input0[0]); @@ -681,7 +681,7 @@ int ElementOptDivReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu return NNACL_OK; } -int ElementDivRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { +int ElementDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { int index = 0; #ifdef 
ENABLE_NEON float16x8_t zeros = vdupq_n_f16(0.0); @@ -703,7 +703,7 @@ int ElementDivRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, return NNACL_OK; } -int ElementOptDivRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param) { #ifdef ENABLE_NEON float16x8_t vin0_opt = vdupq_n_f16(input0[0]); @@ -744,7 +744,7 @@ int ElementOptDivRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp return NNACL_OK; } -int ElementFloorModFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { +int ElementFloorModFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { for (int i = 0; i < element_size; ++i) { if (input1[i] == 0) { return NNACL_ERRCODE_DIVISOR_ZERO; @@ -754,7 +754,7 @@ int ElementFloorModFp16(float16_t *input0, float16_t *input1, float16_t *output, return NNACL_OK; } -int ElementOptFloorModFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptFloorModFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param) { if (param->in_elements_num1_ == 1) { for (int i = 0; i < element_size; ++i) { @@ -770,14 +770,14 @@ int ElementOptFloorModFp16(float16_t *input0, float16_t *input1, float16_t *outp return NNACL_OK; } -int ElementFloorDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { +int ElementFloorDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { for (int i = 0; i < element_size; ++i) { NNACL_ASSERT(input1[i] != 0); output[i] = floorf(input0[i] / input1[i]); } return NNACL_OK; } -int ElementOptFloorDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptFloorDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param) { if (param->in_elements_num1_ == 1) { for (int i = 0; i < element_size; ++i) { @@ -793,7 +793,7 @@ int ElementOptFloorDivFp16(float16_t *input0, float16_t *input1, float16_t *outp return NNACL_OK; } -int ElementLogicalAndFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { +int ElementLogicalAndFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { int index = 0; #ifdef ENABLE_NEON float16x8_t vtrue = vdupq_n_f16(1); @@ -813,7 +813,7 @@ int ElementLogicalAndFp16(float16_t *input0, float16_t *input1, float16_t *outpu return NNACL_OK; } -int ElementOptLogicalAndFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptLogicalAndFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param) { #ifdef ENABLE_NEON float16x8_t vin0_opt = vdupq_n_f16(input0[0]); @@ -854,7 +854,7 @@ int ElementOptLogicalAndFp16(float16_t *input0, float16_t *input1, float16_t *ou return NNACL_OK; } -int ElementLogicalOrFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { +int ElementLogicalOrFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { int index = 0; #ifdef ENABLE_NEON float16x8_t vtrue = vdupq_n_f16(1); @@ -874,7 +874,7 @@ int ElementLogicalOrFp16(float16_t *input0, float16_t *input1, float16_t *output return NNACL_OK; } -int 
ElementOptLogicalOrFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptLogicalOrFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param) { #ifdef ENABLE_NEON float16x8_t vin0_opt = vdupq_n_f16(input0[0]); @@ -915,18 +915,19 @@ int ElementOptLogicalOrFp16(float16_t *input0, float16_t *input1, float16_t *out return NNACL_OK; } -int ElementSquaredDifferenceFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { +int ElementSquaredDifferenceFp16(const float16_t *input0, const float16_t *input1, float16_t *output, + int element_size) { ElementSubFp16(input0, input1, output, element_size); return ElementMulFp16(output, output, output, element_size); } -int ElementOptSquaredDifferenceFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, - ArithmeticParameter *param) { +int ElementOptSquaredDifferenceFp16(const float16_t *input0, const float16_t *input1, float16_t *output, + int element_size, ArithmeticParameter *param) { ElementOptSubFp16(input0, input1, output, element_size, param); return ElementMulFp16(output, output, output, element_size); } -int ElementMaximumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { +int ElementMaximumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { int index = 0; #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { @@ -942,7 +943,7 @@ int ElementMaximumFp16(float16_t *input0, float16_t *input1, float16_t *output, return NNACL_OK; } -int ElementOptMaximumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptMaximumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param) { #ifdef ENABLE_NEON float16x8_t vin0_opt = vdupq_n_f16(input0[0]); @@ -975,7 +976,7 @@ int ElementOptMaximumFp16(float16_t *input0, float16_t *input1, float16_t *outpu return NNACL_OK; } -int ElementMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { +int ElementMinimumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { int index = 0; #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { @@ -991,7 +992,7 @@ int ElementMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output, return NNACL_OK; } -int ElementOptMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptMinimumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param) { #ifdef ENABLE_NEON float16x8_t vin0_opt = vdupq_n_f16(input0[0]); @@ -1024,7 +1025,7 @@ int ElementOptMinimumFp16(float16_t *input0, float16_t *input1, float16_t *outpu return NNACL_OK; } -int ElementNotEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size) { +int ElementNotEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) { int index = 0; #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { @@ -1040,7 +1041,7 @@ int ElementNotEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, i return NNACL_OK; } -int ElementOptNotEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, +int ElementOptNotEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, 
ArithmeticParameter *param) { #ifdef ENABLE_NEON float16x8_t vin0_opt = vdupq_n_f16(input0[0]); @@ -1073,7 +1074,7 @@ int ElementOptNotEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output return NNACL_OK; } -int ElementEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size) { +int ElementEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) { int index = 0; #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { @@ -1089,7 +1090,7 @@ int ElementEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int return NNACL_OK; } -int ElementOptEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, +int ElementOptEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, ArithmeticParameter *param) { #ifdef ENABLE_NEON float16x8_t vin0_opt = vdupq_n_f16(input0[0]); @@ -1122,7 +1123,7 @@ int ElementOptEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, i return NNACL_OK; } -int ElementLessFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size) { +int ElementLessFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) { int index = 0; #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { @@ -1138,7 +1139,7 @@ int ElementLessFp16(float16_t *input0, float16_t *input1, uint8_t *output, int e return NNACL_OK; } -int ElementOptLessFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, +int ElementOptLessFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, ArithmeticParameter *param) { #ifdef ENABLE_NEON float16x8_t vin0_opt = vdupq_n_f16(input0[0]); @@ -1171,7 +1172,7 @@ int ElementOptLessFp16(float16_t *input0, float16_t *input1, uint8_t *output, in return NNACL_OK; } -int ElementLessEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size) { +int ElementLessEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) { int index = 0; #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { @@ -1187,7 +1188,7 @@ int ElementLessEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, return NNACL_OK; } -int ElementOptLessEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, +int ElementOptLessEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, ArithmeticParameter *param) { #ifdef ENABLE_NEON float16x8_t vin0_opt = vdupq_n_f16(input0[0]); @@ -1220,7 +1221,7 @@ int ElementOptLessEqualFp16(float16_t *input0, float16_t *input1, uint8_t *outpu return NNACL_OK; } -int ElementGreaterFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size) { +int ElementGreaterFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) { int index = 0; #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { @@ -1236,7 +1237,7 @@ int ElementGreaterFp16(float16_t *input0, float16_t *input1, uint8_t *output, in return NNACL_OK; } -int ElementOptGreaterFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, +int ElementOptGreaterFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, ArithmeticParameter *param) { #ifdef ENABLE_NEON float16x8_t vin0_opt = vdupq_n_f16(input0[0]); @@ -1269,7 +1270,7 @@ int ElementOptGreaterFp16(float16_t *input0, float16_t 
*input1, uint8_t *output, return NNACL_OK; } -int ElementGreaterEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size) { +int ElementGreaterEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) { int index = 0; #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { @@ -1285,7 +1286,7 @@ int ElementGreaterEqualFp16(float16_t *input0, float16_t *input1, uint8_t *outpu return NNACL_OK; } -int ElementOptGreaterEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, +int ElementOptGreaterEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, ArithmeticParameter *param) { #ifdef ENABLE_NEON float16x8_t vin0_opt = vdupq_n_f16(input0[0]); diff --git a/mindspore/lite/nnacl/fp16/arithmetic_fp16.h b/mindspore/lite/nnacl/fp16/arithmetic_fp16.h index 00d06f3856..10c34073e8 100644 --- a/mindspore/lite/nnacl/fp16/arithmetic_fp16.h +++ b/mindspore/lite/nnacl/fp16/arithmetic_fp16.h @@ -32,92 +32,92 @@ void TileOneDimensionFp16(const float16_t *inData, float16_t *outData, int dim, void TileDimensionsFp16(const float16_t *data0, const float16_t *data1, float16_t *tile_data0, float16_t *tile_data1, ArithmeticParameter *param); -int ElementOptMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptMulFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param); -int ElementOptMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptMulReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param); -int ElementOptMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptMulRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param); -int ElementOptAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptAddFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param); -int ElementOptAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptAddReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param); -int ElementOptAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptAddRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param); -int ElementOptSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptSubFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param); -int ElementOptSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptSubReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param); -int ElementOptSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptSubRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param); -int ElementOptDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int 
element_size, +int ElementOptDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param); -int ElementOptDivReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptDivReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param); -int ElementOptDivRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param); -int ElementOptFloorModFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptFloorModFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param); -int ElementOptFloorDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptFloorDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param); -int ElementOptLogicalAndFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptLogicalAndFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param); -int ElementOptLogicalOrFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptLogicalOrFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param); -int ElementOptSquaredDifferenceFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, - ArithmeticParameter *param); -int ElementOptMaximumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptSquaredDifferenceFp16(const float16_t *input0, const float16_t *input1, float16_t *output, + int element_size, ArithmeticParameter *param); +int ElementOptMaximumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param); -int ElementOptMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, +int ElementOptMinimumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, ArithmeticParameter *param); -int ElementOptNotEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, +int ElementOptNotEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, ArithmeticParameter *param); -int ElementOptEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, +int ElementOptEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, ArithmeticParameter *param); -int ElementOptLessFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, +int ElementOptLessFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, ArithmeticParameter *param); -int ElementOptLessEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, +int ElementOptLessEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, ArithmeticParameter *param); -int ElementOptGreaterFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, +int ElementOptGreaterFp16(const float16_t *input0, const float16_t 
*input1, uint8_t *output, int element_size, ArithmeticParameter *param); -int ElementOptGreaterEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, +int ElementOptGreaterEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, ArithmeticParameter *param); -int ElementMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); -int ElementMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); -int ElementMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); +int ElementMulFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementMulReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementMulRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); -int ElementAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); -int ElementAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); -int ElementAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); +int ElementAddFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementAddReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementAddRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); int BroadcastAddFp16(const float16_t *in0, const float16_t *in1, float16_t *tile_in0, float16_t *tile_in1, float16_t *out, int size, ArithmeticParameter *param); -int ElementSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); -int ElementSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); -int ElementSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); +int ElementSubFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementSubReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementSubRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); -int ElementDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); -int ElementDivReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); -int ElementDivRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); +int ElementDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementDivReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); -int ElementFloorModFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); -int ElementFloorDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); +int ElementFloorModFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementFloorDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); -int ElementLogicalAndFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); -int 
ElementLogicalOrFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); +int ElementLogicalAndFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementLogicalOrFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); -int ElementSquaredDifferenceFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); +int ElementSquaredDifferenceFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); -int ElementMaximumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); -int ElementMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); +int ElementMaximumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementMinimumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); -int ElementNotEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size); -int ElementEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size); -int ElementLessFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size); -int ElementLessEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size); -int ElementGreaterFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size); -int ElementGreaterEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size); +int ElementNotEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size); +int ElementEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size); +int ElementLessFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size); +int ElementLessEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size); +int ElementGreaterFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size); +int ElementGreaterEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size); #ifdef __cplusplus } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_compare_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_compare_fp16.h index 168678f079..98f300334a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_compare_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_compare_fp16.h @@ -23,9 +23,10 @@ #include "schema/model_generated.h" namespace mindspore::kernel { -typedef int (*ArithmeticCompareFuncFp16)(float16_t *input0, float16_t *input1, uint8_t *output, int element_size); -typedef int (*ArithmeticCompareOptFuncFp16)(float16_t *input0, float16_t *input1, uint8_t *output, int element_size, - ArithmeticParameter *param); +typedef int (*ArithmeticCompareFuncFp16)(const float16_t *input0, const float16_t *input1, uint8_t *output, + int element_size); +typedef int (*ArithmeticCompareOptFuncFp16)(const float16_t *input0, const float16_t *input1, uint8_t *output, + int element_size, ArithmeticParameter *param); typedef struct { int primitive_type_; int activation_type_; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc index ea522049d2..5ff560dc7b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc +++ 
b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,16 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #include "src/runtime/kernel/arm/fp16/arithmetic_fp16.h" #include "src/runtime/kernel/arm/fp16/common_fp16.h" -#include "nnacl/fp16/arithmetic_fp16.h" -#include "nnacl/fp16/cast_fp16.h" -#include "schema/model_generated.h" #include "src/kernel_registry.h" -#include "src/runtime/runtime_api.h" -#include "include/errorcode.h" -#include "src/ops/arithmetic.h" +#include "nnacl/fp16/cast_fp16.h" using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; @@ -49,73 +43,10 @@ using mindspore::schema::PrimitiveType_SquaredDifference; using mindspore::schema::PrimitiveType_Sub; namespace mindspore::kernel { -ARITHMETIC_FUNC_INFO_FP16 arithmetic_fun_table_fp16[] = { - {PrimitiveType_Mul, schema::ActivationType_RELU, ElementMulReluFp16, ElementOptMulReluFp16}, - {PrimitiveType_Mul, schema::ActivationType_RELU6, ElementMulRelu6Fp16, ElementOptMulRelu6Fp16}, - {PrimitiveType_Mul, schema::ActivationType_NO_ACTIVATION, ElementMulFp16, ElementOptMulFp16}, - {PrimitiveType_Add, schema::ActivationType_RELU, ElementAddReluFp16, ElementOptAddReluFp16}, - {PrimitiveType_Add, schema::ActivationType_RELU6, ElementAddRelu6Fp16, ElementOptAddRelu6Fp16}, - {PrimitiveType_Add, schema::ActivationType_NO_ACTIVATION, ElementAddFp16, ElementOptAddFp16}, - {PrimitiveType_Sub, schema::ActivationType_RELU, ElementSubReluFp16, ElementOptSubReluFp16}, - {PrimitiveType_Sub, schema::ActivationType_RELU6, ElementSubRelu6Fp16, ElementOptSubRelu6Fp16}, - {PrimitiveType_Sub, schema::ActivationType_NO_ACTIVATION, ElementSubFp16, ElementOptSubFp16}, - {PrimitiveType_Div, schema::ActivationType_RELU, ElementDivReluFp16, ElementOptDivReluFp16}, - {PrimitiveType_Div, schema::ActivationType_RELU6, ElementDivRelu6Fp16, ElementOptDivRelu6Fp16}, - {PrimitiveType_Div, schema::ActivationType_NO_ACTIVATION, ElementDivFp16, ElementOptDivFp16}, - {PrimitiveType_FloorMod, schema::ActivationType_NO_ACTIVATION, ElementFloorModFp16, ElementOptFloorModFp16}, - {PrimitiveType_FloorDiv, schema::ActivationType_NO_ACTIVATION, ElementFloorDivFp16, ElementOptFloorDivFp16}, - {PrimitiveType_LogicalAnd, schema::ActivationType_NO_ACTIVATION, ElementLogicalAndFp16, ElementOptLogicalAndFp16}, - {PrimitiveType_LogicalOr, schema::ActivationType_NO_ACTIVATION, ElementLogicalOrFp16, ElementOptLogicalOrFp16}, - {PrimitiveType_SquaredDifference, schema::ActivationType_NO_ACTIVATION, ElementSquaredDifferenceFp16, - ElementOptSquaredDifferenceFp16}, - {PrimitiveType_Maximum, schema::ActivationType_NO_ACTIVATION, ElementMaximumFp16, ElementOptMaximumFp16}, - {PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION, ElementMinimumFp16, ElementOptMinimumFp16}}; - -ArithmeticFuncFp16 GetArithmeticFun(int primitive_type, int activation_type) { - size_t length = sizeof(arithmetic_fun_table_fp16) / sizeof(ARITHMETIC_FUNC_INFO_FP16); - for (size_t i = 0; i < length; i++) { - if (arithmetic_fun_table_fp16[i].primitive_type_ == primitive_type && - arithmetic_fun_table_fp16[i].activation_type_ == activation_type) { - return arithmetic_fun_table_fp16[i].func_; - } - } - return nullptr; -} - -ArithmeticOptFuncFp16 
GetOptimizedArithmeticFun(int primitive_type, int activation_type) { - size_t length = sizeof(arithmetic_fun_table_fp16) / sizeof(ARITHMETIC_FUNC_INFO_FP16); - for (size_t i = 0; i < length; i++) { - if (arithmetic_fun_table_fp16[i].primitive_type_ == primitive_type && - arithmetic_fun_table_fp16[i].activation_type_ == activation_type) { - return arithmetic_fun_table_fp16[i].opt_func_; - } - } - return nullptr; -} - -int ArithmeticFP16CPUKernel::Init() { - if (!InferShapeDone()) { - return RET_OK; - } - return ReSize(); -} - -void ArithmeticFP16CPUKernel::InitParam() { - auto arithmetic_lite_primitive = (lite::Arithmetic *)primitive_; - param_->broadcasting_ = arithmetic_lite_primitive->Broadcasting(); - param_->ndim_ = arithmetic_lite_primitive->NDims(); - - param_->in_elements_num0_ = in_tensors_[0]->ElementsNum(); - param_->in_elements_num1_ = in_tensors_[1]->ElementsNum(); - param_->out_elements_num_ = out_tensors_[0]->ElementsNum(); - memcpy(param_->in_shape0_, reinterpret_cast(primitive_)->InShape0().data(), - reinterpret_cast(primitive_)->InShape0().size() * sizeof(int)); - memcpy(param_->in_shape1_, reinterpret_cast(primitive_)->InShape1().data(), - reinterpret_cast(primitive_)->InShape1().size() * sizeof(int)); - memcpy(param_->out_shape_, reinterpret_cast(primitive_)->OutputShape().data(), - reinterpret_cast(primitive_)->OutputShape().size() * sizeof(int)); - - return; +int ArithmeticFP16CPUKernel::ReSize() { + auto ret = ArithmeticCPUKernel::ReSize(); + data_type_len_ = sizeof(float16_t); + return ret; } int ArithmeticFP16CPUKernel::CheckDataType() { @@ -130,131 +61,115 @@ int ArithmeticFP16CPUKernel::CheckDataType() { return RET_OK; } -int ArithmeticFP16CPUKernel::ReSize() { - if (CheckDataType() != RET_OK) { - MS_LOG(ERROR) << "ArithmeticFP16CPUKernel resize failed."; - return RET_ERROR; - } - - InitParam(); - - if (param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) { - param_->broadcasting_ = false; - arithmetic_opt_func_ = GetOptimizedArithmeticFun(param_->op_parameter_.type_, param_->activation_type_); - } else { - arithmetic_func_ = GetArithmeticFun(param_->op_parameter_.type_, param_->activation_type_); - } - if (arithmetic_opt_func_ == nullptr && arithmetic_func_ == nullptr) { - MS_LOG(ERROR) << "arithmetic_opt_func_ and arithmetic_func_ function is nullptr!"; - return RET_ERROR; - } - - if (param_->broadcasting_) { - outside_ = 1; - for (int i = param_->ndim_ - 1; i >= 0; --i) { - if (param_->in_shape0_[i] != param_->in_shape1_[i]) { - break_pos_ = i; - break; - } - outside_ *= param_->out_shape_[i]; +void ArithmeticFP16CPUKernel::InitRunFunction() { + ARITHMETIC_FUNC_INFO_FP16 fun_table[] = { + {PrimitiveType_Mul, schema::ActivationType_RELU, ElementMulReluFp16, ElementOptMulReluFp16}, + {PrimitiveType_Mul, schema::ActivationType_RELU6, ElementMulRelu6Fp16, ElementOptMulRelu6Fp16}, + {PrimitiveType_Mul, schema::ActivationType_NO_ACTIVATION, ElementMulFp16, ElementOptMulFp16}, + {PrimitiveType_Add, schema::ActivationType_RELU, ElementAddReluFp16, ElementOptAddReluFp16}, + {PrimitiveType_Add, schema::ActivationType_RELU6, ElementAddRelu6Fp16, ElementOptAddRelu6Fp16}, + {PrimitiveType_Add, schema::ActivationType_NO_ACTIVATION, ElementAddFp16, ElementOptAddFp16}, + {PrimitiveType_Sub, schema::ActivationType_RELU, ElementSubReluFp16, ElementOptSubReluFp16}, + {PrimitiveType_Sub, schema::ActivationType_RELU6, ElementSubRelu6Fp16, ElementOptSubRelu6Fp16}, + {PrimitiveType_Sub, schema::ActivationType_NO_ACTIVATION, ElementSubFp16, ElementOptSubFp16}, + 
{PrimitiveType_Div, schema::ActivationType_RELU, ElementDivReluFp16, ElementOptDivReluFp16}, + {PrimitiveType_Div, schema::ActivationType_RELU6, ElementDivRelu6Fp16, ElementOptDivRelu6Fp16}, + {PrimitiveType_Div, schema::ActivationType_NO_ACTIVATION, ElementDivFp16, ElementOptDivFp16}, + {PrimitiveType_FloorMod, schema::ActivationType_NO_ACTIVATION, ElementFloorModFp16, ElementOptFloorModFp16}, + {PrimitiveType_FloorDiv, schema::ActivationType_NO_ACTIVATION, ElementFloorDivFp16, ElementOptFloorDivFp16}, + {PrimitiveType_LogicalAnd, schema::ActivationType_NO_ACTIVATION, ElementLogicalAndFp16, ElementOptLogicalAndFp16}, + {PrimitiveType_LogicalOr, schema::ActivationType_NO_ACTIVATION, ElementLogicalOrFp16, ElementOptLogicalOrFp16}, + {PrimitiveType_SquaredDifference, schema::ActivationType_NO_ACTIVATION, ElementSquaredDifferenceFp16, + ElementOptSquaredDifferenceFp16}, + {PrimitiveType_Maximum, schema::ActivationType_NO_ACTIVATION, ElementMaximumFp16, ElementOptMaximumFp16}, + {PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION, ElementMinimumFp16, ElementOptMinimumFp16}}; + + size_t length = sizeof(fun_table) / sizeof(ARITHMETIC_FUNC_INFO_FP16); + for (size_t i = 0; i < length; i++) { + if (fun_table[i].primitive_type_ == param_->op_parameter_.type_ && + fun_table[i].activation_type_ == param_->activation_type_) { + arithmetic_opt_func_ = fun_table[i].opt_func_; + arithmetic_func_ = fun_table[i].func_; + return; } - ComputeStrides(param_->in_shape0_, param_->in_strides0_, param_->ndim_); - ComputeStrides(param_->in_shape1_, param_->in_strides1_, param_->ndim_); - ComputeStrides(param_->out_shape_, param_->out_strides_, param_->ndim_); } - return RET_OK; } -int ArithmeticFP16CPUKernel::BroadcastRun(float16_t *input0, float16_t *input1, float16_t *output, int dim, - int out_count, int cur_offset) { - if (dim > break_pos_) { - return arithmetic_func_(input0 + cur_offset, input1 + cur_offset, output + cur_offset, out_count); +int ArithmeticFP16CPUKernel::ConstTensorBroadCast() { + int ret; + if (in_tensors_[0]->data_c() != nullptr) { + ret = ConvertFp32TensorToFp16(in_tensors_[0], context_); + if (ret != RET_OK) { + return ret; + } } - for (int i = 0; i < param_->out_shape_[dim]; ++i) { - int pos0 = param_->in_shape0_[dim] == 1 ? 0 : i; - int pos1 = param_->in_shape1_[dim] == 1 ? 0 : i; - int ret = BroadcastRun(input0 + pos0 * param_->in_strides0_[dim], input1 + pos1 * param_->in_strides1_[dim], - output + i * param_->out_strides_[dim], dim + 1, out_count, cur_offset); + if (in_tensors_[1]->data_c() != nullptr) { + ret = ConvertFp32TensorToFp16(in_tensors_[1], context_); if (ret != RET_OK) { - return RET_ERROR; + return ret; } } - return RET_OK; + return ArithmeticCPUKernel::ConstTensorBroadCast(); } -int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) { - int stride_per_thread = UP_DIV(param_->broadcasting_ ? outside_ : param_->out_elements_num_, context_->thread_num_); - int cur_offset = stride_per_thread * task_id; - int cur_count = param_->broadcasting_ ? 
MSMIN(stride_per_thread, outside_ - cur_offset) - : MSMIN(stride_per_thread, param_->out_elements_num_ - cur_offset); - if (cur_count <= 0) { - return RET_OK; - } +void ArithmeticFP16CPUKernel::TileConstTensor(const void *in_data, void *out_data, size_t ndim, const int *in_shape, + const int *in_strides, const int *out_strides, const int *multiple) { + TileOneDimensionFp16(reinterpret_cast(in_data), reinterpret_cast(out_data), 0, ndim, + in_shape, in_strides, out_strides, multiple); +} +int ArithmeticFP16CPUKernel::Execute(const void *input0, const void *input1, void *output, int size, bool is_opt) { int ret = RET_OK; - if (param_->broadcasting_) { - ret = BroadcastRun(input0_fp16_, input1_fp16_, output_fp16_, 0, cur_count, cur_offset); - } else if (param_->in_elements_num0_ == 1) { - ret = arithmetic_opt_func_(input0_fp16_, input1_fp16_ + cur_offset, output_fp16_ + cur_offset, cur_count, param_); - } else if (param_->in_elements_num1_ == 1) { - ret = arithmetic_opt_func_(input0_fp16_ + cur_offset, input1_fp16_, output_fp16_ + cur_offset, cur_count, param_); - } else { - ret = arithmetic_func_(input0_fp16_ + cur_offset, input1_fp16_ + cur_offset, output_fp16_ + cur_offset, cur_count); - } - if (ret != RET_OK) { - MS_LOG(ERROR) << "DoArithmetic failed, ret = " << ret; + if (in_tensors_[0]->data_type() != kNumberTypeFloat16) { + MS_LOG(ERROR) << "data type is not fp16"; + return RET_ERROR; } - return ret; -} - -static int ArithmeticsRunFp16(void *cdata, int task_id) { - auto arithmetic_kernel = reinterpret_cast(cdata); - auto ret = arithmetic_kernel->DoArithmetic(task_id); - if (ret != RET_OK) { - MS_LOG(ERROR) << "ArithmeticsRunFp16 error task_id[" << task_id << "] ret[" << ret << "]"; + if (is_opt) { + CHECK_NULL_RETURN(arithmetic_opt_func_, RET_ERROR); + ret = arithmetic_opt_func_(reinterpret_cast(input0), reinterpret_cast(input1), + reinterpret_cast(output), size, param_); + } else { + CHECK_NULL_RETURN(arithmetic_func_, RET_ERROR); + ret = arithmetic_func_(reinterpret_cast(input0), reinterpret_cast(input1), + reinterpret_cast(output), size); } return ret; } int ArithmeticFP16CPUKernel::Run() { + if (!input0_broadcast_) { + input0_ptr_ = ConvertInputFp32toFp16(in_tensors_.at(0), context_); + } + if (!input1_broadcast_) { + input1_ptr_ = ConvertInputFp32toFp16(in_tensors_.at(1), context_); + } auto output_tensor = out_tensors_.at(0); - is_input0_fp32_ = in_tensors_.at(0)->data_type() == kNumberTypeFloat32; - is_input1_fp32_ = in_tensors_.at(1)->data_type() == kNumberTypeFloat32; - is_output_fp32_ = output_tensor->data_type() == kNumberTypeFloat32; - - input0_fp16_ = ConvertInputFp32toFp16(in_tensors_.at(0), context_); - input1_fp16_ = ConvertInputFp32toFp16(in_tensors_.at(1), context_); - output_fp16_ = MallocOutputFp16(output_tensor, context_); - if (input0_fp16_ == nullptr || input1_fp16_ == nullptr || output_fp16_ == nullptr) { - MS_LOG(ERROR) << "Memory allocation failed"; - FreeTmpBuffer(); + output_ptr_ = MallocOutputFp16(output_tensor, context_); + if (input0_ptr_ == nullptr || input1_ptr_ == nullptr || output_ptr_ == nullptr) { + FreeFp16Buffer(); return RET_ERROR; } - - auto ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticsRunFp16, this, context_->thread_num_); - if (ret != RET_OK) { - MS_LOG(ERROR) << "ArithmeticsRunFp16 run error error_code[" << ret << "]"; - } - if (is_output_fp32_) { - Float16ToFloat32(output_fp16_, reinterpret_cast(output_tensor->MutableData()), + auto ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticsRun, this, 
context_->thread_num_); + if (out_tensors_.at(0)->data_type() == kNumberTypeFloat32) { + Float16ToFloat32(static_cast(output_ptr_), reinterpret_cast(output_tensor->MutableData()), output_tensor->ElementsNum()); } - FreeTmpBuffer(); + FreeFp16Buffer(); return ret; } -void ArithmeticFP16CPUKernel::FreeTmpBuffer() { - if (is_input0_fp32_) { - context_->allocator->Free(input0_fp16_); - input0_fp16_ = nullptr; +void ArithmeticFP16CPUKernel::FreeFp16Buffer() { + if (!input0_broadcast_ && in_tensors_.at(0)->data_type() == kNumberTypeFloat32) { + context_->allocator->Free(input0_ptr_); + input0_ptr_ = nullptr; } - if (is_input1_fp32_) { - context_->allocator->Free(input1_fp16_); - input1_fp16_ = nullptr; + if (!input1_broadcast_ && in_tensors_.at(1)->data_type() == kNumberTypeFloat32) { + context_->allocator->Free(input1_ptr_); + input1_ptr_ = nullptr; } - if (is_output_fp32_) { - context_->allocator->Free(output_fp16_); - output_fp16_ = nullptr; + if (out_tensors_.at(0)->data_type() == kNumberTypeFloat32) { + context_->allocator->Free(output_ptr_); + output_ptr_ = nullptr; } } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h index a36de76573..bec8a0d620 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,19 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ARITHMETIC_FP16_H_ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ARITHMETIC_FP16_H_ #include -#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/fp32/arithmetic_fp32.h" #include "nnacl/fp16/arithmetic_fp16.h" -#include "schema/model_generated.h" namespace mindspore::kernel { -typedef int (*ArithmeticFuncFp16)(float16_t *input0, float16_t *input1, float16_t *output, int element_size); -typedef int (*ArithmeticOptFuncFp16)(float16_t *input0, float16_t *input1, float16_t *output, int element_size, - ArithmeticParameter *param); +typedef int (*ArithmeticFuncFp16)(const float16_t *input0, const float16_t *input1, float16_t *output, + int element_size); +typedef int (*ArithmeticOptFuncFp16)(const float16_t *input0, const float16_t *input1, float16_t *output, + int element_size, ArithmeticParameter *param); typedef struct { int primitive_type_; int activation_type_; @@ -33,36 +32,24 @@ typedef struct { ArithmeticOptFuncFp16 opt_func_; } ARITHMETIC_FUNC_INFO_FP16; -class ArithmeticFP16CPUKernel : public LiteKernel { +class ArithmeticFP16CPUKernel : public ArithmeticCPUKernel { public: ArithmeticFP16CPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive) { - param_ = reinterpret_cast(parameter); - } + : ArithmeticCPUKernel(parameter, inputs, outputs, ctx, primitive) {} ~ArithmeticFP16CPUKernel() = default; - - int Init() override; int ReSize() override; int Run() override; - int CheckDataType(); - int DoArithmetic(int task_id); - int BroadcastRun(float16_t *input0, float16_t *input1, float16_t *output, int dim, int 
out_count,
-                   int out_thread_stride);
 
  private:
-  void InitParam();
-  void FreeTmpBuffer();
-  int outside_;
-  int break_pos_;
-  bool is_input0_fp32_ = false;
-  bool is_input1_fp32_ = false;
-  bool is_output_fp32_ = false;
-  float16_t *input0_fp16_ = nullptr;
-  float16_t *input1_fp16_ = nullptr;
-  float16_t *output_fp16_ = nullptr;
-  ArithmeticParameter *param_ = nullptr;
+  void InitRunFunction() override;
+  int CheckDataType() override;
+  int ConstTensorBroadCast() override;
+  void TileConstTensor(const void *in_data, void *out_data, size_t ndim, const int *in_shape, const int *in_strides,
+                       const int *out_strides, const int *multiple) override;
+  int Execute(const void *input0, const void *input1, void *output, int size, bool is_opt) override;
+  void FreeFp16Buffer();
   ArithmeticFuncFp16 arithmetic_func_ = nullptr;
   ArithmeticOptFuncFp16 arithmetic_opt_func_ = nullptr;
 };
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/common_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/common_fp16.cc
index 5e43a5b0d1..5a2fc01a0b 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/common_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/common_fp16.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -13,9 +13,12 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 #include "src/runtime/kernel/arm/fp16/common_fp16.h"
 #include "nnacl/fp16/cast_fp16.h"
+#include "include/errorcode.h"
+
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_OK;
 
 namespace mindspore::kernel {
 float16_t *ConvertInputFp32toFp16(lite::Tensor *input, const lite::InnerContext *ctx) {
@@ -52,23 +55,20 @@ float16_t *MallocOutputFp16(lite::Tensor *output, const lite::InnerContext *ctx)
   return fp16_data;
 }
 
-bool IsExistFp16Tensor(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs) {
-  bool result = false;
-  for (auto &input : inputs) {
-    if (input->data_type() == kNumberTypeFloat16) {
-      result = true;
-      break;
-    }
+int ConvertFp32TensorToFp16(lite::Tensor *tensor, const lite::InnerContext *ctx) {
+  if (tensor->data_type() == TypeId::kNumberTypeFloat16) {
+    return RET_OK;
   }
-  if (result) {
-    return true;
-  }
-  for (auto &output : outputs) {
-    if (output->data_type() == kNumberTypeFloat16) {
-      result = true;
-      break;
-    }
+  auto fp32_data = tensor->data_c();
+  tensor->set_data(nullptr);
+  tensor->set_data_type(TypeId::kNumberTypeFloat16);
+  auto ret = tensor->MallocData();
+  if (RET_OK != ret) {
+    MS_LOG(ERROR) << "malloc data failed";
+    return RET_ERROR;
   }
-  return result;
+  Float32ToFloat16(static_cast<float *>(fp32_data), static_cast<float16_t *>(tensor->data_c()), tensor->ElementsNum());
+  ctx->allocator->Free(fp32_data);
+  return RET_OK;
 }
 }  // namespace mindspore::kernel
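The new ConvertFp32TensorToFp16 re-types a tensor in place: it detaches the existing fp32 buffer with set_data(nullptr) before switching the data type, allocates fp16 storage, converts element by element, and only then returns the old buffer to the context allocator. A minimal sketch of the narrowing step it delegates to, assuming Float32ToFloat16 from nnacl/fp16/cast_fp16.h is the usual element-wise cast loop (the name and body below are illustrative, not the exact nnacl implementation):

```cpp
// Assumed behaviour of Float32ToFloat16: a plain element-wise cast.
// Requires an ARM toolchain where float16_t is provided by <arm_fp16.h>.
#include <arm_fp16.h>

static void Float32ToFloat16Sketch(const float *input, float16_t *output, int number) {
  for (int i = 0; i < number; ++i) {
    output[i] = static_cast<float16_t>(input[i]);  // narrow each fp32 value to fp16
  }
}
```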
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/common_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/common_fp16.h
index 88111fdca1..184323f74b 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/common_fp16.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/common_fp16.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@ float16_t *ConvertInputFp32toFp16(lite::Tensor *input, const lite::InnerContext
 
 float16_t *MallocOutputFp16(lite::Tensor *output, const lite::InnerContext *ctx);
 
-bool IsExistFp16Tensor(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs);
+int ConvertFp32TensorToFp16(lite::Tensor *tensor, const lite::InnerContext *ctx);
 }  // namespace mindspore::kernel
 
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_COMMON_FP16_H_
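With the header now exporting ConvertFp32TensorToFp16 instead of IsExistFp16Tensor, callers such as ArithmeticFP16CPUKernel::ConstTensorBroadCast (earlier in this patch) convert whichever const inputs already hold data before reusing the shared fp32 broadcast path. A hedged usage sketch; the wrapper name below is hypothetical:

```cpp
#include "src/runtime/kernel/arm/fp16/common_fp16.h"
#include "include/errorcode.h"

namespace mindspore::kernel {
// Hypothetical helper mirroring the call pattern in ArithmeticFP16CPUKernel::ConstTensorBroadCast.
// ConvertFp32TensorToFp16 simply returns RET_OK when the tensor is already fp16.
int ConvertConstInputsToFp16(lite::Tensor *in0, lite::Tensor *in1, const lite::InnerContext *ctx) {
  if (in0->data_c() != nullptr) {
    auto ret = ConvertFp32TensorToFp16(in0, ctx);
    if (ret != lite::RET_OK) {
      return ret;
    }
  }
  if (in1->data_c() != nullptr) {
    return ConvertFp32TensorToFp16(in1, ctx);
  }
  return lite::RET_OK;
}
}  // namespace mindspore::kernel
```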
0 : i; int error_code; - if (data_type_ == kDataTypeInt) { - error_code = BroadcastRun(reinterpret_cast(input0) + pos0_ * arithmeticParameter_->in_strides0_[dim], - reinterpret_cast(input1) + pos1_ * arithmeticParameter_->in_strides1_[dim], - reinterpret_cast(output) + i * arithmeticParameter_->out_strides_[dim], - dim + 1, out_count, out_thread_stride); + if (in_tensors_[0]->data_type() == kNumberTypeInt) { + error_code = BroadcastRun(reinterpret_cast(input0) + pos0_ * param_->in_strides0_[dim], + reinterpret_cast(input1) + pos1_ * param_->in_strides1_[dim], + reinterpret_cast(output) + i * param_->out_strides_[dim], dim + 1, out_count, + out_thread_stride); } else { - error_code = BroadcastRun(reinterpret_cast(input0) + pos0_ * arithmeticParameter_->in_strides0_[dim], - reinterpret_cast(input1) + pos1_ * arithmeticParameter_->in_strides1_[dim], - reinterpret_cast(output) + i * arithmeticParameter_->out_strides_[dim], - dim + 1, out_count, out_thread_stride); + error_code = BroadcastRun(reinterpret_cast(input0) + pos0_ * param_->in_strides0_[dim], + reinterpret_cast(input1) + pos1_ * param_->in_strides1_[dim], + reinterpret_cast(output) + i * param_->out_strides_[dim], dim + 1, out_count, + out_thread_stride); } if (error_code != RET_OK) { return error_code; @@ -65,8 +65,8 @@ int ArithmeticCompareCPUKernel::BroadcastRun(void *input0, void *input1, void *o int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) { auto element_num = out_tensors_[0]->ElementsNum(); - MS_ASSERT(thread_count_ != 0); - int stride = UP_DIV(element_num, thread_count_); + MS_ASSERT(context_->thread_num_ != 0); + int stride = UP_DIV(element_num, context_->thread_num_); int count = MSMIN(stride, element_num - stride * task_id); if (count <= 0) { return RET_OK; @@ -78,14 +78,14 @@ int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) { } int error_code; - if (arithmeticParameter_->broadcasting_) { // need broadcast - stride = UP_DIV(outside_, thread_count_); + if (param_->broadcasting_) { // need broadcast + stride = UP_DIV(outside_, context_->thread_num_); int out_count = MSMIN(stride, outside_ - stride * task_id); int out_thread_stride = stride * task_id; if (out_count <= 0) { return RET_OK; } - if (data_type_ == kDataTypeFloat) { + if (in_tensors_[0]->data_type() == kNumberTypeFloat32) { error_code = BroadcastRun(reinterpret_cast(input0_ptr_), reinterpret_cast(input1_ptr_), reinterpret_cast(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride); @@ -95,7 +95,7 @@ int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) { reinterpret_cast(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride); } } else { // no broadcast, neither is scalar, two same shape - if (data_type_ == kDataTypeFloat) { + if (in_tensors_[0]->data_type() == kNumberTypeFloat32) { error_code = func_fp32_(reinterpret_cast(input0_ptr_) + stride * task_id, reinterpret_cast(input1_ptr_) + stride * task_id, reinterpret_cast(out_tensors_[0]->data_c()) + stride * task_id, count); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc index 41fabf2292..fca1e66fe2 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the 
License. @@ -13,13 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #include "src/runtime/kernel/arm/fp32/arithmetic_fp32.h" -#include "include/errorcode.h" -#include "schema/model_generated.h" #include "src/kernel_registry.h" -#include "src/runtime/kernel/arm/int8/add_int8.h" -#include "src/runtime/runtime_api.h" #include "src/ops/arithmetic.h" using mindspore::kernel::KERNEL_ARCH::kCPU; @@ -29,79 +24,119 @@ using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_Eltwise; namespace mindspore::kernel { -ArithmeticCPUKernel::~ArithmeticCPUKernel() { - FreeTmpPtr(); - return; -} - int ArithmeticCPUKernel::Init() { + InitRunFunction(); if (!InferShapeDone()) { return RET_OK; } return ReSize(); } -int ArithmeticCPUKernel::InitBroadCastCase() { - /* if const node need broadcast - * and all need-broadcast-node are const - * broadcast in resize */ +int ArithmeticCPUKernel::ReSize() { + if (CheckDataType() != RET_OK) { + MS_LOG(ERROR) << "ArithmeticCPUKernel resize failed."; + return RET_ERROR; + } + auto prim = reinterpret_cast(primitive_); + param_->broadcasting_ = prim->Broadcasting(); + param_->ndim_ = prim->NDims(); + param_->in_elements_num0_ = in_tensors_[0]->ElementsNum(); + param_->in_elements_num1_ = in_tensors_[1]->ElementsNum(); + param_->out_elements_num_ = out_tensors_[0]->ElementsNum(); + memcpy(param_->in_shape0_, prim->InShape0().data(), prim->InShape0().size() * sizeof(int)); + memcpy(param_->in_shape1_, prim->InShape1().data(), prim->InShape1().size() * sizeof(int)); + memcpy(param_->out_shape_, prim->OutputShape().data(), prim->OutputShape().size() * sizeof(int)); + CalcMultiplesAndStrides(param_); + if (param_->broadcasting_) { + outside_ = 1; + for (auto i = param_->ndim_ - 1; i >= 0; --i) { + if (param_->in_shape0_[i] != param_->in_shape1_[i]) { + break_pos_ = i; + break; + } + outside_ *= param_->out_shape_[i]; + } + } + return ConstTensorBroadCast(); +} - if (!arithmeticParameter_->broadcasting_) { - return RET_OK; +int ArithmeticCPUKernel::CheckDataType() { + auto in0_dataType = in_tensors_.at(0)->data_type(); + auto in1_dataType = in_tensors_.at(1)->data_type(); + if (in0_dataType != in1_dataType) { + MS_LOG(ERROR) << "The dataTypes of input tensor0 and input tensor1 should be the same."; + return RET_ERROR; } + return RET_OK; +} +int ArithmeticCPUKernel::ConstTensorBroadCast() { + /* if const node need broadcast and all need-broadcast-node are const, broadcast in resize */ + if (!param_->broadcasting_) { + return RET_OK; + } if (out_tensors_[0]->Size() < 0) { return RET_OK; } - - if (arithmeticParameter_->in_elements_num0_ != arithmeticParameter_->out_elements_num_ && - arithmeticParameter_->in_elements_num1_ != arithmeticParameter_->out_elements_num_) { - /* [1, 1, 2] + [1, 2, 1] -> [1, 2, 2] - * need broadcast both input */ + /* [1, 1, 2] + [1, 2, 1] -> [1, 2, 2], need broadcast both input */ + if (param_->in_elements_num0_ != param_->out_elements_num_ && + param_->in_elements_num1_ != param_->out_elements_num_) { return RET_OK; } - - if ((arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) && + if ((param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) && (arithmetic_opt_run_ != nullptr && arithmetic_opt_run_int_ != nullptr)) { /* run opt function * one of input is scalar */ return RET_OK; } - FreeTmpPtr(); - - CalcMultiplesAndStrides(arithmeticParameter_); - - if (in_tensors_[0]->data_c() != nullptr && - 
arithmeticParameter_->in_elements_num1_ == arithmeticParameter_->out_elements_num_) { - input0_ptr_ = malloc(arithmeticParameter_->out_elements_num_ * sizeof(float)); + FreeConstTileBuff(); + if (in_tensors_[0]->data_c() != nullptr && param_->in_elements_num0_ != param_->out_elements_num_) { + input0_ptr_ = malloc(param_->out_elements_num_ * data_type_len_); if (input0_ptr_ == nullptr) { return RET_ERROR; } - TileOneDimensionFp32(reinterpret_cast(in_tensors_[0]->data_c()), reinterpret_cast(input0_ptr_), 0, - arithmeticParameter_->ndim_, arithmeticParameter_->in_shape0_, - arithmeticParameter_->in_strides0_, arithmeticParameter_->out_strides_, - arithmeticParameter_->multiples0_); - arithmeticParameter_->broadcasting_ = false; + TileConstTensor(in_tensors_[0]->data_c(), input0_ptr_, param_->ndim_, param_->in_shape0_, param_->in_strides0_, + param_->out_strides_, param_->multiples0_); input0_broadcast_ = true; + param_->in_elements_num0_ = param_->out_elements_num_; + param_->broadcasting_ = false; } - if (in_tensors_[1]->data_c() != nullptr && - arithmeticParameter_->in_elements_num0_ == arithmeticParameter_->out_elements_num_) { - input1_ptr_ = malloc(arithmeticParameter_->out_elements_num_ * sizeof(float)); + if (in_tensors_[1]->data_c() != nullptr && param_->in_elements_num1_ != param_->out_elements_num_) { + input1_ptr_ = malloc(param_->out_elements_num_ * data_type_len_); if (input1_ptr_ == nullptr) { - FreeTmpPtr(); + FreeConstTileBuff(); return RET_ERROR; } - TileOneDimensionFp32(reinterpret_cast(in_tensors_[1]->data_c()), reinterpret_cast(input1_ptr_), 0, - arithmeticParameter_->ndim_, arithmeticParameter_->in_shape1_, - arithmeticParameter_->in_strides1_, arithmeticParameter_->out_strides_, - arithmeticParameter_->multiples1_); - arithmeticParameter_->broadcasting_ = false; + TileConstTensor(in_tensors_[1]->data_c(), input1_ptr_, param_->ndim_, param_->in_shape1_, param_->in_strides1_, + param_->out_strides_, param_->multiples1_); input1_broadcast_ = true; + param_->in_elements_num1_ = param_->out_elements_num_; + param_->broadcasting_ = false; } return RET_OK; } +void ArithmeticCPUKernel::TileConstTensor(const void *in_data, void *out_data, size_t ndim, const int *in_shape, + const int *in_strides, const int *out_strides, const int *multiple) { + TileOneDimensionFp32(reinterpret_cast(in_data), reinterpret_cast(out_data), 0, ndim, in_shape, + in_strides, out_strides, multiple); +} + +void ArithmeticCPUKernel::FreeConstTileBuff() { + if (input0_broadcast_ == true && input0_ptr_ != nullptr) { + free(input0_ptr_); + input0_ptr_ = nullptr; + input0_broadcast_ = false; + } + if (input1_broadcast_ == true && input1_ptr_ != nullptr) { + free(input1_ptr_); + input1_ptr_ = nullptr; + input0_broadcast_ = false; + } + return; +} + void ArithmeticCPUKernel::InitRunFunction() { ARITHMETIC_FUNC_INFO_FP32 fun_table[] = { {PrimitiveType_Mul, schema::ActivationType_RELU, ElementMulRelu, ElementMulReluInt, nullptr, ElementOptMulRelu, @@ -146,8 +181,8 @@ void ArithmeticCPUKernel::InitRunFunction() { size_t length = sizeof(fun_table) / sizeof(ARITHMETIC_FUNC_INFO_FP32); for (size_t i = 0; i < length; i++) { - if (fun_table[i].primitive_type_ == op_parameter_->type_ && - fun_table[i].activation_type_ == arithmeticParameter_->activation_type_) { + if (fun_table[i].primitive_type_ == param_->op_parameter_.type_ && + fun_table[i].activation_type_ == param_->activation_type_) { arithmetic_run_ = fun_table[i].func_; arithmetic_run_int_ = fun_table[i].int_func_; arithmetic_run_bool_ = 
fun_table[i].bool_func_; @@ -158,79 +193,53 @@ void ArithmeticCPUKernel::InitRunFunction() { } } -void ArithmeticCPUKernel::InitParam() { - auto arithmetic_lite_primitive = (lite::Arithmetic *)primitive_; - arithmeticParameter_->broadcasting_ = arithmetic_lite_primitive->Broadcasting(); - arithmeticParameter_->ndim_ = arithmetic_lite_primitive->NDims(); - if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat16) { - data_type_ = kDataTypeFloat; +int ArithmeticCPUKernel::Execute(const void *input0, const void *input1, void *output, int size, bool is_opt) { + int ret = RET_OK; + if (in_tensors_[0]->data_type() == kNumberTypeFloat32) { + if (is_opt) { + CHECK_NULL_RETURN(arithmetic_opt_run_, RET_ERROR); + ret = arithmetic_opt_run_(reinterpret_cast(input0), reinterpret_cast(input1), + reinterpret_cast(output), size, param_); + } else { + CHECK_NULL_RETURN(arithmetic_run_, RET_ERROR); + ret = arithmetic_run_(reinterpret_cast(input0), reinterpret_cast(input1), + reinterpret_cast(output), size); + } } else if (in_tensors_[0]->data_type() == kNumberTypeBool) { - data_type_ = KDataTypeBool; + CHECK_NULL_RETURN(arithmetic_run_bool_, RET_ERROR); + ret = arithmetic_run_bool_(reinterpret_cast(input0), reinterpret_cast(input1), + reinterpret_cast(output), size); } else { - data_type_ = kDataTypeInt; - } - - arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum(); - arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum(); - arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum(); - memcpy(arithmeticParameter_->in_shape0_, reinterpret_cast(primitive_)->InShape0().data(), - reinterpret_cast(primitive_)->InShape0().size() * sizeof(int)); - memcpy(arithmeticParameter_->in_shape1_, reinterpret_cast(primitive_)->InShape1().data(), - reinterpret_cast(primitive_)->InShape1().size() * sizeof(int)); - memcpy(arithmeticParameter_->out_shape_, reinterpret_cast(primitive_)->OutputShape().data(), - reinterpret_cast(primitive_)->OutputShape().size() * sizeof(int)); - - return; -} - -int ArithmeticCPUKernel::CheckDataType() { - auto in0_dataType = in_tensors_.at(0)->data_type(); - auto in1_dataType = in_tensors_.at(1)->data_type(); - if (in0_dataType != in1_dataType) { - MS_LOG(ERROR) << "The dataTypes of input tensor0 and input tensor1 should be the same."; - return RET_ERROR; - } - return RET_OK; -} - -int ArithmeticCPUKernel::ReSize() { - if (CheckDataType() != RET_OK) { - MS_LOG(ERROR) << "ArithmeticCPUKernel resize failed."; - return RET_ERROR; + if (is_opt) { + CHECK_NULL_RETURN(arithmetic_opt_run_int_, RET_ERROR); + ret = arithmetic_opt_run_int_(reinterpret_cast(input0), reinterpret_cast(input1), + reinterpret_cast(output), size, param_); + } else { + CHECK_NULL_RETURN(arithmetic_run_int_, RET_ERROR); + ret = arithmetic_run_int_(reinterpret_cast(input0), reinterpret_cast(input1), + reinterpret_cast(output), size); + } } - InitParam(); - return InitBroadCastCase(); + return ret; } int ArithmeticCPUKernel::BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride) { if (dim > break_pos_) { - if (data_type_ == kDataTypeInt) { - return arithmetic_run_int_(reinterpret_cast(input0) + out_thread_stride, - reinterpret_cast(input1) + out_thread_stride, - reinterpret_cast(output) + out_thread_stride, out_count); - } - return arithmetic_run_(reinterpret_cast(input0) + out_thread_stride, - reinterpret_cast(input1) + out_thread_stride, - reinterpret_cast(output) + out_thread_stride, 
out_count); + int offset = out_thread_stride * data_type_len_; + return Execute(static_cast(input0) + offset, static_cast(input1) + offset, + static_cast(output) + offset, out_count, false); } - for (int i = 0; i < arithmeticParameter_->out_shape_[dim]; ++i) { - int pos0_ = arithmeticParameter_->in_shape0_[dim] == 1 ? 0 : i; - int pos1_ = arithmeticParameter_->in_shape1_[dim] == 1 ? 0 : i; - int error_code; - if (data_type_ == kDataTypeInt) { - error_code = BroadcastRun(reinterpret_cast(input0) + pos0_ * arithmeticParameter_->in_strides0_[dim], - reinterpret_cast(input1) + pos1_ * arithmeticParameter_->in_strides1_[dim], - reinterpret_cast(output) + i * arithmeticParameter_->out_strides_[dim], dim + 1, - out_count, out_thread_stride); - } else { - error_code = BroadcastRun(reinterpret_cast(input0) + pos0_ * arithmeticParameter_->in_strides0_[dim], - reinterpret_cast(input1) + pos1_ * arithmeticParameter_->in_strides1_[dim], - reinterpret_cast(output) + i * arithmeticParameter_->out_strides_[dim], - dim + 1, out_count, out_thread_stride); - } - if (error_code != RET_OK) { - return error_code; + int offset_size[] = {param_->in_strides0_[dim] * data_type_len_, param_->in_strides1_[dim] * data_type_len_, + param_->out_strides_[dim] * data_type_len_}; + for (int i = 0; i < param_->out_shape_[dim]; ++i) { + int pos0_ = param_->in_shape0_[dim] == 1 ? 0 : i; + int pos1_ = param_->in_shape1_[dim] == 1 ? 0 : i; + int ret = BroadcastRun(static_cast(input0) + pos0_ * offset_size[0], + static_cast(input1) + pos1_ * offset_size[1], + static_cast(output) + i * offset_size[2], dim + 1, out_count, out_thread_stride); + if (ret != RET_OK) { + return ret; } } return RET_OK; @@ -240,20 +249,20 @@ bool ArithmeticCPUKernel::CanBatchScalar() { // 2 32 240 240, 2 32 1 1 if (input0_broadcast_ || input1_broadcast_) { return false; } - if (arithmeticParameter_->in_elements_num0_ == arithmeticParameter_->in_elements_num1_ || - arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) { + if (param_->in_elements_num0_ == param_->in_elements_num1_ || param_->in_elements_num0_ == 1 || + param_->in_elements_num1_ == 1) { return false; } size_t break_axis = 0; - for (size_t i = 0; i < arithmeticParameter_->ndim_; i++) { - if (arithmeticParameter_->in_shape0_[i] != arithmeticParameter_->in_shape1_[i]) { + for (size_t i = 0; i < param_->ndim_; i++) { + if (param_->in_shape0_[i] != param_->in_shape1_[i]) { break_axis = i; break; } } - if (break_axis < arithmeticParameter_->ndim_) { - for (size_t i = break_axis; i < arithmeticParameter_->ndim_; i++) { - if (arithmeticParameter_->in_shape1_[i] != 1) { + if (break_axis < param_->ndim_) { + for (size_t i = break_axis; i < param_->ndim_; i++) { + if (param_->in_shape1_[i] != 1) { return false; } } @@ -263,16 +272,19 @@ bool ArithmeticCPUKernel::CanBatchScalar() { // 2 32 240 240, 2 32 1 1 } int ArithmeticCPUKernel::BatchScalarCalc(int task_id) { - int batch = arithmeticParameter_->out_elements_num_ / arithmeticParameter_->out_strides_[break_pos_ - 1]; - int batch_per_thread = UP_DIV(batch, thread_count_); + if (break_pos_ < 1) { + return RET_ERROR; + } + int batch = param_->out_elements_num_ / param_->out_strides_[break_pos_ - 1]; + int batch_per_thread = UP_DIV(batch, context_->thread_num_); int start_batch = batch_per_thread * task_id; int end_batch = MSMIN(start_batch + batch_per_thread, batch); int batch_size = end_batch - start_batch; - int stride0 = arithmeticParameter_->in_strides0_[break_pos_ - 1]; - int stride1 = 
arithmeticParameter_->in_strides1_[break_pos_ - 1]; - int out_stride = arithmeticParameter_->out_strides_[break_pos_ - 1]; + int stride0 = param_->in_strides0_[break_pos_ - 1] * data_type_len_; + int stride1 = param_->in_strides1_[break_pos_ - 1] * data_type_len_; + int out_stride = param_->out_strides_[break_pos_ - 1] * data_type_len_; int offset0 = stride0 * start_batch; int offset1 = stride1 * start_batch; @@ -280,15 +292,8 @@ int ArithmeticCPUKernel::BatchScalarCalc(int task_id) { int ret = RET_OK; for (int i = 0; i < batch_size; i++) { - if (data_type_ == kDataTypeFloat) { - ret = arithmetic_opt_run_( - reinterpret_cast(input0_ptr_) + offset0, reinterpret_cast(input1_ptr_) + offset1, - reinterpret_cast(out_tensors_[0]->data_c()) + out_offset, out_stride, arithmeticParameter_); - } else { - ret = arithmetic_opt_run_int_( - reinterpret_cast(input0_ptr_) + offset0, reinterpret_cast(input1_ptr_) + offset1, - reinterpret_cast(out_tensors_[0]->data_c()) + out_offset, out_stride, arithmeticParameter_); - } + ret = Execute(static_cast(input0_ptr_) + offset0, static_cast(input1_ptr_) + offset1, + static_cast(output_ptr_) + out_offset, param_->out_strides_[break_pos_ - 1], true); offset0 += stride0; offset1 += stride1; out_offset += out_stride; @@ -298,143 +303,61 @@ int ArithmeticCPUKernel::BatchScalarCalc(int task_id) { int ArithmeticCPUKernel::DoArithmetic(int task_id) { auto element_num = out_tensors_[0]->ElementsNum(); - - MS_ASSERT(thread_count_ != 0); - int stride = UP_DIV(element_num, thread_count_); + int stride = UP_DIV(element_num, context_->thread_num_); int count = MSMIN(stride, element_num - stride * task_id); if (count <= 0) { return RET_OK; } - - if (arithmetic_run_ == nullptr) { - MS_LOG(ERROR) << "arithmetic_run function is nullptr!"; - return RET_ERROR; - } + /* run opt function, every batch one of input is scalar */ if (CanBatchScalar()) { return BatchScalarCalc(task_id); } - int error_code = RET_OK; - if ((arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) && + int offset = stride * task_id * data_type_len_; + /* run opt function, one of input is scalar */ + if ((param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) && (arithmetic_opt_run_ != nullptr && arithmetic_opt_run_int_ != nullptr)) { - /* run opt function - * one of input is scalar */ - if (arithmeticParameter_->in_elements_num0_ == 1) { - if (data_type_ == kDataTypeFloat) { - error_code = arithmetic_opt_run_( - reinterpret_cast(input0_ptr_), reinterpret_cast(input1_ptr_) + stride * task_id, - reinterpret_cast(out_tensors_[0]->data_c()) + stride * task_id, count, arithmeticParameter_); - } else { - error_code = arithmetic_opt_run_int_( - reinterpret_cast(input0_ptr_), reinterpret_cast(input1_ptr_) + stride * task_id, - reinterpret_cast(out_tensors_[0]->data_c()) + stride * task_id, count, arithmeticParameter_); - } - } else if (arithmeticParameter_->in_elements_num1_ == 1) { - if (data_type_ == kDataTypeFloat) { - error_code = arithmetic_opt_run_( - reinterpret_cast(input0_ptr_) + stride * task_id, reinterpret_cast(input1_ptr_), - reinterpret_cast(out_tensors_[0]->data_c()) + stride * task_id, count, arithmeticParameter_); - } else { - error_code = arithmetic_opt_run_int_( - reinterpret_cast(input0_ptr_) + stride * task_id, reinterpret_cast(input1_ptr_), - reinterpret_cast(out_tensors_[0]->data_c()) + stride * task_id, count, arithmeticParameter_); - } + if (param_->in_elements_num0_ == 1) { + return Execute(input0_ptr_, static_cast(input1_ptr_) + offset, + 
static_cast(output_ptr_) + offset, count, true); + } else if (param_->in_elements_num1_ == 1) { + return Execute(static_cast(input0_ptr_) + offset, input1_ptr_, + static_cast(output_ptr_) + offset, count, true); } - return error_code; } - if (arithmeticParameter_->broadcasting_) { - /* need broadcast in runtime */ - stride = UP_DIV(outside_, thread_count_); + /* need broadcast in runtime */ + if (param_->broadcasting_) { + stride = UP_DIV(outside_, context_->thread_num_); int out_count = MSMIN(stride, outside_ - stride * task_id); if (out_count <= 0) { return RET_OK; } - int out_thread_stride = stride * task_id; - if (data_type_ == kDataTypeFloat) { - error_code = BroadcastRun(reinterpret_cast(input0_ptr_), reinterpret_cast(input1_ptr_), - reinterpret_cast(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride); - } else { - error_code = BroadcastRun(reinterpret_cast(input0_ptr_), reinterpret_cast(input1_ptr_), - reinterpret_cast(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride); - } - return error_code; + return BroadcastRun(input0_ptr_, input1_ptr_, output_ptr_, 0, out_count, stride * task_id); } - /* no broadcast in runtime */ - if (data_type_ == kDataTypeFloat) { - error_code = arithmetic_run_(reinterpret_cast(input0_ptr_) + stride * task_id, - reinterpret_cast(input1_ptr_) + stride * task_id, - reinterpret_cast(out_tensors_[0]->data_c()) + stride * task_id, count); - } else if (data_type_ == KDataTypeBool) { - error_code = arithmetic_run_bool_(reinterpret_cast(input0_ptr_) + stride * task_id, - reinterpret_cast(input1_ptr_) + stride * task_id, - reinterpret_cast(out_tensors_[0]->data_c()) + stride * task_id, count); - } else { - error_code = arithmetic_run_int_(reinterpret_cast(input0_ptr_) + stride * task_id, - reinterpret_cast(input1_ptr_) + stride * task_id, - reinterpret_cast(out_tensors_[0]->data_c()) + stride * task_id, count); - } - return error_code; + return Execute(static_cast(input0_ptr_) + offset, static_cast(input1_ptr_) + offset, + static_cast(output_ptr_) + offset, count, false); } int ArithmeticsRun(void *cdata, int task_id) { - auto arithmetic_kernel = reinterpret_cast(cdata); - auto error_code = arithmetic_kernel->DoArithmetic(task_id); - if (error_code != RET_OK) { - MS_LOG(ERROR) << "ArithmeticsRun error task_id[" << task_id << "] error_code[" << error_code << "]"; - return RET_ERROR; + auto kernel = reinterpret_cast(cdata); + auto ret = kernel->DoArithmetic(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ArithmeticsRun error task_id[" << task_id << "] error_code[" << ret << "]"; } - return RET_OK; -} - -void ArithmeticCPUKernel::FreeTmpPtr() { - if (input0_broadcast_ == true && input0_ptr_ != nullptr) { - free(input0_ptr_); - input0_ptr_ = nullptr; - input0_broadcast_ = false; - } - if (input1_broadcast_ == true && input1_ptr_ != nullptr) { - free(input1_ptr_); - input1_ptr_ = nullptr; - input0_broadcast_ = false; - } - return; + return ret; } -void ArithmeticCPUKernel::InitParamInRunTime() { - /* after infershape */ - if (arithmeticParameter_->broadcasting_) { - outside_ = 1; - for (auto i = arithmeticParameter_->ndim_ - 1; i >= 0; --i) { - if (arithmeticParameter_->in_shape0_[i] != arithmeticParameter_->in_shape1_[i]) { - break_pos_ = i; - break; - } - outside_ *= arithmeticParameter_->out_shape_[i]; - } - } - ComputeStrides(arithmeticParameter_->in_shape0_, arithmeticParameter_->in_strides0_, arithmeticParameter_->ndim_); - ComputeStrides(arithmeticParameter_->in_shape1_, arithmeticParameter_->in_strides1_, arithmeticParameter_->ndim_); - 
ComputeStrides(arithmeticParameter_->out_shape_, arithmeticParameter_->out_strides_, arithmeticParameter_->ndim_); - +int ArithmeticCPUKernel::Run() { if (!input0_broadcast_) { input0_ptr_ = in_tensors_[0]->data_c(); } if (!input1_broadcast_) { input1_ptr_ = in_tensors_[1]->data_c(); } - return; + output_ptr_ = out_tensors_[0]->data_c(); + return ParallelLaunch(this->context_->thread_pool_, ArithmeticsRun, this, context_->thread_num_); } -int ArithmeticCPUKernel::Run() { - InitParamInRunTime(); - - int error_code = ParallelLaunch(this->context_->thread_pool_, ArithmeticsRun, this, thread_count_); - if (error_code != RET_OK) { - MS_LOG(ERROR) << "Arithmetic function error error_code[" << error_code << "]"; - return RET_ERROR; - } - return RET_OK; -} REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Mul, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Mul, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Add, LiteKernelCreator) diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h index 11155bd99c..5be91cbe40 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,14 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_ #include #include "src/lite_kernel.h" #include "nnacl/fp32/arithmetic_fp32.h" -#include "schema/model_generated.h" using mindspore::schema::PrimitiveType_Add; using mindspore::schema::PrimitiveType_Div; @@ -42,6 +40,14 @@ using mindspore::schema::PrimitiveType_RealDiv; using mindspore::schema::PrimitiveType_SquaredDifference; using mindspore::schema::PrimitiveType_Sub; +#define CHECK_NULL_RETURN(ptr, errcode) \ + do { \ + if (ptr == nullptr) { \ + MS_LOG(ERROR) << "ptr must not be null."; \ + return errcode; \ + } \ + } while (0); + namespace mindspore::kernel { class ArithmeticCPUKernel : public LiteKernel { typedef int (*ArithmeticRun)(const float *input0, const float *input1, float *output, const int element_size); @@ -66,11 +72,10 @@ class ArithmeticCPUKernel : public LiteKernel { ArithmeticCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) { - arithmeticParameter_ = reinterpret_cast(parameter); - InitRunFunction(); + : LiteKernel(parameter, inputs, outputs, ctx, primitive) { + param_ = reinterpret_cast(parameter); } - ~ArithmeticCPUKernel() override; + ~ArithmeticCPUKernel() { FreeConstTileBuff(); } int Init() override; int ReSize() override; @@ -78,33 +83,33 @@ class ArithmeticCPUKernel : public LiteKernel { virtual int DoArithmetic(int task_id); virtual int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride); - private: - void InitRunFunction(); - void InitParam(); - void FreeTmpPtr(); - int CheckDataType(); - int InitBroadCastCase(); - void InitParamInRunTime(); 
-  bool CanBatchScalar();
-  int BatchScalarCalc(int task_id);
-
 protected:
+  virtual void InitRunFunction();
+  virtual int CheckDataType();
+  virtual int ConstTensorBroadCast();
+  virtual void TileConstTensor(const void *in_data, void *out_data, size_t ndim, const int *in_shape,
+                               const int *in_strides, const int *out_strides, const int *multiple);
+  virtual int Execute(const void *input0, const void *input1, void *output, int size, bool is_opt);
   bool input0_broadcast_ = false;
   bool input1_broadcast_ = false;
   void *input0_ptr_ = nullptr;
   void *input1_ptr_ = nullptr;
+  void *output_ptr_ = nullptr;
   int break_pos_ = 0;
   int outside_ = 0;
-  int thread_count_ = 1;
-  ArithmeticParameter *arithmeticParameter_ = nullptr;
-  LiteDataType data_type_ = kDataTypeFloat;
+  ArithmeticParameter *param_ = nullptr;
+  int data_type_len_ = sizeof(float);
 private:
+  bool CanBatchScalar();
+  int BatchScalarCalc(int task_id);
+  void FreeConstTileBuff();
   ArithmeticRun arithmetic_run_ = nullptr;
   ArithmeticOptRun arithmetic_opt_run_ = nullptr;
   ArithmeticIntRun arithmetic_run_int_ = nullptr;
   ArithmeticOptIntRun arithmetic_opt_run_int_ = nullptr;
   ArithmeticBoolRun arithmetic_run_bool_ = nullptr;
 };
+int ArithmeticsRun(void *cdata, int task_id);
 }  // namespace mindspore::kernel
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_
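The refactored BroadcastRun and BatchScalarCalc above scale every stride by data_type_len_ and walk the buffers as raw bytes, so one recursion serves the float, int32 and bool paths that Execute() dispatches to. The following is a small self-contained sketch of that pattern; BroadcastInfo, ElementAddF32 and the shapes are invented for the demo, and the per-thread out_count/out_thread_stride split is omitted.

```cpp
// Stand-alone illustration of stride-in-bytes broadcasting; not MindSpore code.
#include <cstdint>
#include <cstdio>

struct BroadcastInfo {
  int ndim;
  int in_shape0[4], in_shape1[4], out_shape[4];
  int in_strides0[4], in_strides1[4], out_strides[4];  // strides counted in elements
  int data_type_len;                                   // sizeof one element
};

// Leaf kernel: add `count` contiguous float elements. In the CPU kernel this is the
// function pointer (arithmetic_run_ / arithmetic_run_int_ / ...) picked by InitRunFunction().
static void ElementAddF32(const void *a, const void *b, void *out, int count) {
  const float *fa = static_cast<const float *>(a);
  const float *fb = static_cast<const float *>(b);
  float *fo = static_cast<float *>(out);
  for (int i = 0; i < count; ++i) fo[i] = fa[i] + fb[i];
}

// Recurse over the axes up to the last axis where the input shapes differ (break_pos);
// everything behind it is a contiguous, equally shaped tail handled in one leaf call.
// All pointer offsets are computed in bytes, so the element type never matters here.
static void BroadcastRun(const BroadcastInfo &p, const uint8_t *in0, const uint8_t *in1, uint8_t *out, int dim,
                         int break_pos) {
  if (dim > break_pos) {
    ElementAddF32(in0, in1, out, p.out_strides[dim - 1]);  // tail length in elements
    return;
  }
  const int step0 = p.in_strides0[dim] * p.data_type_len;
  const int step1 = p.in_strides1[dim] * p.data_type_len;
  const int step_out = p.out_strides[dim] * p.data_type_len;
  for (int i = 0; i < p.out_shape[dim]; ++i) {
    const int pos0 = p.in_shape0[dim] == 1 ? 0 : i;  // a broadcast axis always reads index 0
    const int pos1 = p.in_shape1[dim] == 1 ? 0 : i;
    BroadcastRun(p, in0 + pos0 * step0, in1 + pos1 * step1, out + i * step_out, dim + 1, break_pos);
  }
}

int main() {
  // [2, 1, 3] + [1, 4, 3] -> [2, 4, 3]; the shapes last differ at axis 1, so break_pos = 1.
  BroadcastInfo p = {3,
                     {2, 1, 3}, {1, 4, 3},  {2, 4, 3},
                     {3, 3, 1}, {12, 3, 1}, {12, 3, 1},
                     static_cast<int>(sizeof(float))};
  float a[6], b[12], out[24];
  for (int i = 0; i < 6; ++i) a[i] = static_cast<float>(i);
  for (int i = 0; i < 12; ++i) b[i] = 0.1f * static_cast<float>(i);
  BroadcastRun(p, reinterpret_cast<const uint8_t *>(a), reinterpret_cast<const uint8_t *>(b),
               reinterpret_cast<uint8_t *>(out), 0, /*break_pos=*/1);
  std::printf("out[0]=%.2f out[23]=%.2f\n", out[0], out[23]);  // expect 0.00 and 6.10
  return 0;
}
```

Because the recursion itself never dereferences typed pointers, the fp16 subclass can presumably reuse it unchanged by overriding Execute()/TileConstTensor() and adjusting data_type_len_, which is what the protected virtual interface above is for.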
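One detail worth flagging in the patch: FreeConstTileBuff() frees input1_ptr_ in its second branch but then resets input0_broadcast_ rather than input1_broadcast_, a slip apparently carried over from the old FreeTmpPtr(). A minimal sketch of how the helper presumably ought to read (ConstTileBuffers is a stand-in struct for the kernel's bookkeeping fields, not the real class):

```cpp
// Sketch only; the second branch clears input1_broadcast_, whereas the patch clears
// input0_broadcast_ at that point.
#include <cstdlib>

struct ConstTileBuffers {
  bool input0_broadcast_ = false;
  bool input1_broadcast_ = false;
  void *input0_ptr_ = nullptr;
  void *input1_ptr_ = nullptr;

  void FreeConstTileBuff() {
    if (input0_broadcast_ && input0_ptr_ != nullptr) {
      std::free(input0_ptr_);
      input0_ptr_ = nullptr;
      input0_broadcast_ = false;
    }
    if (input1_broadcast_ && input1_ptr_ != nullptr) {
      std::free(input1_ptr_);
      input1_ptr_ = nullptr;
      input1_broadcast_ = false;  // reset the flag that matches the buffer just freed
    }
  }
};
```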