
!11899 [MSLITE][DEVELOP] add cpu fp16 op: instance norm

From: @yangruoqi713
tags/v1.2.0-rc1
mindspore-ci-bot committed on Gitee, 4 years ago
commit e3a01fea1c
7 changed files with 292 additions and 3 deletions
  1. mindspore/lite/nnacl/fp16/instance_norm_fp16.c (+81, -0)
  2. mindspore/lite/nnacl/fp16/instance_norm_fp16.h (+31, -0)
  3. mindspore/lite/src/lite_session.cc (+3, -3)
  4. mindspore/lite/src/runtime/agent/npu/optimizer/npu_fusion_pass.cc (+3, -0)
  5. mindspore/lite/src/runtime/kernel/arm/fp16/instance_norm_fp16.cc (+121, -0)
  6. mindspore/lite/src/runtime/kernel/arm/fp16/instance_norm_fp16.h (+51, -0)
  7. mindspore/lite/src/runtime/kernel/npu/instance_norm_npu.cc (+2, -0)

mindspore/lite/nnacl/fp16/instance_norm_fp16.c (+81, -0)

@@ -0,0 +1,81 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp16/instance_norm_fp16.h"
#include <math.h>
#include "nnacl/errorcode.h"

int InstanceNormFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *gamma_data,
                     const float16_t *beta_data, const InstanceNormParameter *param, size_t task_id) {
  if (src_data == NULL || dst_data == NULL) {
    return NNACL_NULL_PTR;
  }
  // Split channels across threads: this task handles [channel_begin, channel_end).
  int channel_step = UP_DIV(param->channel_, param->op_parameter_.thread_num_);
  int channel_begin = task_id * channel_step;
  int channel_end = MSMIN(channel_begin + channel_step, param->channel_);

  for (int b = 0; b < param->batch_; b++) {
    const float16_t *src_b = src_data + b * param->channel_ * param->inner_size_;
    float16_t *dst_b = dst_data + b * param->channel_ * param->inner_size_;
    for (int c = channel_begin; c < channel_end; c++) {
      const float16_t *src = src_b + c * param->inner_size_;
      float16_t *dst = dst_b + c * param->inner_size_;
      float mean = 0.0f;
      float square_mean = 0.0f;

      // Accumulate sum(x) and sum(x^2) eight fp16 lanes at a time, widening the
      // partial sums to fp32 before the horizontal add to limit accumulation error.
      int index = 0;
      for (; index < param->inner_size_ - C8NUM; index += C8NUM) {
        float16x8_t srcv = vld1q_f16(src + index);
        float16x8_t squarev = vmulq_f16(srcv, srcv);

        float16x4_t sum2 = vadd_f16(vget_low_f16(srcv), vget_high_f16(srcv));
        float32x4_t sum_f32 = vcvt_f32_f16(sum2);
        mean += vaddvq_f32(sum_f32);

        float16x4_t square2 = vadd_f16(vget_low_f16(squarev), vget_high_f16(squarev));
        float32x4_t square_f32 = vcvt_f32_f16(square2);
        square_mean += vaddvq_f32(square_f32);
      }
      for (; index < param->inner_size_; index++) {
        mean += src[index];
        square_mean += src[index] * src[index];
      }

      mean /= (float)param->inner_size_;
      square_mean /= (float)param->inner_size_;
      // variance = E[x^2] - (E[x])^2; deno = 1 / sqrt(variance + epsilon)
      const float deno = 1 / sqrtf(square_mean - mean * mean + param->epsilon_);

      // Normalize, then apply the per-channel affine transform gamma * x_hat + beta.
      index = 0;
      float16x8_t meanv = vdupq_n_f16(mean);
      float16x8_t denov = vdupq_n_f16(deno);
      for (; index < param->inner_size_ - C8NUM; index += C8NUM) {
        float16x8_t srcv = vld1q_f16(src + index);
        float16x8_t outv = vsubq_f16(srcv, meanv);
        outv = vmulq_f16(outv, denov);

        float16x8_t gammav = vdupq_n_f16(gamma_data[c]);
        float16x8_t betav = vdupq_n_f16(beta_data[c]);
        outv = vmulq_f16(outv, gammav);
        outv = vaddq_f16(outv, betav);
        vst1q_f16(dst + index, outv);
      }
      for (; index < param->inner_size_; index++) {
        dst[index] = (src[index] - mean) * deno;
        dst[index] = dst[index] * gamma_data[c] + beta_data[c];
      }
    }
  }
  return NNACL_OK;
}
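Per (batch, channel) slice, the kernel above computes y = gamma[c] * (x - mean) / sqrt(E[x^2] - mean^2 + epsilon) + beta[c]. As a cross-check for the vector path, here is a NEON-free scalar sketch of the same per-slice computation; the standalone helper and its name are illustrative, not part of this commit:

#include <math.h> /* sqrtf; float16_t comes from the ARM fp16 extension, e.g. <arm_neon.h> */

/* Scalar reference for one (batch, channel) slice of instance normalization. */
static void InstanceNormSliceRef(const float16_t *src, float16_t *dst, float16_t gamma,
                                 float16_t beta, int inner_size, float epsilon) {
  float mean = 0.0f;
  float square_mean = 0.0f;
  for (int i = 0; i < inner_size; i++) {
    mean += (float)src[i];
    square_mean += (float)src[i] * (float)src[i];
  }
  mean /= (float)inner_size;
  square_mean /= (float)inner_size;
  /* variance = E[x^2] - (E[x])^2, as in the vector path */
  const float deno = 1.0f / sqrtf(square_mean - mean * mean + epsilon);
  for (int i = 0; i < inner_size; i++) {
    dst[i] = (float16_t)(((float)src[i] - mean) * deno * (float)gamma + (float)beta);
  }
}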

mindspore/lite/nnacl/fp16/instance_norm_fp16.h (+31, -0)

@@ -0,0 +1,31 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP16_INSTANCE_NORM_H_
#define MINDSPORE_LITE_NNACL_FP16_INSTANCE_NORM_H_

#include "nnacl/instance_norm_parameter.h"

#ifdef __cplusplus
extern "C" {
#endif

int InstanceNormFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *gamma_data,
                     const float16_t *beta_data, const InstanceNormParameter *param, size_t task_id);
#ifdef __cplusplus
}
#endif

#endif // MINDSPORE_LITE_NNACL_FP16_INSTANCE_NORM_H_
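InstanceNormFp16 is data-parallel over channels: each task_id takes a disjoint channel range (UP_DIV splits the channel count across thread_num_, MSMIN clamps the last range), so calling it once for every task id covers the whole tensor. A minimal serial driver sketch, standing in for the thread-pool dispatch that the CPU kernel below performs (the driver function itself is hypothetical):

#include "nnacl/fp16/instance_norm_fp16.h"
#include "nnacl/errorcode.h"

/* Hypothetical serial driver: run each task in turn instead of on a thread pool. */
int InstanceNormFp16Serial(const float16_t *src, float16_t *dst, const float16_t *gamma,
                           const float16_t *beta, const InstanceNormParameter *param) {
  for (int task_id = 0; task_id < param->op_parameter_.thread_num_; task_id++) {
    int ret = InstanceNormFp16(src, dst, gamma, beta, param, (size_t)task_id);
    if (ret != NNACL_OK) {
      return ret;
    }
  }
  return NNACL_OK;
}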

mindspore/lite/src/lite_session.cc (+3, -3)

@@ -538,6 +538,9 @@ LiteSession::~LiteSession() {
     MS_LOG(ERROR) << "Not support multi-threading";
     return;
   }
+  for (auto *kernel : kernels_) {
+    delete kernel;
+  }
   for (size_t i = 0; i < tensors_.size(); i++) {
     auto *tensor = tensors_.at(i);
     MS_ASSERT(tensor != nullptr);
@@ -552,9 +555,6 @@ LiteSession::~LiteSession() {
   output_node_map_.clear();
   output_tensor_map_.clear();
   input_vec_.clear();
-  for (auto *kernel : kernels_) {
-    delete kernel;
-  }
   delete this->context_;
   delete this->executor_;
   this->executor_ = nullptr;


mindspore/lite/src/runtime/agent/npu/optimizer/npu_fusion_pass.cc (+3, -0)

@@ -141,6 +141,9 @@ void UpdatePreTensors(kernel::LiteKernel *cur_kernel) {
 
 void UpdatePostTensors(kernel::LiteKernel *cur_kernel) {
   auto tensor = cur_kernel->out_tensors()[0];
+  tensor->set_format(schema::Format_NCHW);
+  auto nhwc_shape = tensor->shape();
+  tensor->set_shape({nhwc_shape[0], nhwc_shape[3], nhwc_shape[1], nhwc_shape[2]});
   for (auto out_kernel : cur_kernel->out_kernels()) {
     auto out_tensor = out_kernel->out_tensors()[0];
     if (out_kernel->out_kernels().empty()) {
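The three added lines re-tag the fused kernel's output tensor as NCHW by permuting its stored NHWC dimensions in place: {N, H, W, C} becomes {N, C, H, W}. The same shuffle written out as a standalone sketch (a hypothetical helper, for illustration only):

/* NHWC -> NCHW shape permutation, mirroring the set_shape call above. */
static void NhwcShapeToNchw(const int nhwc[4], int nchw[4]) {
  nchw[0] = nhwc[0]; /* N */
  nchw[1] = nhwc[3]; /* C */
  nchw[2] = nhwc[1]; /* H */
  nchw[3] = nhwc[2]; /* W */
}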


mindspore/lite/src/runtime/kernel/arm/fp16/instance_norm_fp16.cc (+121, -0)

@@ -0,0 +1,121 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp16/instance_norm_fp16.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "nnacl/fp16/cast_fp16.h"
#include "nnacl/fp16/instance_norm_fp16.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_InstanceNorm;

namespace mindspore::kernel {
void InstanceNormFp16CPUKernel::FreeTmpBuffer() {
  // gamma_data_ / beta_data_ are owned copies only when the source tensors were
  // fp32 and had to be converted; fp16 inputs are used in place and not freed here.
  if (in_tensors_[1]->data_type() == kNumberTypeFloat32) {
    if (gamma_data_ != nullptr) {
      free(gamma_data_);
      gamma_data_ = nullptr;
    }
  }
  if (in_tensors_[2]->data_type() == kNumberTypeFloat32) {
    if (beta_data_ != nullptr) {
      free(beta_data_);
      beta_data_ = nullptr;
    }
  }
}

int InstanceNormFp16CPUKernel::Init() {
  auto gamma = in_tensors_[1];
  if (gamma->data_type() == kNumberTypeFloat32) {
    // Convert fp32 gamma to an owned fp16 copy once, at init time.
    gamma_data_ = reinterpret_cast<float16_t *>(malloc(gamma->ElementsNum() * sizeof(float16_t)));
    if (gamma_data_ == nullptr) {
      MS_LOG(ERROR) << "InstanceNorm fp16 kernel malloc gamma_data_ error.";
      return RET_ERROR;
    }
    Float32ToFloat16(reinterpret_cast<float *>(gamma->data_c()), gamma_data_, gamma->ElementsNum());
  } else if (gamma->data_type() == kNumberTypeFloat16) {
    gamma_data_ = reinterpret_cast<float16_t *>(gamma->data_c());
  } else {
    MS_LOG(ERROR) << "Unsupported data type of gamma tensor for instance norm.";
    return RET_ERROR;
  }

  auto beta = in_tensors_[2];
  if (beta->data_type() == kNumberTypeFloat32) {
    // Same conversion for beta.
    beta_data_ = reinterpret_cast<float16_t *>(malloc(beta->ElementsNum() * sizeof(float16_t)));
    if (beta_data_ == nullptr) {
      MS_LOG(ERROR) << "InstanceNorm fp16 kernel malloc beta_data_ error.";
      return RET_ERROR;
    }
    Float32ToFloat16(reinterpret_cast<float *>(beta->data_c()), beta_data_, beta->ElementsNum());
  } else if (beta->data_type() == kNumberTypeFloat16) {
    beta_data_ = reinterpret_cast<float16_t *>(beta->data_c());
  } else {
    MS_LOG(ERROR) << "Unsupported data type of beta tensor for instance norm.";
    return RET_ERROR;
  }
  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}

int InstanceNormFp16CPUKernel::ReSize() {
  param_->op_parameter_.thread_num_ = context_->thread_num_;
  // NCHW layout: e.g. a [1, 32, 56, 56] input gives batch_ = 1, channel_ = 32,
  // inner_size_ = 56 * 56 = 3136.
  auto shape = in_tensors_.front()->shape();
  param_->batch_ = shape[0];
  param_->inner_size_ = shape[2] * shape[3];
  param_->channel_ = shape[1];
  return RET_OK;
}

int InstanceNormFp16CPUKernel::DoInstanceNorm(int task_id) {
  int ret = InstanceNormFp16(src_data_, dst_data_, gamma_data_, beta_data_, param_, task_id);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "DoInstanceNorm error error_code[" << ret << "]";
    return ret;
  }
  return RET_OK;
}

int InstanceNormFp16Run(void *cdata, int task_id) {
  auto kernel = reinterpret_cast<InstanceNormFp16CPUKernel *>(cdata);
  auto ret = kernel->DoInstanceNorm(task_id);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "InstanceNormFp16Run error task_id[" << task_id << "] error_code[" << ret << "]";
    return RET_ERROR;
  }
  return RET_OK;
}

int InstanceNormFp16CPUKernel::Run() {
  src_data_ = reinterpret_cast<float16_t *>(in_tensors_[0]->data_c());
  dst_data_ = reinterpret_cast<float16_t *>(out_tensors_[0]->data_c());
  // One task per thread; each task normalizes its own disjoint slice of channels.
  auto ret = ParallelLaunch(this->context_->thread_pool_, InstanceNormFp16Run, this, op_parameter_->thread_num_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "InstanceNormFp16Run error error_code[" << ret << "]";
    return ret;
  }
  return RET_OK;
}

REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_InstanceNorm, LiteKernelCreator<InstanceNormFp16CPUKernel>)
} // namespace mindspore::kernel
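When gamma or beta arrives as fp32, Init allocates an fp16 buffer and converts the weights once with Float32ToFloat16 from nnacl/fp16/cast_fp16.h; FreeTmpBuffer later frees only those converted copies, never tensor-owned fp16 data. A plain scalar sketch of such a narrowing cast (the real nnacl routine may be vectorized; this stand-in is illustrative):

/* Element-wise narrowing cast from fp32 to fp16. */
static void Float32ToFloat16Ref(const float *input, float16_t *output, int number) {
  for (int i = 0; i < number; i++) {
    output[i] = (float16_t)input[i]; /* rounds to the nearest representable fp16 value */
  }
}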

mindspore/lite/src/runtime/kernel/arm/fp16/instance_norm_fp16.h (+51, -0)

@@ -0,0 +1,51 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_INSTANCE_NORM_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_INSTANCE_NORM_H_
#include <vector>
#include "src/lite_kernel.h"
#include "include/context.h"
#include "nnacl/instance_norm_parameter.h"

using mindspore::lite::InnerContext;

namespace mindspore::kernel {
class InstanceNormFp16CPUKernel : public LiteKernel {
 public:
  InstanceNormFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                            const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                            const mindspore::lite::PrimitiveC *primitive)
      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
    param_ = reinterpret_cast<InstanceNormParameter *>(parameter);
  }
  ~InstanceNormFp16CPUKernel() override { FreeTmpBuffer(); };

  int Init() override;
  int ReSize() override;
  int Run() override;
  int DoInstanceNorm(int task_id);

 private:
  void FreeTmpBuffer();
  InstanceNormParameter *param_ = nullptr;
  float16_t *src_data_ = nullptr;
  float16_t *dst_data_ = nullptr;
  float16_t *gamma_data_ = nullptr;
  float16_t *beta_data_ = nullptr;
};
}  // namespace mindspore::kernel

#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_INSTANCE_NORM_H_

mindspore/lite/src/runtime/kernel/npu/instance_norm_npu.cc (+2, -0)

@@ -54,6 +54,7 @@ int InstanceNormNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &input
                                           lite::ConverterToNPUDataType(inputs[1]->data_type()));
   gamma_tensor->SetTensorDesc(gamma_tensor_desc);
   gamma_tensor->SetData(reinterpret_cast<const uint8_t *>(inputs[1]->data_c()), inputs[1]->Size());
+  gamma->set_attr_value(gamma_tensor);
   op_->set_input_gamma(*gamma);
 
   auto beta = new (std::nothrow) hiai::op::Const(name_ + "_beta");
@@ -71,6 +72,7 @@ int InstanceNormNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &input
                                          lite::ConverterToNPUDataType(inputs[2]->data_type()));
   beta_tensor->SetTensorDesc(beta_tensor_desc);
   beta_tensor->SetData(reinterpret_cast<const uint8_t *>(inputs[2]->data_c()), inputs[2]->Size());
+  beta->set_attr_value(beta_tensor);
   op_->set_input_beta(*beta);
   op_->set_attr_epsilon(instance_norm_param_->epsilon_);
   return RET_OK;

