|
|
|
@@ -159,6 +159,16 @@ int ConvolutionCPUKernel::Run() { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
// Allocates a new ConvParameter with malloc and fills it with a byte-wise copy
// of *parameter.
//
// Returns the new struct, or nullptr when `parameter` is null or the
// allocation fails (an error is logged in both cases).
//
// The copy is a plain memcpy, so any pointer members end up referring to the
// same targets as the source struct. The result comes from malloc and must be
// released with free(); NOTE(review): which component performs that release is
// not visible from this function - confirm against the kernel that receives it.
ConvParameter *CreateNewConvParameter(ConvParameter *parameter) {
  // memcpy from a null source is undefined behaviour, so reject it up front.
  if (parameter == nullptr) {
    MS_LOG(ERROR) << "Input conv parameter is nullptr.";
    return nullptr;
  }
  // malloc returns void *, for which static_cast is the appropriate named cast.
  auto conv_parameter = static_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
  if (conv_parameter == nullptr) {
    MS_LOG(ERROR) << "Malloc new conv parameter failed.";
    return nullptr;
  }
  memcpy(conv_parameter, parameter, sizeof(ConvParameter));
  return conv_parameter;
}
|
|
|
|
|
|
|
kernel::LiteKernel *CpuConvFp32KernelSelect(const std::vector<lite::Tensor *> &inputs, |
|
|
|
const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter, |
|
|
|
const InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive, |
|
|
|
@@ -215,6 +225,11 @@ kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor |
|
|
|
for (int i = 0; i < group; ++i) { |
|
|
|
std::vector<lite::Tensor *> new_inputs; |
|
|
|
std::vector<lite::Tensor *> new_outputs; |
|
|
|
auto new_conv_parameter = CreateNewConvParameter(conv_param); |
|
|
|
if (new_conv_parameter == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Get new conv parameter failed."; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
// get new input for each group |
|
|
|
auto in_tensor = |
|
|
|
new (std::nothrow) lite::Tensor(inputs.front()->data_type(), in_shape, Format_NHWC, lite::Tensor::Category::VAR); |
|
|
|
@@ -253,10 +268,10 @@ kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor |
|
|
|
new_outputs.emplace_back(tmp_out_tensor); |
|
|
|
} |
|
|
|
|
|
|
|
group_convs.emplace_back( |
|
|
|
CpuConvFp32KernelSelect(new_inputs, new_outputs, op_parameter, ctx, primitive, use_winograd, out_unit)); |
|
|
|
group_convs.emplace_back(CpuConvFp32KernelSelect(new_inputs, new_outputs, |
|
|
|
reinterpret_cast<OpParameter *>(new_conv_parameter), ctx, |
|
|
|
primitive, use_winograd, out_unit)); |
|
|
|
} |
|
|
|
// sub kernels and group conv kernel share the same op_parameter struct |
|
|
|
return new (std::nothrow) |
|
|
|
GroupConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, primitive, group_convs, group); |
|
|
|
} |
|
|
|
|