diff --git a/mindspore/lite/src/runtime/kernel/arm/base/matrix.cc b/mindspore/lite/src/runtime/kernel/arm/base/matrix.cc
index 4eb619562f..2c26d1fa88 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/matrix.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/matrix.cc
@@ -23,10 +23,6 @@ Matrix *TransformMatrixGenerator(int m, int k) {
   auto aa = malloc(m * k * sizeof(float));
   matrix->SetData(aa);
   matrix->SetNum(m, k);
-  // matrix->data_ = malloc(m * k * sizeof(float));
-  // matrix->m_ = m;
-  // matrix->k_ = k;
-  // matrix->row_major_ = true;
   return matrix;
 }
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/matrix.h b/mindspore/lite/src/runtime/kernel/arm/base/matrix.h
index c29d0944ac..f5265728ae 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/matrix.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/matrix.h
@@ -65,26 +65,6 @@ class Matrix {
   int n_dim_;
   bool row_major_;
 };
-// struct Matrix {
-//   void *data_;
-//   int *shape_;
-//   int *stride_;
-//   int m_;
-//   int k_;
-//   int n_dim_;
-//   bool row_major_;
-//   ~Matrix() {
-//     if (data_ != nullptr) {
-//       free(data_);
-//     }
-//     if (shape_ != nullptr) {
-//       free(shape_);
-//     }
-//     if (shape_ != nullptr) {
-//       free(stride_);
-//     }
-//   }
-//};
 Matrix *TransformMatrixGenerator(int m, int k);
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc
index ccb92e2992..3f365c2ccd 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc
@@ -16,6 +16,7 @@
 #include "src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h"
 #include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h"
+#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h"
 #include "src/runtime/kernel/arm/nnacl/fp16/winograd_transform_fp16.h"
 #include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h"
 #include "src/runtime/kernel/arm/fp16/layout_transform_fp16.h"
@@ -265,11 +266,9 @@ int Convolution3x3FP16CPUKernel::Run() {
     return RET_ERROR;
   }
   auto input_tensor = in_tensors_.at(kInputIndex);
+  auto input_ele_num = input_tensor->ElementsNum();
   auto ori_input_data = reinterpret_cast<float *>(input_tensor->Data());
-  auto input_element_num = input_tensor->ElementsNum();
-  for (int i = 0; i < input_element_num; ++i) {
-    fp16_input_[i] = (float16_t)ori_input_data[i];
-  }
+  Float32ToFloat16(ori_input_data, fp16_input_, input_ele_num);
 
   int in_batch = conv_param_->input_batch_;
   int in_h = conv_param_->input_h_;
@@ -285,12 +284,9 @@ int Convolution3x3FP16CPUKernel::Run() {
 
   // cast fp16 out to fp32 data
   auto out_tensor = out_tensors_.at(kOutputIndex);
+  auto out_ele_num = out_tensor->ElementsNum();
   auto output_addr = reinterpret_cast<float *>(out_tensor->Data());
-  auto output_element_num = out_tensor->ElementsNum();
-
-  for (int j = 0; j < output_element_num; ++j) {
-    output_addr[j] = static_cast<float>(fp16_out_[j]);
-  }
+  Float16ToFloat32(fp16_out_, output_addr, out_ele_num);
   return RET_OK;
 }
 }  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc
index 0a6e5d874f..65116d77f9 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc
@@ -15,8 +15,11 @@
  */
 
 #include "src/runtime/kernel/arm/fp16/convolution_fp16.h"
+#include <vector>
+#include "src/runtime/kernel/arm/fp16/convolution_sw_fp16.h"
 #include "src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h"
 #include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h"
+#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h"
 #include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h"
 #include "src/runtime/kernel/arm/fp16/layout_transform_fp16.h"
 #include "schema/model_generated.h"
@@ -231,10 +234,8 @@ int ConvolutionFP16CPUKernel::Run() {
   }
   auto input_tensor = in_tensors_.at(kInputIndex);
   auto ori_input_data = reinterpret_cast<float *>(input_tensor->Data());
-  auto input_element_num = input_tensor->ElementsNum();
-  for (int i = 0; i < input_element_num; ++i) {
-    fp16_input_[i] = (float16_t)ori_input_data[i];
-  }
+  auto input_ele_num = input_tensor->ElementsNum();
+  Float32ToFloat16(ori_input_data, fp16_input_, input_ele_num);
 
   int in_batch = conv_param_->input_batch_;
   int in_h = conv_param_->input_h_;
@@ -251,10 +252,8 @@ int ConvolutionFP16CPUKernel::Run() {
 
   // cast fp16 out to fp32 data
   auto out_tensor = out_tensors_.at(kOutputIndex);
   auto output_addr = reinterpret_cast<float *>(out_tensor->Data());
-  auto output_element_num = out_tensor->ElementsNum();
-  for (int j = 0; j < output_element_num; ++j) {
-    output_addr[j] = static_cast<float>(fp16_out_[j]);
-  }
+  auto out_ele_num = out_tensor->ElementsNum();
+  Float16ToFloat32(fp16_out_, output_addr, out_ele_num);
   return RET_OK;
 }
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_sw_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_sw_fp16.cc
new file mode 100644
index 0000000000..7ef48cd94f
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_sw_fp16.cc
@@ -0,0 +1,269 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "src/runtime/kernel/arm/fp16/convolution_sw_fp16.h" +#include +#include "src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h" +#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h" +#include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h" +#include "src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.h" +#include "src/runtime/kernel/arm/fp16/layout_transform_fp16.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Conv2D; + +namespace mindspore::kernel { +int ConvolutionSWFP16CPUKernel::ProcessFilter() { + int kernel_h = conv_param_->kernel_h_; + int kernel_w = conv_param_->kernel_w_; + int in_channel = conv_param_->input_channel_; + int out_channel = conv_param_->output_channel_; + int ic4 = UP_DIV(in_channel, C4NUM); + + auto *origin_weight = reinterpret_cast(in_tensors_.at(kWeightIndex)->Data()); + size_t fp16_weight_size = in_channel * out_channel * kernel_h * kernel_w * sizeof(float16_t); + fp16_weight_ = reinterpret_cast(malloc(fp16_weight_size)); + if (fp16_weight_ == nullptr) { + MS_LOG(ERROR) << "malloc fp16_weight_ failed."; + return RET_ERROR; + } + // cast origin fp32 weight data to fp16 data + for (int i = 0; i < fp16_weight_size / sizeof(float16_t); ++i) { + fp16_weight_[i] = (float16_t)origin_weight[i]; + } + + for (int oc = 0; oc < out_channel; ++oc) { + int src_oc_offset = oc * kernel_h * kernel_w * in_channel; + int dst_oc_offset = oc * kernel_h * kernel_w * ic4 * C4NUM; + for (int i = 0; i < kernel_h * kernel_w; ++i) { + const float16_t *src = fp16_weight_ + src_oc_offset + i * in_channel; + float16_t *dst = packed_weight_ + dst_oc_offset + i * ic4 * C4NUM; + memcpy(dst, src, in_channel * sizeof(float16_t)); + } + } + + return RET_OK; +} + +int ConvolutionSWFP16CPUKernel::InitWeightBias() { + int kernel_h = conv_param_->kernel_h_; + int kernel_w = conv_param_->kernel_w_; + int in_channel = conv_param_->input_channel_; + int out_channel = conv_param_->output_channel_; + int oc4 = UP_DIV(out_channel, C4NUM); + int ic4 = UP_DIV(in_channel, C4NUM); + int kernel_plane = kernel_h * kernel_w; + int pack_weight_size = oc4 * ic4 * C4NUM * C4NUM * kernel_plane; + + // init weight + packed_weight_ = reinterpret_cast(malloc(pack_weight_size * sizeof(float16_t))); + if (packed_weight_ == nullptr) { + MS_LOG(ERROR) << "malloc packed_weight_ failed."; + return RET_ERROR; + } + memset(packed_weight_, 0, pack_weight_size * sizeof(float16_t)); + auto ret = ProcessFilter(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Process filter failed."; + return ret; + } + + // init bias + bias_data_ = malloc(oc4 * C4NUM * sizeof(float16_t)); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "malloc bias_data_ failed."; + return RET_ERROR; + } + memset(bias_data_, 0, oc4 * C4NUM * sizeof(float16_t)); + auto fp16_bias_data = reinterpret_cast(bias_data_); + if (in_tensors_.size() == kInputSize2) { + auto ori_bias = reinterpret_cast(in_tensors_.at(kBiasIndex)->Data()); + for (int i = 0; i < out_channel; ++i) { + fp16_bias_data[i] = (float16_t)ori_bias[i]; + } + } else { + MS_ASSERT(in_tensor_.size() == kInputSize1); + } + return RET_OK; +} + +int ConvolutionSWFP16CPUKernel::InitTmpBuffer() { + int in_channel = conv_param_->input_channel_; + int out_channel = conv_param_->output_channel_; + int channel_block = 
UP_DIV(in_channel, C4NUM); + int oc4 = UP_DIV(out_channel, C4NUM); + + /*=============================fp16_input_============================*/ + size_t fp16_input_size = + in_channel * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t); + fp16_input_ = reinterpret_cast(malloc(fp16_input_size)); + if (fp16_input_ == nullptr) { + MS_LOG(ERROR) << "malloc fp16_input_ failed."; + return RET_ERROR; + } + + /*=============================nhwc4_input_============================*/ + size_t nhwc4_input_size = channel_block * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * + conv_param_->input_w_ * sizeof(float16_t); + nhwc4_input_ = malloc(nhwc4_input_size); + if (nhwc4_input_ == nullptr) { + MS_LOG(ERROR) << "malloc nhwc4_input_ failed."; + return RET_ERROR; + } + memset(nhwc4_input_, 0, nhwc4_input_size); + + /*=============================tmp_output_block_============================*/ + tmp_output_block_ = reinterpret_cast(malloc(conv_param_->output_batch_ * conv_param_->output_h_ * + conv_param_->output_w_ * oc4 * C4NUM * sizeof(float16_t))); + if (tmp_output_block_ == nullptr) { + MS_LOG(ERROR) << "malloc tmp_output_block_ failed."; + return RET_ERROR; + } + + /*=============================fp16_out_============================*/ + size_t fp16_output_size = + out_channel * conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * sizeof(float16_t); + fp16_out_ = reinterpret_cast(malloc(fp16_output_size)); + if (fp16_out_ == nullptr) { + MS_LOG(ERROR) << "malloc fp16_out_ failed."; + return RET_ERROR; + } + return RET_OK; +} + +void ConvolutionSWFP16CPUKernel::ConfigInputOutput() { + auto input_tensor = in_tensors_.at(kInputIndex); + auto input_format = input_tensor->GetFormat(); + schema::Format execute_format = schema::Format_NHWC4; + convert_func_ = LayoutTransformFp16(input_format, execute_format); + if (convert_func_ == nullptr) { + MS_LOG(ERROR) << "layout convert func is nullptr."; + return; + } +} + +int ConvolutionSWFP16CPUKernel::Init() { + if (context_->infer_shape_interrupt_ && !context_->running_) { + set_need_reinit(); + return RET_OK; + } + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init fail!ret: " << ret; + return ret; + } + ret = InitWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init weight bias failed."; + return RET_ERROR; + } + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + ConfigInputOutput(); + + // init sliding window param + slidingWindow_param_ = new SlidingWindowParam; + InitSlidingParamConv(slidingWindow_param_, conv_param_, C4NUM); + return RET_OK; +} + +int ConvolutionSWFP16CPUKernel::ReSize() { + if (tmp_output_block_ != nullptr) { + free(tmp_output_block_); + } + if (nhwc4_input_ != nullptr) { + free(nhwc4_input_); + } + if (fp16_input_ != nullptr) { + free(fp16_input_); + } + if (fp16_out_ != nullptr) { + free(fp16_out_); + } + delete slidingWindow_param_; + + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return ret; + } + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + // init sliding window param + slidingWindow_param_ = new SlidingWindowParam; + InitSlidingParamConv(slidingWindow_param_, conv_param_, C4NUM); + return RET_OK; +} + +int ConvolutionSWFP16CPUKernel::RunImpl(int task_id) { + 
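+  // Each task covers the C4 output-channel blocks whose index is congruent to
+  // task_id modulo the thread count; ConvSWFp16 evaluates the border rows/columns
+  // (where the kernel window overlaps padding) with SWBorderFp16 and the
+  // padding-free interior with SWCenterFp16.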
ConvSWFp16(reinterpret_cast(nhwc4_input_), packed_weight_, reinterpret_cast(bias_data_), + tmp_output_block_, fp16_out_, task_id, conv_param_, slidingWindow_param_); + return RET_OK; +} + +int ConvolutionSWFp16Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto conv = reinterpret_cast(cdata); + auto error_code = conv->RunImpl(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "ConvolutionFp16 Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionSWFP16CPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } + auto input_tensor = in_tensors_.at(kInputIndex); + auto input_ele_num = input_tensor->ElementsNum(); + auto ori_input_data = reinterpret_cast(input_tensor->Data()); + Float32ToFloat16(ori_input_data, fp16_input_, input_ele_num); + + int in_batch = conv_param_->input_batch_; + int in_h = conv_param_->input_h_; + int in_w = conv_param_->input_w_; + int in_channel = conv_param_->input_channel_; + convert_func_(reinterpret_cast(fp16_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel); + + int error_code = LiteBackendParallelLaunch(ConvolutionSWFp16Impl, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "conv fp16 error error_code[" << error_code << "]"; + return RET_ERROR; + } + + // cast fp16 out to fp32 data + auto out_tensor = out_tensors_.at(kOutputIndex); + auto out_ele_num = out_tensor->ElementsNum(); + auto output_addr = reinterpret_cast(out_tensor->Data()); + Float16ToFloat32(fp16_out_, output_addr, out_ele_num); + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_sw_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_sw_fp16.h new file mode 100644 index 0000000000..08239853a6 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_sw_fp16.h @@ -0,0 +1,69 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_SW_FP16_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_SW_FP16_H_ + +#include +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" + +namespace mindspore::kernel { +class ConvolutionSWFP16CPUKernel : public ConvolutionBaseCPUKernel { + public: + ConvolutionSWFP16CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx, + const lite::Primitive *primitive) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {} + ~ConvolutionSWFP16CPUKernel() override { + if (fp16_input_ != nullptr) { + free(fp16_input_); + } + if (fp16_weight_ != nullptr) { + free(fp16_weight_); + } + if (fp16_out_ != nullptr) { + free(fp16_out_); + } + if (packed_weight_ != nullptr) { + free(packed_weight_); + } + if (tmp_output_block_ != nullptr) { + free(tmp_output_block_); + } + delete slidingWindow_param_; + } + + int Init() override; + int ReSize() override; + int Run() override; + int RunImpl(int task_id); + int InitWeightBias(); + int InitTmpBuffer(); + void ConfigInputOutput(); + int ProcessFilter(); + + private: + float16_t *fp16_input_; + float16_t *fp16_weight_; + float16_t *fp16_out_; + float16_t *packed_weight_; + float16_t *tmp_output_block_; + SlidingWindowParam *slidingWindow_param_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_SW_FP16_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.cc index 91f302e102..bcf3665763 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.cc @@ -259,6 +259,21 @@ int Convolution3x3CPUKernel::Run() { MS_LOG(ERROR) << "conv3x3 error error_code[" << error_code << "]"; return RET_ERROR; } + + auto is_relu = conv_param_->is_relu_; + auto is_relu6 = conv_param_->is_relu6_; + auto output_addr = reinterpret_cast(out_tensors_.at(kOutputIndex)->Data()); + PackNC4HW4ToNHWCFp32(nc4hw4_out_, output_addr, conv_param_->output_batch_, + conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_); + int output_num = + conv_param_->output_channel_ * conv_param_->output_h_ * conv_param_->output_w_ * conv_param_->output_batch_; + if (is_relu) { + ReluFp32(output_addr, output_addr, output_num); + } else if (is_relu6) { + Relu6Fp32(output_addr, output_addr, output_num); + } else { + // do nothing + } return RET_OK; } } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc index 3d7be4f2a1..2a3be98743 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc @@ -189,8 +189,8 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() { /*=============================tmp_out_data_============================*/ int out_w_block = UP_DIV(output_w, output_unit_); int out_h_block = UP_DIV(output_h, output_unit_); - tmp_out_data_ = reinterpret_cast( - malloc(out_w_block * out_h_block * output_unit_ * output_unit_ * oc4 * C4NUM * sizeof(float))); + tmp_out_data_ = reinterpret_cast(malloc(conv_param_->output_batch_ * out_w_block * out_h_block * + output_unit_ * output_unit_ * oc4 * C4NUM * sizeof(float))); if (tmp_out_data_ == nullptr) { 
     MS_LOG(ERROR) << "malloc tmp_out_data_ failed.";
     return RET_ERROR;
@@ -365,6 +365,22 @@ int ConvolutionWinogradCPUKernel::Run() {
     MS_LOG(ERROR) << "conv winograd error error_code[" << error_code << "]";
     return RET_ERROR;
   }
+
+  // get real output
+  auto out_tensor = out_tensors_.front();
+  auto out_data = reinterpret_cast<float *>(out_tensor->Data());
+  UnPackWinogradOutput(tmp_out_data_, out_data, conv_param_->output_batch_, conv_param_->output_h_,
+                       conv_param_->output_w_, conv_param_->output_channel_, output_unit_);
+  int output_num =
+    conv_param_->output_channel_ * conv_param_->output_h_ * conv_param_->output_w_ * conv_param_->output_batch_;
+  if (conv_param_->is_relu_) {
+    ReluFp32(out_data, out_data, output_num);
+  } else if (conv_param_->is_relu6_) {
+    Relu6Fp32(out_data, out_data, output_num);
+  } else {
+    // do nothing
+  }
+
   return RET_OK;
 }
 }  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/conv_fp16.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/conv_fp16.c
index 09e19c040d..0d7b8ca211 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/conv_fp16.c
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/conv_fp16.c
@@ -18,7 +18,6 @@
 #include "nnacl/fp16/pack_fp16.h"
 #include "nnacl/fp16/winograd_transform_fp16.h"
 
-
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -112,6 +111,209 @@ void IndirectGemmFp16_16x8_tmp(float16_t *output, float16_t *input, float16_t *w
 }
 #endif
 
+void SWBorderPixel(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int height,
+                   int width, int in_kh_step, int in_kw_step, int kernel_h, int kernel_w, int ic, bool is_relu,
+                   bool is_relu6) {
+  int ic8 = ic / C8NUM;
+  int ic8_res = ic % C8NUM;
+  int ic4 = ic8_res / C4NUM;
+  for (int c = 0; c < C4NUM; c++) {
+    dst[c] = 0;
+  }
+  const float16_t *weight_oc = weight;
+  for (int oc = 0; oc < C4NUM; ++oc) {
+    const float16_t *weight_kh = weight_oc;
+    const float16_t *src_kh = src;
+    for (int kh = 0; kh < height; kh++) {
+      const float16_t *src_kw = src_kh;
+      const float16_t *weight_kw = weight_kh;
+      for (int kw = 0; kw < width; kw++) {
+        const float16_t *src_ic8 = src_kw;
+        const float16_t *weight_ic8 = weight_kw;
+
+        for (int rc = 0; rc < ic8; ++rc) {
+          for (int c = 0; c < C8NUM; c++) {
+            dst[oc] += src_ic8[c] * weight_ic8[c];
+          }
+          src_ic8 += C8NUM;
+          weight_ic8 += C8NUM;
+        }  // ic8 loop
+
+        const float16_t *src_ic4 = src_ic8;
+        const float16_t *weight_ic4 = weight_ic8;
+        for (int rc = 0; rc < ic4; ++rc) {
+          for (int c = 0; c < C4NUM; c++) {
+            dst[oc] += src_ic4[c] * weight_ic4[c];
+          }
+          src_ic4 += C4NUM;
+          weight_ic4 += C4NUM;
+        }  // ic4 loop
+
+        src_kw += in_kw_step;
+        weight_kw += ic4 * C4NUM;
+      }  // kernel_w loop
+      src_kh += in_kh_step;
+      weight_kh += kernel_w * ic4 * C4NUM;
+    }  // kernel_h loop
+    dst[oc] += bias[oc];
+    dst[oc] = (is_relu) ? (MSMAX(0, dst[oc])) : (dst[oc]);
+    dst[oc] = (is_relu6) ?
(MSMIN(6, MSMAX(0, dst[oc]))) : (dst[oc]); + weight_oc += kernel_h * kernel_w * ic4 * C4NUM; + } // oc loop +} + +void SWBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int top, + int bottom, int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) { + float16_t *dst_h = dst + top * sliding->out_h_step_; + for (int oh = top; oh < bottom; oh++) { + int ih = oh * conv_param->stride_h_ - conv_param->pad_h_; + int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_)); + const float16_t *src_h = src + ih * sliding->in_h_step_; + + float16_t *dst_kernel = dst_h + left * sliding->block_channel_; + for (int ow = left; ow < right; ow++) { + int iw = ow * conv_param->stride_w_ - conv_param->pad_w_; + int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_)); + const float16_t *src_w = src_h + iw * sliding->ic4_channel_; + + const float16_t *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; + const float16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * sliding->ic4_channel_; + + SWBorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_h_, conv_param->kernel_w_, + sliding->ic4_channel_, conv_param->is_relu_, conv_param->is_relu6_); + + dst_kernel += sliding->block_channel_; + } // width loop + dst_h += sliding->out_h_step_; + } // height loop +} + +void SWCenterFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int height, + int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, int ic, int in_sh_step, + int in_sw_step, int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6) { + int ic8 = ic / C8NUM; + int ic8_res = ic % C8NUM; + int ic4 = ic8_res / C4NUM; + float16_t *dst_h = dst; + const float16_t *src_h = src; + for (int oh = 0; oh < height; oh++) { + float16_t *dst_w = dst_h; + const float16_t *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + const float16_t *weight_oc = weight; + for (int c = 0; c < C4NUM; c++) { + dst_w[c] = 0; + } + + for (int oc = 0; oc < C4NUM; oc++) { + const float16_t *weight_kh = weight_oc; + const float16_t *src_kh = src_w; + for (int kh = 0; kh < kernel_h; kh++) { + const float16_t *src_kw = src_kh; + const float16_t *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { + const float16_t *src_ic8 = src_kw; + const float16_t *weight_ic8 = weight_kw; + + for (int rc = 0; rc < ic8; ++rc) { + for (int c = 0; c < C8NUM; c++) { + dst_w[oc] += src_ic8[c] * weight_ic8[c]; + } + + src_ic8 += C8NUM; + weight_ic8 += C8NUM; + } // ic8 loop + + const float16_t *src_ic4 = src_ic8; + const float16_t *weight_ic4 = weight_ic8; + for (int rc = 0; rc < ic4; ++rc) { + for (int c = 0; c < C4NUM; c++) { + dst_w[oc] += src_ic4[c] * weight_ic4[c]; + } + + src_ic4 += C4NUM; + weight_ic4 += C4NUM; + } // ic4 loop + + src_kw += in_kw_step; + weight_kw += ic4 * C4NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w * ic4 * C4NUM; + } // kernel_h loop + // add biad relu + + dst_w[oc] += bias[oc]; + dst_w[oc] = (is_relu) ? (MSMAX(0, dst_w[oc])) : (dst_w[oc]); + dst_w[oc] = (is_relu6) ? 
(MSMIN(6, MSMAX(0, dst_w[oc]))) : (dst_w[oc]); + weight_oc += kernel_h * kernel_w * ic4 * C4NUM; + } // oc block + + dst_w += block_channel; + src_w += in_sw_step; + } // dst_width loop + dst_h += out_h_step; + src_h += in_sh_step; + } // dst_height loop +} + +// fp16 conv sliding window +void ConvSWFp16(const float16_t *input_data, const float16_t *packed_weight, const float16_t *bias_data, + float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param, + SlidingWindowParam *slidingWindow_param) { + int oc4_res = conv_param->output_channel_ % C4NUM; + const float16_t *src = input_data; + float16_t *dst; + if (oc4_res == 0) { + dst = output_data; + } else { + dst = tmp_out_block; + } + + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oc = task_id; oc < slidingWindow_param->c_block_; oc += conv_param->thread_num_) { + const float16_t *src_data = src; + float16_t *dst_data = dst + oc * C4NUM; + const float16_t *weight = packed_weight + oc * slidingWindow_param->kernel_step_; + const float16_t *bias = bias_data + oc * C4NUM; + SWBorderFp16(dst_data, src_data, weight, bias, 0, slidingWindow_param->top_, 0, conv_param->output_w_, conv_param, + slidingWindow_param); + SWBorderFp16(dst_data, src_data, weight, bias, slidingWindow_param->bottom_, conv_param->output_h_, 0, + conv_param->output_w_, conv_param, slidingWindow_param); + SWBorderFp16(dst_data, src_data, weight, bias, slidingWindow_param->top_, slidingWindow_param->bottom_, 0, + slidingWindow_param->left_, conv_param, slidingWindow_param); + SWBorderFp16(dst_data, src_data, weight, bias, slidingWindow_param->top_, slidingWindow_param->bottom_, + slidingWindow_param->right_, conv_param->output_w_, conv_param, slidingWindow_param); + + if (slidingWindow_param->right_ > slidingWindow_param->left_ && + slidingWindow_param->bottom_ > slidingWindow_param->top_) { + int in_h_start = slidingWindow_param->top_ * conv_param->stride_h_ - conv_param->pad_h_; + int in_w_start = slidingWindow_param->left_ * conv_param->stride_w_ - conv_param->pad_w_; + const float16_t *in_t = + src_data + in_h_start * slidingWindow_param->in_h_step_ + in_w_start * slidingWindow_param->ic4_channel_; + float16_t *out_t = dst_data + slidingWindow_param->top_ * slidingWindow_param->out_h_step_ + + slidingWindow_param->left_ * slidingWindow_param->block_channel_; + SWCenterFp16(out_t, in_t, weight, bias, slidingWindow_param->bottom_ - slidingWindow_param->top_, + slidingWindow_param->right_ - slidingWindow_param->left_, conv_param->kernel_h_, + conv_param->kernel_w_, slidingWindow_param->out_h_step_, slidingWindow_param->block_channel_, + slidingWindow_param->ic4_channel_, slidingWindow_param->in_sh_step_, + slidingWindow_param->in_sw_step_, slidingWindow_param->in_kh_step_, + slidingWindow_param->in_kw_step_, conv_param->is_relu_, conv_param->is_relu6_); + } + } // output C4 loop + src += slidingWindow_param->in_step_; + dst += slidingWindow_param->out_step_; + } // batch loop + // output nhwc4 + if (oc4_res != 0) { + PackNHWC4ToNHWCFp16((const void *)tmp_out_block, (void *)output_data, conv_param->output_batch_, + conv_param->output_h_ * conv_param->output_w_, conv_param->output_channel_); + } +} + // fp16 convolution common (im2col+gemm) void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data, float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param) { @@ -144,7 +346,7 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t 
*packed_ // we write 32 bytes per st1 instruction, after which the pointer in register will step 32B forward for (int b = 0; b < in_batch; b++) { - int in_batch_offset = b * in_channel * in_h * in_w; + int in_batch_offset = b * ic4 * C4NUM * in_h * in_w; int out_batch_offset = b * out_channel * out_h * out_w; int gemm_in_batch_offset = b * packed_input_size; for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { @@ -172,7 +374,6 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16_t *bias_data, float16_t *output_data, float16_t *tile_buffer, float16_t *block_unit_buffer, float16_t *tmp_dst_buffer, float16_t *tmp_out, int task_id, ConvParameter *conv_param) { - // todo int thread_count = conv_param->thread_num_; int tile_num = 16; const int output_unit = 4; @@ -195,6 +396,8 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16 int input_batch = conv_param->input_batch_; for (int batch = 0; batch < input_batch; batch++) { + int in_batch_offset = batch * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_; + int tmp_out_batch_offset = batch * oc8 * C8NUM * out_w_block * out_h_block * output_unit * output_unit; for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { int start_index = thread_id * tile_num; int real_cal_num = (output_count - start_index) < tile_num ? (output_count - start_index) : tile_num; @@ -207,8 +410,8 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16 tile_buffer + task_id * tile_buffer_offset, transed_weight, NULL, 36, ic4, oc8 * C8NUM, oc8 * C8NUM * 36 * sizeof(float16_t), 1, 1, 0, 0); - Conv3x3Fp16OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out, bias_data, start_index, - real_cal_num, out_w_block, conv_param); + Conv3x3Fp16OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out + tmp_out_batch_offset, + bias_data, start_index, real_cal_num, out_w_block, conv_param); } } @@ -217,7 +420,10 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16 bool relu = conv_param->is_relu_; bool relu6 = conv_param->is_relu6_; for (int batch = 0; batch < output_batch; batch++) { - int batch_size = batch * output_channel * output_h * output_w; + int tmp_out_batch_offset = batch * oc8 * C8NUM * out_w_block * out_h_block * output_unit * output_unit; + int ro_batch_size = batch * output_channel * output_h * output_w; + const float16_t *batch_tmp_out = tmp_out + tmp_out_batch_offset; + float16_t *batch_out = output_data + ro_batch_size; for (int h = 0; h < output_h; h++) { for (int w = 0; w < output_w; w++) { for (int c = 0; c < output_channel; c++) { @@ -226,12 +432,12 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16 int src_offset = oc8_block * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM + C8NUM * (h * out_w_block * output_unit + w) + oc8_res; int dst_offset = (h * output_w + w) * output_channel + c; - (output_data + dst_offset)[0] = (tmp_out + src_offset)[0]; + (batch_out + dst_offset)[0] = (batch_tmp_out + src_offset)[0]; if (relu) { - (output_data + dst_offset)[0] = (output_data + dst_offset)[0] < 0 ? 0 : (output_data + dst_offset)[0]; + (batch_out + dst_offset)[0] = (batch_out + dst_offset)[0] < 0 ? 0 : (batch_out + dst_offset)[0]; } else if (relu6) { - (output_data + dst_offset)[0] = (output_data + dst_offset)[0] < 0 ? 
0 : (output_data + dst_offset)[0]; - (output_data + dst_offset)[0] = (output_data + dst_offset)[0] > 6 ? 6 : (output_data + dst_offset)[0]; + (batch_out + dst_offset)[0] = (batch_out + dst_offset)[0] < 0 ? 0 : (batch_out + dst_offset)[0]; + (batch_out + dst_offset)[0] = (batch_out + dst_offset)[0] > 6 ? 6 : (batch_out + dst_offset)[0]; } } } diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h index 9bb9c7a27e..a6a6e5674e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/conv_fp16.h @@ -28,6 +28,18 @@ void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weigh #ifdef __cplusplus extern "C" { #endif +void SWBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int top, + int bottom, int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding); + +void SWCenterFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int height, + int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, int ic, int in_sh_step, + int in_sw_step, int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6); + +// fp16 sliding window +void ConvSWFp16(const float16_t *input_data, const float16_t *packed_weight, const float16_t *bias_data, + float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param, + SlidingWindowParam *slidingWindow_param); + // fp16 convolution common (im2col+gemm) void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data, float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param); diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.c index 6efec415a3..0bbd701d0e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.c +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.c @@ -219,6 +219,24 @@ void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int c } } +void PackNHWC4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + int nhwc_batch_unit_offset = channel * plane; + for (int b = 0; b < batch; b++) { + int batch_offset = b * c4 * C4NUM * plane; + for (int i = 0; i < plane; i++) { + memcpy((float16_t *)dst + b * nhwc_batch_unit_offset + i * channel, + (float16_t *)src + batch_offset + i * c4 * C4NUM, channel * sizeof(float16_t)); + } + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float16_t); + memcpy((float16_t *)dst, (float16_t *)src, ori_input_size); + } +} + void PackNCHWToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) { int nhwc4_batch_offset = 0; int ic4 = UP_DIV(channel, C4NUM); diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h index ac4fc51ab8..349f97b29b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h @@ -41,6 +41,8 @@ void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel); +void 
PackNHWC4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel); + void PackNCHWToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel); void PackNC4HW4ToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel); diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv.c index 54eeeb48de..0a764b94b3 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv.c +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv.c @@ -217,7 +217,7 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons size_t output_offset = out_channel * sizeof(float); for (int b = 0; b < in_batch; b++) { - int in_batch_offset = b * in_channel * in_h * in_w; + int in_batch_offset = b * ic4 * C4NUM * in_h * in_w; int out_batch_offset = b * out_channel * out_h * out_w; int gemm_in_batch_offset = b * packed_input_size; for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { @@ -263,12 +263,9 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_ int output_count = out_w_block * out_h_block; int output_tile_count = UP_DIV(output_count, TILE_NUM); int out_channel = conv_param->output_channel_; - int out_batch = conv_param->output_batch_; int oc4 = UP_DIV(out_channel, C4NUM); int input_unit_square = input_unit * input_unit; size_t output_offset = oc4 * C4NUM * input_unit_square * sizeof(float); - bool is_relu = conv_param->is_relu_; - bool is_relu6 = conv_param->is_relu6_; float *trans_input = buffer_list[0]; float *gemm_out = buffer_list[1]; @@ -280,11 +277,13 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_ // step 1 : filter transform (pre-processed offline) // step 2 : input transform (online) for (int b = 0; b < in_batch; b++) { + int in_batch_offset = b * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_; + int tmp_out_batch_offset = b * out_w_block * out_h_block * out_unit * out_unit * oc4 * C4NUM; for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) { int out_tile_index = thread_id * TILE_NUM; int cal_num = output_count - thread_id * TILE_NUM; cal_num = cal_num > TILE_NUM ? 
TILE_NUM : cal_num; - WinogradInputTransform(input_data, trans_input + task_id * trans_input_offset, + WinogradInputTransform(input_data + in_batch_offset, trans_input + task_id * trans_input_offset, tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param, input_trans_func); // step 3 : gemm @@ -292,21 +291,10 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_ input_unit_square, ic4, oc4 * C4NUM, output_offset, 1, 1, 0, 0); // step 4 : output transform - WinogradOutputTransform(gemm_out + task_id * gemm_out_offset, tmp_out_data, bias_data, cal_num, out_tile_index, - out_w_block, conv_param, output_trans_func); + WinogradOutputTransform(gemm_out + task_id * gemm_out_offset, tmp_out_data + tmp_out_batch_offset, bias_data, + cal_num, out_tile_index, out_w_block, conv_param, output_trans_func); } } - // get real output - UnPackWinogradOutput(tmp_out_data, output_data, out_batch, conv_param->output_h_, conv_param->output_w_, out_channel, - out_unit); - int output_num = out_channel * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_batch_; - if (is_relu) { - ReluFp32(output_data, output_data, output_num); - } else if (is_relu6) { - Relu6Fp32(output_data, output_data, output_num); - } else { - // do nothing - } } void UnPackWinogradOutput(const float *src, float *dst, int batch, int height, int width, int channel, @@ -360,8 +348,6 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat int output_count = out_w_block * out_h_block; int output_tile_count = UP_DIV(output_count, TILE_NUM); int input_unit_square = 4 * 4; - bool is_relu = conv_param->is_relu_; - bool is_relu6 = conv_param->is_relu6_; float *tile_buffer = buffer_list[0]; float *block_unit_buffer = buffer_list[1]; float *tmp_dst_buffer = buffer_list[2]; @@ -372,10 +358,13 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat int input_batch = conv_param->input_batch_; for (int batch = 0; batch < input_batch; batch++) { + int in_batch_offset = batch * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_; + int nc4hw4_buffer_offset = batch * oc4 * C4NUM * conv_param->output_h_ * conv_param->output_w_; + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { int start_index = thread_id * TILE_NUM; int real_cal_num = (output_count - start_index) < TILE_NUM ? 
(output_count - start_index) : TILE_NUM;
-    Conv3x3Fp32InputTransform(input_data, tile_buffer + task_id * tile_buffer_offset,
+    Conv3x3Fp32InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset,
                               block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num, out_w_block, conv_param);
@@ -383,17 +372,8 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
                      transed_weight, NULL, input_unit_square, ic4, oc4 * C4NUM, oc4 * C4NUM * input_unit_square * sizeof(float), 1, 1, 0, 0);
-      Conv3x3Fp32OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, nc4hw4_out, bias_data, start_index,
-                                 real_cal_num, out_w_block, conv_param);
+      Conv3x3Fp32OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, nc4hw4_out + nc4hw4_buffer_offset,
+                                 bias_data, start_index, real_cal_num, out_w_block, conv_param);
     }
-    PackNC4HW4ToNHWCFp32(nc4hw4_out, output_data, 1, conv_param->output_h_ * conv_param->output_w_, output_channel);
-  }
-  int output_num = output_channel * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_batch_;
-  if (is_relu) {
-    ReluFp32(output_data, output_data, output_num);
-  } else if (is_relu6) {
-    Relu6Fp32(output_data, output_data, output_num);
-  } else {
-    // do nothing
   }
 }
diff --git a/mindspore/lite/test/models_tflite.cfg b/mindspore/lite/test/models_tflite.cfg
index d87713a150..622861c5f0 100644
--- a/mindspore/lite/test/models_tflite.cfg
+++ b/mindspore/lite/test/models_tflite.cfg
@@ -1,6 +1,6 @@
 hiai_model_0909_kd_rot_ps_softmax.tflite
 hiai_chinese_english_recognize_model_float32.tflite
-#hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite.tflite
+hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite.tflite
 hiai_bigmodel_ghost_5_1_no_normalized_no_trans_tflite.tflite
 hiai_cn_recognize_modify_padv2.tflite
 hiai_model_normalize_object_scene_ps_20200519.tflite
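
Review note on the recurring batch-offset change: several hunks above replace b * in_channel * in_h * in_w with b * ic4 * C4NUM * in_h * in_w because NHWC4/NC4HW4-packed buffers stride each batch by the channel count rounded up to a multiple of C4NUM, not by the raw channel count. The following sketch only illustrates that arithmetic; it is not part of the patch (plain C, float in place of float16_t, the helper names NhwcBatchOffset/Nhwc4BatchOffset are made up here, and C4NUM/UP_DIV mirror the nnacl macros).

#include <stdio.h>

#define C4NUM 4
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))

/* Element offset of batch b in a plain NHWC buffer. */
static int NhwcBatchOffset(int b, int h, int w, int channel) { return b * h * w * channel; }

/* Element offset of batch b in an NHWC4 buffer: the channel dimension is padded
 * up to a multiple of C4NUM, so the per-batch stride must use ic4 * C4NUM.
 * Using the raw channel count makes every batch after the first read from the
 * wrong position. */
static int Nhwc4BatchOffset(int b, int h, int w, int channel) {
  int ic4 = UP_DIV(channel, C4NUM);
  return b * h * w * ic4 * C4NUM;
}

int main(void) {
  /* 16 x 16 x 3 input: 3 channels are padded to 4 in the packed layout. */
  printf("batch 1 offset, NHWC : %d\n", NhwcBatchOffset(1, 16, 16, 3));  /* 768 */
  printf("batch 1 offset, NHWC4: %d\n", Nhwc4BatchOffset(1, 16, 16, 3)); /* 1024 */
  return 0;
}

The same arithmetic underlies the oc4/oc8-padded temporary output buffers that the Winograd and 3x3 paths now index per batch before unpacking to NHWC.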