From: @yangruoqi713 Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -0,0 +1,81 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "nnacl/fp16/instance_norm_fp16.h" | |||||
| #include <math.h> | |||||
| #include "nnacl/errorcode.h" | |||||
| int InstanceNormFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *gamma_data, | |||||
| const float16_t *beta_data, const InstanceNormParameter *param, size_t task_id) { | |||||
| if (src_data == NULL || dst_data == NULL) { | |||||
| return NNACL_NULL_PTR; | |||||
| } | |||||
| int channel_step = UP_DIV(param->channel_, param->op_parameter_.thread_num_); | |||||
| int channel_begin = task_id * channel_step; | |||||
| int channel_end = MSMIN(channel_begin + channel_step, param->channel_); | |||||
| for (int b = 0; b < param->batch_; b++) { | |||||
| const float16_t *src_b = src_data + b * param->channel_ * param->inner_size_; | |||||
| float16_t *dst_b = dst_data + b * param->channel_ * param->inner_size_; | |||||
| for (int c = channel_begin; c < channel_end; c++) { | |||||
| const float16_t *src = src_b + c * param->inner_size_; | |||||
| float16_t *dst = dst_b + c * param->inner_size_; | |||||
| float mean = 0.0f; | |||||
| float square_mean = 0.0f; | |||||
| int index = 0; | |||||
| for (; index < param->inner_size_ - C8NUM; index += C8NUM) { | |||||
| float16x8_t srcv = vld1q_f16(src + index); | |||||
| float16x8_t squarev = vmulq_f16(srcv, srcv); | |||||
| float16x4_t sum2 = vadd_f16(vget_low_f16(srcv), vget_high_f16(srcv)); | |||||
| float32x4_t sum_f32 = vcvt_f32_f16(sum2); | |||||
| mean += vaddvq_f32(sum_f32); | |||||
| float16x4_t square2 = vadd_f16(vget_low_f16(squarev), vget_high_f16(squarev)); | |||||
| float32x4_t square_f32 = vcvt_f32_f16(square2); | |||||
| square_mean += vaddvq_f32(square_f32); | |||||
| } | |||||
| for (; index < param->inner_size_; index++) { | |||||
| mean += src[index]; | |||||
| square_mean += src[index] * src[index]; | |||||
| } | |||||
| mean /= (float)param->inner_size_; | |||||
| square_mean /= (float)param->inner_size_; | |||||
| const float deno = 1 / sqrtf(square_mean - mean * mean + param->epsilon_); | |||||
| index = 0; | |||||
| float16x8_t meanv = vdupq_n_f16(mean); | |||||
| float16x8_t denov = vdupq_n_f16(deno); | |||||
| for (; index < param->inner_size_ - C8NUM; index += C8NUM) { | |||||
| float16x8_t srcv = vld1q_f16(src + index); | |||||
| float16x8_t outv = vsubq_f16(srcv, meanv); | |||||
| outv = vmulq_f16(outv, denov); | |||||
| float16x8_t gammav = vdupq_n_f16(gamma_data[c]); | |||||
| float16x8_t betav = vdupq_n_f16(beta_data[c]); | |||||
| outv = vmulq_f16(outv, gammav); | |||||
| outv = vaddq_f16(outv, betav); | |||||
| vst1q_f16(dst + index, outv); | |||||
| } | |||||
| for (; index < param->inner_size_; index++) { | |||||
| dst[index] = (src[index] - mean) * deno; | |||||
| dst[index] = dst[index] * gamma_data[c] + beta_data[c]; | |||||
| } | |||||
| } | |||||
| } | |||||
| return NNACL_OK; | |||||
| } | |||||
| @@ -0,0 +1,31 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_NNACL_FP16_INSTANCE_NORM_H_ | |||||
| #define MINDSPORE_LITE_NNACL_FP16_INSTANCE_NORM_H_ | |||||
| #include "nnacl/instance_norm_parameter.h" | |||||
| #ifdef __cplusplus | |||||
| extern "C" { | |||||
| #endif | |||||
| int InstanceNormFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *gamma_data, | |||||
| const float16_t *beta_data, const InstanceNormParameter *param, size_t task_id); | |||||
| #ifdef __cplusplus | |||||
| } | |||||
| #endif | |||||
| #endif // MINDSPORE_LITE_NNACL_FP16_INSTANCE_NORM_H_ | |||||
| @@ -538,6 +538,9 @@ LiteSession::~LiteSession() { | |||||
| MS_LOG(ERROR) << "Not support multi-threading"; | MS_LOG(ERROR) << "Not support multi-threading"; | ||||
| return; | return; | ||||
| } | } | ||||
| for (auto *kernel : kernels_) { | |||||
| delete kernel; | |||||
| } | |||||
| for (size_t i = 0; i < tensors_.size(); i++) { | for (size_t i = 0; i < tensors_.size(); i++) { | ||||
| auto *tensor = tensors_.at(i); | auto *tensor = tensors_.at(i); | ||||
| MS_ASSERT(tensor != nullptr); | MS_ASSERT(tensor != nullptr); | ||||
| @@ -552,9 +555,6 @@ LiteSession::~LiteSession() { | |||||
| output_node_map_.clear(); | output_node_map_.clear(); | ||||
| output_tensor_map_.clear(); | output_tensor_map_.clear(); | ||||
| input_vec_.clear(); | input_vec_.clear(); | ||||
| for (auto *kernel : kernels_) { | |||||
| delete kernel; | |||||
| } | |||||
| delete this->context_; | delete this->context_; | ||||
| delete this->executor_; | delete this->executor_; | ||||
| this->executor_ = nullptr; | this->executor_ = nullptr; | ||||
| @@ -141,6 +141,9 @@ void UpdatePreTensors(kernel::LiteKernel *cur_kernel) { | |||||
| void UpdatePostTensors(kernel::LiteKernel *cur_kernel) { | void UpdatePostTensors(kernel::LiteKernel *cur_kernel) { | ||||
| auto tensor = cur_kernel->out_tensors()[0]; | auto tensor = cur_kernel->out_tensors()[0]; | ||||
| tensor->set_format(schema::Format_NCHW); | |||||
| auto nhwc_shape = tensor->shape(); | |||||
| tensor->set_shape({nhwc_shape[0], nhwc_shape[3], nhwc_shape[1], nhwc_shape[2]}); | |||||
| for (auto out_kernel : cur_kernel->out_kernels()) { | for (auto out_kernel : cur_kernel->out_kernels()) { | ||||
| auto out_tensor = out_kernel->out_tensors()[0]; | auto out_tensor = out_kernel->out_tensors()[0]; | ||||
| if (out_kernel->out_kernels().empty()) { | if (out_kernel->out_kernels().empty()) { | ||||
| @@ -0,0 +1,121 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/fp16/instance_norm_fp16.h" | |||||
| #include "schema/model_generated.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "nnacl/fp16/cast_fp16.h" | |||||
| #include "nnacl/fp16/instance_norm_fp16.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_InstanceNorm; | |||||
| namespace mindspore::kernel { | |||||
| void InstanceNormFp16CPUKernel::FreeTmpBuffer() { | |||||
| if (in_tensors_[1]->data_type() == kNumberTypeFloat32) { | |||||
| if (gamma_data_ != nullptr) { | |||||
| free(gamma_data_); | |||||
| gamma_data_ = nullptr; | |||||
| } | |||||
| } | |||||
| if (in_tensors_[2]->data_type() == kNumberTypeFloat32) { | |||||
| if (beta_data_ != nullptr) { | |||||
| free(beta_data_); | |||||
| beta_data_ = nullptr; | |||||
| } | |||||
| } | |||||
| } | |||||
| int InstanceNormFp16CPUKernel::Init() { | |||||
| auto gamma = in_tensors_[1]; | |||||
| if (gamma->data_type() == kNumberTypeFloat32) { | |||||
| gamma_data_ = reinterpret_cast<float16_t *>(malloc(gamma->ElementsNum() * sizeof(float16_t))); | |||||
| if (gamma_data_ == nullptr) { | |||||
| MS_LOG(ERROR) << "InstanceNorm fp16 kernel malloc gamma_data_ error."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| Float32ToFloat16(reinterpret_cast<float *>(gamma->data_c()), gamma_data_, gamma->ElementsNum()); | |||||
| } else if (gamma->data_type() == kNumberTypeFloat16) { | |||||
| gamma_data_ = reinterpret_cast<float16_t *>(gamma->data_c()); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupported data type of gamma tensor for instance norm."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto beta = in_tensors_[2]; | |||||
| if (beta->data_type() == kNumberTypeFloat32) { | |||||
| beta_data_ = reinterpret_cast<float16_t *>(malloc(beta->ElementsNum() * sizeof(float16_t))); | |||||
| if (beta_data_ == nullptr) { | |||||
| MS_LOG(ERROR) << "InstanceNorm fp16 kernel malloc beta_data_ error."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| Float32ToFloat16(reinterpret_cast<float *>(beta->data_c()), beta_data_, beta->ElementsNum()); | |||||
| } else if (beta->data_type() == kNumberTypeFloat16) { | |||||
| beta_data_ = reinterpret_cast<float16_t *>(beta->data_c()); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupported data type of beta tensor for instance norm."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (!InferShapeDone()) { | |||||
| return RET_OK; | |||||
| } | |||||
| return ReSize(); | |||||
| } | |||||
| int InstanceNormFp16CPUKernel::ReSize() { | |||||
| param_->op_parameter_.thread_num_ = context_->thread_num_; | |||||
| auto shape = in_tensors_.front()->shape(); | |||||
| param_->batch_ = shape[0]; | |||||
| param_->inner_size_ = shape[2] * shape[3]; | |||||
| param_->channel_ = shape[1]; | |||||
| return RET_OK; | |||||
| } | |||||
| int InstanceNormFp16CPUKernel::DoInstanceNorm(int task_id) { | |||||
| int ret = InstanceNormFp16(src_data_, dst_data_, gamma_data_, beta_data_, param_, task_id); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "DoInstanceNorm error error_code[" << ret << "]"; | |||||
| return ret; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int InstanceNormFp16Run(void *cdata, int task_id) { | |||||
| auto kernel = reinterpret_cast<InstanceNormFp16CPUKernel *>(cdata); | |||||
| auto ret = kernel->DoInstanceNorm(task_id); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "InstanceNormFp16Run error task_id[" << task_id << "] error_code[" << ret << "]"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int InstanceNormFp16CPUKernel::Run() { | |||||
| src_data_ = reinterpret_cast<float16_t *>(in_tensors_[0]->data_c()); | |||||
| dst_data_ = reinterpret_cast<float16_t *>(out_tensors_[0]->data_c()); | |||||
| auto ret = ParallelLaunch(this->context_->thread_pool_, InstanceNormFp16Run, this, op_parameter_->thread_num_); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "InstanceNormFp16Run error error_code[" << ret << "]"; | |||||
| return ret; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_InstanceNorm, LiteKernelCreator<InstanceNormFp16CPUKernel>) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,51 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_INSTANCE_NORM_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_INSTANCE_NORM_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| #include "include/context.h" | |||||
| #include "nnacl/instance_norm_parameter.h" | |||||
| using mindspore::lite::InnerContext; | |||||
| namespace mindspore::kernel { | |||||
| class InstanceNormFp16CPUKernel : public LiteKernel { | |||||
| public: | |||||
| InstanceNormFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||||
| const mindspore::lite::PrimitiveC *primitive) | |||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) { | |||||
| param_ = reinterpret_cast<InstanceNormParameter *>(parameter); | |||||
| } | |||||
| ~InstanceNormFp16CPUKernel() override { FreeTmpBuffer(); }; | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| int DoInstanceNorm(int task_id); | |||||
| private: | |||||
| void FreeTmpBuffer(); | |||||
| InstanceNormParameter *param_ = nullptr; | |||||
| float16_t *src_data_ = nullptr; | |||||
| float16_t *dst_data_ = nullptr; | |||||
| float16_t *gamma_data_ = nullptr; | |||||
| float16_t *beta_data_ = nullptr; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_INSTANCE_NORM_H_ | |||||
| @@ -54,6 +54,7 @@ int InstanceNormNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &input | |||||
| lite::ConverterToNPUDataType(inputs[1]->data_type())); | lite::ConverterToNPUDataType(inputs[1]->data_type())); | ||||
| gamma_tensor->SetTensorDesc(gamma_tensor_desc); | gamma_tensor->SetTensorDesc(gamma_tensor_desc); | ||||
| gamma_tensor->SetData(reinterpret_cast<const uint8_t *>(inputs[1]->data_c()), inputs[1]->Size()); | gamma_tensor->SetData(reinterpret_cast<const uint8_t *>(inputs[1]->data_c()), inputs[1]->Size()); | ||||
| gamma->set_attr_value(gamma_tensor); | |||||
| op_->set_input_gamma(*gamma); | op_->set_input_gamma(*gamma); | ||||
| auto beta = new (std::nothrow) hiai::op::Const(name_ + "_beta"); | auto beta = new (std::nothrow) hiai::op::Const(name_ + "_beta"); | ||||
| @@ -71,6 +72,7 @@ int InstanceNormNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &input | |||||
| lite::ConverterToNPUDataType(inputs[2]->data_type())); | lite::ConverterToNPUDataType(inputs[2]->data_type())); | ||||
| beta_tensor->SetTensorDesc(beta_tensor_desc); | beta_tensor->SetTensorDesc(beta_tensor_desc); | ||||
| beta_tensor->SetData(reinterpret_cast<const uint8_t *>(inputs[2]->data_c()), inputs[2]->Size()); | beta_tensor->SetData(reinterpret_cast<const uint8_t *>(inputs[2]->data_c()), inputs[2]->Size()); | ||||
| beta->set_attr_value(beta_tensor); | |||||
| op_->set_input_beta(*beta); | op_->set_input_beta(*beta); | ||||
| op_->set_attr_epsilon(instance_norm_param_->epsilon_); | op_->set_attr_epsilon(instance_norm_param_->epsilon_); | ||||
| return RET_OK; | return RET_OK; | ||||