Merge pull request !7606 from fuzhiye/tmptags/v1.1.0
| @@ -34,10 +34,6 @@ ConvolutionBaseCPUKernel::~ConvolutionBaseCPUKernel() { | |||||
| free(bias_data_); | free(bias_data_); | ||||
| bias_data_ = nullptr; | bias_data_ = nullptr; | ||||
| } | } | ||||
| if (nhwc4_input_ != nullptr) { | |||||
| free(nhwc4_input_); | |||||
| nhwc4_input_ = nullptr; | |||||
| } | |||||
| } | } | ||||
| void ConvolutionBaseCPUKernel::FreeQuantParam() { | void ConvolutionBaseCPUKernel::FreeQuantParam() { | ||||
| @@ -112,18 +108,6 @@ int ConvolutionBaseCPUKernel::CheckResizeValid() { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int ConvolutionBaseCPUKernel::CheckLayout(lite::Tensor *input_tensor) { | |||||
| auto data_type = input_tensor->data_type(); | |||||
| auto input_format = input_tensor->GetFormat(); | |||||
| schema::Format execute_format = schema::Format::Format_NHWC4; | |||||
| convert_func_ = LayoutTransform(data_type, input_format, execute_format); | |||||
| if (convert_func_ == nullptr) { | |||||
| MS_LOG(ERROR) << "layout convert func is nullptr."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int ConvolutionBaseCPUKernel::SetIfPerChannel() { | int ConvolutionBaseCPUKernel::SetIfPerChannel() { | ||||
| auto filter_tensor = in_tensors_.at(kWeightIndex); | auto filter_tensor = in_tensors_.at(kWeightIndex); | ||||
| auto input_channel = filter_tensor->Channel(); | auto input_channel = filter_tensor->Channel(); | ||||
| @@ -48,7 +48,6 @@ class ConvolutionBaseCPUKernel : public LiteKernel { | |||||
| int Init() override; | int Init() override; | ||||
| int ReSize() override { return 0; } | int ReSize() override { return 0; } | ||||
| int Run() override { return 0; } | int Run() override { return 0; } | ||||
| virtual int CheckLayout(lite::Tensor *input_tensor); | |||||
| int SetIfAsymmetric(); | int SetIfAsymmetric(); | ||||
| int SetIfPerChannel(); | int SetIfPerChannel(); | ||||
| int MallocQuantParam(); | int MallocQuantParam(); | ||||
| @@ -61,14 +60,12 @@ class ConvolutionBaseCPUKernel : public LiteKernel { | |||||
| void FreeQuantParam(); | void FreeQuantParam(); | ||||
| protected: | protected: | ||||
| int tile_num_; | |||||
| void *bias_data_ = nullptr; | void *bias_data_ = nullptr; | ||||
| void *nhwc4_input_ = nullptr; | |||||
| const InnerContext *ctx_; | const InnerContext *ctx_; | ||||
| int thread_count_; | |||||
| ConvParameter *conv_param_; | ConvParameter *conv_param_; | ||||
| ConvQuantArg *conv_quant_arg_; | ConvQuantArg *conv_quant_arg_; | ||||
| LayoutConvertor convert_func_ = nullptr; | |||||
| int tile_num_; | |||||
| int thread_count_; | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -61,6 +61,10 @@ int ConvolutionFP16CPUKernel::InitWeightBias() { | |||||
| } | } | ||||
| memset(packed_weight_, 0, pack_weight_size * sizeof(float16_t)); | memset(packed_weight_, 0, pack_weight_size * sizeof(float16_t)); | ||||
| RowMajor2Col8MajorFp16(execute_weight_, packed_weight_, out_channel, in_channel * kernel_plane, false); | RowMajor2Col8MajorFp16(execute_weight_, packed_weight_, out_channel, in_channel * kernel_plane, false); | ||||
| if (fp16_weight_ != nullptr) { | |||||
| free(fp16_weight_); | |||||
| fp16_weight_ = nullptr; | |||||
| } | |||||
| // init bias | // init bias | ||||
| bias_data_ = malloc(oc8 * C8NUM * sizeof(float16_t)); | bias_data_ = malloc(oc8 * C8NUM * sizeof(float16_t)); | ||||
| @@ -17,6 +17,7 @@ | |||||
| #include "src/runtime/kernel/arm/fp32/convolution.h" | #include "src/runtime/kernel/arm/fp32/convolution.h" | ||||
| #include "src/runtime/kernel/arm/fp32/convolution_1x1.h" | #include "src/runtime/kernel/arm/fp32/convolution_1x1.h" | ||||
| #include "src/runtime/kernel/arm/fp32/convolution_winograd.h" | #include "src/runtime/kernel/arm/fp32/convolution_winograd.h" | ||||
| #include "src/runtime/kernel/arm/fp32/group_convolution.h" | |||||
| #include "nnacl/fp32/conv.h" | #include "nnacl/fp32/conv.h" | ||||
| #include "nnacl/common_func.h" | #include "nnacl/common_func.h" | ||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| @@ -31,6 +32,7 @@ using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_INFER_INVALID; | using mindspore::lite::RET_INFER_INVALID; | ||||
| using mindspore::lite::RET_OK; | using mindspore::lite::RET_OK; | ||||
| using mindspore::schema::PrimitiveType_Conv2D; | using mindspore::schema::PrimitiveType_Conv2D; | ||||
| using mindspore::schema::Format::Format_NHWC; | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| int ConvolutionCPUKernel::InitWeightBias() { | int ConvolutionCPUKernel::InitWeightBias() { | ||||
| @@ -157,6 +159,108 @@ int ConvolutionCPUKernel::Run() { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| kernel::LiteKernel *CpuConvFp32KernelSelect(const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter, | |||||
| const InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive, | |||||
| bool use_winograd, int out_unit) { | |||||
| auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter); | |||||
| if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) { | |||||
| return new (std::nothrow) kernel::Convolution1x1CPUKernel(op_parameter, inputs, outputs, ctx, primitive); | |||||
| } else if (use_winograd) { | |||||
| return new (std::nothrow) | |||||
| kernel::ConvolutionWinogradCPUKernel(op_parameter, inputs, outputs, ctx, primitive, out_unit); | |||||
| } else { | |||||
| return new (std::nothrow) kernel::ConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, primitive); | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter, | |||||
| const InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive, | |||||
| int group) { | |||||
| std::vector<kernel::LiteKernel *> group_convs; | |||||
| std::vector<int> in_shape; | |||||
| std::vector<int> filter_shape; | |||||
| std::vector<int> bias_shape; | |||||
| std::vector<int> out_shape; | |||||
| auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter); | |||||
| int out_channel = inputs.at(kWeightIndex)->Batch(); | |||||
| int new_in_channel = inputs.at(kWeightIndex)->Channel(); | |||||
| int new_out_channel = out_channel / group; | |||||
| int kernel_h = conv_param->kernel_h_; | |||||
| int kernel_w = conv_param->kernel_w_; | |||||
| int input_num = inputs.size(); | |||||
| int output_num = outputs.size(); | |||||
| bool has_bias = input_num == 3; | |||||
| bool use_winograd = false; | |||||
| int out_unit; | |||||
| if (primitive != nullptr && primitive->GetInferFlag()) { | |||||
| int batch = inputs.front()->Batch(); | |||||
| int in_h = inputs.front()->Height(); | |||||
| int in_w = inputs.front()->Width(); | |||||
| conv_param->input_channel_ = new_in_channel; | |||||
| conv_param->output_channel_ = new_out_channel; | |||||
| CheckIfUseWinograd(&use_winograd, &out_unit, conv_param); | |||||
| in_shape = {batch, in_h, in_w, new_in_channel}; | |||||
| out_shape = {batch, conv_param->output_h_, conv_param->output_w_, new_out_channel}; | |||||
| } | |||||
| filter_shape = {new_out_channel, kernel_h, kernel_w, new_in_channel}; | |||||
| bias_shape = {new_out_channel}; | |||||
| auto *origin_weight = reinterpret_cast<float *>(inputs.at(kWeightIndex)->data_c()); | |||||
| auto *origin_bias = reinterpret_cast<float *>(inputs.at(kBiasIndex)->data_c()); | |||||
| for (int i = 0; i < group; ++i) { | |||||
| std::vector<lite::Tensor *> new_inputs; | |||||
| std::vector<lite::Tensor *> new_outputs; | |||||
| // get new input for each group | |||||
| auto in_tensor = | |||||
| new (std::nothrow) lite::Tensor(inputs.front()->data_type(), in_shape, Format_NHWC, lite::Tensor::Category::VAR); | |||||
| if (primitive != nullptr && primitive->GetInferFlag()) { | |||||
| in_tensor->MallocData(); | |||||
| } | |||||
| new_inputs.emplace_back(in_tensor); | |||||
| // nwe weight | |||||
| auto filter_tensor = new (std::nothrow) | |||||
| lite::Tensor(inputs.at(kWeightIndex)->data_type(), filter_shape, Format_NHWC, lite::Tensor::Category::CONST); | |||||
| filter_tensor->MallocData(); | |||||
| int copy_length = kernel_h * kernel_w * new_in_channel * new_out_channel; | |||||
| memcpy(filter_tensor->data_c(), origin_weight + i * copy_length, copy_length * sizeof(float)); | |||||
| new_inputs.emplace_back(filter_tensor); | |||||
| // if has bias, set new bias | |||||
| if (has_bias) { | |||||
| auto bias_tensor = new (std::nothrow) | |||||
| lite::Tensor(inputs.at(kBiasIndex)->data_type(), bias_shape, Format_NHWC, lite::Tensor::Category::CONST); | |||||
| bias_tensor->MallocData(); | |||||
| memcpy(bias_tensor->data_c(), origin_bias + i * new_out_channel, new_out_channel * sizeof(float)); | |||||
| new_inputs.emplace_back(bias_tensor); | |||||
| } | |||||
| // set new output tensor | |||||
| for (int j = 0; j < output_num; ++j) { | |||||
| auto tmp_out_tensor = new (std::nothrow) lite::Tensor(); | |||||
| tmp_out_tensor->set_data_type(outputs.at(j)->data_type()); | |||||
| tmp_out_tensor->SetFormat(outputs.at(j)->GetFormat()); | |||||
| if (primitive != nullptr && primitive->GetInferFlag()) { | |||||
| tmp_out_tensor->set_shape(out_shape); | |||||
| tmp_out_tensor->MallocData(); | |||||
| } | |||||
| new_outputs.emplace_back(tmp_out_tensor); | |||||
| } | |||||
| group_convs.emplace_back( | |||||
| CpuConvFp32KernelSelect(new_inputs, new_outputs, op_parameter, ctx, primitive, use_winograd, out_unit)); | |||||
| } | |||||
| // sub kernels and group conv kernel share the same op_parameter struct | |||||
| return new (std::nothrow) | |||||
| GroupConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, primitive, group_convs, group); | |||||
| } | |||||
| kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | ||||
| const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter, | const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter, | ||||
| const InnerContext *ctx, const kernel::KernelKey &desc, | const InnerContext *ctx, const kernel::KernelKey &desc, | ||||
| @@ -164,8 +268,7 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> & | |||||
| MS_ASSERT(op_parameter != nullptr); | MS_ASSERT(op_parameter != nullptr); | ||||
| MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); | MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); | ||||
| auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter); | auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter); | ||||
| int kernel_h = conv_param->kernel_h_; | |||||
| int kernel_w = conv_param->kernel_w_; | |||||
| int group = conv_param->group_; | |||||
| bool use_winograd = false; | bool use_winograd = false; | ||||
| int out_unit; | int out_unit; | ||||
| if (primitive != nullptr && primitive->GetInferFlag()) { | if (primitive != nullptr && primitive->GetInferFlag()) { | ||||
| @@ -192,14 +295,12 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> & | |||||
| } | } | ||||
| kernel::LiteKernel *kernel; | kernel::LiteKernel *kernel; | ||||
| if (kernel_h == 1 && kernel_w == 1) { | |||||
| kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(op_parameter, inputs, outputs, ctx, primitive); | |||||
| } else if (use_winograd) { | |||||
| kernel = | |||||
| new (std::nothrow) kernel::ConvolutionWinogradCPUKernel(op_parameter, inputs, outputs, ctx, primitive, out_unit); | |||||
| if (group == 1) { | |||||
| kernel = CpuConvFp32KernelSelect(inputs, outputs, op_parameter, ctx, primitive, use_winograd, out_unit); | |||||
| } else { | } else { | ||||
| kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, primitive); | |||||
| kernel = CpuGroupConvFp32KernelCreator(inputs, outputs, op_parameter, ctx, primitive, group); | |||||
| } | } | ||||
| if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
| MS_LOG(ERROR) << "kernel is nullptr."; | MS_LOG(ERROR) << "kernel is nullptr."; | ||||
| if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) { | ||||
| @@ -0,0 +1,150 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/fp32/group_convolution.h" | |||||
| #include "schema/model_generated.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "include/errorcode.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_Conv2D; | |||||
| namespace mindspore::kernel { | |||||
| int GroupConvolutionCPUKernel::Init() { | |||||
| for (int i = 0; i < group_num_; ++i) { | |||||
| auto ret = group_convs_[i]->Init(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Sub kernel init failed."; | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| // if infer shape is done, resize func will be invoked in sub kernels | |||||
| return RET_OK; | |||||
| } | |||||
| int GroupConvolutionCPUKernel::ReSize() { | |||||
| for (int i = 0; i < group_num_; ++i) { | |||||
| auto ret = group_convs_[i]->ReSize(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Sub kernel resize failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| conv_param_->input_channel_ /= group_num_; | |||||
| conv_param_->output_channel_ /= group_num_; | |||||
| return RET_OK; | |||||
| } | |||||
| int GroupConvolutionCPUKernel::PreProcess() { | |||||
| if (!InferShapeDone()) { | |||||
| auto ret = (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->InferShape(in_tensors_, out_tensors_); | |||||
| if (ret != 0) { | |||||
| (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->SetInferFlag(false); | |||||
| MS_LOG(ERROR) << "InferShape fail!"; | |||||
| return ret; | |||||
| } | |||||
| (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->SetInferFlag(true); | |||||
| ret = ReSize(); | |||||
| if (ret != 0) { | |||||
| MS_LOG(ERROR) << "ReSize fail!ret: " << ret; | |||||
| return ret; | |||||
| } | |||||
| // if infershape func is called in runtime stage, we should malloc memory and set shape info for outputs of sub | |||||
| // kernels here. | |||||
| std::vector<int> in_shape; | |||||
| std::vector<int> out_shape; | |||||
| for (int i = 0; i < group_num_; ++i) { | |||||
| // in | |||||
| int in_batch = conv_param_->input_batch_; | |||||
| int in_h = conv_param_->input_h_; | |||||
| int in_w = conv_param_->input_w_; | |||||
| int in_c = conv_param_->input_channel_; | |||||
| in_shape = {in_batch, in_h, in_w, in_c}; | |||||
| auto sub_kernel_in_tensor = group_convs_[i]->in_tensors().front(); | |||||
| sub_kernel_in_tensor->set_shape(in_shape); | |||||
| sub_kernel_in_tensor->MallocData(); | |||||
| // out | |||||
| int out_batch = conv_param_->output_batch_; | |||||
| int out_h = conv_param_->output_h_; | |||||
| int out_w = conv_param_->output_w_; | |||||
| int out_c = conv_param_->output_channel_; | |||||
| out_shape = {out_batch, out_h, out_w, out_c}; | |||||
| auto sub_kernel_out_tensors = group_convs_[i]->out_tensors(); | |||||
| for (auto tensor : sub_kernel_out_tensors) { | |||||
| tensor->set_shape(out_shape); | |||||
| tensor->MallocData(); | |||||
| } | |||||
| } | |||||
| } | |||||
| auto outputs = this->out_tensors(); | |||||
| for (auto *output : outputs) { | |||||
| MS_ASSERT(output != nullptr); | |||||
| output->MallocData(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| void GroupConvolutionCPUKernel::SeparateInput(int group_id) { | |||||
| int in_h = conv_param_->input_h_; | |||||
| int in_w = conv_param_->input_w_; | |||||
| int in_plane = in_h * in_w; | |||||
| int sub_in_channel = conv_param_->input_channel_; | |||||
| int ori_in_channel = sub_in_channel * group_num_; | |||||
| auto sub_in_data = reinterpret_cast<float *>(group_convs_[group_id]->in_tensors().front()->data_c()); | |||||
| float *src_ptr = ori_in_data_ + group_id * sub_in_channel; | |||||
| float *dst_ptr = sub_in_data; | |||||
| for (int i = 0; i < in_plane; ++i) { | |||||
| memcpy(dst_ptr, src_ptr, sub_in_channel * sizeof(float)); | |||||
| src_ptr += ori_in_channel; | |||||
| dst_ptr += sub_in_channel; | |||||
| } | |||||
| } | |||||
| void GroupConvolutionCPUKernel::PostConcat(int group_id) { | |||||
| int out_h = conv_param_->output_h_; | |||||
| int out_w = conv_param_->output_w_; | |||||
| int out_plane = out_h * out_w; | |||||
| int sub_out_channel = conv_param_->output_channel_; | |||||
| int ori_out_channel = sub_out_channel * group_num_; | |||||
| auto sub_out_data = reinterpret_cast<float *>(group_convs_[group_id]->out_tensors().front()->data_c()); | |||||
| float *src_ptr = sub_out_data; | |||||
| float *dst_ptr = ori_out_data_ + group_id * sub_out_channel; | |||||
| for (int i = 0; i < out_plane; ++i) { | |||||
| memcpy(dst_ptr, src_ptr, sub_out_channel * sizeof(float)); | |||||
| src_ptr += sub_out_channel; | |||||
| dst_ptr += ori_out_channel; | |||||
| } | |||||
| } | |||||
| int GroupConvolutionCPUKernel::Run() { | |||||
| ori_in_data_ = reinterpret_cast<float *>(in_tensors().front()->data_c()); | |||||
| ori_out_data_ = reinterpret_cast<float *>(out_tensors().front()->data_c()); | |||||
| for (int i = 0; i < group_num_; ++i) { | |||||
| // first, separate group conv input into several parts. This step must be in runtime stage. | |||||
| SeparateInput(i); | |||||
| // sun kernels run | |||||
| group_convs_[i]->Run(); | |||||
| // post process, concat all outputs of sub-kernels into one output | |||||
| PostConcat(i); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,70 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GROUP_CONVOLUTION_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GROUP_CONVOLUTION_H_ | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| #include "nnacl/op_base.h" | |||||
| #include "src/runtime/kernel/arm/base/convolution_base.h" | |||||
| #include "nnacl/fp32/conv.h" | |||||
| namespace mindspore::kernel { | |||||
| class GroupConvolutionCPUKernel : public ConvolutionBaseCPUKernel { | |||||
| public: | |||||
| GroupConvolutionCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||||
| const mindspore::lite::PrimitiveC *primitive, std::vector<kernel::LiteKernel *> group_convs, | |||||
| const int group_num) | |||||
| : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive), | |||||
| group_convs_(std::move(group_convs)), | |||||
| group_num_(group_num) {} // opParameter(in channel, out channel) in this kernel has been split to groups, if | |||||
| // you want to get real params, multiply in channel / out channel with group num | |||||
| ~GroupConvolutionCPUKernel() override { | |||||
| for (auto sub_conv : group_convs_) { | |||||
| // free sub conv input tensors / output tensors manually | |||||
| auto sub_in_tensors = sub_conv->in_tensors(); | |||||
| auto sub_in_tensor_num = sub_in_tensors.size(); | |||||
| for (size_t i = 0; i < sub_in_tensor_num; ++i) { | |||||
| delete sub_in_tensors[i]; | |||||
| } | |||||
| auto sub_out_tensors = sub_conv->out_tensors(); | |||||
| auto sub_out_tensor_num = sub_out_tensors.size(); | |||||
| for (size_t i = 0; i < sub_out_tensor_num; ++i) { | |||||
| delete sub_out_tensors[i]; | |||||
| } | |||||
| delete sub_conv; | |||||
| } | |||||
| }; | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| int PreProcess() override; | |||||
| void SeparateInput(int group_id); | |||||
| void PostConcat(int group_id); | |||||
| private: | |||||
| std::vector<kernel::LiteKernel *> group_convs_; | |||||
| float *ori_in_data_ = nullptr; // do not free | |||||
| float *ori_out_data_ = nullptr; // do not free | |||||
| const int group_num_; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GROUP_CONVOLUTION_H_ | |||||