From 58366cbeb92b486e1afe80be87c4efc06969f5db Mon Sep 17 00:00:00 2001 From: yangruoqi713 Date: Sat, 30 Jan 2021 10:34:37 +0800 Subject: [PATCH] [MSLITE][DEVELOP] add cpu fp16 op: instance norm --- .../lite/nnacl/fp16/instance_norm_fp16.c | 81 ++++++++++++ .../lite/nnacl/fp16/instance_norm_fp16.h | 31 +++++ mindspore/lite/src/lite_session.cc | 6 +- .../agent/npu/optimizer/npu_fusion_pass.cc | 3 + .../kernel/arm/fp16/instance_norm_fp16.cc | 121 ++++++++++++++++++ .../kernel/arm/fp16/instance_norm_fp16.h | 51 ++++++++ .../runtime/kernel/npu/instance_norm_npu.cc | 2 + 7 files changed, 292 insertions(+), 3 deletions(-) create mode 100644 mindspore/lite/nnacl/fp16/instance_norm_fp16.c create mode 100644 mindspore/lite/nnacl/fp16/instance_norm_fp16.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp16/instance_norm_fp16.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp16/instance_norm_fp16.h diff --git a/mindspore/lite/nnacl/fp16/instance_norm_fp16.c b/mindspore/lite/nnacl/fp16/instance_norm_fp16.c new file mode 100644 index 0000000000..c5d286c01b --- /dev/null +++ b/mindspore/lite/nnacl/fp16/instance_norm_fp16.c @@ -0,0 +1,81 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include "nnacl/fp16/instance_norm_fp16.h"
+#include <math.h>
+#include "nnacl/errorcode.h"
+
+int InstanceNormFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *gamma_data,
+                     const float16_t *beta_data, const InstanceNormParameter *param, size_t task_id) {
+  if (src_data == NULL || dst_data == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  int channel_step = UP_DIV(param->channel_, param->op_parameter_.thread_num_);
+  int channel_begin = task_id * channel_step;
+  int channel_end = MSMIN(channel_begin + channel_step, param->channel_);
+
+  for (int b = 0; b < param->batch_; b++) {
+    const float16_t *src_b = src_data + b * param->channel_ * param->inner_size_;
+    float16_t *dst_b = dst_data + b * param->channel_ * param->inner_size_;
+    for (int c = channel_begin; c < channel_end; c++) {
+      const float16_t *src = src_b + c * param->inner_size_;
+      float16_t *dst = dst_b + c * param->inner_size_;
+      float mean = 0.0f;
+      float square_mean = 0.0f;
+
+      int index = 0;
+      for (; index < param->inner_size_ - C8NUM; index += C8NUM) {
+        float16x8_t srcv = vld1q_f16(src + index);
+        float16x8_t squarev = vmulq_f16(srcv, srcv);
+
+        float16x4_t sum2 = vadd_f16(vget_low_f16(srcv), vget_high_f16(srcv));
+        float32x4_t sum_f32 = vcvt_f32_f16(sum2);
+        mean += vaddvq_f32(sum_f32);
+
+        float16x4_t square2 = vadd_f16(vget_low_f16(squarev), vget_high_f16(squarev));
+        float32x4_t square_f32 = vcvt_f32_f16(square2);
+        square_mean += vaddvq_f32(square_f32);
+      }
+      for (; index < param->inner_size_; index++) {
+        mean += src[index];
+        square_mean += src[index] * src[index];
+      }
+
+      mean /= (float)param->inner_size_;
+      square_mean /= (float)param->inner_size_;
+      const float deno = 1 / sqrtf(square_mean - mean * mean + param->epsilon_);
+
+      index = 0;
+      float16x8_t meanv = vdupq_n_f16(mean);
+      float16x8_t denov = vdupq_n_f16(deno);
+      for (; index < param->inner_size_ - C8NUM; index += C8NUM) {
+        float16x8_t srcv = vld1q_f16(src + index);
+        float16x8_t outv = vsubq_f16(srcv, meanv);
+        outv = vmulq_f16(outv, denov);
+
+        float16x8_t gammav = vdupq_n_f16(gamma_data[c]);
+        float16x8_t betav = vdupq_n_f16(beta_data[c]);
+        outv = vmulq_f16(outv, gammav);
+        outv = vaddq_f16(outv, betav);
+        vst1q_f16(dst + index, outv);
+      }
+      for (; index < param->inner_size_; index++) {
+        dst[index] = (src[index] - mean) * deno;
+        dst[index] = dst[index] * gamma_data[c] + beta_data[c];
+      }
+    }
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore/lite/nnacl/fp16/instance_norm_fp16.h b/mindspore/lite/nnacl/fp16/instance_norm_fp16.h
new file mode 100644
index 0000000000..e6fc99331f
--- /dev/null
+++ b/mindspore/lite/nnacl/fp16/instance_norm_fp16.h
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifndef MINDSPORE_LITE_NNACL_FP16_INSTANCE_NORM_H_ +#define MINDSPORE_LITE_NNACL_FP16_INSTANCE_NORM_H_ + +#include "nnacl/instance_norm_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int InstanceNormFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *gamma_data, + const float16_t *beta_data, const InstanceNormParameter *param, size_t task_id); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_LITE_NNACL_FP16_INSTANCE_NORM_H_ diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index 7d04b804be..20c8284bd8 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -538,6 +538,9 @@ LiteSession::~LiteSession() { MS_LOG(ERROR) << "Not support multi-threading"; return; } + for (auto *kernel : kernels_) { + delete kernel; + } for (size_t i = 0; i < tensors_.size(); i++) { auto *tensor = tensors_.at(i); MS_ASSERT(tensor != nullptr); @@ -552,9 +555,6 @@ LiteSession::~LiteSession() { output_node_map_.clear(); output_tensor_map_.clear(); input_vec_.clear(); - for (auto *kernel : kernels_) { - delete kernel; - } delete this->context_; delete this->executor_; this->executor_ = nullptr; diff --git a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_fusion_pass.cc b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_fusion_pass.cc index 4687ec549e..fcfcd15f79 100644 --- a/mindspore/lite/src/runtime/agent/npu/optimizer/npu_fusion_pass.cc +++ b/mindspore/lite/src/runtime/agent/npu/optimizer/npu_fusion_pass.cc @@ -141,6 +141,9 @@ void UpdatePreTensors(kernel::LiteKernel *cur_kernel) { void UpdatePostTensors(kernel::LiteKernel *cur_kernel) { auto tensor = cur_kernel->out_tensors()[0]; + tensor->set_format(schema::Format_NCHW); + auto nhwc_shape = tensor->shape(); + tensor->set_shape({nhwc_shape[0], nhwc_shape[3], nhwc_shape[1], nhwc_shape[2]}); for (auto out_kernel : cur_kernel->out_kernels()) { auto out_tensor = out_kernel->out_tensors()[0]; if (out_kernel->out_kernels().empty()) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/instance_norm_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/instance_norm_fp16.cc new file mode 100644 index 0000000000..363acbe80d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/instance_norm_fp16.cc @@ -0,0 +1,121 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include "src/runtime/kernel/arm/fp16/instance_norm_fp16.h"
+#include "schema/model_generated.h"
+#include "src/kernel_registry.h"
+#include "include/errorcode.h"
+#include "nnacl/fp16/cast_fp16.h"
+#include "nnacl/fp16/instance_norm_fp16.h"
+
+using mindspore::kernel::KERNEL_ARCH::kCPU;
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_OK;
+using mindspore::schema::PrimitiveType_InstanceNorm;
+
+namespace mindspore::kernel {
+void InstanceNormFp16CPUKernel::FreeTmpBuffer() {
+  if (in_tensors_[1]->data_type() == kNumberTypeFloat32) {
+    if (gamma_data_ != nullptr) {
+      free(gamma_data_);
+      gamma_data_ = nullptr;
+    }
+  }
+  if (in_tensors_[2]->data_type() == kNumberTypeFloat32) {
+    if (beta_data_ != nullptr) {
+      free(beta_data_);
+      beta_data_ = nullptr;
+    }
+  }
+}
+
+int InstanceNormFp16CPUKernel::Init() {
+  auto gamma = in_tensors_[1];
+  if (gamma->data_type() == kNumberTypeFloat32) {
+    gamma_data_ = reinterpret_cast<float16_t *>(malloc(gamma->ElementsNum() * sizeof(float16_t)));
+    if (gamma_data_ == nullptr) {
+      MS_LOG(ERROR) << "InstanceNorm fp16 kernel malloc gamma_data_ error.";
+      return RET_ERROR;
+    }
+    Float32ToFloat16(reinterpret_cast<float *>(gamma->data_c()), gamma_data_, gamma->ElementsNum());
+  } else if (gamma->data_type() == kNumberTypeFloat16) {
+    gamma_data_ = reinterpret_cast<float16_t *>(gamma->data_c());
+  } else {
+    MS_LOG(ERROR) << "Unsupported data type of gamma tensor for instance norm.";
+    return RET_ERROR;
+  }
+
+  auto beta = in_tensors_[2];
+  if (beta->data_type() == kNumberTypeFloat32) {
+    beta_data_ = reinterpret_cast<float16_t *>(malloc(beta->ElementsNum() * sizeof(float16_t)));
+    if (beta_data_ == nullptr) {
+      MS_LOG(ERROR) << "InstanceNorm fp16 kernel malloc beta_data_ error.";
+      return RET_ERROR;
+    }
+    Float32ToFloat16(reinterpret_cast<float *>(beta->data_c()), beta_data_, beta->ElementsNum());
+  } else if (beta->data_type() == kNumberTypeFloat16) {
+    beta_data_ = reinterpret_cast<float16_t *>(beta->data_c());
+  } else {
+    MS_LOG(ERROR) << "Unsupported data type of beta tensor for instance norm.";
+    return RET_ERROR;
+  }
+  if (!InferShapeDone()) {
+    return RET_OK;
+  }
+  return ReSize();
+}
+
+int InstanceNormFp16CPUKernel::ReSize() {
+  param_->op_parameter_.thread_num_ = context_->thread_num_;
+  auto shape = in_tensors_.front()->shape();
+  param_->batch_ = shape[0];
+  param_->inner_size_ = shape[2] * shape[3];
+  param_->channel_ = shape[1];
+  return RET_OK;
+}
+
+int InstanceNormFp16CPUKernel::DoInstanceNorm(int task_id) {
+  int ret = InstanceNormFp16(src_data_, dst_data_, gamma_data_, beta_data_, param_, task_id);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "DoInstanceNorm error error_code[" << ret << "]";
+    return ret;
+  }
+  return RET_OK;
+}
+
+int InstanceNormFp16Run(void *cdata, int task_id) {
+  auto kernel = reinterpret_cast<InstanceNormFp16CPUKernel *>(cdata);
+  auto ret = kernel->DoInstanceNorm(task_id);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "InstanceNormFp16Run error task_id[" << task_id << "] error_code[" << ret << "]";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+int InstanceNormFp16CPUKernel::Run() {
+  src_data_ = reinterpret_cast<float16_t *>(in_tensors_[0]->data_c());
+  dst_data_ = reinterpret_cast<float16_t *>(out_tensors_[0]->data_c());
+  auto ret = ParallelLaunch(this->context_->thread_pool_, InstanceNormFp16Run, this, op_parameter_->thread_num_);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "InstanceNormFp16Run error error_code[" << ret << "]";
+    return ret;
+  }
+  return RET_OK;
+}
+
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_InstanceNorm, LiteKernelCreator<InstanceNormFp16CPUKernel>)
+}  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/instance_norm_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/instance_norm_fp16.h
new file mode 100644
index 0000000000..b4b865b3ef
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/instance_norm_fp16.h
@@ -0,0 +1,51 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_INSTANCE_NORM_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_INSTANCE_NORM_H_
+#include <vector>
+#include "src/lite_kernel.h"
+#include "include/context.h"
+#include "nnacl/instance_norm_parameter.h"
+
+using mindspore::lite::InnerContext;
+
+namespace mindspore::kernel {
+class InstanceNormFp16CPUKernel : public LiteKernel {
+ public:
+  InstanceNormFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
+                            const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
+                            const mindspore::lite::PrimitiveC *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
+    param_ = reinterpret_cast<InstanceNormParameter *>(parameter);
+  }
+  ~InstanceNormFp16CPUKernel() override { FreeTmpBuffer(); };
+
+  int Init() override;
+  int ReSize() override;
+  int Run() override;
+  int DoInstanceNorm(int task_id);
+
+ private:
+  void FreeTmpBuffer();
+  InstanceNormParameter *param_ = nullptr;
+  float16_t *src_data_ = nullptr;
+  float16_t *dst_data_ = nullptr;
+  float16_t *gamma_data_ = nullptr;
+  float16_t *beta_data_ = nullptr;
+};
+}  // namespace mindspore::kernel
+
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_INSTANCE_NORM_H_
diff --git a/mindspore/lite/src/runtime/kernel/npu/instance_norm_npu.cc b/mindspore/lite/src/runtime/kernel/npu/instance_norm_npu.cc
index fc38efadd6..f0578dedec 100644
--- a/mindspore/lite/src/runtime/kernel/npu/instance_norm_npu.cc
+++ b/mindspore/lite/src/runtime/kernel/npu/instance_norm_npu.cc
@@ -54,6 +54,7 @@ int InstanceNormNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &input
                                    lite::ConverterToNPUDataType(inputs[1]->data_type()));
   gamma_tensor->SetTensorDesc(gamma_tensor_desc);
   gamma_tensor->SetData(reinterpret_cast<const uint8_t *>(inputs[1]->data_c()), inputs[1]->Size());
+  gamma->set_attr_value(gamma_tensor);
   op_->set_input_gamma(*gamma);
 
   auto beta = new (std::nothrow) hiai::op::Const(name_ + "_beta");
@@ -71,6 +72,7 @@ int InstanceNormNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &input
                                   lite::ConverterToNPUDataType(inputs[2]->data_type()));
   beta_tensor->SetTensorDesc(beta_tensor_desc);
   beta_tensor->SetData(reinterpret_cast<const uint8_t *>(inputs[2]->data_c()), inputs[2]->Size());
+  beta->set_attr_value(beta_tensor);
   op_->set_input_beta(*beta);
   op_->set_attr_epsilon(instance_norm_param_->epsilon_);
   return RET_OK;