From 4327fd7e2069974c595c525c0562c8dda99b16f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E8=B4=B5?= Date: Fri, 7 Aug 2020 16:10:47 +0800 Subject: [PATCH] add fused_activation function for sub,add,mul and div op --- mindspore/lite/schema/ops.fbs | 8 +- mindspore/lite/src/populate_parameter.cc | 17 + .../src/runtime/kernel/arm/fp32/arithmetic.cc | 39 ++- .../src/runtime/kernel/arm/fp32/arithmetic.h | 54 +++- .../kernel/arm/nnacl/arithmetic_common.h | 2 +- .../kernel/arm/nnacl/fp32/activation.h | 2 +- .../kernel/arm/nnacl/fp32/arithmetic.cc | 297 ++++++++++++++++-- .../kernel/arm/nnacl/fp32/arithmetic.h | 8 + 8 files changed, 359 insertions(+), 68 deletions(-) diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index a870a6bb01..9defbbaea7 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -384,19 +384,19 @@ table Eltwise { } table Add { - activationType : ActivationType; + activationType: ActivationType = 0; } table Sub { - activationType : ActivationType; + activationType: ActivationType = 0; } table Mul { - activationType : ActivationType; + activationType: ActivationType = 0; } table Div { - activationType : ActivationType; + activationType: ActivationType = 0; } table AddGrad { diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc index ced8133048..7b4646e20c 100644 --- a/mindspore/lite/src/populate_parameter.cc +++ b/mindspore/lite/src/populate_parameter.cc @@ -510,6 +510,23 @@ OpParameter *PopulateArithmetic(const lite::Primitive *primitive) { arithmetic_param->op_parameter_.type_ = primitive->Type(); arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting(); arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims(); + switch (primitive->Type()) { + case schema::PrimitiveType_Add: + arithmetic_param->activation_type_ = primitive->Value()->value_as_Add()->activationType(); + break; + case schema::PrimitiveType_Sub: + arithmetic_param->activation_type_ = primitive->Value()->value_as_Sub()->activationType(); + break; + case schema::PrimitiveType_Mul: + arithmetic_param->activation_type_ = primitive->Value()->value_as_Mul()->activationType(); + break; + case schema::PrimitiveType_Div: + arithmetic_param->activation_type_ = primitive->Value()->value_as_Div()->activationType(); + break; + default: + arithmetic_param->activation_type_ = 0; + break; + } auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0(); (void)memcpy(arithmetic_param->in_shape0_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); tmp_shape = ((lite::Arithmetic *)primitive)->InShape1(); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc index 0dc394b22c..3d3d55f894 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc @@ -56,29 +56,26 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) { auto input1_data1 = reinterpret_cast(inputs_[1]->Data()); auto output_data = reinterpret_cast(outputs_[0]->Data()); auto element_num = outputs_[0]->ElementsNum(); + + MS_ASSERT(thread_count_ != 0); + int stride = UP_DIV(element_num, thread_count_); + int count = MSMIN(stride, element_num - stride * task_id); + + if (arithmetic_run_ == nullptr) { + MS_LOG(ERROR) << "arithmetic_run function is nullptr!"; + return RET_ERROR; + } + + int error_code = RET_OK; if (arithmeticParameter_->broadcasting_) { - if (arithmetic_broadcast_run_ == nullptr) { - MS_LOG(ERROR) << "broadcasting_run function is nullptr!"; - return RET_ERROR; - } - - MS_ASSERT(thread_count_ != 0); - int stride = UP_DIV(element_num, thread_count_); - int count = MSMIN(stride, element_num - stride * task_id); - - int error_code = arithmetic_run_(tile_data0_ + stride * task_id, tile_data1_ + stride * task_id, - output_data + stride * task_id, count); - - if (error_code != RET_OK) { - return RET_ERROR; - } - } else if (arithmetic_run_ != nullptr) { - int error_code = arithmetic_run_(input0_data, input1_data1, output_data, element_num); - if (error_code != RET_OK) { - return RET_ERROR; - } + error_code = arithmetic_run_(tile_data0_ + stride * task_id, tile_data1_ + stride * task_id, + output_data + stride * task_id, count); + } else { - MS_LOG(ERROR) << "arithmetic_run function is nullptr!"; + error_code = arithmetic_run_(input0_data + stride * task_id, input1_data1 + stride * task_id, + output_data + stride * task_id, count); + } + if (error_code != RET_OK) { return RET_ERROR; } return RET_OK; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h index 968f5fbb4e..e723df5f1e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h @@ -50,22 +50,59 @@ class ArithmeticCPUKernel : public LiteKernel { ArithmeticCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::Context *ctx) : LiteKernel(parameter, inputs, outputs), thread_count_(ctx->thread_num_) { + arithmeticParameter_ = reinterpret_cast(parameter); switch (parameter->type_) { case PrimitiveType_Mul: - arithmetic_run_ = ElementMul; - arithmetic_broadcast_run_ = BroadcastMul; + switch (arithmeticParameter_->activation_type_) { + case schema::ActivationType_RELU: + arithmetic_run_ = ElementMulRelu; + break; + case schema::ActivationType_RELU6: + arithmetic_run_ = ElementMulRelu6; + break; + default: + arithmetic_run_ = ElementMul; + break; + } break; case PrimitiveType_Add: - arithmetic_run_ = ElementAdd; - arithmetic_broadcast_run_ = BroadcastAdd; + switch (arithmeticParameter_->activation_type_) { + case schema::ActivationType_RELU: + arithmetic_run_ = ElementAddRelu; + break; + case schema::ActivationType_RELU6: + arithmetic_run_ = ElementAddRelu6; + break; + default: + arithmetic_run_ = ElementAdd; + break; + } break; case PrimitiveType_Sub: - arithmetic_run_ = ElementSub; - arithmetic_broadcast_run_ = BroadcastSub; + switch (arithmeticParameter_->activation_type_) { + case schema::ActivationType_RELU: + arithmetic_run_ = ElementSubRelu; + break; + case schema::ActivationType_RELU6: + arithmetic_run_ = ElementSubRelu6; + break; + default: + arithmetic_run_ = ElementSub; + break; + } break; case PrimitiveType_Div: - arithmetic_run_ = ElementDiv; - arithmetic_broadcast_run_ = BroadcastDiv; + switch (arithmeticParameter_->activation_type_) { + case schema::ActivationType_RELU: + arithmetic_run_ = ElementDivRelu; + break; + case schema::ActivationType_RELU6: + arithmetic_run_ = ElementDivRelu6; + break; + default: + arithmetic_run_ = ElementDiv; + break; + } break; case PrimitiveType_LogicalAnd: arithmetic_run_ = ElementLogicalAnd; @@ -125,7 +162,6 @@ class ArithmeticCPUKernel : public LiteKernel { arithmetic_broadcast_run_ = nullptr; break; } - arithmeticParameter_ = reinterpret_cast(parameter); } ~ArithmeticCPUKernel() override; diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/arithmetic_common.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/arithmetic_common.h index dc9f3d48f7..b0e52b8694 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/arithmetic_common.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/arithmetic_common.h @@ -27,6 +27,7 @@ struct ArithmeticParameter { OpParameter op_parameter_; bool broadcasting_; size_t ndim_; + int activation_type_; int in_shape0_[5]; int in_shape1_[5]; int out_shape_[5]; @@ -49,4 +50,3 @@ void TileDimensionsInt8(int8_t *data0, int8_t *data1, int8_t *tile_data0, int8_t ArithmeticParameter *param); #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ARITHMETIC_COMMON_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/activation.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/activation.h index 419dc20f5b..249bfacbce 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/activation.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/activation.h @@ -47,7 +47,7 @@ inline int Relu6(const float *src, int length, float *dst) { inline int LRelu(const float *src, int length, float *dst, float alpha) { for (int i = 0; i < length; ++i) { - dst[i] = src[i] > (src[i] * alpha) ? src[i] : (src[i] * alpha); + dst[i] = src[i] > 0 ? src[i] : (src[i] * alpha); } return NNACL_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic.cc index 285f5c8684..0fc3d805a7 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic.cc +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic.cc @@ -21,7 +21,7 @@ int ElementMul(float *input0, float *input1, float *output, int element_size) { int block_c4 = element_size - block_mod; for (int index = 0; index < block_c4; index += C4NUM) { -#ifdef USE_NEON +#ifdef ENABLE_NEON float32x4_t vin0 = vld1q_f32(input0); float32x4_t vin1 = vld1q_f32(input1); float32x4_t vout = vmulq_f32(vin0, vin1); @@ -43,6 +43,73 @@ int ElementMul(float *input0, float *input1, float *output, int element_size) { return NNACL_OK; } +int ElementMulRelu(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + +#ifdef ENABLE_NEON + float32x4_t zeros = {0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vmulq_f32(vin0, vin1); + vout = vbslq_f32(vcgtq_f32(vout, zeros), vout, zeros); + vst1q_f32(output, vout); +#else + float res = input0[0] * input1[0]; + output[0] = res > 0 ? res : 0; + res = input0[1] * input1[1]; + output[1] = res > 0 ? res : 0; + res = input0[2] * input1[2]; + output[2] = res > 0 ? res : 0; + res = input0[3] * input1[3]; + output[3] = res > 0 ? res : 0; +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + float res = input0[index] * input1[index]; + output[index] = res > 0 ? res : 0; + } + + return NNACL_OK; +} + +int ElementMulRelu6(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + +#ifdef ENABLE_NEON + float32x4_t zeros = {0, 0, 0, 0}; + float32x4_t bounds = {6, 6, 6, 6}; +#endif + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vminq_f32(vmaxq_f32(vmulq_f32(vin0, vin1), zeros), bounds); + vst1q_f32(output, vout); +#else + output[0] = MSMIN(MSMAX(input0[0] * input1[0], 0), 6); + output[1] = MSMIN(MSMAX(input0[1] * input1[1], 0), 6); + output[2] = MSMIN(MSMAX(input0[2] * input1[2], 0), 6); + output[3] = MSMIN(MSMAX(input0[3] * input1[3], 0), 6); +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMIN(MSMAX(input0[index] * input1[index], 0), 6); + } + + return NNACL_OK; +} + int BroadcastMul(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, ArithmeticParameter *param) { TileDimensions(input0, input1, tile_input0, tile_input1, param); @@ -54,7 +121,7 @@ int ElementAdd(float *input0, float *input1, float *output, int element_size) { int block_c4 = element_size - block_mod; for (int index = 0; index < block_c4; index += C4NUM) { -#ifdef USE_NEON +#ifdef ENABLE_NEON float32x4_t vin0 = vld1q_f32(input0); float32x4_t vin1 = vld1q_f32(input1); float32x4_t vout = vaddq_f32(vin0, vin1); @@ -75,6 +142,72 @@ int ElementAdd(float *input0, float *input1, float *output, int element_size) { return NNACL_OK; } +int ElementAddRelu(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + +#ifdef ENABLE_NEON + float32x4_t zeros = {0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vaddq_f32(vin0, vin1); + vout = vbslq_f32(vcgtq_f32(vout, zeros), vout, zeros); + vst1q_f32(output, vout); +#else + float res = input0[0] + input1[0]; + output[0] = res > 0 ? res : 0; + res = input0[1] + input1[1]; + output[1] = res > 0 ? res : 0; + res = input0[2] + input1[2]; + output[2] = res > 0 ? res : 0; + res = input0[3] + input1[3]; + output[3] = res > 0 ? res : 0; +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + float res = input0[index] + input1[index]; + output[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementAddRelu6(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + +#ifdef ENABLE_NEON + float32x4_t zeros = {0, 0, 0, 0}; + float32x4_t bounds = {6, 6, 6, 6}; +#endif + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vminq_f32(vmaxq_f32(vaddq_f32(vin0, vin1), zeros), bounds); + vst1q_f32(output, vout); +#else + output[0] = MSMIN(MSMAX(input0[0] + input1[0], 0), 6); + output[1] = MSMIN(MSMAX(input0[1] + input1[1], 0), 6); + output[2] = MSMIN(MSMAX(input0[2] + input1[2], 0), 6); + output[3] = MSMIN(MSMAX(input0[3] + input1[3], 0), 6); +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMIN(MSMAX(input0[index] + input1[index], 0), 6); + } + + return NNACL_OK; +} + int ElementAddInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size) { for (int i = 0; i < element_size; i++) { output[i] = input0[i] + input1[i]; @@ -99,7 +232,7 @@ int ElementSub(float *input0, float *input1, float *output, int element_size) { int block_c4 = element_size - block_mod; for (int index = 0; index < block_c4; index += C4NUM) { -#ifdef USE_NEON +#ifdef ENABLE_NEON float32x4_t vin0 = vld1q_f32(input0); float32x4_t vin1 = vld1q_f32(input1); float32x4_t vout = vsubq_f32(vin0, vin1); @@ -120,6 +253,72 @@ int ElementSub(float *input0, float *input1, float *output, int element_size) { return NNACL_OK; } +int ElementSubRelu(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + +#ifdef ENABLE_NEON + float32x4_t zeros = {0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vsubq_f32(vin0, vin1); + vout = vbslq_f32(vcgtq_f32(vout, zeros), vout, zeros); + vst1q_f32(output, vout); +#else + float res = input0[0] - input1[0]; + output[0] = res > 0 ? res : 0; + res = input0[1] - input1[1]; + output[1] = res > 0 ? res : 0; + res = input0[2] - input1[2]; + output[2] = res > 0 ? res : 0; + res = input0[3] - input1[3]; + output[3] = res > 0 ? res : 0; +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + float res = input0[index] - input1[index]; + output[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementSubRelu6(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + +#ifdef ENABLE_NEON + float32x4_t zeros = {0, 0, 0, 0}; + float32x4_t bounds = {6, 6, 6, 6}; +#endif + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vminq_f32(vmaxq_f32(vsubq_f32(vin0, vin1), zeros), bounds); + vst1q_f32(output, vout); +#else + output[0] = MSMIN(MSMAX(input0[0] - input1[0], 0), 6); + output[1] = MSMIN(MSMAX(input0[1] - input1[1], 0), 6); + output[2] = MSMIN(MSMAX(input0[2] - input1[2], 0), 6); + output[3] = MSMIN(MSMAX(input0[3] - input1[3], 0), 6); +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMIN(MSMAX(input0[index] - input1[index], 0), 6); + } + + return NNACL_OK; +} + int BroadcastSub(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, ArithmeticParameter *param) { TileDimensions(input0, input1, tile_input0, tile_input1, param); @@ -137,6 +336,27 @@ int ElementDiv(float *input0, float *input1, float *output, int element_size) { return NNACL_OK; } +int ElementDivRelu(float *input0, float *input1, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + if (input1[i] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + float res = input0[i] / input1[i]; + output[i] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementDivRelu6(float *input0, float *input1, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + if (input1[i] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + output[i] = MSMIN(MSMAX(input0[i] / input1[i], 0), 6); + } + return NNACL_OK; +} + int BroadcastDiv(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, ArithmeticParameter *param) { TileDimensions(input0, input1, tile_input0, tile_input1, param); @@ -179,11 +399,18 @@ int ElementLogicalAnd(float *input0, float *input1, float *output, int element_s int block_mod = element_size % C4NUM; int block_c4 = element_size - block_mod; +#ifdef ENABLE_NEON + float32x4_t vtrue = {1, 1, 1, 1}; + float32x4_t vfalse = {0, 0, 0, 0}; + uint32x4_t mask = vmovq_n_u32((uint32_t(1u << 31) - 1)); + uint32x4_t zeros = {0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c4; index += C4NUM) { -#ifdef USE_NEON - float32x4_t vin0 = vld1q_f32(input0); - float32x4_t vin1 = vld1q_f32(input1); - float32x4_t vout = vandq_f32(vin0, vin1); +#ifdef ENABLE_NEON + uint32x4_t vin0 = vandq_u32(vreinterpretq_s32_f32(vld1q_f32(input0)), mask); + uint32x4_t vin1 = vandq_u32(vreinterpretq_s32_f32(vld1q_f32(input1)), mask); + float32x4_t vout = vbslq_f32(vceqq_u32(vandq_u32(vin0, vin1), zeros), vfalse, vtrue); vst1q_f32(output, vout); #else output[0] = (float)((bool)(input0[0]) & (bool)(input1[0])); @@ -222,11 +449,18 @@ int ElementLogicalOr(float *input0, float *input1, float *output, int element_si int block_mod = element_size % C4NUM; int block_c4 = element_size - block_mod; +#ifdef ENABLE_NEON + float32x4_t vtrue = {1, 1, 1, 1}; + float32x4_t vfalse = {0, 0, 0, 0}; + uint32x4_t mask = vmovq_n_u32((uint32_t(1u << 31) - 1)); + uint32x4_t zeros = {0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c4; index += C4NUM) { -#ifdef USE_NEON - float32x4_t vin0 = vld1q_f32(input0); - float32x4_t vin1 = vld1q_f32(input1); - float32x4_t vout = vorrq_f32(vin0, vin1); +#ifdef ENABLE_NEON + uint32x4_t vin0 = vandq_u32(vreinterpretq_s32_f32(vld1q_f32(input0)), mask); + uint32x4_t vin1 = vandq_u32(vreinterpretq_s32_f32(vld1q_f32(input1)), mask); + float32x4_t vout = vbslq_f32(vceqq_u32(vorrq_u32(vin0, vin1), zeros), vfalse, vtrue); vst1q_f32(output, vout); #else output[0] = (float)((bool)(input0[0]) | (bool)(input1[0])); @@ -255,7 +489,7 @@ int ElementMaximum(float *input0, float *input1, float *output, int element_size int block_c4 = element_size - block_mod; for (int index = 0; index < block_c4; index += C4NUM) { -#ifdef USE_NEON +#ifdef ENABLE_NEON float32x4_t vin0 = vld1q_f32(input0); float32x4_t vin1 = vld1q_f32(input1); float32x4_t vout = vmaxq_f32(vin0, vin1); @@ -287,7 +521,7 @@ int ElementMinimum(float *input0, float *input1, float *output, int element_size int block_c4 = element_size - block_mod; for (int index = 0; index < block_c4; index += C4NUM) { -#ifdef USE_NEON +#ifdef ENABLE_NEON float32x4_t vin0 = vld1q_f32(input0); float32x4_t vin1 = vld1q_f32(input1); float32x4_t vout = vminq_f32(vin0, vin1); @@ -317,15 +551,15 @@ int BroadcastMinimum(float *input0, float *input1, float *tile_input0, float *ti int ElementNotEqual(float *input0, float *input1, float *output, int element_size) { int block_mod = element_size % C4NUM; int block_c4 = element_size - block_mod; -#ifdef USE_NEON +#ifdef ENABLE_NEON float32x4_t vtrue = {1, 1, 1, 1}; float32x4_t vfalse = {0, 0, 0, 0}; #endif for (int index = 0; index < block_c4; index += C4NUM) { -#ifdef USE_NEON +#ifdef ENABLE_NEON float32x4_t vin0 = vld1q_f32(input0); float32x4_t vin1 = vld1q_f32(input1); - float32x4_t vout = vbslq_f32(vceqq_fp32(vin0, vin1), vfalse, vtrue); + float32x4_t vout = vbslq_f32(vceqq_f32(vin0, vin1), vfalse, vtrue); vst1q_f32(output, vout); #else output[0] = (float)(input0[0] != input1[0]); @@ -352,15 +586,15 @@ int BroadcastNotEqual(float *input0, float *input1, float *tile_input0, float *t int ElementEqual(float *input0, float *input1, float *output, int element_size) { int block_mod = element_size % C4NUM; int block_c4 = element_size - block_mod; -#ifdef USE_NEON +#ifdef ENABLE_NEON float32x4_t vtrue = {1, 1, 1, 1}; float32x4_t vfalse = {0, 0, 0, 0}; #endif for (int index = 0; index < block_c4; index += C4NUM) { -#ifdef USE_NEON +#ifdef ENABLE_NEON float32x4_t vin0 = vld1q_f32(input0); float32x4_t vin1 = vld1q_f32(input1); - float32x4_t vout = vbslq_f32(vceqq_fp32(vin0, vin1), vtrue, vfalse); + float32x4_t vout = vbslq_f32(vceqq_f32(vin0, vin1), vtrue, vfalse); vst1q_f32(output, vout); #else output[0] = (float)(input0[0] == input1[0]); @@ -387,15 +621,15 @@ int BroadcastEqual(float *input0, float *input1, float *tile_input0, float *tile int ElementLess(float *input0, float *input1, float *output, int element_size) { int block_mod = element_size % C4NUM; int block_c4 = element_size - block_mod; -#ifdef USE_NEON +#ifdef ENABLE_NEON float32x4_t vtrue = {1, 1, 1, 1}; float32x4_t vfalse = {0, 0, 0, 0}; #endif for (int index = 0; index < block_c4; index += C4NUM) { -#ifdef USE_NEON +#ifdef ENABLE_NEON float32x4_t vin0 = vld1q_f32(input0); float32x4_t vin1 = vld1q_f32(input1); - float32x4_t vout = vbslq_f32(vcltq_fp32(vin0, vin1), vtrue, vfalse); + float32x4_t vout = vbslq_f32(vcltq_f32(vin0, vin1), vtrue, vfalse); vst1q_f32(output, vout); #else output[0] = (float)(input0[0] < input1[0]); @@ -422,15 +656,15 @@ int BroadcastLess(float *input0, float *input1, float *tile_input0, float *tile_ int ElementLessEqual(float *input0, float *input1, float *output, int element_size) { int block_mod = element_size % C4NUM; int block_c4 = element_size - block_mod; -#ifdef USE_NEON +#ifdef ENABLE_NEON float32x4_t vtrue = {1, 1, 1, 1}; float32x4_t vfalse = {0, 0, 0, 0}; #endif for (int index = 0; index < block_c4; index += C4NUM) { -#ifdef USE_NEON +#ifdef ENABLE_NEON float32x4_t vin0 = vld1q_f32(input0); float32x4_t vin1 = vld1q_f32(input1); - float32x4_t vout = vbslq_f32(vcleq_fp32(vin0, vin1), vtrue, vfalse); + float32x4_t vout = vbslq_f32(vcleq_f32(vin0, vin1), vtrue, vfalse); vst1q_f32(output, vout); #else output[0] = (float)(input0[0] <= input1[0]); @@ -457,15 +691,15 @@ int BroadcastLessEqual(float *input0, float *input1, float *tile_input0, float * int ElementGreater(float *input0, float *input1, float *output, int element_size) { int block_mod = element_size % C4NUM; int block_c4 = element_size - block_mod; -#ifdef USE_NEON +#ifdef ENABLE_NEON float32x4_t vtrue = {1, 1, 1, 1}; float32x4_t vfalse = {0, 0, 0, 0}; #endif for (int index = 0; index < block_c4; index += C4NUM) { -#ifdef USE_NEON +#ifdef ENABLE_NEON float32x4_t vin0 = vld1q_f32(input0); float32x4_t vin1 = vld1q_f32(input1); - float32x4_t vout = vbslq_f32(vcgtq_fp32(vin0, vin1), vtrue, vfalse); + float32x4_t vout = vbslq_f32(vcgtq_f32(vin0, vin1), vtrue, vfalse); vst1q_f32(output, vout); #else output[0] = (float)(input0[0] > input1[0]); @@ -492,15 +726,15 @@ int BroadcastGreater(float *input0, float *input1, float *tile_input0, float *ti int ElementGreaterEqual(float *input0, float *input1, float *output, int element_size) { int block_mod = element_size % C4NUM; int block_c4 = element_size - block_mod; -#ifdef USE_NEON +#ifdef ENABLE_NEON float32x4_t vtrue = {1, 1, 1, 1}; float32x4_t vfalse = {0, 0, 0, 0}; #endif for (int index = 0; index < block_c4; index += C4NUM) { -#ifdef USE_NEON +#ifdef ENABLE_NEON float32x4_t vin0 = vld1q_f32(input0); float32x4_t vin1 = vld1q_f32(input1); - float32x4_t vout = vbslq_f32(vcgeq_fp32(vin0, vin1), vtrue, vfalse); + float32x4_t vout = vbslq_f32(vcgeq_f32(vin0, vin1), vtrue, vfalse); vst1q_f32(output, vout); #else output[0] = (float)(input0[0] >= input1[0]); @@ -523,4 +757,3 @@ int BroadcastGreaterEqual(float *input0, float *input1, float *tile_input0, floa TileDimensions(input0, input1, tile_input0, tile_input1, param); return ElementGreaterEqual(tile_input0, tile_input1, output, element_size); } - diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic.h index b3a57f83e9..81f388800a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic.h @@ -24,20 +24,28 @@ #include "src/runtime/kernel/arm/nnacl/errorcode.h" int ElementMul(float *input0, float *input1, float *output, int element_size); +int ElementMulRelu(float *input0, float *input1, float *output, int element_size); +int ElementMulRelu6(float *input0, float *input1, float *output, int element_size); int BroadcastMul(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, ArithmeticParameter *param); int ElementAdd(float *input0, float *input1, float *output, int element_size); +int ElementAddRelu(float *input0, float *input1, float *output, int element_size); +int ElementAddRelu6(float *input0, float *input1, float *output, int element_size); int BroadcastAdd(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, ArithmeticParameter *param); int BroadcastAddInt8(int8_t *input0, int8_t *input1, int8_t *tile_input0, int8_t *tile_input1, int8_t *output, int element_size, ArithmeticParameter *param); int ElementSub(float *input0, float *input1, float *output, int element_size); +int ElementSubRelu(float *input0, float *input1, float *output, int element_size); +int ElementSubRelu6(float *input0, float *input1, float *output, int element_size); int BroadcastSub(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, ArithmeticParameter *param); int ElementDiv(float *input0, float *input1, float *output, int element_size); +int ElementDivRelu(float *input0, float *input1, float *output, int element_size); +int ElementDivRelu6(float *input0, float *input1, float *output, int element_size); int BroadcastDiv(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, ArithmeticParameter *param);