refactor arithmetic

tags/v1.2.0-rc1
sunsuodong committed 4 years ago
commit cc0b7ecc0e
11 changed files with 436 additions and 606 deletions
  1. mindspore/lite/nnacl/base/arithmetic_base.c (+1, -3)
  2. mindspore/lite/nnacl/fp16/arithmetic_fp16.c (+52, -51)
  3. mindspore/lite/nnacl/fp16/arithmetic_fp16.h (+51, -51)
  4. mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_compare_fp16.h (+4, -3)
  5. mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc (+87, -172)
  6. mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h (+15, -28)
  7. mindspore/lite/src/runtime/kernel/arm/fp16/common_fp16.cc (+18, -18)
  8. mindspore/lite/src/runtime/kernel/arm/fp16/common_fp16.h (+2, -2)
  9. mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc (+19, -19)
  10. mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc (+162, -239)
  11. mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h (+25, -20)
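Most of the diff below is mechanical: the fp16 element-wise kernels now take their two inputs as `const float16_t *` instead of `float16_t *`, and ArithmeticFP16CPUKernel is rebuilt on top of the fp32 ArithmeticCPUKernel. A minimal sketch of the const-qualification pattern, with plain `float` standing in for `float16_t` so it builds without ARM fp16 support (the function name is illustrative, not the library's):

```c
#include <stdio.h>

#define NNACL_OK 0 /* stand-in for the library's status code */

/* Before the refactor the inputs were plain pointers; after it they are
 * const-qualified, so the compiler rejects accidental writes to the operands. */
static int ElementAddExample(const float *input0, const float *input1, float *output,
                             int element_size) {
  for (int i = 0; i < element_size; ++i) {
    output[i] = input0[i] + input1[i];
  }
  return NNACL_OK;
}

int main(void) {
  const float a[4] = {1, 2, 3, 4}, b[4] = {10, 20, 30, 40};
  float out[4];
  ElementAddExample(a, b, out, 4);
  printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]); /* 11 22 33 44 */
  return 0;
}
```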

mindspore/lite/nnacl/base/arithmetic_base.c (+1, -3)

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,8 +17,6 @@
#include "nnacl/base/arithmetic_base.h"

void CalcMultiplesAndStrides(ArithmeticParameter *param) {
NNACL_ASSERT(param->in_shape0_[i] != 0);
NNACL_ASSERT(param->in_shape1_[i] != 0);
for (size_t i = 0; i < param->ndim_; i++) {
if (param->in_shape0_[i] != 0) {
param->multiples0_[i] = param->out_shape_[i] / param->in_shape0_[i];


mindspore/lite/nnacl/fp16/arithmetic_fp16.c (+52, -51)

@@ -52,7 +52,7 @@ void TileDimensionsFp16(const float16_t *data0, const float16_t *data1, float16_
param->multiples1_);
}

int ElementMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int ElementMulFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) {
int index = 0;
#ifdef ENABLE_NEON
for (; index <= element_size - 8; index += C8NUM) {
@@ -68,7 +68,7 @@ int ElementMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int
return NNACL_OK;
}

int ElementOptMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptMulFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
@@ -101,7 +101,7 @@ int ElementOptMulFp16(float16_t *input0, float16_t *input1, float16_t *output, i
return NNACL_OK;
}

int ElementMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int ElementMulReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) {
#ifdef ENABLE_NEON
float16x8_t zeros = vdupq_n_f16(0.0);
#endif
@@ -122,7 +122,7 @@ int ElementMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output,
return NNACL_OK;
}

int ElementOptMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptMulReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
@@ -160,7 +160,7 @@ int ElementOptMulReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu
return NNACL_OK;
}

int ElementMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int ElementMulRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) {
int index = 0;
#ifdef ENABLE_NEON
float16x8_t zeros = vdupq_n_f16(0.0);
@@ -179,7 +179,7 @@ int ElementMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output,
return NNACL_OK;
}

int ElementOptMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptMulRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
@@ -216,7 +216,7 @@ int ElementOptMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp
return NNACL_OK;
}

int ElementAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int ElementAddFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) {
int index = 0;
#ifdef ENABLE_NEON
for (; index <= element_size - 8; index += C8NUM) {
@@ -238,7 +238,7 @@ int ElementAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int
return NNACL_OK;
}

int ElementOptAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptAddFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
@@ -271,7 +271,7 @@ int ElementOptAddFp16(float16_t *input0, float16_t *input1, float16_t *output, i
return NNACL_OK;
}

int ElementAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int ElementAddReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) {
int index = 0;
#ifdef ENABLE_NEON
float16x8_t zeros = vdupq_n_f16(0.0);
@@ -298,7 +298,7 @@ int ElementAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output,
return NNACL_OK;
}

int ElementOptAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptAddReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
@@ -336,7 +336,7 @@ int ElementOptAddReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu
return NNACL_OK;
}

int ElementAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int ElementAddRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) {
int index = 0;
#ifdef ENABLE_NEON
float16x8_t zeros = vdupq_n_f16(0.0);
@@ -364,7 +364,7 @@ int ElementAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output,
return NNACL_OK;
}

int ElementOptAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptAddRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
@@ -401,7 +401,7 @@ int ElementOptAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp
return NNACL_OK;
}

int ElementSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int ElementSubFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) {
int index = 0;
#ifdef ENABLE_NEON
for (; index <= element_size - 8; index += C8NUM) {
@@ -417,7 +417,7 @@ int ElementSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int
return NNACL_OK;
}

int ElementOptSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptSubFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
@@ -450,7 +450,7 @@ int ElementOptSubFp16(float16_t *input0, float16_t *input1, float16_t *output, i
return NNACL_OK;
}

int ElementSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int ElementSubReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) {
int index = 0;
#ifdef ENABLE_NEON
float16x8_t zeros = vdupq_n_f16(0.0);
@@ -469,7 +469,7 @@ int ElementSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output,
return NNACL_OK;
}

int ElementOptSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptSubReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
@@ -507,7 +507,7 @@ int ElementOptSubReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu
return NNACL_OK;
}

int ElementSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int ElementSubRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) {
int index = 0;
#ifdef ENABLE_NEON
float16x8_t zeros = vdupq_n_f16(0.0);
@@ -526,7 +526,7 @@ int ElementSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output,
return NNACL_OK;
}

int ElementOptSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptSubRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
@@ -563,7 +563,7 @@ int ElementOptSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp
return NNACL_OK;
}

int ElementDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int ElementDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) {
int index = 0;
#ifdef ENABLE_NEON
for (; index <= element_size - 8; index += C8NUM) {
@@ -580,7 +580,7 @@ int ElementDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int
return NNACL_OK;
}

int ElementOptDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
@@ -617,7 +617,7 @@ int ElementOptDivFp16(float16_t *input0, float16_t *input1, float16_t *output, i
return NNACL_OK;
}

int ElementDivReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int ElementDivReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) {
int index = 0;
#ifdef ENABLE_NEON
float16x8_t zeros = vdupq_n_f16(0.0);
@@ -640,7 +640,7 @@ int ElementDivReluFp16(float16_t *input0, float16_t *input1, float16_t *output,
return NNACL_OK;
}

int ElementOptDivReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptDivReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
@@ -681,7 +681,7 @@ int ElementOptDivReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu
return NNACL_OK;
}

int ElementDivRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int ElementDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) {
int index = 0;
#ifdef ENABLE_NEON
float16x8_t zeros = vdupq_n_f16(0.0);
@@ -703,7 +703,7 @@ int ElementDivRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output,
return NNACL_OK;
}

int ElementOptDivRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
@@ -744,7 +744,7 @@ int ElementOptDivRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp
return NNACL_OK;
}

int ElementFloorModFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int ElementFloorModFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) {
for (int i = 0; i < element_size; ++i) {
if (input1[i] == 0) {
return NNACL_ERRCODE_DIVISOR_ZERO;
@@ -754,7 +754,7 @@ int ElementFloorModFp16(float16_t *input0, float16_t *input1, float16_t *output,
return NNACL_OK;
}

int ElementOptFloorModFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptFloorModFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param) {
if (param->in_elements_num1_ == 1) {
for (int i = 0; i < element_size; ++i) {
@@ -770,14 +770,14 @@ int ElementOptFloorModFp16(float16_t *input0, float16_t *input1, float16_t *outp
return NNACL_OK;
}

int ElementFloorDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int ElementFloorDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) {
for (int i = 0; i < element_size; ++i) {
NNACL_ASSERT(input1[i] != 0);
output[i] = floorf(input0[i] / input1[i]);
}
return NNACL_OK;
}
int ElementOptFloorDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptFloorDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param) {
if (param->in_elements_num1_ == 1) {
for (int i = 0; i < element_size; ++i) {
@@ -793,7 +793,7 @@ int ElementOptFloorDivFp16(float16_t *input0, float16_t *input1, float16_t *outp
return NNACL_OK;
}

int ElementLogicalAndFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int ElementLogicalAndFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) {
int index = 0;
#ifdef ENABLE_NEON
float16x8_t vtrue = vdupq_n_f16(1);
@@ -813,7 +813,7 @@ int ElementLogicalAndFp16(float16_t *input0, float16_t *input1, float16_t *outpu
return NNACL_OK;
}

int ElementOptLogicalAndFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptLogicalAndFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
@@ -854,7 +854,7 @@ int ElementOptLogicalAndFp16(float16_t *input0, float16_t *input1, float16_t *ou
return NNACL_OK;
}

int ElementLogicalOrFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int ElementLogicalOrFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) {
int index = 0;
#ifdef ENABLE_NEON
float16x8_t vtrue = vdupq_n_f16(1);
@@ -874,7 +874,7 @@ int ElementLogicalOrFp16(float16_t *input0, float16_t *input1, float16_t *output
return NNACL_OK;
}

int ElementOptLogicalOrFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptLogicalOrFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
@@ -915,18 +915,19 @@ int ElementOptLogicalOrFp16(float16_t *input0, float16_t *input1, float16_t *out
return NNACL_OK;
}

int ElementSquaredDifferenceFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int ElementSquaredDifferenceFp16(const float16_t *input0, const float16_t *input1, float16_t *output,
int element_size) {
ElementSubFp16(input0, input1, output, element_size);
return ElementMulFp16(output, output, output, element_size);
}

int ElementOptSquaredDifferenceFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param) {
int ElementOptSquaredDifferenceFp16(const float16_t *input0, const float16_t *input1, float16_t *output,
int element_size, ArithmeticParameter *param) {
ElementOptSubFp16(input0, input1, output, element_size, param);
return ElementMulFp16(output, output, output, element_size);
}

int ElementMaximumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int ElementMaximumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) {
int index = 0;
#ifdef ENABLE_NEON
for (; index <= element_size - 8; index += C8NUM) {
@@ -942,7 +943,7 @@ int ElementMaximumFp16(float16_t *input0, float16_t *input1, float16_t *output,
return NNACL_OK;
}

int ElementOptMaximumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptMaximumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
@@ -975,7 +976,7 @@ int ElementOptMaximumFp16(float16_t *input0, float16_t *input1, float16_t *outpu
return NNACL_OK;
}

int ElementMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int ElementMinimumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) {
int index = 0;
#ifdef ENABLE_NEON
for (; index <= element_size - 8; index += C8NUM) {
@@ -991,7 +992,7 @@ int ElementMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output,
return NNACL_OK;
}

int ElementOptMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptMinimumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
@@ -1024,7 +1025,7 @@ int ElementOptMinimumFp16(float16_t *input0, float16_t *input1, float16_t *outpu
return NNACL_OK;
}

int ElementNotEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size) {
int ElementNotEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) {
int index = 0;
#ifdef ENABLE_NEON
for (; index <= element_size - 8; index += C8NUM) {
@@ -1040,7 +1041,7 @@ int ElementNotEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, i
return NNACL_OK;
}

int ElementOptNotEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size,
int ElementOptNotEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size,
ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
@@ -1073,7 +1074,7 @@ int ElementOptNotEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output
return NNACL_OK;
}

int ElementEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size) {
int ElementEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) {
int index = 0;
#ifdef ENABLE_NEON
for (; index <= element_size - 8; index += C8NUM) {
@@ -1089,7 +1090,7 @@ int ElementEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int
return NNACL_OK;
}

int ElementOptEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size,
int ElementOptEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size,
ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
@@ -1122,7 +1123,7 @@ int ElementOptEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, i
return NNACL_OK;
}

int ElementLessFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size) {
int ElementLessFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) {
int index = 0;
#ifdef ENABLE_NEON
for (; index <= element_size - 8; index += C8NUM) {
@@ -1138,7 +1139,7 @@ int ElementLessFp16(float16_t *input0, float16_t *input1, uint8_t *output, int e
return NNACL_OK;
}

int ElementOptLessFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size,
int ElementOptLessFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size,
ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
@@ -1171,7 +1172,7 @@ int ElementOptLessFp16(float16_t *input0, float16_t *input1, uint8_t *output, in
return NNACL_OK;
}

int ElementLessEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size) {
int ElementLessEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) {
int index = 0;
#ifdef ENABLE_NEON
for (; index <= element_size - 8; index += C8NUM) {
@@ -1187,7 +1188,7 @@ int ElementLessEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output,
return NNACL_OK;
}

int ElementOptLessEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size,
int ElementOptLessEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size,
ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
@@ -1220,7 +1221,7 @@ int ElementOptLessEqualFp16(float16_t *input0, float16_t *input1, uint8_t *outpu
return NNACL_OK;
}

int ElementGreaterFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size) {
int ElementGreaterFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) {
int index = 0;
#ifdef ENABLE_NEON
for (; index <= element_size - 8; index += C8NUM) {
@@ -1236,7 +1237,7 @@ int ElementGreaterFp16(float16_t *input0, float16_t *input1, uint8_t *output, in
return NNACL_OK;
}

int ElementOptGreaterFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size,
int ElementOptGreaterFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size,
ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
@@ -1269,7 +1270,7 @@ int ElementOptGreaterFp16(float16_t *input0, float16_t *input1, uint8_t *output,
return NNACL_OK;
}

int ElementGreaterEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size) {
int ElementGreaterEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) {
int index = 0;
#ifdef ENABLE_NEON
for (; index <= element_size - 8; index += C8NUM) {
@@ -1285,7 +1286,7 @@ int ElementGreaterEqualFp16(float16_t *input0, float16_t *input1, uint8_t *outpu
return NNACL_OK;
}

int ElementOptGreaterEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size,
int ElementOptGreaterEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size,
ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);


mindspore/lite/nnacl/fp16/arithmetic_fp16.h (+51, -51)

@@ -32,92 +32,92 @@ void TileOneDimensionFp16(const float16_t *inData, float16_t *outData, int dim,
void TileDimensionsFp16(const float16_t *data0, const float16_t *data1, float16_t *tile_data0, float16_t *tile_data1,
ArithmeticParameter *param);

int ElementOptMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptMulFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptMulReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptMulRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptAddFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptAddReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptAddRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptSubFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptSubReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptSubRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptDivReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptDivReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptDivRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptFloorModFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptFloorModFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptFloorDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptFloorDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptLogicalAndFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptLogicalAndFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptLogicalOrFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptLogicalOrFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptSquaredDifferenceFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptMaximumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptSquaredDifferenceFp16(const float16_t *input0, const float16_t *input1, float16_t *output,
int element_size, ArithmeticParameter *param);
int ElementOptMaximumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
int ElementOptMinimumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptNotEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size,
int ElementOptNotEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size,
int ElementOptEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptLessFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size,
int ElementOptLessFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptLessEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size,
int ElementOptLessEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptGreaterFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size,
int ElementOptGreaterFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size,
ArithmeticParameter *param);
int ElementOptGreaterEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size,
int ElementOptGreaterEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size,
ArithmeticParameter *param);

int ElementMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
int ElementMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
int ElementMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
int ElementMulFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size);
int ElementMulReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size);
int ElementMulRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size);

int ElementAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
int ElementAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
int ElementAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
int ElementAddFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size);
int ElementAddReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size);
int ElementAddRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size);
int BroadcastAddFp16(const float16_t *in0, const float16_t *in1, float16_t *tile_in0, float16_t *tile_in1,
float16_t *out, int size, ArithmeticParameter *param);

int ElementSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
int ElementSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
int ElementSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
int ElementSubFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size);
int ElementSubReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size);
int ElementSubRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size);

int ElementDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
int ElementDivReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
int ElementDivRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
int ElementDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size);
int ElementDivReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size);
int ElementDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size);

int ElementFloorModFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
int ElementFloorDivFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
int ElementFloorModFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size);
int ElementFloorDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size);

int ElementLogicalAndFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
int ElementLogicalOrFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
int ElementLogicalAndFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size);
int ElementLogicalOrFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size);

int ElementSquaredDifferenceFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
int ElementSquaredDifferenceFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size);

int ElementMaximumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
int ElementMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
int ElementMaximumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size);
int ElementMinimumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size);

int ElementNotEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size);
int ElementEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size);
int ElementLessFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size);
int ElementLessEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size);
int ElementGreaterFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size);
int ElementGreaterEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size);
int ElementNotEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size);
int ElementEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size);
int ElementLessFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size);
int ElementLessEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size);
int ElementGreaterFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size);
int ElementGreaterEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size);

#ifdef __cplusplus
}


mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_compare_fp16.h (+4, -3)

@@ -23,9 +23,10 @@
#include "schema/model_generated.h"

namespace mindspore::kernel {
typedef int (*ArithmeticCompareFuncFp16)(float16_t *input0, float16_t *input1, uint8_t *output, int element_size);
typedef int (*ArithmeticCompareOptFuncFp16)(float16_t *input0, float16_t *input1, uint8_t *output, int element_size,
ArithmeticParameter *param);
typedef int (*ArithmeticCompareFuncFp16)(const float16_t *input0, const float16_t *input1, uint8_t *output,
int element_size);
typedef int (*ArithmeticCompareOptFuncFp16)(const float16_t *input0, const float16_t *input1, uint8_t *output,
int element_size, ArithmeticParameter *param);
typedef struct {
int primitive_type_;
int activation_type_;


mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc (+87, -172)

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,16 +13,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/runtime/kernel/arm/fp16/arithmetic_fp16.h"
#include "src/runtime/kernel/arm/fp16/common_fp16.h"
#include "nnacl/fp16/arithmetic_fp16.h"
#include "nnacl/fp16/cast_fp16.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
#include "include/errorcode.h"
#include "src/ops/arithmetic.h"
#include "nnacl/fp16/cast_fp16.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
@@ -49,73 +43,10 @@ using mindspore::schema::PrimitiveType_SquaredDifference;
using mindspore::schema::PrimitiveType_Sub;

namespace mindspore::kernel {
ARITHMETIC_FUNC_INFO_FP16 arithmetic_fun_table_fp16[] = {
{PrimitiveType_Mul, schema::ActivationType_RELU, ElementMulReluFp16, ElementOptMulReluFp16},
{PrimitiveType_Mul, schema::ActivationType_RELU6, ElementMulRelu6Fp16, ElementOptMulRelu6Fp16},
{PrimitiveType_Mul, schema::ActivationType_NO_ACTIVATION, ElementMulFp16, ElementOptMulFp16},
{PrimitiveType_Add, schema::ActivationType_RELU, ElementAddReluFp16, ElementOptAddReluFp16},
{PrimitiveType_Add, schema::ActivationType_RELU6, ElementAddRelu6Fp16, ElementOptAddRelu6Fp16},
{PrimitiveType_Add, schema::ActivationType_NO_ACTIVATION, ElementAddFp16, ElementOptAddFp16},
{PrimitiveType_Sub, schema::ActivationType_RELU, ElementSubReluFp16, ElementOptSubReluFp16},
{PrimitiveType_Sub, schema::ActivationType_RELU6, ElementSubRelu6Fp16, ElementOptSubRelu6Fp16},
{PrimitiveType_Sub, schema::ActivationType_NO_ACTIVATION, ElementSubFp16, ElementOptSubFp16},
{PrimitiveType_Div, schema::ActivationType_RELU, ElementDivReluFp16, ElementOptDivReluFp16},
{PrimitiveType_Div, schema::ActivationType_RELU6, ElementDivRelu6Fp16, ElementOptDivRelu6Fp16},
{PrimitiveType_Div, schema::ActivationType_NO_ACTIVATION, ElementDivFp16, ElementOptDivFp16},
{PrimitiveType_FloorMod, schema::ActivationType_NO_ACTIVATION, ElementFloorModFp16, ElementOptFloorModFp16},
{PrimitiveType_FloorDiv, schema::ActivationType_NO_ACTIVATION, ElementFloorDivFp16, ElementOptFloorDivFp16},
{PrimitiveType_LogicalAnd, schema::ActivationType_NO_ACTIVATION, ElementLogicalAndFp16, ElementOptLogicalAndFp16},
{PrimitiveType_LogicalOr, schema::ActivationType_NO_ACTIVATION, ElementLogicalOrFp16, ElementOptLogicalOrFp16},
{PrimitiveType_SquaredDifference, schema::ActivationType_NO_ACTIVATION, ElementSquaredDifferenceFp16,
ElementOptSquaredDifferenceFp16},
{PrimitiveType_Maximum, schema::ActivationType_NO_ACTIVATION, ElementMaximumFp16, ElementOptMaximumFp16},
{PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION, ElementMinimumFp16, ElementOptMinimumFp16}};

ArithmeticFuncFp16 GetArithmeticFun(int primitive_type, int activation_type) {
size_t length = sizeof(arithmetic_fun_table_fp16) / sizeof(ARITHMETIC_FUNC_INFO_FP16);
for (size_t i = 0; i < length; i++) {
if (arithmetic_fun_table_fp16[i].primitive_type_ == primitive_type &&
arithmetic_fun_table_fp16[i].activation_type_ == activation_type) {
return arithmetic_fun_table_fp16[i].func_;
}
}
return nullptr;
}

ArithmeticOptFuncFp16 GetOptimizedArithmeticFun(int primitive_type, int activation_type) {
size_t length = sizeof(arithmetic_fun_table_fp16) / sizeof(ARITHMETIC_FUNC_INFO_FP16);
for (size_t i = 0; i < length; i++) {
if (arithmetic_fun_table_fp16[i].primitive_type_ == primitive_type &&
arithmetic_fun_table_fp16[i].activation_type_ == activation_type) {
return arithmetic_fun_table_fp16[i].opt_func_;
}
}
return nullptr;
}

int ArithmeticFP16CPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}

void ArithmeticFP16CPUKernel::InitParam() {
auto arithmetic_lite_primitive = (lite::Arithmetic *)primitive_;
param_->broadcasting_ = arithmetic_lite_primitive->Broadcasting();
param_->ndim_ = arithmetic_lite_primitive->NDims();

param_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
param_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
param_->out_elements_num_ = out_tensors_[0]->ElementsNum();
memcpy(param_->in_shape0_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape0().data(),
reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape0().size() * sizeof(int));
memcpy(param_->in_shape1_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape1().data(),
reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape1().size() * sizeof(int));
memcpy(param_->out_shape_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->OutputShape().data(),
reinterpret_cast<const lite::Arithmetic *>(primitive_)->OutputShape().size() * sizeof(int));

return;
int ArithmeticFP16CPUKernel::ReSize() {
auto ret = ArithmeticCPUKernel::ReSize();
data_type_len_ = sizeof(float16_t);
return ret;
}

int ArithmeticFP16CPUKernel::CheckDataType() {
@@ -130,131 +61,115 @@ int ArithmeticFP16CPUKernel::CheckDataType() {
return RET_OK;
}

int ArithmeticFP16CPUKernel::ReSize() {
if (CheckDataType() != RET_OK) {
MS_LOG(ERROR) << "ArithmeticFP16CPUKernel resize failed.";
return RET_ERROR;
}

InitParam();

if (param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) {
param_->broadcasting_ = false;
arithmetic_opt_func_ = GetOptimizedArithmeticFun(param_->op_parameter_.type_, param_->activation_type_);
} else {
arithmetic_func_ = GetArithmeticFun(param_->op_parameter_.type_, param_->activation_type_);
}
if (arithmetic_opt_func_ == nullptr && arithmetic_func_ == nullptr) {
MS_LOG(ERROR) << "arithmetic_opt_func_ and arithmetic_func_ function is nullptr!";
return RET_ERROR;
}

if (param_->broadcasting_) {
outside_ = 1;
for (int i = param_->ndim_ - 1; i >= 0; --i) {
if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
break_pos_ = i;
break;
}
outside_ *= param_->out_shape_[i];
void ArithmeticFP16CPUKernel::InitRunFunction() {
ARITHMETIC_FUNC_INFO_FP16 fun_table[] = {
{PrimitiveType_Mul, schema::ActivationType_RELU, ElementMulReluFp16, ElementOptMulReluFp16},
{PrimitiveType_Mul, schema::ActivationType_RELU6, ElementMulRelu6Fp16, ElementOptMulRelu6Fp16},
{PrimitiveType_Mul, schema::ActivationType_NO_ACTIVATION, ElementMulFp16, ElementOptMulFp16},
{PrimitiveType_Add, schema::ActivationType_RELU, ElementAddReluFp16, ElementOptAddReluFp16},
{PrimitiveType_Add, schema::ActivationType_RELU6, ElementAddRelu6Fp16, ElementOptAddRelu6Fp16},
{PrimitiveType_Add, schema::ActivationType_NO_ACTIVATION, ElementAddFp16, ElementOptAddFp16},
{PrimitiveType_Sub, schema::ActivationType_RELU, ElementSubReluFp16, ElementOptSubReluFp16},
{PrimitiveType_Sub, schema::ActivationType_RELU6, ElementSubRelu6Fp16, ElementOptSubRelu6Fp16},
{PrimitiveType_Sub, schema::ActivationType_NO_ACTIVATION, ElementSubFp16, ElementOptSubFp16},
{PrimitiveType_Div, schema::ActivationType_RELU, ElementDivReluFp16, ElementOptDivReluFp16},
{PrimitiveType_Div, schema::ActivationType_RELU6, ElementDivRelu6Fp16, ElementOptDivRelu6Fp16},
{PrimitiveType_Div, schema::ActivationType_NO_ACTIVATION, ElementDivFp16, ElementOptDivFp16},
{PrimitiveType_FloorMod, schema::ActivationType_NO_ACTIVATION, ElementFloorModFp16, ElementOptFloorModFp16},
{PrimitiveType_FloorDiv, schema::ActivationType_NO_ACTIVATION, ElementFloorDivFp16, ElementOptFloorDivFp16},
{PrimitiveType_LogicalAnd, schema::ActivationType_NO_ACTIVATION, ElementLogicalAndFp16, ElementOptLogicalAndFp16},
{PrimitiveType_LogicalOr, schema::ActivationType_NO_ACTIVATION, ElementLogicalOrFp16, ElementOptLogicalOrFp16},
{PrimitiveType_SquaredDifference, schema::ActivationType_NO_ACTIVATION, ElementSquaredDifferenceFp16,
ElementOptSquaredDifferenceFp16},
{PrimitiveType_Maximum, schema::ActivationType_NO_ACTIVATION, ElementMaximumFp16, ElementOptMaximumFp16},
{PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION, ElementMinimumFp16, ElementOptMinimumFp16}};

size_t length = sizeof(fun_table) / sizeof(ARITHMETIC_FUNC_INFO_FP16);
for (size_t i = 0; i < length; i++) {
if (fun_table[i].primitive_type_ == param_->op_parameter_.type_ &&
fun_table[i].activation_type_ == param_->activation_type_) {
arithmetic_opt_func_ = fun_table[i].opt_func_;
arithmetic_func_ = fun_table[i].func_;
return;
}
ComputeStrides(param_->in_shape0_, param_->in_strides0_, param_->ndim_);
ComputeStrides(param_->in_shape1_, param_->in_strides1_, param_->ndim_);
ComputeStrides(param_->out_shape_, param_->out_strides_, param_->ndim_);
}
return RET_OK;
}

int ArithmeticFP16CPUKernel::BroadcastRun(float16_t *input0, float16_t *input1, float16_t *output, int dim,
int out_count, int cur_offset) {
if (dim > break_pos_) {
return arithmetic_func_(input0 + cur_offset, input1 + cur_offset, output + cur_offset, out_count);
int ArithmeticFP16CPUKernel::ConstTensorBroadCast() {
int ret;
if (in_tensors_[0]->data_c() != nullptr) {
ret = ConvertFp32TensorToFp16(in_tensors_[0], context_);
if (ret != RET_OK) {
return ret;
}
}
for (int i = 0; i < param_->out_shape_[dim]; ++i) {
int pos0 = param_->in_shape0_[dim] == 1 ? 0 : i;
int pos1 = param_->in_shape1_[dim] == 1 ? 0 : i;
int ret = BroadcastRun(input0 + pos0 * param_->in_strides0_[dim], input1 + pos1 * param_->in_strides1_[dim],
output + i * param_->out_strides_[dim], dim + 1, out_count, cur_offset);
if (in_tensors_[1]->data_c() != nullptr) {
ret = ConvertFp32TensorToFp16(in_tensors_[1], context_);
if (ret != RET_OK) {
return RET_ERROR;
return ret;
}
}
return RET_OK;
return ArithmeticCPUKernel::ConstTensorBroadCast();
}

int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) {
int stride_per_thread = UP_DIV(param_->broadcasting_ ? outside_ : param_->out_elements_num_, context_->thread_num_);
int cur_offset = stride_per_thread * task_id;
int cur_count = param_->broadcasting_ ? MSMIN(stride_per_thread, outside_ - cur_offset)
: MSMIN(stride_per_thread, param_->out_elements_num_ - cur_offset);
if (cur_count <= 0) {
return RET_OK;
}
void ArithmeticFP16CPUKernel::TileConstTensor(const void *in_data, void *out_data, size_t ndim, const int *in_shape,
const int *in_strides, const int *out_strides, const int *multiple) {
TileOneDimensionFp16(reinterpret_cast<const float16_t *>(in_data), reinterpret_cast<float16_t *>(out_data), 0, ndim,
in_shape, in_strides, out_strides, multiple);
}

int ArithmeticFP16CPUKernel::Execute(const void *input0, const void *input1, void *output, int size, bool is_opt) {
int ret = RET_OK;
if (param_->broadcasting_) {
ret = BroadcastRun(input0_fp16_, input1_fp16_, output_fp16_, 0, cur_count, cur_offset);
} else if (param_->in_elements_num0_ == 1) {
ret = arithmetic_opt_func_(input0_fp16_, input1_fp16_ + cur_offset, output_fp16_ + cur_offset, cur_count, param_);
} else if (param_->in_elements_num1_ == 1) {
ret = arithmetic_opt_func_(input0_fp16_ + cur_offset, input1_fp16_, output_fp16_ + cur_offset, cur_count, param_);
} else {
ret = arithmetic_func_(input0_fp16_ + cur_offset, input1_fp16_ + cur_offset, output_fp16_ + cur_offset, cur_count);
}
if (ret != RET_OK) {
MS_LOG(ERROR) << "DoArithmetic failed, ret = " << ret;
if (in_tensors_[0]->data_type() != kNumberTypeFloat16) {
MS_LOG(ERROR) << "data type is not fp16";
return RET_ERROR;
}
return ret;
}
static int ArithmeticsRunFp16(void *cdata, int task_id) {
auto arithmetic_kernel = reinterpret_cast<ArithmeticFP16CPUKernel *>(cdata);
auto ret = arithmetic_kernel->DoArithmetic(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ArithmeticsRunFp16 error task_id[" << task_id << "] ret[" << ret << "]";
if (is_opt) {
CHECK_NULL_RETURN(arithmetic_opt_func_, RET_ERROR);
ret = arithmetic_opt_func_(reinterpret_cast<const float16_t *>(input0), reinterpret_cast<const float16_t *>(input1),
reinterpret_cast<float16_t *>(output), size, param_);
} else {
CHECK_NULL_RETURN(arithmetic_func_, RET_ERROR);
ret = arithmetic_func_(reinterpret_cast<const float16_t *>(input0), reinterpret_cast<const float16_t *>(input1),
reinterpret_cast<float16_t *>(output), size);
}
return ret;
}

int ArithmeticFP16CPUKernel::Run() {
if (!input0_broadcast_) {
input0_ptr_ = ConvertInputFp32toFp16(in_tensors_.at(0), context_);
}
if (!input1_broadcast_) {
input1_ptr_ = ConvertInputFp32toFp16(in_tensors_.at(1), context_);
}
auto output_tensor = out_tensors_.at(0);
is_input0_fp32_ = in_tensors_.at(0)->data_type() == kNumberTypeFloat32;
is_input1_fp32_ = in_tensors_.at(1)->data_type() == kNumberTypeFloat32;
is_output_fp32_ = output_tensor->data_type() == kNumberTypeFloat32;

input0_fp16_ = ConvertInputFp32toFp16(in_tensors_.at(0), context_);
input1_fp16_ = ConvertInputFp32toFp16(in_tensors_.at(1), context_);
output_fp16_ = MallocOutputFp16(output_tensor, context_);
if (input0_fp16_ == nullptr || input1_fp16_ == nullptr || output_fp16_ == nullptr) {
MS_LOG(ERROR) << "Memory allocation failed";
FreeTmpBuffer();
output_ptr_ = MallocOutputFp16(output_tensor, context_);
if (input0_ptr_ == nullptr || input1_ptr_ == nullptr || output_ptr_ == nullptr) {
FreeFp16Buffer();
return RET_ERROR;
}

auto ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticsRunFp16, this, context_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ArithmeticsRunFp16 run error error_code[" << ret << "]";
}
if (is_output_fp32_) {
Float16ToFloat32(output_fp16_, reinterpret_cast<float *>(output_tensor->MutableData()),
auto ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticsRun, this, context_->thread_num_);
if (out_tensors_.at(0)->data_type() == kNumberTypeFloat32) {
Float16ToFloat32(static_cast<float16_t *>(output_ptr_), reinterpret_cast<float *>(output_tensor->MutableData()),
output_tensor->ElementsNum());
}
FreeTmpBuffer();
FreeFp16Buffer();
return ret;
}

void ArithmeticFP16CPUKernel::FreeTmpBuffer() {
if (is_input0_fp32_) {
context_->allocator->Free(input0_fp16_);
input0_fp16_ = nullptr;
void ArithmeticFP16CPUKernel::FreeFp16Buffer() {
if (!input0_broadcast_ && in_tensors_.at(0)->data_type() == kNumberTypeFloat32) {
context_->allocator->Free(input0_ptr_);
input0_ptr_ = nullptr;
}
if (is_input1_fp32_) {
context_->allocator->Free(input1_fp16_);
input1_fp16_ = nullptr;
if (!input1_broadcast_ && in_tensors_.at(1)->data_type() == kNumberTypeFloat32) {
context_->allocator->Free(input1_ptr_);
input1_ptr_ = nullptr;
}
if (is_output_fp32_) {
context_->allocator->Free(output_fp16_);
output_fp16_ = nullptr;
if (out_tensors_.at(0)->data_type() == kNumberTypeFloat32) {
context_->allocator->Free(output_ptr_);
output_ptr_ = nullptr;
}
}
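
For reference, the lookup that the new InitRunFunction performs over fun_table is a linear scan keyed on (primitive type, activation type); on a match it stores both the plain and the opt function pointer. A simplified C sketch of that dispatch-table pattern, reduced to a single function pointer per row (all type, enum, and function names here are illustrative stand-ins, not the kernel's actual symbols):

```c
#include <stddef.h>
#include <stdio.h>

/* Illustrative stand-ins for the (primitive type, activation type) keys. */
enum { PRIM_ADD = 1, PRIM_MUL = 2 };
enum { ACT_NONE = 0, ACT_RELU = 1 };

typedef int (*ElemFunc)(const float *in0, const float *in1, float *out, int n);

typedef struct {
  int primitive_type;
  int activation_type;
  ElemFunc func;
} FuncEntry;

static int ElementAdd(const float *in0, const float *in1, float *out, int n) {
  for (int i = 0; i < n; ++i) out[i] = in0[i] + in1[i];
  return 0;
}

static int ElementAddRelu(const float *in0, const float *in1, float *out, int n) {
  for (int i = 0; i < n; ++i) {
    float v = in0[i] + in1[i];
    out[i] = v > 0.0f ? v : 0.0f;
  }
  return 0;
}

static const FuncEntry kFuncTable[] = {
  {PRIM_ADD, ACT_NONE, ElementAdd},
  {PRIM_ADD, ACT_RELU, ElementAddRelu},
};

/* Same linear scan InitRunFunction does: the first matching row wins. */
static ElemFunc FindFunc(int primitive_type, int activation_type) {
  size_t len = sizeof(kFuncTable) / sizeof(kFuncTable[0]);
  for (size_t i = 0; i < len; ++i) {
    if (kFuncTable[i].primitive_type == primitive_type &&
        kFuncTable[i].activation_type == activation_type) {
      return kFuncTable[i].func;
    }
  }
  return NULL; /* mirrors the kernel leaving the pointer unset on no match */
}

int main(void) {
  const float a[3] = {1.0f, -2.0f, 3.0f}, b[3] = {4.0f, -5.0f, 6.0f};
  float out[3];
  ElemFunc f = FindFunc(PRIM_ADD, ACT_RELU);
  if (f != NULL) {
    f(a, b, out, 3);
    printf("%g %g %g\n", out[0], out[1], out[2]); /* 5 0 9 */
  }
  return 0;
}
```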



mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h (+15, -28)

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,19 +13,18 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ARITHMETIC_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ARITHMETIC_FP16_H_

#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/fp32/arithmetic_fp32.h"
#include "nnacl/fp16/arithmetic_fp16.h"
#include "schema/model_generated.h"

namespace mindspore::kernel {
typedef int (*ArithmeticFuncFp16)(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
typedef int (*ArithmeticOptFuncFp16)(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
typedef int (*ArithmeticFuncFp16)(const float16_t *input0, const float16_t *input1, float16_t *output,
int element_size);
typedef int (*ArithmeticOptFuncFp16)(const float16_t *input0, const float16_t *input1, float16_t *output,
int element_size, ArithmeticParameter *param);
typedef struct {
int primitive_type_;
int activation_type_;
@@ -33,36 +32,24 @@ typedef struct {
ArithmeticOptFuncFp16 opt_func_;
} ARITHMETIC_FUNC_INFO_FP16;

class ArithmeticFP16CPUKernel : public LiteKernel {
class ArithmeticFP16CPUKernel : public ArithmeticCPUKernel {
public:
ArithmeticFP16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
param_ = reinterpret_cast<ArithmeticParameter *>(parameter);
}
: ArithmeticCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~ArithmeticFP16CPUKernel() = default;

int Init() override;
int ReSize() override;
int Run() override;
int CheckDataType();
int DoArithmetic(int task_id);
int BroadcastRun(float16_t *input0, float16_t *input1, float16_t *output, int dim, int out_count,
int out_thread_stride);

private:
void InitParam();
void FreeTmpBuffer();
int outside_;
int break_pos_;
bool is_input0_fp32_ = false;
bool is_input1_fp32_ = false;
bool is_output_fp32_ = false;
float16_t *input0_fp16_ = nullptr;
float16_t *input1_fp16_ = nullptr;
float16_t *output_fp16_ = nullptr;
ArithmeticParameter *param_ = nullptr;
void InitRunFunction() override;
int CheckDataType() override;
int ConstTensorBroadCast() override;
void TileConstTensor(const void *in_data, void *out_data, size_t ndim, const int *in_shape, const int *in_strides,
const int *out_strides, const int *multiple) override;
int Execute(const void *input0, const void *input1, void *output, int size, bool is_opt) override;
void FreeFp16Buffer();
ArithmeticFuncFp16 arithmetic_func_ = nullptr;
ArithmeticOptFuncFp16 arithmetic_opt_func_ = nullptr;
};


mindspore/lite/src/runtime/kernel/arm/fp16/common_fp16.cc (+18, -18)

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,9 +13,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/runtime/kernel/arm/fp16/common_fp16.h"
#include "nnacl/fp16/cast_fp16.h"
#include "include/errorcode.h"

using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;

namespace mindspore::kernel {
float16_t *ConvertInputFp32toFp16(lite::Tensor *input, const lite::InnerContext *ctx) {
@@ -52,23 +55,20 @@ float16_t *MallocOutputFp16(lite::Tensor *output, const lite::InnerContext *ctx)
return fp16_data;
}

bool IsExistFp16Tensor(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs) {
bool result = false;
for (auto &input : inputs) {
if (input->data_type() == kNumberTypeFloat16) {
result = true;
break;
}
int ConvertFp32TensorToFp16(lite::Tensor *tensor, const lite::InnerContext *ctx) {
if (tensor->data_type() == TypeId::kNumberTypeFloat16) {
return RET_OK;
}
if (result) {
return true;
}
for (auto &output : outputs) {
if (output->data_type() == kNumberTypeFloat16) {
result = true;
break;
}
auto fp32_data = tensor->data_c();
tensor->set_data(nullptr);
tensor->set_data_type(TypeId::kNumberTypeFloat16);
auto ret = tensor->MallocData();
if (RET_OK != ret) {
MS_LOG(ERROR) << "malloc data failed";
return RET_ERROR;
}
return result;
Float32ToFloat16(static_cast<float *>(fp32_data), static_cast<float16_t *>(tensor->data_c()), tensor->ElementsNum());
ctx->allocator->Free(fp32_data);
return RET_OK;
}
} // namespace mindspore::kernel
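
The new ConvertFp32TensorToFp16 helper converts a tensor in place by detaching its fp32 buffer, allocating fp16 storage on the tensor, converting element by element, and freeing the old buffer through the context allocator. A hedged sketch of that buffer-swap pattern, using double/float in place of fp32/fp16 so it compiles without ARM fp16 support (the tensor struct and function names are illustrative, not lite::Tensor's API):

```c
#include <stdlib.h>

/* Illustrative tensor with a type-erased data pointer, standing in for lite::Tensor. */
typedef enum { TYPE_F64, TYPE_F32 } DataType;
typedef struct {
  DataType data_type;
  void *data;
  int elements;
} Tensor;

/* Narrow the tensor's storage in place: detach the wide buffer, allocate the
 * narrow one, convert element by element, then release the old memory. */
static int NarrowTensor(Tensor *t) {
  if (t->data_type == TYPE_F32) {
    return 0; /* already converted; mirrors the early return in the helper */
  }
  double *wide = (double *)t->data;
  float *narrow = (float *)malloc(sizeof(float) * (size_t)t->elements);
  if (narrow == NULL) {
    return -1; /* allocation failure maps to RET_ERROR in the kernel */
  }
  for (int i = 0; i < t->elements; ++i) {
    narrow[i] = (float)wide[i];
  }
  t->data = narrow;
  t->data_type = TYPE_F32;
  free(wide);
  return 0;
}

int main(void) {
  double *src = (double *)malloc(3 * sizeof(double));
  if (src == NULL) return 1;
  src[0] = 1.5; src[1] = 2.5; src[2] = 3.5;
  Tensor t = {TYPE_F64, src, 3};
  int ret = NarrowTensor(&t);
  free(t.data);
  return ret == 0 ? 0 : 1;
}
```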

mindspore/lite/src/runtime/kernel/arm/fp16/common_fp16.h (+2, -2)

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@ float16_t *ConvertInputFp32toFp16(lite::Tensor *input, const lite::InnerContext

float16_t *MallocOutputFp16(lite::Tensor *output, const lite::InnerContext *ctx);

bool IsExistFp16Tensor(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs);
int ConvertFp32TensorToFp16(lite::Tensor *tensor, const lite::InnerContext *ctx);
} // namespace mindspore::kernel

#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_COMMON_FP16_H_

mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc (+19, -19)

@@ -31,7 +31,7 @@ namespace mindspore::kernel {
int ArithmeticCompareCPUKernel::BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count,
int out_thread_stride) {
if (dim > break_pos_) {
if (data_type_ == kDataTypeInt) {
if (in_tensors_[0]->data_type() == kNumberTypeInt) {
return func_int32_(reinterpret_cast<int *>(input0) + out_thread_stride,
reinterpret_cast<int *>(input1) + out_thread_stride,
reinterpret_cast<uint8_t *>(output) + out_thread_stride, out_count);
@@ -40,20 +40,20 @@ int ArithmeticCompareCPUKernel::BroadcastRun(void *input0, void *input1, void *o
reinterpret_cast<float *>(input1) + out_thread_stride,
reinterpret_cast<uint8_t *>(output) + out_thread_stride, out_count);
}
for (int i = 0; i < arithmeticParameter_->out_shape_[dim]; ++i) {
int pos0_ = arithmeticParameter_->in_shape0_[dim] == 1 ? 0 : i;
int pos1_ = arithmeticParameter_->in_shape1_[dim] == 1 ? 0 : i;
for (int i = 0; i < param_->out_shape_[dim]; ++i) {
int pos0_ = param_->in_shape0_[dim] == 1 ? 0 : i;
int pos1_ = param_->in_shape1_[dim] == 1 ? 0 : i;
int error_code;
if (data_type_ == kDataTypeInt) {
error_code = BroadcastRun(reinterpret_cast<int *>(input0) + pos0_ * arithmeticParameter_->in_strides0_[dim],
reinterpret_cast<int *>(input1) + pos1_ * arithmeticParameter_->in_strides1_[dim],
reinterpret_cast<uint8_t *>(output) + i * arithmeticParameter_->out_strides_[dim],
dim + 1, out_count, out_thread_stride);
if (in_tensors_[0]->data_type() == kNumberTypeInt) {
error_code = BroadcastRun(reinterpret_cast<int *>(input0) + pos0_ * param_->in_strides0_[dim],
reinterpret_cast<int *>(input1) + pos1_ * param_->in_strides1_[dim],
reinterpret_cast<uint8_t *>(output) + i * param_->out_strides_[dim], dim + 1, out_count,
out_thread_stride);
} else {
error_code = BroadcastRun(reinterpret_cast<float *>(input0) + pos0_ * arithmeticParameter_->in_strides0_[dim],
reinterpret_cast<float *>(input1) + pos1_ * arithmeticParameter_->in_strides1_[dim],
reinterpret_cast<uint8_t *>(output) + i * arithmeticParameter_->out_strides_[dim],
dim + 1, out_count, out_thread_stride);
error_code = BroadcastRun(reinterpret_cast<float *>(input0) + pos0_ * param_->in_strides0_[dim],
reinterpret_cast<float *>(input1) + pos1_ * param_->in_strides1_[dim],
reinterpret_cast<uint8_t *>(output) + i * param_->out_strides_[dim], dim + 1, out_count,
out_thread_stride);
}
if (error_code != RET_OK) {
return error_code;
@@ -65,8 +65,8 @@ int ArithmeticCompareCPUKernel::BroadcastRun(void *input0, void *input1, void *o
int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) {
auto element_num = out_tensors_[0]->ElementsNum();

MS_ASSERT(thread_count_ != 0);
int stride = UP_DIV(element_num, thread_count_);
MS_ASSERT(context_->thread_num_ != 0);
int stride = UP_DIV(element_num, context_->thread_num_);
int count = MSMIN(stride, element_num - stride * task_id);
if (count <= 0) {
return RET_OK;
@@ -78,14 +78,14 @@ int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) {
}

int error_code;
if (arithmeticParameter_->broadcasting_) { // need broadcast
stride = UP_DIV(outside_, thread_count_);
if (param_->broadcasting_) { // need broadcast
stride = UP_DIV(outside_, context_->thread_num_);
int out_count = MSMIN(stride, outside_ - stride * task_id);
int out_thread_stride = stride * task_id;
if (out_count <= 0) {
return RET_OK;
}
if (data_type_ == kDataTypeFloat) {
if (in_tensors_[0]->data_type() == kNumberTypeFloat32) {
error_code =
BroadcastRun(reinterpret_cast<float *>(input0_ptr_), reinterpret_cast<float *>(input1_ptr_),
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
@@ -95,7 +95,7 @@ int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) {
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
}
} else { // no broadcast, neither is scalar, two same shape
if (data_type_ == kDataTypeFloat) {
if (in_tensors_[0]->data_type() == kNumberTypeFloat32) {
error_code = func_fp32_(reinterpret_cast<float *>(input0_ptr_) + stride * task_id,
reinterpret_cast<float *>(input1_ptr_) + stride * task_id,
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()) + stride * task_id, count);


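Both the compare kernel above and the arithmetic kernel below split the output elements across threads with the same UP_DIV/MSMIN pattern. A small self-contained sketch of that slicing, with the macro bodies restated in the style of nnacl/op_base.h (the numbers in the comment are only a worked example):

#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
#define MSMIN(x, y) ((x) < (y) ? (x) : (y))

// e.g. element_num = 1001, thread_num = 4 -> stride = 251 and the per-task
// counts are 251, 251, 251, 248; a task whose count is <= 0 just returns RET_OK.
static int SliceCount(int element_num, int thread_num, int task_id) {
  int stride = UP_DIV(element_num, thread_num);
  return MSMIN(stride, element_num - stride * task_id);
}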
+ 162
- 239
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,13 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/runtime/kernel/arm/fp32/arithmetic_fp32.h"
#include "include/errorcode.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/int8/add_int8.h"
#include "src/runtime/runtime_api.h"
#include "src/ops/arithmetic.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
@@ -29,79 +24,119 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Eltwise;

namespace mindspore::kernel {
ArithmeticCPUKernel::~ArithmeticCPUKernel() {
FreeTmpPtr();
return;
}

int ArithmeticCPUKernel::Init() {
InitRunFunction();
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}

int ArithmeticCPUKernel::InitBroadCastCase() {
/* if const node need broadcast
* and all need-broadcast-node are const
* broadcast in resize */
int ArithmeticCPUKernel::ReSize() {
if (CheckDataType() != RET_OK) {
MS_LOG(ERROR) << "ArithmeticCPUKernel resize failed.";
return RET_ERROR;
}
auto prim = reinterpret_cast<const lite::Arithmetic *>(primitive_);
param_->broadcasting_ = prim->Broadcasting();
param_->ndim_ = prim->NDims();
param_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
param_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
param_->out_elements_num_ = out_tensors_[0]->ElementsNum();
memcpy(param_->in_shape0_, prim->InShape0().data(), prim->InShape0().size() * sizeof(int));
memcpy(param_->in_shape1_, prim->InShape1().data(), prim->InShape1().size() * sizeof(int));
memcpy(param_->out_shape_, prim->OutputShape().data(), prim->OutputShape().size() * sizeof(int));
CalcMultiplesAndStrides(param_);
if (param_->broadcasting_) {
outside_ = 1;
for (auto i = param_->ndim_ - 1; i >= 0; --i) {
if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
break_pos_ = i;
break;
}
outside_ *= param_->out_shape_[i];
}
}
return ConstTensorBroadCast();
}

if (!arithmeticParameter_->broadcasting_) {
return RET_OK;
int ArithmeticCPUKernel::CheckDataType() {
auto in0_dataType = in_tensors_.at(0)->data_type();
auto in1_dataType = in_tensors_.at(1)->data_type();
if (in0_dataType != in1_dataType) {
MS_LOG(ERROR) << "The dataTypes of input tensor0 and input tensor1 should be the same.";
return RET_ERROR;
}
return RET_OK;
}

int ArithmeticCPUKernel::ConstTensorBroadCast() {
/* if a const input needs broadcasting and every input that needs broadcasting is const, broadcast during Resize */
if (!param_->broadcasting_) {
return RET_OK;
}
if (out_tensors_[0]->Size() < 0) {
return RET_OK;
}

if (arithmeticParameter_->in_elements_num0_ != arithmeticParameter_->out_elements_num_ &&
arithmeticParameter_->in_elements_num1_ != arithmeticParameter_->out_elements_num_) {
/* [1, 1, 2] + [1, 2, 1] -> [1, 2, 2]
* need broadcast both input */
/* [1, 1, 2] + [1, 2, 1] -> [1, 2, 2]: both inputs need broadcasting */
if (param_->in_elements_num0_ != param_->out_elements_num_ &&
param_->in_elements_num1_ != param_->out_elements_num_) {
return RET_OK;
}

if ((arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) &&
if ((param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) &&
(arithmetic_opt_run_ != nullptr && arithmetic_opt_run_int_ != nullptr)) {
/* run opt function
 * one of the inputs is a scalar */
return RET_OK;
}

FreeTmpPtr();

CalcMultiplesAndStrides(arithmeticParameter_);

if (in_tensors_[0]->data_c() != nullptr &&
arithmeticParameter_->in_elements_num1_ == arithmeticParameter_->out_elements_num_) {
input0_ptr_ = malloc(arithmeticParameter_->out_elements_num_ * sizeof(float));
FreeConstTileBuff();
if (in_tensors_[0]->data_c() != nullptr && param_->in_elements_num0_ != param_->out_elements_num_) {
input0_ptr_ = malloc(param_->out_elements_num_ * data_type_len_);
if (input0_ptr_ == nullptr) {
return RET_ERROR;
}
TileOneDimensionFp32(reinterpret_cast<float *>(in_tensors_[0]->data_c()), reinterpret_cast<float *>(input0_ptr_), 0,
arithmeticParameter_->ndim_, arithmeticParameter_->in_shape0_,
arithmeticParameter_->in_strides0_, arithmeticParameter_->out_strides_,
arithmeticParameter_->multiples0_);
arithmeticParameter_->broadcasting_ = false;
TileConstTensor(in_tensors_[0]->data_c(), input0_ptr_, param_->ndim_, param_->in_shape0_, param_->in_strides0_,
param_->out_strides_, param_->multiples0_);
input0_broadcast_ = true;
param_->in_elements_num0_ = param_->out_elements_num_;
param_->broadcasting_ = false;
}
if (in_tensors_[1]->data_c() != nullptr &&
arithmeticParameter_->in_elements_num0_ == arithmeticParameter_->out_elements_num_) {
input1_ptr_ = malloc(arithmeticParameter_->out_elements_num_ * sizeof(float));
if (in_tensors_[1]->data_c() != nullptr && param_->in_elements_num1_ != param_->out_elements_num_) {
input1_ptr_ = malloc(param_->out_elements_num_ * data_type_len_);
if (input1_ptr_ == nullptr) {
FreeTmpPtr();
FreeConstTileBuff();
return RET_ERROR;
}
TileOneDimensionFp32(reinterpret_cast<float *>(in_tensors_[1]->data_c()), reinterpret_cast<float *>(input1_ptr_), 0,
arithmeticParameter_->ndim_, arithmeticParameter_->in_shape1_,
arithmeticParameter_->in_strides1_, arithmeticParameter_->out_strides_,
arithmeticParameter_->multiples1_);
arithmeticParameter_->broadcasting_ = false;
TileConstTensor(in_tensors_[1]->data_c(), input1_ptr_, param_->ndim_, param_->in_shape1_, param_->in_strides1_,
param_->out_strides_, param_->multiples1_);
input1_broadcast_ = true;
param_->in_elements_num1_ = param_->out_elements_num_;
param_->broadcasting_ = false;
}
return RET_OK;
}

void ArithmeticCPUKernel::TileConstTensor(const void *in_data, void *out_data, size_t ndim, const int *in_shape,
const int *in_strides, const int *out_strides, const int *multiple) {
TileOneDimensionFp32(reinterpret_cast<const float *>(in_data), reinterpret_cast<float *>(out_data), 0, ndim, in_shape,
in_strides, out_strides, multiple);
}

void ArithmeticCPUKernel::FreeConstTileBuff() {
if (input0_broadcast_ == true && input0_ptr_ != nullptr) {
free(input0_ptr_);
input0_ptr_ = nullptr;
input0_broadcast_ = false;
}
if (input1_broadcast_ == true && input1_ptr_ != nullptr) {
free(input1_ptr_);
input1_ptr_ = nullptr;
input1_broadcast_ = false;
}
return;
}

void ArithmeticCPUKernel::InitRunFunction() {
ARITHMETIC_FUNC_INFO_FP32 fun_table[] = {
{PrimitiveType_Mul, schema::ActivationType_RELU, ElementMulRelu, ElementMulReluInt, nullptr, ElementOptMulRelu,
@@ -146,8 +181,8 @@ void ArithmeticCPUKernel::InitRunFunction() {

size_t length = sizeof(fun_table) / sizeof(ARITHMETIC_FUNC_INFO_FP32);
for (size_t i = 0; i < length; i++) {
if (fun_table[i].primitive_type_ == op_parameter_->type_ &&
fun_table[i].activation_type_ == arithmeticParameter_->activation_type_) {
if (fun_table[i].primitive_type_ == param_->op_parameter_.type_ &&
fun_table[i].activation_type_ == param_->activation_type_) {
arithmetic_run_ = fun_table[i].func_;
arithmetic_run_int_ = fun_table[i].int_func_;
arithmetic_run_bool_ = fun_table[i].bool_func_;
@@ -158,79 +193,53 @@ void ArithmeticCPUKernel::InitRunFunction() {
}
}

void ArithmeticCPUKernel::InitParam() {
auto arithmetic_lite_primitive = (lite::Arithmetic *)primitive_;
arithmeticParameter_->broadcasting_ = arithmetic_lite_primitive->Broadcasting();
arithmeticParameter_->ndim_ = arithmetic_lite_primitive->NDims();
if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat16) {
data_type_ = kDataTypeFloat;
int ArithmeticCPUKernel::Execute(const void *input0, const void *input1, void *output, int size, bool is_opt) {
int ret = RET_OK;
if (in_tensors_[0]->data_type() == kNumberTypeFloat32) {
if (is_opt) {
CHECK_NULL_RETURN(arithmetic_opt_run_, RET_ERROR);
ret = arithmetic_opt_run_(reinterpret_cast<const float *>(input0), reinterpret_cast<const float *>(input1),
reinterpret_cast<float *>(output), size, param_);
} else {
CHECK_NULL_RETURN(arithmetic_run_, RET_ERROR);
ret = arithmetic_run_(reinterpret_cast<const float *>(input0), reinterpret_cast<const float *>(input1),
reinterpret_cast<float *>(output), size);
}
} else if (in_tensors_[0]->data_type() == kNumberTypeBool) {
data_type_ = KDataTypeBool;
CHECK_NULL_RETURN(arithmetic_run_bool_, RET_ERROR);
ret = arithmetic_run_bool_(reinterpret_cast<const bool *>(input0), reinterpret_cast<const bool *>(input1),
reinterpret_cast<bool *>(output), size);
} else {
data_type_ = kDataTypeInt;
}

arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
memcpy(arithmeticParameter_->in_shape0_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape0().data(),
reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape0().size() * sizeof(int));
memcpy(arithmeticParameter_->in_shape1_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape1().data(),
reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape1().size() * sizeof(int));
memcpy(arithmeticParameter_->out_shape_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->OutputShape().data(),
reinterpret_cast<const lite::Arithmetic *>(primitive_)->OutputShape().size() * sizeof(int));

return;
}

int ArithmeticCPUKernel::CheckDataType() {
auto in0_dataType = in_tensors_.at(0)->data_type();
auto in1_dataType = in_tensors_.at(1)->data_type();
if (in0_dataType != in1_dataType) {
MS_LOG(ERROR) << "The dataTypes of input tensor0 and input tensor1 should be the same.";
return RET_ERROR;
}
return RET_OK;
}

int ArithmeticCPUKernel::ReSize() {
if (CheckDataType() != RET_OK) {
MS_LOG(ERROR) << "ArithmeticCPUKernel resize failed.";
return RET_ERROR;
if (is_opt) {
CHECK_NULL_RETURN(arithmetic_opt_run_int_, RET_ERROR);
ret = arithmetic_opt_run_int_(reinterpret_cast<const int *>(input0), reinterpret_cast<const int *>(input1),
reinterpret_cast<int *>(output), size, param_);
} else {
CHECK_NULL_RETURN(arithmetic_run_int_, RET_ERROR);
ret = arithmetic_run_int_(reinterpret_cast<const int *>(input0), reinterpret_cast<const int *>(input1),
reinterpret_cast<int *>(output), size);
}
}
InitParam();
return InitBroadCastCase();
return ret;
}

int ArithmeticCPUKernel::BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count,
int out_thread_stride) {
if (dim > break_pos_) {
if (data_type_ == kDataTypeInt) {
return arithmetic_run_int_(reinterpret_cast<int *>(input0) + out_thread_stride,
reinterpret_cast<int *>(input1) + out_thread_stride,
reinterpret_cast<int *>(output) + out_thread_stride, out_count);
}
return arithmetic_run_(reinterpret_cast<float *>(input0) + out_thread_stride,
reinterpret_cast<float *>(input1) + out_thread_stride,
reinterpret_cast<float *>(output) + out_thread_stride, out_count);
int offset = out_thread_stride * data_type_len_;
return Execute(static_cast<uint8_t *>(input0) + offset, static_cast<uint8_t *>(input1) + offset,
static_cast<uint8_t *>(output) + offset, out_count, false);
}
for (int i = 0; i < arithmeticParameter_->out_shape_[dim]; ++i) {
int pos0_ = arithmeticParameter_->in_shape0_[dim] == 1 ? 0 : i;
int pos1_ = arithmeticParameter_->in_shape1_[dim] == 1 ? 0 : i;
int error_code;
if (data_type_ == kDataTypeInt) {
error_code = BroadcastRun(reinterpret_cast<int *>(input0) + pos0_ * arithmeticParameter_->in_strides0_[dim],
reinterpret_cast<int *>(input1) + pos1_ * arithmeticParameter_->in_strides1_[dim],
reinterpret_cast<int *>(output) + i * arithmeticParameter_->out_strides_[dim], dim + 1,
out_count, out_thread_stride);
} else {
error_code = BroadcastRun(reinterpret_cast<float *>(input0) + pos0_ * arithmeticParameter_->in_strides0_[dim],
reinterpret_cast<float *>(input1) + pos1_ * arithmeticParameter_->in_strides1_[dim],
reinterpret_cast<float *>(output) + i * arithmeticParameter_->out_strides_[dim],
dim + 1, out_count, out_thread_stride);
}
if (error_code != RET_OK) {
return error_code;
int offset_size[] = {param_->in_strides0_[dim] * data_type_len_, param_->in_strides1_[dim] * data_type_len_,
param_->out_strides_[dim] * data_type_len_};
for (int i = 0; i < param_->out_shape_[dim]; ++i) {
int pos0_ = param_->in_shape0_[dim] == 1 ? 0 : i;
int pos1_ = param_->in_shape1_[dim] == 1 ? 0 : i;
int ret = BroadcastRun(static_cast<uint8_t *>(input0) + pos0_ * offset_size[0],
static_cast<uint8_t *>(input1) + pos1_ * offset_size[1],
static_cast<uint8_t *>(output) + i * offset_size[2], dim + 1, out_count, out_thread_stride);
if (ret != RET_OK) {
return ret;
}
}
return RET_OK;
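A note on the BroadcastRun hunk just above: the strides are pre-multiplied by data_type_len_, so the recursion walks raw byte pointers and only the leaf call (Execute) casts to a concrete element type. A rough standalone sketch of the same recursion, simplified to scalar fp32 addition at the leaf instead of a vectorized slice:

#include <cstdint>

// Strides are in elements; byte offsets are formed with sizeof(float), mirroring
// the offset_size[] computation in the kernel. No threading, fp32 only.
static void BroadcastAdd(const uint8_t *in0, const uint8_t *in1, uint8_t *out, int dim, int ndim,
                         const int *shape0, const int *shape1, const int *out_shape,
                         const int *strides0, const int *strides1, const int *out_strides) {
  if (dim == ndim) {
    *reinterpret_cast<float *>(out) =
        *reinterpret_cast<const float *>(in0) + *reinterpret_cast<const float *>(in1);
    return;
  }
  for (int i = 0; i < out_shape[dim]; ++i) {
    int pos0 = shape0[dim] == 1 ? 0 : i;
    int pos1 = shape1[dim] == 1 ? 0 : i;
    BroadcastAdd(in0 + pos0 * strides0[dim] * sizeof(float), in1 + pos1 * strides1[dim] * sizeof(float),
                 out + i * out_strides[dim] * sizeof(float), dim + 1, ndim, shape0, shape1, out_shape,
                 strides0, strides1, out_strides);
  }
}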
@@ -240,20 +249,20 @@ bool ArithmeticCPUKernel::CanBatchScalar() { // 2 32 240 240, 2 32 1 1
if (input0_broadcast_ || input1_broadcast_) {
return false;
}
if (arithmeticParameter_->in_elements_num0_ == arithmeticParameter_->in_elements_num1_ ||
arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) {
if (param_->in_elements_num0_ == param_->in_elements_num1_ || param_->in_elements_num0_ == 1 ||
param_->in_elements_num1_ == 1) {
return false;
}
size_t break_axis = 0;
for (size_t i = 0; i < arithmeticParameter_->ndim_; i++) {
if (arithmeticParameter_->in_shape0_[i] != arithmeticParameter_->in_shape1_[i]) {
for (size_t i = 0; i < param_->ndim_; i++) {
if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
break_axis = i;
break;
}
}
if (break_axis < arithmeticParameter_->ndim_) {
for (size_t i = break_axis; i < arithmeticParameter_->ndim_; i++) {
if (arithmeticParameter_->in_shape1_[i] != 1) {
if (break_axis < param_->ndim_) {
for (size_t i = break_axis; i < param_->ndim_; i++) {
if (param_->in_shape1_[i] != 1) {
return false;
}
}
@@ -263,16 +272,19 @@ bool ArithmeticCPUKernel::CanBatchScalar() { // 2 32 240 240, 2 32 1 1
}

int ArithmeticCPUKernel::BatchScalarCalc(int task_id) {
int batch = arithmeticParameter_->out_elements_num_ / arithmeticParameter_->out_strides_[break_pos_ - 1];
int batch_per_thread = UP_DIV(batch, thread_count_);
if (break_pos_ < 1) {
return RET_ERROR;
}
int batch = param_->out_elements_num_ / param_->out_strides_[break_pos_ - 1];
int batch_per_thread = UP_DIV(batch, context_->thread_num_);

int start_batch = batch_per_thread * task_id;
int end_batch = MSMIN(start_batch + batch_per_thread, batch);
int batch_size = end_batch - start_batch;

int stride0 = arithmeticParameter_->in_strides0_[break_pos_ - 1];
int stride1 = arithmeticParameter_->in_strides1_[break_pos_ - 1];
int out_stride = arithmeticParameter_->out_strides_[break_pos_ - 1];
int stride0 = param_->in_strides0_[break_pos_ - 1] * data_type_len_;
int stride1 = param_->in_strides1_[break_pos_ - 1] * data_type_len_;
int out_stride = param_->out_strides_[break_pos_ - 1] * data_type_len_;

int offset0 = stride0 * start_batch;
int offset1 = stride1 * start_batch;
@@ -280,15 +292,8 @@ int ArithmeticCPUKernel::BatchScalarCalc(int task_id) {

int ret = RET_OK;
for (int i = 0; i < batch_size; i++) {
if (data_type_ == kDataTypeFloat) {
ret = arithmetic_opt_run_(
reinterpret_cast<float *>(input0_ptr_) + offset0, reinterpret_cast<float *>(input1_ptr_) + offset1,
reinterpret_cast<float *>(out_tensors_[0]->data_c()) + out_offset, out_stride, arithmeticParameter_);
} else {
ret = arithmetic_opt_run_int_(
reinterpret_cast<int *>(input0_ptr_) + offset0, reinterpret_cast<int *>(input1_ptr_) + offset1,
reinterpret_cast<int *>(out_tensors_[0]->data_c()) + out_offset, out_stride, arithmeticParameter_);
}
ret = Execute(static_cast<uint8_t *>(input0_ptr_) + offset0, static_cast<uint8_t *>(input1_ptr_) + offset1,
static_cast<uint8_t *>(output_ptr_) + out_offset, param_->out_strides_[break_pos_ - 1], true);
offset0 += stride0;
offset1 += stride1;
out_offset += out_stride;
@@ -298,143 +303,61 @@ int ArithmeticCPUKernel::BatchScalarCalc(int task_id) {

int ArithmeticCPUKernel::DoArithmetic(int task_id) {
auto element_num = out_tensors_[0]->ElementsNum();

MS_ASSERT(thread_count_ != 0);
int stride = UP_DIV(element_num, thread_count_);
int stride = UP_DIV(element_num, context_->thread_num_);
int count = MSMIN(stride, element_num - stride * task_id);
if (count <= 0) {
return RET_OK;
}

if (arithmetic_run_ == nullptr) {
MS_LOG(ERROR) << "arithmetic_run function is nullptr!";
return RET_ERROR;
}
/* run opt function: in every batch, one of the inputs is a scalar */
if (CanBatchScalar()) {
return BatchScalarCalc(task_id);
}
int error_code = RET_OK;
if ((arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) &&
int offset = stride * task_id * data_type_len_;
/* run opt function: one of the inputs is a scalar */
if ((param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) &&
(arithmetic_opt_run_ != nullptr && arithmetic_opt_run_int_ != nullptr)) {
/* run opt function
* one of input is scalar */
if (arithmeticParameter_->in_elements_num0_ == 1) {
if (data_type_ == kDataTypeFloat) {
error_code = arithmetic_opt_run_(
reinterpret_cast<float *>(input0_ptr_), reinterpret_cast<float *>(input1_ptr_) + stride * task_id,
reinterpret_cast<float *>(out_tensors_[0]->data_c()) + stride * task_id, count, arithmeticParameter_);
} else {
error_code = arithmetic_opt_run_int_(
reinterpret_cast<int *>(input0_ptr_), reinterpret_cast<int *>(input1_ptr_) + stride * task_id,
reinterpret_cast<int *>(out_tensors_[0]->data_c()) + stride * task_id, count, arithmeticParameter_);
}
} else if (arithmeticParameter_->in_elements_num1_ == 1) {
if (data_type_ == kDataTypeFloat) {
error_code = arithmetic_opt_run_(
reinterpret_cast<float *>(input0_ptr_) + stride * task_id, reinterpret_cast<float *>(input1_ptr_),
reinterpret_cast<float *>(out_tensors_[0]->data_c()) + stride * task_id, count, arithmeticParameter_);
} else {
error_code = arithmetic_opt_run_int_(
reinterpret_cast<int *>(input0_ptr_) + stride * task_id, reinterpret_cast<int *>(input1_ptr_),
reinterpret_cast<int *>(out_tensors_[0]->data_c()) + stride * task_id, count, arithmeticParameter_);
}
if (param_->in_elements_num0_ == 1) {
return Execute(input0_ptr_, static_cast<uint8_t *>(input1_ptr_) + offset,
static_cast<uint8_t *>(output_ptr_) + offset, count, true);
} else if (param_->in_elements_num1_ == 1) {
return Execute(static_cast<uint8_t *>(input0_ptr_) + offset, input1_ptr_,
static_cast<uint8_t *>(output_ptr_) + offset, count, true);
}
return error_code;
}
if (arithmeticParameter_->broadcasting_) {
/* need broadcast in runtime */
stride = UP_DIV(outside_, thread_count_);
/* need broadcast in runtime */
if (param_->broadcasting_) {
stride = UP_DIV(outside_, context_->thread_num_);
int out_count = MSMIN(stride, outside_ - stride * task_id);
if (out_count <= 0) {
return RET_OK;
}
int out_thread_stride = stride * task_id;
if (data_type_ == kDataTypeFloat) {
error_code = BroadcastRun(reinterpret_cast<float *>(input0_ptr_), reinterpret_cast<float *>(input1_ptr_),
reinterpret_cast<float *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
} else {
error_code = BroadcastRun(reinterpret_cast<int *>(input0_ptr_), reinterpret_cast<int *>(input1_ptr_),
reinterpret_cast<int *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
}
return error_code;
return BroadcastRun(input0_ptr_, input1_ptr_, output_ptr_, 0, out_count, stride * task_id);
}

/* no broadcast in runtime */
if (data_type_ == kDataTypeFloat) {
error_code = arithmetic_run_(reinterpret_cast<float *>(input0_ptr_) + stride * task_id,
reinterpret_cast<float *>(input1_ptr_) + stride * task_id,
reinterpret_cast<float *>(out_tensors_[0]->data_c()) + stride * task_id, count);
} else if (data_type_ == KDataTypeBool) {
error_code = arithmetic_run_bool_(reinterpret_cast<bool *>(input0_ptr_) + stride * task_id,
reinterpret_cast<bool *>(input1_ptr_) + stride * task_id,
reinterpret_cast<bool *>(out_tensors_[0]->data_c()) + stride * task_id, count);
} else {
error_code = arithmetic_run_int_(reinterpret_cast<int *>(input0_ptr_) + stride * task_id,
reinterpret_cast<int *>(input1_ptr_) + stride * task_id,
reinterpret_cast<int *>(out_tensors_[0]->data_c()) + stride * task_id, count);
}
return error_code;
return Execute(static_cast<uint8_t *>(input0_ptr_) + offset, static_cast<uint8_t *>(input1_ptr_) + offset,
static_cast<uint8_t *>(output_ptr_) + offset, count, false);
}

int ArithmeticsRun(void *cdata, int task_id) {
auto arithmetic_kernel = reinterpret_cast<ArithmeticCPUKernel *>(cdata);
auto error_code = arithmetic_kernel->DoArithmetic(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "ArithmeticsRun error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
auto kernel = reinterpret_cast<ArithmeticCPUKernel *>(cdata);
auto ret = kernel->DoArithmetic(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ArithmeticsRun error task_id[" << task_id << "] error_code[" << ret << "]";
}
return RET_OK;
}

void ArithmeticCPUKernel::FreeTmpPtr() {
if (input0_broadcast_ == true && input0_ptr_ != nullptr) {
free(input0_ptr_);
input0_ptr_ = nullptr;
input0_broadcast_ = false;
}
if (input1_broadcast_ == true && input1_ptr_ != nullptr) {
free(input1_ptr_);
input1_ptr_ = nullptr;
input1_broadcast_ = false;
}
return;
return ret;
}

void ArithmeticCPUKernel::InitParamInRunTime() {
/* after infershape */
if (arithmeticParameter_->broadcasting_) {
outside_ = 1;
for (auto i = arithmeticParameter_->ndim_ - 1; i >= 0; --i) {
if (arithmeticParameter_->in_shape0_[i] != arithmeticParameter_->in_shape1_[i]) {
break_pos_ = i;
break;
}
outside_ *= arithmeticParameter_->out_shape_[i];
}
}
ComputeStrides(arithmeticParameter_->in_shape0_, arithmeticParameter_->in_strides0_, arithmeticParameter_->ndim_);
ComputeStrides(arithmeticParameter_->in_shape1_, arithmeticParameter_->in_strides1_, arithmeticParameter_->ndim_);
ComputeStrides(arithmeticParameter_->out_shape_, arithmeticParameter_->out_strides_, arithmeticParameter_->ndim_);

int ArithmeticCPUKernel::Run() {
if (!input0_broadcast_) {
input0_ptr_ = in_tensors_[0]->data_c();
}
if (!input1_broadcast_) {
input1_ptr_ = in_tensors_[1]->data_c();
}
return;
output_ptr_ = out_tensors_[0]->data_c();
return ParallelLaunch(this->context_->thread_pool_, ArithmeticsRun, this, context_->thread_num_);
}

int ArithmeticCPUKernel::Run() {
InitParamInRunTime();

int error_code = ParallelLaunch(this->context_->thread_pool_, ArithmeticsRun, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Arithmetic function error error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Mul, LiteKernelCreator<ArithmeticCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Mul, LiteKernelCreator<ArithmeticCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Add, LiteKernelCreator<ArithmeticCPUKernel>)


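To summarize the arithmetic_fp32.cc refactor above: every call site now funnels into one Execute(input0, input1, output, size, is_opt) that picks the fp32, int32, or bool element kernel from the input tensor's data type, while the surrounding control flow deals only in untyped byte pointers and a data_type_len_ stride. A condensed, self-contained sketch of that dispatch shape (the typedefs and Add* kernels below are stand-ins, not the NNACL functions):

typedef int (*FloatRun)(const float *in0, const float *in1, float *out, int size);
typedef int (*IntRun)(const int *in0, const int *in1, int *out, int size);

static int AddFp32(const float *a, const float *b, float *c, int n) {
  for (int i = 0; i < n; ++i) c[i] = a[i] + b[i];
  return 0;
}
static int AddInt32(const int *a, const int *b, int *c, int n) {
  for (int i = 0; i < n; ++i) c[i] = a[i] + b[i];
  return 0;
}

// Cast the void pointers exactly once, at the leaf, according to the element type.
static int Execute(const void *in0, const void *in1, void *out, int size, bool is_fp32,
                   FloatRun fp32_run, IntRun int_run) {
  if (is_fp32) {
    return fp32_run(static_cast<const float *>(in0), static_cast<const float *>(in1),
                    static_cast<float *>(out), size);
  }
  return int_run(static_cast<const int *>(in0), static_cast<const int *>(in1),
                 static_cast<int *>(out), size);
}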
+ 25
- 20
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,14 +13,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_

#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/fp32/arithmetic_fp32.h"
#include "schema/model_generated.h"

using mindspore::schema::PrimitiveType_Add;
using mindspore::schema::PrimitiveType_Div;
@@ -42,6 +40,14 @@ using mindspore::schema::PrimitiveType_RealDiv;
using mindspore::schema::PrimitiveType_SquaredDifference;
using mindspore::schema::PrimitiveType_Sub;

#define CHECK_NULL_RETURN(ptr, errcode) \
do { \
if (ptr == nullptr) { \
MS_LOG(ERROR) << "ptr must not be null."; \
return errcode; \
} \
} while (0);

namespace mindspore::kernel {
class ArithmeticCPUKernel : public LiteKernel {
typedef int (*ArithmeticRun)(const float *input0, const float *input1, float *output, const int element_size);
@@ -66,11 +72,10 @@ class ArithmeticCPUKernel : public LiteKernel {
ArithmeticCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {
arithmeticParameter_ = reinterpret_cast<ArithmeticParameter *>(parameter);
InitRunFunction();
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
param_ = reinterpret_cast<ArithmeticParameter *>(parameter);
}
~ArithmeticCPUKernel() override;
~ArithmeticCPUKernel() { FreeConstTileBuff(); }

int Init() override;
int ReSize() override;
@@ -78,33 +83,33 @@ class ArithmeticCPUKernel : public LiteKernel {
virtual int DoArithmetic(int task_id);
virtual int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride);

private:
void InitRunFunction();
void InitParam();
void FreeTmpPtr();
int CheckDataType();
int InitBroadCastCase();
void InitParamInRunTime();
bool CanBatchScalar();
int BatchScalarCalc(int task_id);

protected:
virtual void InitRunFunction();
virtual int CheckDataType();
virtual int ConstTensorBroadCast();
virtual void TileConstTensor(const void *in_data, void *out_data, size_t ndim, const int *in_shape,
const int *in_strides, const int *out_strides, const int *multiple);
virtual int Execute(const void *input0, const void *input1, void *output, int size, bool is_opt);
bool input0_broadcast_ = false;
bool input1_broadcast_ = false;
void *input0_ptr_ = nullptr;
void *input1_ptr_ = nullptr;
void *output_ptr_ = nullptr;
int break_pos_ = 0;
int outside_ = 0;
int thread_count_ = 1;
ArithmeticParameter *arithmeticParameter_ = nullptr;
LiteDataType data_type_ = kDataTypeFloat;
ArithmeticParameter *param_ = nullptr;
int data_type_len_ = sizeof(float);

private:
bool CanBatchScalar();
int BatchScalarCalc(int task_id);
void FreeConstTileBuff();
ArithmeticRun arithmetic_run_ = nullptr;
ArithmeticOptRun arithmetic_opt_run_ = nullptr;
ArithmeticIntRun arithmetic_run_int_ = nullptr;
ArithmeticOptIntRun arithmetic_opt_run_int_ = nullptr;
ArithmeticBoolRun arithmetic_run_bool_ = nullptr;
};
int ArithmeticsRun(void *cdata, int task_id);
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_

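The header changes above make the type-specific steps (InitRunFunction, CheckDataType, ConstTensorBroadCast, TileConstTensor, Execute) protected and virtual, so a subclass (presumably the fp16 arithmetic kernel touched earlier in this commit) can reuse the fp32 control flow and override only the leaf computation. A toy, self-contained illustration of that hook pattern (class and method names below are placeholders, not the real kernels):

class ArithmeticBase {
 public:
  virtual ~ArithmeticBase() = default;
  // Shared control flow lives in the base class and calls the virtual leaf.
  int Run(const void *in0, const void *in1, void *out, int size) { return Execute(in0, in1, out, size); }

 protected:
  virtual int Execute(const void *in0, const void *in1, void *out, int size) {  // fp32 default
    auto *a = static_cast<const float *>(in0);
    auto *b = static_cast<const float *>(in1);
    auto *c = static_cast<float *>(out);
    for (int i = 0; i < size; ++i) c[i] = a[i] + b[i];
    return 0;
  }
};

class ArithmeticHalf : public ArithmeticBase {
 protected:
  int Execute(const void *in0, const void *in1, void *out, int size) override {
    // A real subclass would cast to float16_t and call the fp16 element kernels here.
    return ArithmeticBase::Execute(in0, in1, out, size);
  }
};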