|
|
|
@@ -333,9 +333,10 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p |
|
|
|
w_axis = 2; |
|
|
|
} |
|
|
|
int64_t group = CheckAttrPositiveInt64(op_name, primitive->GetAttr("group"), "group"); |
|
|
|
if ((x_shape[c_axis] != Shape::SHP_ANY) && (x_shape[c_axis] % group != 0)) { |
|
|
|
MS_LOG(EXCEPTION) << "x_shape[" << c_axis << "] = " << x_shape[c_axis] |
|
|
|
<< " (channels) must be divisible by group = " << group; |
|
|
|
if ((x_shape[c_axis] != Shape::SHP_ANY) && (w_shape[c_axis] != Shape::SHP_ANY) && |
|
|
|
((x_shape[c_axis] / group) != w_shape[c_axis])) { |
|
|
|
MS_LOG(EXCEPTION) << "x_shape[C_in] / group must equal to w_shape[C_in] = " << w_shape[c_axis] << ", but got " |
|
|
|
<< (x_shape[c_axis] / group); |
|
|
|
} |
|
|
|
int64_t out_channel = CheckAttrPositiveInt64(op_name, primitive->GetAttr("out_channel"), "out_channel"); |
|
|
|
if ((w_shape[n_axis] != Shape::SHP_ANY) && (w_shape[n_axis] != out_channel)) { |
|
|
|
|