|
|
|
@@ -182,6 +182,16 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT |
|
|
|
if (input_node->isa<Parameter>()) { |
|
|
|
auto param_node = input_node->cast<ParameterPtr>(); |
|
|
|
ConvertConvWeight<float>(param_node); |
|
|
|
auto abstractBase = param_node->abstract(); |
|
|
|
MS_ASSERT(abstractBase != nullptr); |
|
|
|
if (utils::isa<abstract::AbstractTensorPtr>(abstractBase)) { |
|
|
|
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase); |
|
|
|
MS_ASSERT(abstractTensor != nullptr); |
|
|
|
if (utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) { |
|
|
|
auto dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape(); |
|
|
|
attr->channelIn = dims[kAnfPopulaterOne]; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; |
|
|
|
|