| @@ -159,6 +159,16 @@ int ConvolutionCPUKernel::Run() { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| ConvParameter *CreateNewConvParameter(ConvParameter *parameter) { | |||||
| auto conv_parameter = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter))); | |||||
| if (conv_parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "Malloc new conv parameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memcpy(conv_parameter, parameter, sizeof(ConvParameter)); | |||||
| return conv_parameter; | |||||
| } | |||||
| kernel::LiteKernel *CpuConvFp32KernelSelect(const std::vector<lite::Tensor *> &inputs, | kernel::LiteKernel *CpuConvFp32KernelSelect(const std::vector<lite::Tensor *> &inputs, | ||||
| const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter, | const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter, | ||||
| const InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive, | const InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive, | ||||
| @@ -215,6 +225,11 @@ kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor | |||||
| for (int i = 0; i < group; ++i) { | for (int i = 0; i < group; ++i) { | ||||
| std::vector<lite::Tensor *> new_inputs; | std::vector<lite::Tensor *> new_inputs; | ||||
| std::vector<lite::Tensor *> new_outputs; | std::vector<lite::Tensor *> new_outputs; | ||||
| auto new_conv_parameter = CreateNewConvParameter(conv_param); | |||||
| if (new_conv_parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "Get new conv parameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| // get new input for each group | // get new input for each group | ||||
| auto in_tensor = | auto in_tensor = | ||||
| new (std::nothrow) lite::Tensor(inputs.front()->data_type(), in_shape, Format_NHWC, lite::Tensor::Category::VAR); | new (std::nothrow) lite::Tensor(inputs.front()->data_type(), in_shape, Format_NHWC, lite::Tensor::Category::VAR); | ||||
| @@ -253,10 +268,10 @@ kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor | |||||
| new_outputs.emplace_back(tmp_out_tensor); | new_outputs.emplace_back(tmp_out_tensor); | ||||
| } | } | ||||
| group_convs.emplace_back( | |||||
| CpuConvFp32KernelSelect(new_inputs, new_outputs, op_parameter, ctx, primitive, use_winograd, out_unit)); | |||||
| group_convs.emplace_back(CpuConvFp32KernelSelect(new_inputs, new_outputs, | |||||
| reinterpret_cast<OpParameter *>(new_conv_parameter), ctx, | |||||
| primitive, use_winograd, out_unit)); | |||||
| } | } | ||||
| // sub kernels and group conv kernel share the same op_parameter struct | |||||
| return new (std::nothrow) | return new (std::nothrow) | ||||
| GroupConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, primitive, group_convs, group); | GroupConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, primitive, group_convs, group); | ||||
| } | } | ||||