Merge pull request !4826 from zhaozhenlong/lite/issue/fp16_softmax_activation
@@ -0,0 +1,98 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/fp16/activation_fp16.h"
#include "nnacl/errorcode.h"
int ReluFp16(const float16_t *src, float16_t *dst, int ele_num) {
  int eight_block = UP_DIV(ele_num, C8NUM);
  int i;
  for (i = 0; i < eight_block - 1; i++) {
    int index = i * C8NUM;
#ifdef ENABLE_NEON
    float16x8_t relu_src = vld1q_f16(src + index);
    float16x8_t zero_src = vdupq_n_f16(0);
    relu_src = vmaxq_f16(relu_src, zero_src);
    vst1q_f16(dst + index, relu_src);
#else
    int j;
    for (j = 0; j < C8NUM; j++) {
      dst[index + j] = src[index + j] < 0 ? 0 : src[index + j];
    }
#endif
  }
  for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) {
    dst[j] = src[j] < 0 ? 0 : src[j];
  }
  return NNACL_OK;
}

int Relu6Fp16(const float16_t *data, float16_t *dst, int ele_num) {
  int eight_block = UP_DIV(ele_num, C8NUM);
  int i;
  for (i = 0; i < eight_block - 1; i++) {
    int index = i * C8NUM;
#ifdef ENABLE_NEON
    float16x8_t relu6_data = vld1q_f16(data + index);
    float16x8_t zero_data = vdupq_n_f16(0);
    float16x8_t six_data = vdupq_n_f16(6);
    relu6_data = vmaxq_f16(relu6_data, zero_data);
    relu6_data = vminq_f16(relu6_data, six_data);
    vst1q_f16(dst + index, relu6_data);
#else
    int j;
    for (j = 0; j < C8NUM; ++j) {
      dst[index + j] = data[index + j] < 0 ? 0 : data[index + j];
      dst[index + j] = dst[index + j] > 6 ? 6 : dst[index + j];
    }
#endif
  }
  for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) {
    dst[j] = data[j] < 0 ? 0 : data[j];
    dst[j] = dst[j] > 6 ? 6 : dst[j];
  }
  return NNACL_OK;
}

int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha) {
  for (int i = 0; i < ele_num; ++i) {
    dst[i] = src[i] > (float16_t)0.0f ? src[i] : (src[i] * alpha);
  }
  return NNACL_OK;
}

int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num) {
  for (int i = 0; i < ele_num; ++i) {
    dst[i] = (float16_t)1.0f / (float16_t)(1.0f + exp(-src[i]));
  }
  return NNACL_OK;
}

int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) {
  for (int i = 0; i < ele_num; ++i) {
    dst[i] = (float16_t)1.0f - (float16_t)2.0f / (float16_t)(exp(2 * src[i]) + 1);
  }
  return NNACL_OK;
}

int HSwishFp16(const float16_t *src, float16_t *dst, int ele_num) {
  for (int i = 0; i < ele_num; ++i) {
    float16_t in = src[i];
    float16_t relu6 = MSMIN(MSMAX(in + 3, 0), 6);
    dst[i] = in * relu6 / (float16_t)6.0f;
  }
  return NNACL_OK;
}
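
A minimal usage sketch for the new fp16 activations (illustrative only; the buffer below and its contents are assumed, not part of this patch):

  float16_t in[8] = {-2, -1, 0, 1, 3, 5, 6, 7};
  float16_t out[8];
  ReluFp16(in, out, 8);                     // out = {0, 0, 0, 1, 3, 5, 6, 7}
  Relu6Fp16(in, out, 8);                    // out = {0, 0, 0, 1, 3, 5, 6, 6}
  LReluFp16(in, out, 8, (float16_t)0.1f);   // negative entries are scaled by alpha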
@@ -0,0 +1,44 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_NNACL_FP16_ACTIVATION_FP16_H_
#define MINDSPORE_LITE_NNACL_FP16_ACTIVATION_FP16_H_

#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include <math.h>
#include "nnacl/op_base.h"
#include "nnacl/quantization/fixed_point.h"

typedef struct ActivationParameter {
  OpParameter op_parameter_;
  int type_;
  float alpha_;
} ActivationParameter;
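// type_ carries a schema::ActivationType value; in this fp16 path alpha_ is
// only consumed by LReluFp16 (leaky relu).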

#ifdef __cplusplus
extern "C" {
#endif
int ReluFp16(const float16_t *src, float16_t *dst, int ele_num);
int Relu6Fp16(const float16_t *data, float16_t *dst, int ele_num);
int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha);
int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num);
int TanhFp16(const float16_t *src, float16_t *dst, int ele_num);
int HSwishFp16(const float16_t *src, float16_t *dst, int ele_num);
#ifdef __cplusplus
}
#endif

#endif  // MINDSPORE_LITE_NNACL_FP16_ACTIVATION_FP16_H_
@@ -1,61 +0,0 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/fp16/common_func.h"

void ReluFp16(float16_t *data, float16_t *dst, int ele_num) {
  int eight_block = UP_DIV(ele_num, C8NUM);
  for (int i = 0; i < eight_block - 1; i++) {
    int index = i * C8NUM;
#ifdef ENABLE_NEON
    float16x8_t relu_data = vld1q_f16(data + index);
    float16x8_t zero_data = vdupq_n_f16(0);
    relu_data = vmaxq_f16(relu_data, zero_data);
    vst1q_f16(dst + index, relu_data);
#else
    data[index] = data[index] < 0 ? 0 : data[index];
    data[index + 1] = data[index + 1] < 0 ? 0 : data[index + 1];
    data[index + 2] = data[index + 2] < 0 ? 0 : data[index + 2];
    data[index + 3] = data[index + 3] < 0 ? 0 : data[index + 3];
#endif
  }
  for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) {
    data[j] = data[j] < 0 ? 0 : data[j];
  }
}

void Relu6Fp16(float16_t *data, float16_t *dst, int ele_num) {
  int eight_block = UP_DIV(ele_num, C8NUM);
  for (int i = 0; i < eight_block - 1; i++) {
    int index = i * C8NUM;
#ifdef ENABLE_NEON
    float16x8_t relu6_data = vld1q_f16(data + index);
    float16x8_t zero_data = vdupq_n_f16(0);
    float16x8_t six_data = vdupq_n_f16(6);
    relu6_data = vmaxq_f16(relu6_data, zero_data);
    relu6_data = vminq_f16(relu6_data, six_data);
    vst1q_f16(dst + index, relu6_data);
#else
    for (int j = 0; j < C8NUM; ++j) {
      data[index + j] = data[index + j] < 0 ? 0 : data[index + j];
      data[index + j] = data[index + j] > 6 ? 6 : data[index + j];
    }
#endif
  }
  for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) {
    data[j] = data[j] < 0 ? 0 : data[j];
    data[j] = data[j] > 6 ? 6 : data[j];
  }
}
@@ -41,8 +41,6 @@ void DeconvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *w
                        size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
                        size_t in_sw_step, size_t in_kh_step, size_t in_kw_step);
#endif
void ReluFp16(float16_t *data, float16_t *dst, int ele_num);
void Relu6Fp16(float16_t *data, float16_t *dst, int ele_num);
#ifdef __cplusplus
}
@@ -0,0 +1,67 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/fp16/softmax_fp16.h"
#include <math.h>
#include <float.h>

// output = exp(input) / reduce_sum(exp(input), axis)
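// Implemented in the numerically stable form exp(x - max) / sum(exp(x - max)):
// a single max over the whole tensor is subtracted before exp to keep the fp16
// range under control, and it cancels out in the final division. sum_data must
// be zero-initialized by the caller (the fp16 softmax kernel memsets it).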
void SoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *sum_data, SoftmaxParameter *parameter) {
  int32_t axis = parameter->axis_;
  int n_dim = parameter->n_dim_;
  int ele_size = parameter->element_size_;
  int *input_shape = parameter->input_shape_;
  float16_t max_data = input_ptr[0];
  for (int i = 0; i < ele_size; i++) {
    max_data = max_data > input_ptr[i] ? max_data : input_ptr[i];
  }
  for (int i = 0; i < ele_size; i++) {
    output_ptr[i] = exp(input_ptr[i] - max_data);
  }
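  // Shape decomposition around the softmax axis. Illustrative (hypothetical)
  // values: input_shape = {2, 3, 4} with axis = 1 gives outter_size = 2,
  // inner_size = 4, and sum_data holds outter_size * inner_size = 8 partial sums.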
  int inner_size = 1, outter_size = 1;
  for (int i = 0; i < axis; i++) {
    outter_size *= input_shape[i];
  }
  for (int i = axis + 1; i < n_dim; i++) {
    inner_size *= input_shape[i];
  }
  for (int i = 0; i < outter_size; i++) {
    int outter_offset = i * input_shape[axis] * inner_size;
    int sum_outter_offset = i * inner_size;
    for (int k = 0; k < inner_size; k++) {
      int inner_offset = outter_offset + k;
      for (int j = 0; j < input_shape[axis]; j++) {
        int axis_offset = inner_offset + j * inner_size;
        sum_data[k + sum_outter_offset] += output_ptr[axis_offset];
      }
    }
  }
  for (int i = 0; i < outter_size; i++) {
    int outter_offset = i * input_shape[axis] * inner_size;
    int sum_outter_offset = i * inner_size;
    for (int j = 0; j < input_shape[axis]; j++) {
      int axis_offset = outter_offset + j * inner_size;
      for (int k = 0; k < inner_size; k++) {
        int inner_offset = axis_offset + k;
        output_ptr[inner_offset] = output_ptr[inner_offset] / sum_data[k + sum_outter_offset];
      }
    }
  }
}
@@ -0,0 +1,33 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_NNACL_FP16_SOFTMAX_FP16_H_
#define MINDSPORE_LITE_NNACL_FP16_SOFTMAX_FP16_H_

#include "nnacl/op_base.h"
#include "nnacl/softmax_parameter.h"
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif

#ifdef __cplusplus
extern "C" {
#endif
void SoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *sum_data, SoftmaxParameter *parameter);
#ifdef __cplusplus
}
#endif

#endif  // MINDSPORE_LITE_NNACL_FP16_SOFTMAX_FP16_H_
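
A minimal call sketch for SoftmaxFp16 (illustrative only; the shapes and values are assumed, as is treating SoftmaxParameter::input_shape_ as a plain int array member):

  SoftmaxParameter param;
  param.n_dim_ = 2;
  param.axis_ = 1;
  param.element_size_ = 2 * 3;
  param.input_shape_[0] = 2;          // assumed array member
  param.input_shape_[1] = 3;
  float16_t in[6] = {1, 2, 3, 1, 2, 3};
  float16_t out[6];
  float16_t sum[2] = {0, 0};          // outter_size * inner_size slots, zero-initialized
  SoftmaxFp16(in, out, sum, &param);  // each row of out now sums to ~1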
@@ -0,0 +1,156 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/kernel/arm/fp16/activation_fp16.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
#include "include/errorcode.h"
#include "src/runtime/kernel/arm/fp16/common_fp16.h"
#include "nnacl/fp16/cast_fp16.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::ActivationType_HSWISH;
using mindspore::schema::ActivationType_LEAKY_RELU;
using mindspore::schema::ActivationType_RELU;
using mindspore::schema::ActivationType_RELU6;
using mindspore::schema::PrimitiveType_Activation;

namespace mindspore::kernel {
int ActivationFp16CPUKernel::Init() { return RET_OK; }

int ActivationFp16CPUKernel::ReSize() { return RET_OK; }

int ActivationFp16CPUKernel::MallocTmpBuffer() {
  fp16_input_ = ConvertInputFp32toFp16(in_tensors_.at(0), context_);
  if (fp16_input_ == nullptr) {
    MS_LOG(ERROR) << "malloc data failed";
    return RET_ERROR;
  }
  fp16_output_ = MallocOutputFp16(out_tensors_.at(0), context_);
  if (fp16_output_ == nullptr) {
    MS_LOG(ERROR) << "malloc data failed";
    return RET_ERROR;
  }
  return RET_OK;
}

void ActivationFp16CPUKernel::FreeTmpBuffer() {
  if (in_tensors_.at(0)->data_type() == kNumberTypeFloat32) {
    if (fp16_input_ != nullptr) {
      context_->allocator->Free(fp16_input_);
      fp16_input_ = nullptr;
    }
  }
  if (out_tensors_.at(0)->data_type() == kNumberTypeFloat32) {
    if (fp16_output_ != nullptr) {
      context_->allocator->Free(fp16_output_);
      fp16_output_ = nullptr;
    }
  }
}

int ActivationFp16CPUKernel::DoActivation(int task_id) {
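  // The flattened tensor is split into thread_count_ contiguous strides and
  // task task_id handles one of them. With hypothetical numbers length = 100
  // and thread_count_ = 4: stride = UP_DIV(100, 4) = 25, so task_id 2 covers
  // elements 50..74.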
  auto length = in_tensors_.at(0)->ElementsNum();
  int stride = UP_DIV(length, thread_count_);
  int count = MSMIN(stride, length - stride * task_id);
  int error_code;
  if (type_ == schema::ActivationType_RELU) {
    error_code = ReluFp16(fp16_input_ + stride * task_id, fp16_output_ + stride * task_id, count);
  } else if (type_ == schema::ActivationType_RELU6) {
    error_code = Relu6Fp16(fp16_input_ + stride * task_id, fp16_output_ + stride * task_id, count);
  } else if (type_ == schema::ActivationType_LEAKY_RELU) {
    error_code = LReluFp16(fp16_input_ + stride * task_id, fp16_output_ + stride * task_id, count, alpha_);
  } else if (type_ == schema::ActivationType_SIGMOID) {
    error_code = SigmoidFp16(fp16_input_ + stride * task_id, fp16_output_ + stride * task_id, count);
  } else if (type_ == schema::ActivationType_TANH) {
    error_code = TanhFp16(fp16_input_ + stride * task_id, fp16_output_ + stride * task_id, count);
  } else if (type_ == schema::ActivationType_HSWISH) {
    error_code = HSwishFp16(fp16_input_ + stride * task_id, fp16_output_ + stride * task_id, count);
  } else {
    MS_LOG(ERROR) << "Activation fp16 not support type: " << type_;
    return RET_ERROR;
  }
  return error_code;
}

int ActivationRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
  auto activation_kernel = reinterpret_cast<ActivationFp16CPUKernel *>(cdata);
  auto error_code = activation_kernel->DoActivation(task_id);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "ActivationRun error task_id[" << task_id << "] error_code[" << error_code << "]";
    return RET_ERROR;
  }
  return RET_OK;
}

int ActivationFp16CPUKernel::Run() {
  auto ret = Prepare();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Prepare failed.";
    return ret;
  }
  ret = MallocTmpBuffer();
  if (ret != RET_OK) {
    FreeTmpBuffer();
    return ret;
  }
  int error_code = LiteBackendParallelLaunch(ActivationRun, this, thread_count_);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "Activation function error error_code[" << error_code << "]";
    FreeTmpBuffer();
    return RET_ERROR;
  }
  auto out_tensor = out_tensors_.at(0);
  if (out_tensor->data_type() == kNumberTypeFloat32) {
    Float16ToFloat32(fp16_output_, reinterpret_cast<float *>(out_tensor->Data()), out_tensor->ElementsNum());
  }
  FreeTmpBuffer();
  return RET_OK;
}

kernel::LiteKernel *CpuActivationFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                   const std::vector<lite::tensor::Tensor *> &outputs,
                                                   OpParameter *opParameter, const lite::Context *ctx,
                                                   const kernel::KernelKey &desc,
                                                   const mindspore::lite::PrimitiveC *primitive) {
  MS_ASSERT(opParameter != nullptr);
  MS_ASSERT(desc.type == schema::PrimitiveType_Activation);
  auto *kernel = new (std::nothrow) ActivationFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "kernel is nullptr.";
    return nullptr;
  }
  auto ret = kernel->Init();
  if (ret != RET_OK) {
    delete kernel;
    MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
    return nullptr;
  }
  return kernel;
}

REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Activation, CpuActivationFp16KernelCreator)
}  // namespace mindspore::kernel
@@ -0,0 +1,52 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ACTIVATION_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ACTIVATION_FP16_H_

#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/fp16/activation_fp16.h"

namespace mindspore::kernel {
class ActivationFp16CPUKernel : public LiteKernel {
 public:
  ActivationFp16CPUKernel(OpParameter *param, const std::vector<lite::tensor::Tensor *> &inputs,
                          const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
                          const mindspore::lite::PrimitiveC *primitive)
      : LiteKernel(param, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {
    type_ = (reinterpret_cast<ActivationParameter *>(param))->type_;
    alpha_ = (float16_t)((reinterpret_cast<ActivationParameter *>(param))->alpha_);
  }
  ~ActivationFp16CPUKernel() override = default;

  int Init() override;
  int ReSize() override;
  int Run() override;
  int DoActivation(int task_id);
  int MallocTmpBuffer();
  void FreeTmpBuffer();

 private:
  int thread_count_;
  int type_;
  float16_t alpha_;
  float16_t *fp16_input_ = nullptr;
  float16_t *fp16_output_ = nullptr;
};
}  // namespace mindspore::kernel

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ACTIVATION_FP16_H_
@@ -0,0 +1,156 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <string.h>
#include <vector>
#include "src/runtime/kernel/arm/fp16/softmax_fp16.h"
#include "src/runtime/kernel/arm/fp16/common_fp16.h"
#include "nnacl/fp16/softmax_fp16.h"
#include "nnacl/fp16/cast_fp16.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_SoftMax;

namespace mindspore::kernel {
int SoftmaxFp16CPUKernel::Init() {
  auto ret = SoftmaxBaseCPUKernel::Init();
  if (ret != RET_OK) {
    return ret;
  }
  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}

int SoftmaxFp16CPUKernel::ReSize() {
  return SoftmaxBaseCPUKernel::ReSize();
}

int SoftmaxFp16CPUKernel::MallocTmpBuffer() {
  auto n_dim = softmax_param_->n_dim_;
  auto axis = softmax_param_->axis_;
  if (axis == -1) {
    softmax_param_->axis_ += n_dim;
    axis = softmax_param_->axis_;
  }
  auto in_shape = in_tensors_.front()->shape();
  int out_plane_size = 1;
  for (int i = 0; i < axis; ++i) {
    out_plane_size *= in_shape[i];
  }
  int in_plane_size = 1;
  for (int i = axis + 1; i < n_dim; i++) {
    in_plane_size *= in_shape[i];
  }
  sum_data_ =
    reinterpret_cast<float16_t *>(context_->allocator->Malloc(out_plane_size * in_plane_size * sizeof(float16_t)));
  if (sum_data_ == nullptr) {
    MS_LOG(ERROR) << "malloc data for softmax fail!";
    return RET_ERROR;
  }
  memset(sum_data_, 0, out_plane_size * in_plane_size * sizeof(float16_t));
  input_fp16_ = ConvertInputFp32toFp16(in_tensors_.at(kInputIndex), context_);
  if (input_fp16_ == nullptr) {
    MS_LOG(ERROR) << "malloc data failed";
    return RET_ERROR;
  }
  output_fp16_ = MallocOutputFp16(out_tensors_.at(kOutputIndex), context_);
  if (output_fp16_ == nullptr) {
    MS_LOG(ERROR) << "malloc data failed";
    return RET_ERROR;
  }
  return RET_OK;
}

void SoftmaxFp16CPUKernel::FreeTmpBuffer() {
  if (sum_data_ != nullptr) {
    context_->allocator->Free(sum_data_);
    sum_data_ = nullptr;
  }
  if (in_tensors_.at(kInputIndex)->data_type() == kNumberTypeFloat32) {
    if (input_fp16_ != nullptr) {
      context_->allocator->Free(input_fp16_);
      input_fp16_ = nullptr;
    }
  }
  if (out_tensors_.at(kOutputIndex)->data_type() == kNumberTypeFloat32) {
    if (output_fp16_ != nullptr) {
      context_->allocator->Free(output_fp16_);
      output_fp16_ = nullptr;
    }
  }
}
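
// Run data flow: a float32 input tensor is converted to float16 by
// ConvertInputFp32toFp16 inside MallocTmpBuffer, SoftmaxFp16 works entirely in
// float16, and the result is copied back with Float16ToFloat32 only when the
// output tensor's data type is float32.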
int SoftmaxFp16CPUKernel::Run() {
  auto ret = Prepare();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
    return RET_ERROR;
  }
  ret = MallocTmpBuffer();
  if (ret != RET_OK) {
    FreeTmpBuffer();
    MS_LOG(ERROR) << "MallocTmpBuffer failed";
    return RET_ERROR;
  }
  SoftmaxFp16(input_fp16_, output_fp16_, sum_data_, softmax_param_);
  auto out_tensor = out_tensors_.at(kOutputIndex);
  if (out_tensor->data_type() == kNumberTypeFloat32) {
    Float16ToFloat32(output_fp16_, reinterpret_cast<float *>(out_tensor->Data()), out_tensor->ElementsNum());
  }
  FreeTmpBuffer();
  return RET_OK;
}

kernel::LiteKernel *CpuSoftmaxFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                const std::vector<lite::tensor::Tensor *> &outputs,
                                                OpParameter *opParameter, const lite::Context *ctx,
                                                const kernel::KernelKey &desc,
                                                const mindspore::lite::PrimitiveC *primitive) {
  if (opParameter == nullptr) {
    MS_LOG(ERROR) << "Input opParameter is nullptr!";
    return nullptr;
  }
  MS_ASSERT(desc.type == schema::PrimitiveType_SoftMax);
  auto *kernel = new (std::nothrow) SoftmaxFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "new SoftmaxFp16CPUKernel fail!";
    return nullptr;
  }
  auto ret = kernel->Init();
  if (ret != RET_OK) {
    delete kernel;
    MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
    return nullptr;
  }
  return kernel;
}

REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_SoftMax, CpuSoftmaxFp16KernelCreator)
}  // namespace mindspore::kernel
@@ -0,0 +1,47 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SOFTMAX_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SOFTMAX_FP16_H_

#include <arm_neon.h>
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/softmax_base.h"

namespace mindspore::kernel {
class SoftmaxFp16CPUKernel : public SoftmaxBaseCPUKernel {
 public:
  SoftmaxFp16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                       const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
                       const mindspore::lite::PrimitiveC *primitive)
      : SoftmaxBaseCPUKernel(parameter, inputs, outputs, ctx, primitive), sum_data_(nullptr) {}
  ~SoftmaxFp16CPUKernel() = default;

  int Init() override;
  int ReSize() override;
  int Run() override;
  int MallocTmpBuffer();
  void FreeTmpBuffer();

 private:
  float16_t *sum_data_ = nullptr;
  float16_t *input_fp16_ = nullptr;
  float16_t *output_fp16_ = nullptr;
};
}  // namespace mindspore::kernel

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SOFTMAX_FP16_H_