| @@ -470,6 +470,65 @@ int ElementOptAddRelu6(float *input0, float *input1, float *output, int element_ | |||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| int ElementOptDiv(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) { | |||||
| if (param->in_elements_num0_ == 1) { | |||||
| for (int index = 0; index < element_size; ++index) { | |||||
| if (input1[index] == 0) { | |||||
| return NNACL_ERRCODE_DIVISOR_ZERO; | |||||
| } | |||||
| output[index] = input0[0] / input1[index]; | |||||
| } | |||||
| } else { | |||||
| if (input1[0] == 0) { | |||||
| return NNACL_ERRCODE_DIVISOR_ZERO; | |||||
| } | |||||
| for (int index = 0; index < element_size; ++index) { | |||||
| output[index] = input0[index] / input1[0]; | |||||
| } | |||||
| } | |||||
| return NNACL_OK; | |||||
| } | |||||
| int ElementOptDivRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) { | |||||
| if (param->in_elements_num0_ == 1) { | |||||
| for (int index = 0; index < element_size; ++index) { | |||||
| if (input1[index] == 0) { | |||||
| return NNACL_ERRCODE_DIVISOR_ZERO; | |||||
| } | |||||
| output[index] = input0[0] / input1[index]; | |||||
| output[index] = output[index] > 0 ? output[index] : 0; | |||||
| } | |||||
| } else { | |||||
| if (input1[0] == 0) { | |||||
| return NNACL_ERRCODE_DIVISOR_ZERO; | |||||
| } | |||||
| for (int index = 0; index < element_size; ++index) { | |||||
| output[index] = input0[index] / input1[0]; | |||||
| output[index] = output[index] > 0 ? output[index] : 0; | |||||
| } | |||||
| } | |||||
| return NNACL_OK; | |||||
| } | |||||
| int ElementOptDivRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) { | |||||
| if (param->in_elements_num0_ == 1) { | |||||
| for (int index = 0; index < element_size; ++index) { | |||||
| if (input1[index] == 0) { | |||||
| return NNACL_ERRCODE_DIVISOR_ZERO; | |||||
| } | |||||
| output[index] = MSMIN(MSMAX(input0[0] / input1[index], 0), 6); | |||||
| } | |||||
| } else { | |||||
| if (input1[0] == 0) { | |||||
| return NNACL_ERRCODE_DIVISOR_ZERO; | |||||
| } | |||||
| for (int index = 0; index < element_size; ++index) { | |||||
| output[index] = MSMIN(MSMAX(input0[index] / input1[0], 0), 6); | |||||
| } | |||||
| } | |||||
| return NNACL_OK; | |||||
| } | |||||
| int ElementMul(float *input0, float *input1, float *output, int element_size) { | int ElementMul(float *input0, float *input1, float *output, int element_size) { | ||||
| int block_mod = element_size % C4NUM; | int block_mod = element_size % C4NUM; | ||||
| int block_c4 = element_size - block_mod; | int block_c4 = element_size - block_mod; | ||||
| @@ -35,6 +35,9 @@ int ElementOptSubRelu6(float *input0, float *input1, float *output, int element_ | |||||
| int ElementOptMul(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); | int ElementOptMul(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); | ||||
| int ElementOptMulRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); | int ElementOptMulRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); | ||||
| int ElementOptMulRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); | int ElementOptMulRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); | ||||
| int ElementOptDiv(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); | |||||
| int ElementOptDivRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); | |||||
| int ElementOptDivRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); | |||||
| int ElementMul(float *input0, float *input1, float *output, int element_size); | int ElementMul(float *input0, float *input1, float *output, int element_size); | ||||
| int ElementMulRelu(float *input0, float *input1, float *output, int element_size); | int ElementMulRelu(float *input0, float *input1, float *output, int element_size); | ||||
| int ElementMulRelu6(float *input0, float *input1, float *output, int element_size); | int ElementMulRelu6(float *input0, float *input1, float *output, int element_size); | ||||
| @@ -94,6 +94,22 @@ int ArithmeticCPUKernel::ReSize() { | |||||
| break; | break; | ||||
| } | } | ||||
| break; | break; | ||||
| case PrimitiveType_Div: | |||||
| switch (arithmeticParameter_->activation_type_) { | |||||
| case schema::ActivationType_RELU: | |||||
| arithmeticParameter_->broadcasting_ = false; | |||||
| arithmetic_opt_run_ = ElementOptDivRelu; | |||||
| break; | |||||
| case schema::ActivationType_RELU6: | |||||
| arithmeticParameter_->broadcasting_ = false; | |||||
| arithmetic_opt_run_ = ElementOptDivRelu6; | |||||
| break; | |||||
| default: | |||||
| arithmeticParameter_->broadcasting_ = false; | |||||
| arithmetic_opt_run_ = ElementOptDiv; | |||||
| break; | |||||
| } | |||||
| break; | |||||
| default: | default: | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -158,6 +158,21 @@ TEST_F(TestArithmeticTestFp32, DivTest) { | |||||
| delete div_param; | delete div_param; | ||||
| } | } | ||||
| TEST_F(TestArithmeticTestFp32, DivTest2) { | |||||
| std::vector<float> in0 = {10, 20, 30, 40, 50, 60, 70, 80, 90, 100}; | |||||
| std::vector<float> in1 = {5, 10, 2, 8, 2, 3, 7, 80, 45, 20}; | |||||
| std::vector<float> correct_out = {2, 2, 15, 5, 25, 20, 10, 1, 2, 5}; | |||||
| constexpr int kOutSize = 10; | |||||
| float out[kOutSize]; | |||||
| ElementDiv(in0.data(), in1.data(), out, kOutSize); | |||||
| std::cout << "out: "; | |||||
| for (int i = 0; i < kOutSize; ++i) { | |||||
| std::cout << out[i] << " "; | |||||
| } | |||||
| std::cout << "\n"; | |||||
| CompareOutputData(out, correct_out.data(), kOutSize, 0.00001); | |||||
| } | |||||
| TEST_F(TestArithmeticTestFp32, FloorDivTest) { | TEST_F(TestArithmeticTestFp32, FloorDivTest) { | ||||
| auto fdiv_param = new ArithmeticParameter(); | auto fdiv_param = new ArithmeticParameter(); | ||||
| fdiv_param->ndim_ = 4; | fdiv_param->ndim_ = 4; | ||||