Browse Source

!12193 fix Conv2d incorrect in_channel size error message

From: @tom__chen
Reviewed-by: @robingrosman
Signed-off-by: @robingrosman
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
4b8d983bf7
2 changed files with 6 additions and 4 deletions
  1. +4
    -3
      mindspore/core/abstract/prim_nn.cc
  2. +2
    -1
      tests/ut/python/model/test_lenet_core_after_exception.py

+ 4
- 3
mindspore/core/abstract/prim_nn.cc View File

@@ -333,9 +333,10 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p
w_axis = 2;
}
int64_t group = CheckAttrPositiveInt64(op_name, primitive->GetAttr("group"), "group");
if ((x_shape[c_axis] != Shape::SHP_ANY) && (x_shape[c_axis] % group != 0)) {
MS_LOG(EXCEPTION) << "x_shape[" << c_axis << "] = " << x_shape[c_axis]
<< " (channels) must be divisible by group = " << group;
if ((x_shape[c_axis] != Shape::SHP_ANY) && (w_shape[c_axis] != Shape::SHP_ANY) &&
((x_shape[c_axis] / group) != w_shape[c_axis])) {
MS_LOG(EXCEPTION) << "x_shape[C_in] / group must equal to w_shape[C_in] = " << w_shape[c_axis] << ", but got "
<< (x_shape[c_axis] / group);
}
int64_t out_channel = CheckAttrPositiveInt64(op_name, primitive->GetAttr("out_channel"), "out_channel");
if ((w_shape[n_axis] != Shape::SHP_ANY) && (w_shape[n_axis] != out_channel)) {


+ 2
- 1
tests/ut/python/model/test_lenet_core_after_exception.py View File

@@ -53,5 +53,6 @@ def test_lenet5_exception():
predict = Tensor(in1)
label = Tensor(in2)
net = train_step_with_loss_warp(LeNet5())
with pytest.raises(ValueError):
with pytest.raises(RuntimeError) as info:
_executor.compile(net, predict, label)
assert "x_shape[C_in] / group must equal to w_shape[C_in] = " in str(info.value)

Loading…
Cancel
Save