|
|
|
@@ -22,7 +22,6 @@ |
|
|
|
#include "src/runtime/allocator.h" |
|
|
|
#include "nnacl/arithmetic_common.h" |
|
|
|
#include "nnacl/fp32/arithmetic.h" |
|
|
|
#include "schema/ops_generated.h" |
|
|
|
|
|
|
|
typedef int (*ArithmeticRun)(float *input0, float *input1, float *output, int element_size); |
|
|
|
typedef int (*ArithmeticOptRun)(float *input0, float *input1, float *output, int element_size, |
|
|
|
@@ -167,9 +166,9 @@ int DoArithmeticInferShape(const TensorPtrVector &in_tensors, const TensorPtrVec |
|
|
|
|
|
|
|
int ChooseKernel(const int kernel_type, ArithmeticRun *arithmetic_run, ArithmeticParameter *params) { |
|
|
|
if (kernel_type == KernelType::Mul) { |
|
|
|
if (params->activation_type_ == mindspore::schema::ActivationType_RELU) { |
|
|
|
if (params->activation_type_ == ActivationType::RELU) { |
|
|
|
*arithmetic_run = ElementMulRelu; |
|
|
|
} else if (params->activation_type_ == mindspore::schema::ActivationType_RELU6) { |
|
|
|
} else if (params->activation_type_ == ActivationType::RELU6) { |
|
|
|
*arithmetic_run = ElementMulRelu6; |
|
|
|
} else { |
|
|
|
*arithmetic_run = ElementMul; |
|
|
|
@@ -183,9 +182,9 @@ int ChooseKernel(const int kernel_type, ArithmeticRun *arithmetic_run, Arithmeti |
|
|
|
|
|
|
|
int ChooseOptKernel(const int kernel_type, ArithmeticOptRun *arithmetic_opt_run, ArithmeticParameter *params) { |
|
|
|
if (kernel_type == KernelType::Mul) { |
|
|
|
if (params->activation_type_ == mindspore::schema::ActivationType_RELU) { |
|
|
|
if (params->activation_type_ == ActivationType::RELU) { |
|
|
|
*arithmetic_opt_run = ElementOptMulRelu; |
|
|
|
} else if (params->activation_type_ == mindspore::schema::ActivationType_RELU6) { |
|
|
|
} else if (params->activation_type_ == ActivationType::RELU6) { |
|
|
|
*arithmetic_opt_run = ElementOptMulRelu6; |
|
|
|
} else { |
|
|
|
*arithmetic_opt_run = ElementOptMul; |
|
|
|
|