| @@ -182,6 +182,16 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT | |||||
| if (input_node->isa<Parameter>()) { | if (input_node->isa<Parameter>()) { | ||||
| auto param_node = input_node->cast<ParameterPtr>(); | auto param_node = input_node->cast<ParameterPtr>(); | ||||
| ConvertConvWeight<float>(param_node); | ConvertConvWeight<float>(param_node); | ||||
| auto abstractBase = param_node->abstract(); | |||||
| MS_ASSERT(abstractBase != nullptr); | |||||
| if (utils::isa<abstract::AbstractTensorPtr>(abstractBase)) { | |||||
| auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase); | |||||
| MS_ASSERT(abstractTensor != nullptr); | |||||
| if (utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) { | |||||
| auto dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape(); | |||||
| attr->channelIn = dims[kAnfPopulaterOne]; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | ||||