Browse Source

!5907 optimize op maximum minimum greater

Merge pull request !5907 from 陶云浩/lite
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
a2002da77c
2 changed files with 39 additions and 43 deletions
  1. mindspore/lite/nnacl/fp16/arithmetic_fp16.c (+20, -20)
  2. mindspore/lite/nnacl/fp32/arithmetic.c (+19, -23)

mindspore/lite/nnacl/fp16/arithmetic_fp16.c (+20, -20) View File

@@ -1321,19 +1321,14 @@ int ElementOptSquaredDifferenceFp16(float16_t *input0, float16_t *input1, float1
} }


int ElementMaximumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { int ElementMaximumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
#ifdef ENABLE_NEON
int block_mod = element_size % C8NUM; int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod; int block_c8 = element_size - block_mod;
for (int index = 0; index < block_c8; index += C8NUM) { for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0); float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vld1q_f16(input1); float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vmaxq_f16(vin0, vin1); float16x8_t vout = vmaxq_f16(vin0, vin1);
vst1q_f16(output, vout); vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = MSMAX(input0[i], input1[i]);
}
#endif
input0 += C8NUM; input0 += C8NUM;
input1 += C8NUM; input1 += C8NUM;
output += C8NUM; output += C8NUM;
@@ -1341,6 +1336,11 @@ int ElementMaximumFp16(float16_t *input0, float16_t *input1, float16_t *output,
for (int index = 0; index < block_mod; ++index) { for (int index = 0; index < block_mod; ++index) {
output[index] = MSMAX(input0[index], input1[index]); output[index] = MSMAX(input0[index], input1[index]);
} }
#else
for (int index = 0; index < element_size; ++index) {
output[index] = MSMAX(input0[index], input1[index]);
}
#endif
return NNACL_OK; return NNACL_OK;
} }
int ElementOptMaximumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, int ElementOptMaximumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
@@ -1394,19 +1394,14 @@ int ElementOptMaximumFp16(float16_t *input0, float16_t *input1, float16_t *outpu
} }


int ElementMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { int ElementMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
#ifdef ENABLE_NEON
int block_mod = element_size % C8NUM; int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod; int block_c8 = element_size - block_mod;
for (int index = 0; index < block_c8; index += C8NUM) { for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0); float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vld1q_f16(input1); float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vminq_f16(vin0, vin1); float16x8_t vout = vminq_f16(vin0, vin1);
vst1q_f16(output, vout); vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = MSMIN(input0[i], input1[i]);
}
#endif
input0 += C8NUM; input0 += C8NUM;
input1 += C8NUM; input1 += C8NUM;
output += C8NUM; output += C8NUM;
@@ -1414,6 +1409,11 @@ int ElementMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output,
for (int index = 0; index < block_mod; ++index) { for (int index = 0; index < block_mod; ++index) {
output[index] = MSMIN(input0[index], input1[index]); output[index] = MSMIN(input0[index], input1[index]);
} }
#else
for (int index = 0; index < element_size; ++index) {
output[index] = MSMIN(input0[index], input1[index]);
}
#endif
return NNACL_OK; return NNACL_OK;
} }
int ElementOptMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, int ElementOptMinimumFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
@@ -1783,23 +1783,18 @@ int ElementOptLessEqualFp16(float16_t *input0, float16_t *input1, float16_t *out
} }


int ElementGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { int ElementGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
#ifdef ENABLE_NEON
int block_mod = element_size % C8NUM; int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod; int block_c8 = element_size - block_mod;
#ifdef ENABLE_NEON
float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1}; float16x8_t vtrue = {1, 1, 1, 1, 1, 1, 1, 1};
float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t vfalse = {0, 0, 0, 0, 0, 0, 0, 0};
#endif
for (int index = 0; index < block_c8; index += C8NUM) { for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0); float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vld1q_f16(input1); float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vbslq_f16(vcgtq_f16(vin0, vin1), vtrue, vfalse); float16x8_t vout = vbslq_f16(vcgtq_f16(vin0, vin1), vtrue, vfalse);
vst1q_f16(output, vout); vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = (float16_t)(input0[i] > input1[i]);
}
#endif

input0 += C8NUM; input0 += C8NUM;
input1 += C8NUM; input1 += C8NUM;
output += C8NUM; output += C8NUM;
@@ -1807,6 +1802,11 @@ int ElementGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output,
for (int index = 0; index < block_mod; ++index) { for (int index = 0; index < block_mod; ++index) {
output[index] = (float16_t)(input0[index] > input1[index]); output[index] = (float16_t)(input0[index] > input1[index]);
} }
#else
for (int index = 0; index < element_size; ++index) {
output[index] = (float16_t)(input0[index] > input1[index]);
}
#endif
return NNACL_OK; return NNACL_OK;
} }
int ElementOptGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, int ElementOptGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,


mindspore/lite/nnacl/fp32/arithmetic.c (+19, -23) View File

@@ -997,21 +997,15 @@ int BroadcastLogicalOr(float *input0, float *input1, float *tile_input0, float *
} }


int ElementMaximum(float *input0, float *input1, float *output, int element_size) { int ElementMaximum(float *input0, float *input1, float *output, int element_size) {
#ifdef ENABLE_NEON
int block_mod = element_size % C4NUM; int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod; int block_c4 = element_size - block_mod;


for (int index = 0; index < block_c4; index += C4NUM) { for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0); float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vld1q_f32(input1); float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vmaxq_f32(vin0, vin1); float32x4_t vout = vmaxq_f32(vin0, vin1);
vst1q_f32(output, vout); vst1q_f32(output, vout);
#else
output[0] = input0[0] > input1[0] ? input0[0] : input1[0];
output[1] = input0[1] > input1[1] ? input0[1] : input1[1];
output[2] = input0[2] > input1[2] ? input0[2] : input1[2];
output[3] = input0[3] > input1[3] ? input0[3] : input1[3];
#endif
input0 += C4NUM; input0 += C4NUM;
input1 += C4NUM; input1 += C4NUM;
output += C4NUM; output += C4NUM;
@@ -1019,6 +1013,11 @@ int ElementMaximum(float *input0, float *input1, float *output, int element_size
for (int index = 0; index < block_mod; ++index) { for (int index = 0; index < block_mod; ++index) {
output[index] = input0[index] > input1[index] ? input0[index] : input1[index]; output[index] = input0[index] > input1[index] ? input0[index] : input1[index];
} }
#else
for (int index = 0; index < element_size; ++index) {
output[index] = MSMAX(input0[index], input1[index]);
}
#endif
return NNACL_OK; return NNACL_OK;
} }


@@ -1029,21 +1028,15 @@ int BroadcastMaximum(float *input0, float *input1, float *tile_input0, float *ti
} }


int ElementMinimum(float *input0, float *input1, float *output, int element_size) { int ElementMinimum(float *input0, float *input1, float *output, int element_size) {
#ifdef ENABLE_NEON
int block_mod = element_size % C4NUM; int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod; int block_c4 = element_size - block_mod;


for (int index = 0; index < block_c4; index += C4NUM) { for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0); float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vld1q_f32(input1); float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vminq_f32(vin0, vin1); float32x4_t vout = vminq_f32(vin0, vin1);
vst1q_f32(output, vout); vst1q_f32(output, vout);
#else
output[0] = input0[0] > input1[0] ? input1[0] : input0[0];
output[1] = input0[1] > input1[1] ? input1[1] : input0[1];
output[2] = input0[2] > input1[2] ? input1[2] : input0[2];
output[3] = input0[3] > input1[3] ? input1[3] : input0[3];
#endif
input0 += C4NUM; input0 += C4NUM;
input1 += C4NUM; input1 += C4NUM;
output += C4NUM; output += C4NUM;
@@ -1051,6 +1044,11 @@ int ElementMinimum(float *input0, float *input1, float *output, int element_size
for (int index = 0; index < block_mod; ++index) { for (int index = 0; index < block_mod; ++index) {
output[index] = input0[index] > input1[index] ? input1[index] : input0[index]; output[index] = input0[index] > input1[index] ? input1[index] : input0[index];
} }
#else
for (int index = 0; index < element_size; ++index) {
output[index] = MSMIN(input0[index], input1[index]);
}
#endif
return NNACL_OK; return NNACL_OK;
} }


@@ -1217,24 +1215,17 @@ int BroadcastLessEqual(float *input0, float *input1, float *tile_input0, float *
} }


int ElementGreater(float *input0, float *input1, float *output, int element_size) { int ElementGreater(float *input0, float *input1, float *output, int element_size) {
#ifdef ENABLE_NEON
int block_mod = element_size % C4NUM; int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod; int block_c4 = element_size - block_mod;
#ifdef ENABLE_NEON
float32x4_t vtrue = {1, 1, 1, 1}; float32x4_t vtrue = {1, 1, 1, 1};
float32x4_t vfalse = {0, 0, 0, 0}; float32x4_t vfalse = {0, 0, 0, 0};
#endif
for (int index = 0; index < block_c4; index += C4NUM) { for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0); float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vld1q_f32(input1); float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vbslq_f32(vcgtq_f32(vin0, vin1), vtrue, vfalse); float32x4_t vout = vbslq_f32(vcgtq_f32(vin0, vin1), vtrue, vfalse);
vst1q_f32(output, vout); vst1q_f32(output, vout);
#else
output[0] = (float)(input0[0] > input1[0]);
output[1] = (float)(input0[1] > input1[1]);
output[2] = (float)(input0[2] > input1[2]);
output[3] = (float)(input0[3] > input1[3]);
#endif
input0 += C4NUM; input0 += C4NUM;
input1 += C4NUM; input1 += C4NUM;
output += C4NUM; output += C4NUM;
@@ -1242,6 +1233,11 @@ int ElementGreater(float *input0, float *input1, float *output, int element_size
for (int index = 0; index < block_mod; ++index) { for (int index = 0; index < block_mod; ++index) {
output[index] = (float)(input0[index] > input1[index]); output[index] = (float)(input0[index] > input1[index]);
} }
#else
for (int index = 0; index < element_size; ++index) {
output[index] = (float)(input0[index] > input1[index]);
}
#endif
return NNACL_OK; return NNACL_OK;
} }




Loading…
Cancel
Save