Browse Source

fix reduce relu6

tags/v1.0.0
sunsuodong 5 years ago
parent
commit
5c97d0fb3d
2 changed files with 16 additions and 26 deletions
  1. +10
    -18
      mindspore/lite/nnacl/fp16/activation_fp16.c
  2. +6
    -8
      mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc

+ 10
- 18
mindspore/lite/nnacl/fp16/activation_fp16.c View File

@@ -34,28 +34,20 @@ int ReluFp16(const float16_t *src, float16_t *dst, int ele_num) {
}

int Relu6Fp16(const float16_t *data, float16_t *dst, int ele_num) {
int eight_block = UP_DIV(ele_num, C8NUM);
int i;
for (i = 0; i < eight_block - 1; i++) {
int index = i * C8NUM;
int offset = 0;
#ifdef ENABLE_NEON
float16x8_t relu6_data = vld1q_f16(data + index);
float16x8_t zero_data = vdupq_n_f16(0);
float16x8_t six_data = vdupq_n_f16(6);
float16x8_t zero_data = vdupq_n_f16(0);
float16x8_t six_data = vdupq_n_f16(6);
for (; offset <= ele_num - C8NUM; offset += C8NUM) {
float16x8_t relu6_data = vld1q_f16(data + offset);
relu6_data = vmaxq_f16(relu6_data, zero_data);
relu6_data = vminq_f16(relu6_data, six_data);
vst1q_f16(dst + index, relu6_data);
#else
int j;
for (j = 0; j < C8NUM; ++j) {
dst[index + j] = data[index + j] < 0 ? 0 : data[index + j];
dst[index + j] = dst[index + j] > 6 ? 6 : dst[index + j];
}
#endif
vst1q_f16(dst + offset, relu6_data);
}
for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) {
dst[j] = data[j] < 0 ? 0 : data[j];
dst[j] = dst[j] > 6 ? 6 : dst[j];
#endif
for (; offset < ele_num; offset++) {
dst[offset] = data[offset] < 0 ? 0 : data[offset];
dst[offset] = dst[offset] > 6 ? 6 : dst[offset];
}
return NNACL_OK;
}


+ 6
- 8
mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc View File

@@ -82,14 +82,7 @@ int ReduceCPUKernel::Init() {
return ReSize();
}

int ReduceCPUKernel::ReSize() {
if (in_tensors().at(0)->data_type() == kNumberTypeFloat32) {
data_type_ = kDataTypeFloat;
} else {
data_type_ = kDataTypeInt;
}
return ReduceBaseCPUKernel::ReSize();
}
int ReduceCPUKernel::ReSize() { return ReduceBaseCPUKernel::ReSize(); }

int ReduceCPUKernel::CallReduceUnit(int task_id) {
int ret;
@@ -120,6 +113,11 @@ int ReduceCPUKernel::Run() {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
if (in_tensors().at(0)->data_type() == kNumberTypeFloat32) {
data_type_ = kDataTypeFloat;
} else {
data_type_ = kDataTypeInt;
}
auto ret = MallocTmpBuffer();
if (ret != RET_OK) {
FreeTmpBuffer();


Loading…
Cancel
Save