Browse Source

support None argument for the outermost net

tags/v1.6.0
huanghui 4 years ago
parent
commit
1af32f74f9
3 changed files with 13 additions and 13 deletions
  1. +3
    -3
      mindspore/ccsrc/pipeline/jit/pipeline.cc
  2. +5
    -5
      tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py
  3. +5
    -5
      tests/ut/python/pynative_mode/test_outermost_non_tensor_input.py

+ 3
- 3
mindspore/ccsrc/pipeline/jit/pipeline.cc View File

@@ -144,8 +144,8 @@ bool CheckArgValid(const py::handle &arg) {
return std::all_of(dict_arg.begin(), dict_arg.end(), [](const auto &pair) { return CheckArgValid(pair.second); });
}

return py::isinstance<py::int_>(arg) || py::isinstance<py::float_>(arg) || py::isinstance<Number>(arg) ||
(py::isinstance<Tensor>(arg) && !py::hasattr(arg, "__parameter__"));
return py::isinstance<py::int_>(arg) || py::isinstance<py::float_>(arg) || py::isinstance<py::none>(arg) ||
py::isinstance<Number>(arg) || (py::isinstance<Tensor>(arg) && !py::hasattr(arg, "__parameter__"));
}

std::string GetCompileExceptionInfo() {
@@ -235,7 +235,7 @@ void CheckArgsValid(const py::tuple &args) {
for (size_t i = 0; i < args.size(); i++) {
if (!CheckArgValid(args[i])) {
MS_EXCEPTION(TypeError)
<< "The inputs types of the outermost network support bool, int, float, tensor, "
<< "The inputs types of the outermost network support bool, int, float, None, tensor, "
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), "
"and tuple or list containing only these types, and dict whose values are these types, but the "
<< i << "th arg type is " << args[i].get_type() << ", value is '" << py::str(args[i]) << "'.";


+ 5
- 5
tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py View File

@@ -95,7 +95,7 @@ def test_grad_first_input_net():
def test_net_inputs_including_str():
with pytest.raises(TypeError) as err:
grad_all_inputs_net(arg_t0, s, arg_l0, w, sl, args_d0, flag_0)
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but the 1th arg type is <class 'str'>, value is 'ok'" in str(err.value)
@@ -104,7 +104,7 @@ def test_net_inputs_including_str():
def test_outermost_net_pass_parameter():
with pytest.raises(TypeError) as err:
forward_net(arg_t0, p, arg_l0, w, sl, args_d0, flag_0)
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but the 1th arg type is <class 'mindspore.common.parameter.ParameterTensor'>, " \
@@ -115,7 +115,7 @@ def test_outermost_net_pass_parameter():
def test_outermost_net_pass_tuple_including_parameter():
with pytest.raises(TypeError) as err:
forward_net(arg_t0, z, arg_l0, sl, args_d0, flag_0, (z, w, p))
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but the 6th arg type is <class 'tuple'>, value is '(" in str(err.value)
@@ -124,7 +124,7 @@ def test_outermost_net_pass_tuple_including_parameter():
def test_outermost_net_pass_list_including_parameter():
with pytest.raises(TypeError) as err:
forward_net(arg_t0, z, arg_l0, sl, [z, w, p], args_d0, flag_0)
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but the 4th arg type is <class 'list'>, value is '[" in str(err.value)
@@ -133,7 +133,7 @@ def test_outermost_net_pass_list_including_parameter():
def test_grad_net_pass_dict_including_parameter():
with pytest.raises(TypeError) as err:
grad_all_inputs_net(arg_t0, z, arg_l0, {"x": z, "y": w, "z": p}, sl, args_d0, flag_0)
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but the 3th arg type is <class 'dict'>, value is '{" in str(err.value)

+ 5
- 5
tests/ut/python/pynative_mode/test_outermost_non_tensor_input.py View File

@@ -95,7 +95,7 @@ def test_grad_first_input_net():
def test_net_inputs_including_str():
with pytest.raises(TypeError) as err:
grad_all_inputs_net(arg_t0, s, arg_l0, w, sl, args_d0, flag_0)
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but the 1th arg type is <class 'str'>, value is 'ok'" in str(err.value)
@@ -104,7 +104,7 @@ def test_net_inputs_including_str():
def test_outermost_net_pass_parameter():
with pytest.raises(TypeError) as err:
forward_net(arg_t0, p, arg_l0, w, sl, args_d0, flag_0)
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but the 1th arg type is <class 'mindspore.common.parameter.ParameterTensor'>, " \
@@ -115,7 +115,7 @@ def test_outermost_net_pass_parameter():
def test_outermost_net_pass_tuple_including_parameter():
with pytest.raises(TypeError) as err:
forward_net(arg_t0, z, arg_l0, sl, args_d0, flag_0, (z, w, p))
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but the 6th arg type is <class 'tuple'>, value is '(" in str(err.value)
@@ -124,7 +124,7 @@ def test_outermost_net_pass_tuple_including_parameter():
def test_outermost_net_pass_list_including_parameter():
with pytest.raises(TypeError) as err:
forward_net(arg_t0, z, arg_l0, sl, [z, w, p], args_d0, flag_0)
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but the 4th arg type is <class 'list'>, value is '[" in str(err.value)
@@ -133,7 +133,7 @@ def test_outermost_net_pass_list_including_parameter():
def test_grad_net_pass_dict_including_parameter():
with pytest.raises(TypeError) as err:
grad_all_inputs_net(arg_t0, z, arg_l0, {"x": z, "y": w, "z": p}, sl, args_d0, flag_0)
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
assert "The inputs types of the outermost network support bool, int, float, None, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but the 3th arg type is <class 'dict'>, value is '{" in str(err.value)

Loading…
Cancel
Save