| @@ -16,6 +16,7 @@ | |||
| #include "nnacl/fp32/exp_fp32.h" | |||
| #include <math.h> | |||
| #include <string.h> | |||
| #include "nnacl/errorcode.h" | |||
| int Exp(const float *input_data, float *output_data, ExpParameter *parameter, int task_id) { | |||
| @@ -35,3 +36,40 @@ int Exp(const float *input_data, float *output_data, ExpParameter *parameter, in | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
// Approximates dst[i] = exp(src[i]) for every i in [0, num).
//
// Each input is clamped to [-88, 88] and split as x = n * ln2 + r, where n is
// x / ln2 truncated toward zero. 2^n is produced by placing (n + 127) in bits
// 23 and up of a 32-bit pattern that is then reinterpreted as a float (this
// assumes float is IEEE-754 binary32), and e^r is evaluated with the degree-5
// Taylor polynomial 1 + r + r^2/2 + r^3/6 + r^4/24 + r^5/120.
//
// src and dst may point to the same buffer: every element is read before the
// matching output is written. A non-positive num leaves dst untouched.
// NaN inputs yield NaN.
void ExpFp32(const float *src, float *dst, int num) {
  int i = 0;
  const float ln2 = log(2.0f);
  const float c5 = 1.0f / 120;
  const float c4 = 1.0f / 24;
  const float c3 = 1.0f / 6;
#ifdef ENABLE_ARM64
  const float32x4_t maxv = vdupq_n_f32(88.0f);
  const float32x4_t minv = vdupq_n_f32(-88.0f);
  const float32x4_t ln2_4 = vdupq_n_f32(ln2);
  const float32x4_t c5_4 = vdupq_n_f32(c5);
  const float32x4_t c4_4 = vdupq_n_f32(c4);
  const float32x4_t c3_4 = vdupq_n_f32(c3);
  const float32x4_t half4 = vdupq_n_f32(0.5f);
  const float32x4_t one4 = vdupq_n_f32(1.0f);
  // "<=" so that a final full group of four is also handled by the vector path.
  for (; i <= num - C4NUM; i += C4NUM) {
    float32x4_t input4 = vmaxq_f32(minv, vminq_f32(maxv, vld1q_f32(src + i)));
    int32x4_t integer4 = vcvtq_s32_f32(vdivq_f32(input4, ln2_4));
    float32x4_t decimal4 = vsubq_f32(input4, vmulq_f32(vcvtq_f32_s32(integer4), ln2_4));
    // Build 2^n in registers instead of bouncing the bit pattern through dst.
    int32x4_t int_exp4 = vshlq_s32(vaddq_s32(integer4, vdupq_n_s32(127)), vdupq_n_s32(23));
    float32x4_t pow2_4 = vreinterpretq_f32_s32(int_exp4);
    // Horner evaluation of the Taylor polynomial, same order as the scalar tail.
    float32x4_t decimal_exp4 = vaddq_f32(c4_4, vmulq_f32(decimal4, c5_4));
    decimal_exp4 = vmulq_f32(decimal4, vaddq_f32(c3_4, vmulq_f32(decimal4, decimal_exp4)));
    decimal_exp4 = vaddq_f32(one4, vmulq_f32(decimal4, vaddq_f32(half4, decimal_exp4)));
    decimal_exp4 = vaddq_f32(one4, vmulq_f32(decimal4, decimal_exp4));
    vst1q_f32(dst + i, vmulq_f32(pow2_4, decimal_exp4));
  }
#endif
  for (; i < num; ++i) {
    float input = src[i];
    if (isnan(input)) {
      // Converting NaN to int below would be undefined behaviour; propagate it instead.
      dst[i] = input;
      continue;
    }
    // Clamp so that (integer + 127) stays a valid, positive exponent field.
    input = input < -88.0f ? -88.0f : (input > 88.0f ? 88.0f : input);
    const int integer = (int)(input / ln2);
    const float decimal = input - integer * ln2;
    const int int_exp = (integer + 127) << 23;
    float pow2 = 0.0f;
    memcpy(&pow2, &int_exp, sizeof(float));
    const float decimal_exp =
      1.0f + decimal * (1.0f + decimal * (0.5f + decimal * (c3 + decimal * (c4 + decimal * c5))));
    dst[i] = pow2 * decimal_exp;
  }
}
| @@ -34,6 +34,7 @@ typedef struct ExpParameter { | |||
| extern "C" { | |||
| #endif | |||
| int Exp(const float *input_data, float *output_data, ExpParameter *parameter, int task_id); | |||
| void ExpFp32(const float *src, float *dst, int num); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| @@ -13,9 +13,79 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "nnacl/fp32/softmax_fp32.h" | |||
| #include <math.h> | |||
| #include "nnacl/fp32/exp_fp32.h" | |||
// For each of the `batch` rows of `channel` contiguous floats, writes
// dst[row][c] = src[row][c] - max(src[row][*]).
//
// Does nothing when channel <= 0 (there is no element to take a maximum of).
void SoftmaxNorm(const float *src, float *dst, int batch, int channel) {
  if (channel <= 0) {
    return;
  }
  int cur_batch_offset = 0;
  for (int i = 0; i < batch; i++, cur_batch_offset += channel) {
    const float *row_src = src + cur_batch_offset;
    float *row_dst = dst + cur_batch_offset;

    // Pass 1: row maximum.
    int j = 0;
    float max = row_src[0];
#ifdef ENABLE_ARM64
    // Only use the 4-wide load when the row really holds four elements;
    // shorter rows are handled entirely by the scalar loop below, which
    // avoids reading past the end of the row and missing the true maximum.
    if (channel >= C4NUM) {
      float32x4_t max4 = vld1q_f32(row_src);
      for (j = C4NUM; j <= channel - C4NUM; j += C4NUM) {
        max4 = vmaxq_f32(max4, vld1q_f32(row_src + j));
      }
      max = vmaxvq_f32(max4);
    }
#endif
    for (; j < channel; j++) {
      if (row_src[j] > max) {
        max = row_src[j];
      }
    }

    // Pass 2: subtract the maximum from every element of the row.
    int k = 0;
#ifdef ENABLE_NEON
    const float32x4_t max_dup = vdupq_n_f32(max);
    for (; k <= channel - C4NUM; k += C4NUM) {
      vst1q_f32(row_dst + k, vsubq_f32(vld1q_f32(row_src + k), max_dup));
    }
#endif
    for (; k < channel; k++) {
      row_dst[k] = row_src[k] - max;
    }
  }
}
// Normalises every row of `channel` contiguous floats by the row's sum:
// dst[row][c] = src[row][c] / sum(src[row][*]), for `batch` rows.
// The sum of a row is fully accumulated before any output of that row is
// written.
void SumAndDiv(const float *src, float *dst, int batch, int channel) {
  for (int row = 0; row < batch; ++row) {
    const float *in_row = src + row * channel;
    float *out_row = dst + row * channel;

    // Accumulate the row total.
    float total = 0;
    int idx = 0;
#ifdef ENABLE_NEON
    float32x4_t acc4 = vdupq_n_f32(0);
    while (idx < channel - C4NUM) {
      acc4 = vaddq_f32(acc4, vld1q_f32(in_row + idx));
      idx += C4NUM;
    }
    total = acc4[0] + acc4[1] + acc4[2] + acc4[3];
#endif
    while (idx < channel) {
      total += in_row[idx];
      ++idx;
    }

    // Scale the row by the total.
    idx = 0;
#ifdef ENABLE_NEON
    float inv_total = 1.0f / total;
    while (idx < channel - C4NUM) {
      vst1q_f32(out_row + idx, vmulq_n_f32(vld1q_f32(in_row + idx), inv_total));
      idx += C4NUM;
    }
#endif
    while (idx < channel) {
      out_row[idx] = in_row[idx] / total;
      ++idx;
    }
  }
}
| void SoftmaxLastAxis(const float *src, float *dst, int batch, int channel) { | |||
| SoftmaxNorm(src, dst, batch, channel); | |||
| ExpFp32(dst, dst, batch * channel); | |||
| SumAndDiv(dst, dst, batch, channel); | |||
| } | |||
| // output = exp(input) / reduce_sum(exp(input), axis) | |||
| void Softmax(const float *input_ptr, float *output_ptr, float *sum_data, SoftmaxParameter *parameter) { | |||
| @@ -23,6 +23,7 @@ | |||
| extern "C" { | |||
| #endif | |||
| void Softmax(const float *input_ptr, float *output_ptr, float *sum_data, SoftmaxParameter *parameter); | |||
| void SoftmaxLastAxis(const float *src, float *dst, int batch, int channel); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| @@ -70,12 +70,41 @@ int SoftmaxCPUKernel::ReSize() { | |||
| return RET_OK; | |||
| } | |||
| int SoftmaxCPUKernel::Run() { | |||
| memset(sum_data_, 0, in_plane_size_ * out_plane_size_ * sizeof(float)); | |||
| int SoftmaxCPUKernel::DoSoftmaxLastAxis(int task_id) { | |||
| int unit = UP_DIV(out_plane_size_, context_->thread_num_); | |||
| int begin = task_id * unit; | |||
| int end = MSMIN(begin + unit, out_plane_size_); | |||
| int channel = softmax_param_->input_shape_[softmax_param_->axis_]; | |||
| int offset = begin * channel; | |||
| auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(kInputIndex)->MutableData()); | |||
| auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData()); | |||
| Softmax(input_ptr, output_ptr, sum_data_, softmax_param_); | |||
| SoftmaxLastAxis(input_ptr + offset, output_ptr + offset, end - begin, channel); | |||
| return RET_OK; | |||
| } | |||
| int SoftmaxLastAxisRun(void *cdata, int task_id) { | |||
| auto kernel = reinterpret_cast<SoftmaxCPUKernel *>(cdata); | |||
| auto ret = kernel->DoSoftmaxLastAxis(task_id); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "DoSoftmaxLastAxis error task_id: " << task_id << ", ret: " << ret; | |||
| } | |||
| return ret; | |||
| } | |||
| int SoftmaxCPUKernel::Run() { | |||
| auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(kInputIndex)->MutableData()); | |||
| auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData()); | |||
| int ret = RET_OK; | |||
| if (in_plane_size_ == 1) { | |||
| ret = ParallelLaunch(this->context_->thread_pool_, SoftmaxLastAxisRun, this, context_->thread_num_); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "SoftmaxCPUKernel ParallelLaunch failed, ret: " << ret; | |||
| } | |||
| } else { | |||
| memset(sum_data_, 0, in_plane_size_ * out_plane_size_ * sizeof(float)); | |||
| Softmax(input_ptr, output_ptr, sum_data_, softmax_param_); | |||
| } | |||
| return ret; | |||
| } | |||
| } // namespace mindspore::kernel | |||
| @@ -37,6 +37,7 @@ class SoftmaxCPUKernel : public SoftmaxBaseCPUKernel { | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int DoSoftmaxLastAxis(int task_id); | |||
| private: | |||
| float *sum_data_ = nullptr; | |||
| @@ -0,0 +1,59 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include "common/common_test.h" | |||
| #include "nnacl/softmax_parameter.h" | |||
| #include "mindspore/lite/src/kernel_registry.h" | |||
| namespace mindspore { | |||
| class TestSoftmaxFp32 : public mindspore::CommonTest { | |||
| public: | |||
| TestSoftmaxFp32() {} | |||
| }; | |||
| TEST_F(TestSoftmaxFp32, 001) { | |||
| lite::Tensor in_tensor(kNumberTypeFloat32, {2, 1, 1, 5}); | |||
| lite::Tensor out_tensor(kNumberTypeFloat32, {2, 1, 1, 5}); | |||
| float input_data[] = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; | |||
| float output_data[10] = {0}; | |||
| in_tensor.set_data(input_data); | |||
| out_tensor.set_data(output_data); | |||
| std::vector<lite::Tensor *> inputs = {&in_tensor}; | |||
| std::vector<lite::Tensor *> outputs = {&out_tensor}; | |||
| SoftmaxParameter parameter = {{}, -1, 10, 4, {2, 1, 1, 5}}; | |||
| kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_SoftMax}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| ASSERT_NE(creator, nullptr); | |||
| auto ctx = std::make_shared<lite::InnerContext>(); | |||
| ASSERT_EQ(lite::RET_OK, ctx->Init()); | |||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(¶meter), ctx.get(), desc, nullptr); | |||
| ASSERT_NE(kernel, nullptr); | |||
| auto ret = kernel->Run(); | |||
| EXPECT_EQ(0, ret); | |||
| float expect[] = {0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f}; | |||
| for (size_t i = 0; i < sizeof(expect) / sizeof(expect[0]); ++i) { | |||
| EXPECT_EQ(output_data[i], expect[i]); | |||
| } | |||
| in_tensor.set_data(nullptr); | |||
| out_tensor.set_data(nullptr); | |||
| } | |||
| } // namespace mindspore | |||