diff --git a/mindspore/_extends/parse/standard_method.py b/mindspore/_extends/parse/standard_method.py index 03c940f1b9..bcf1cc4247 100644 --- a/mindspore/_extends/parse/standard_method.py +++ b/mindspore/_extends/parse/standard_method.py @@ -42,6 +42,7 @@ itemsize_map = {mstype.bool_: 1, mstype.int8: 1, mstype.uint8: 1, mstype.float32: 4, mstype.int32: 4, mstype.uint32: 4, mstype.float64: 8, mstype.int64: 8, mstype.uint64: 8} + def mean(x, axis=(), keep_dims=False): """ Reduces a dimension of a tensor by averaging all elements in the dimension. @@ -218,11 +219,11 @@ def swapaxes(x, axis1, axis2): perm = F.make_range(0, x.ndim) new_perm = None if axis2 + 1 < x.ndim: - new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \ - perm[axis1+1:axis2] + perm[axis1:axis1+1] + perm[axis2+1:] + new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \ + perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1] + perm[axis2 + 1:] else: - new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \ - perm[axis1+1:axis2] + perm[axis1:axis1+1] + new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \ + perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1] return F.transpose(x, new_perm) @@ -343,7 +344,7 @@ def isinstance_(x, base_type): def while_cond(x): - """For while condtion, if the condition is a tensor, the loop will not be unrolled""" + """For while condition, if the condition is a tensor, the loop will not be unrolled""" if F.issubclass_(F.typeof(x), F.typeof(mstype.tensor)): is_cond = check_is_tensor_bool_cond(F.shape(x)) if is_cond: @@ -373,7 +374,8 @@ def check_type_same(x_type, base_type): target_type = pytype_to_mstype[base_type] return isinstance(x_type, target_type) except KeyError: - raise TypeError(f"The type '{base_type}' is not supported for 'isinstance'") + raise TypeError(f"The second arg of 'isinstance' should be bool, int, float, str, list, tuple, " + f"Tensor, Parameter, or a tuple only including these types, but got {base_type}") @constexpr @@ -441,7 +443,7 @@ def check_view_shape(x): return x -# convert noraml param_check functions to constexpr functions +# convert normal param_check functions to constexpr functions check_astype_dtype_const = constexpr(validator.check_astype_dtype) check_transpose_axis_const = constexpr(validator.check_transpose_axis) check_reshape_shp_const = constexpr(validator.check_reshape_shp) @@ -449,8 +451,9 @@ check_flatten_order_const = constexpr(validator.check_flatten_order) check_swapaxes_axis_const = constexpr(validator.check_swapaxes_axis) prepare_shape_for_squeeze_const = constexpr(validator.prepare_shape_for_squeeze) + def tensor_bool(x): - """tensor as conditon, if is constant, return immediate bool value""" + """tensor as condition, if is constant, return immediate bool value""" is_cond = check_is_tensor_bool_cond(F.shape(x)) if is_cond and F.isconstant(x): return const_tensor_to_bool(x) diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.cc b/mindspore/ccsrc/frontend/operator/composite/composite.cc index e498e730ad..89a6bd30ab 100644 --- a/mindspore/ccsrc/frontend/operator/composite/composite.cc +++ b/mindspore/ccsrc/frontend/operator/composite/composite.cc @@ -382,7 +382,7 @@ REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) { .def(py::init<>()); })); -FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) { +FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) const { MS_EXCEPTION_IF_NULL(sequeue); FuncGraphPtr ret = std::make_shared(); diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.h b/mindspore/ccsrc/frontend/operator/composite/composite.h index 981efcd3e2..91a7c1b8b2 100644 --- a/mindspore/ccsrc/frontend/operator/composite/composite.h +++ b/mindspore/ccsrc/frontend/operator/composite/composite.h @@ -104,7 +104,7 @@ class Tail : public MetaFuncGraph { MS_DECLARE_PARENT(Tail, MetaFuncGraph) FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; - FuncGraphPtr GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue); + FuncGraphPtr GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) const; friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; } diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc index fab20a82ad..780636c601 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc @@ -156,9 +156,12 @@ bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, } bool IsAllFuncInValueSequence(const std::vector &value_vec) { + if (value_vec.empty()) { + return false; + } for (auto &elem : value_vec) { if (elem->isa() || elem->isa()) { - const auto &vec = GetValue>(elem); + const auto &vec = GetValue(elem); auto is_graph = IsAllFuncInValueSequence(vec); if (!is_graph) { return false; @@ -194,20 +197,20 @@ AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const F return cnode; } -// transform the ValueTuple or ValueList of graph/primitve node to make tuple of const graph/primitve node +// transform the ValueTuple or ValueList of graph/primitive node to make tuple of const graph/primitive node bool TransformVectorFuncValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph, const ValueNodePtr &value_node, AnfNodePtr *const transformed) { MS_EXCEPTION_IF_NULL(value_node); - const auto &value_vec = GetValue>(value_node->value()); + const auto &value_vec = GetValue(value_node->value()); if (!IsAllFuncInValueSequence(value_vec)) { return false; } // (1) The celllist or ordered_cell will be parsed as valuetuple of const graph in it, // So if has graph in list, try to replace the node with make tuple of graph value node. - // we do this because the graphmanger won't investigate the graph inside valuetuple, + // we do this because the graph manager won't investigate the graph inside valuetuple, // change the vector of graph to be make_tuple of graph value node. - // (2) the primitve valuetuple or valuelist may encounter to abstract error, make it all + // (2) the primitive valuetuple or valuelist may encounter to abstract error, make it all // independent nodes. auto node_tuple_graphs = TransformToMakeTupleNodes(manager, func_graph, value_vec); // replace the ret ptr to be make tuple of graph value node diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index 6e8c299834..56e58c2f9d 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -69,10 +69,6 @@ namespace pipeline { using Tensor = mindspore::tensor::Tensor; using MetaTensor = mindspore::tensor::MetaTensor; using TensorOrderMap = std::map>; -using mindspore::abstract::AbstractDictionary; -using mindspore::abstract::AbstractDictionaryPtr; -using mindspore::abstract::AbstractList; -using mindspore::abstract::AbstractListPtr; using mindspore::abstract::AbstractTensor; using mindspore::abstract::AbstractTensorPtr; using mindspore::abstract::AbstractTuple; @@ -103,6 +99,33 @@ AbstractBasePtr ArgsToAbstract(const ValuePtr &value) { return abstract::FromValue(value, broaden); } +bool CheckArgValid(const py::handle &arg) { + if (py::isinstance(arg) || py::isinstance(arg)) { + auto vector_arg = py::cast(arg); + return std::all_of(vector_arg.begin(), vector_arg.end(), CheckArgValid); + } + + if (py::isinstance(arg)) { + auto dict_arg = py::cast(arg); + return std::all_of(dict_arg.begin(), dict_arg.end(), [](const auto &pair) { return CheckArgValid(pair.second); }); + } + + return py::isinstance(arg) || py::isinstance(arg) || py::isinstance(arg) || + (py::isinstance(arg) && !py::hasattr(arg, "__parameter__")); +} + +void CheckArgsValid(const py::tuple &args) { + for (size_t i = 0; i < args.size(); i++) { + if (!CheckArgValid(args[i])) { + MS_EXCEPTION(TypeError) + << "The inputs types of the outermost network support bool, int, float, tensor, " + "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " + "and tuple or list containing only these types, and dict whose values are these types, but got " + << i << "th arg is " << py::str(args[i]); + } + } +} + std::string GetCompileExceptionInfo() { std::ostringstream oss; trace::TraceGraphEval(); @@ -470,11 +493,13 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons MS_LOG(ERROR) << "Arg phase must be string."; return false; } - // check the arg valid? + // check the function or net is valid if (py::isinstance(obj)) { MS_LOG(ERROR) << "Find error: parse obj is None."; return false; } + // check the args of function or net is valid + CheckArgsValid(args); #ifdef ENABLE_GE GetGeBackendPolicy(); #endif diff --git a/mindspore/ccsrc/pybind_api/ir/dtype_py.cc b/mindspore/ccsrc/pybind_api/ir/dtype_py.cc index 1f139cdd27..7c0bc4366b 100644 --- a/mindspore/ccsrc/pybind_api/ir/dtype_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/dtype_py.cc @@ -52,7 +52,7 @@ REGISTER_PYBIND_DEFINE( return t->DeepCopy(); }); (void)py::class_>(m_sub, "Number").def(py::init()); - (void)py::class_>(m_sub, "Bool") + (void)py::class_>(m_sub, "Bool") .def(py::init()) .def(py::pickle( [](const Bool &) { // __getstate__ @@ -61,7 +61,7 @@ REGISTER_PYBIND_DEFINE( [](const py::tuple &) { // __setstate__ return std::make_shared(); })); - (void)py::class_>(m_sub, "Int") + (void)py::class_>(m_sub, "Int") .def(py::init()) .def(py::init(), py::arg("nbits")) .def(py::pickle( @@ -77,7 +77,7 @@ REGISTER_PYBIND_DEFINE( Int data(t[0].cast()); return data; })); - (void)py::class_>(m_sub, "UInt") + (void)py::class_>(m_sub, "UInt") .def(py::init()) .def(py::init(), py::arg("nbits")) .def(py::pickle( @@ -93,7 +93,7 @@ REGISTER_PYBIND_DEFINE( UInt data(t[0].cast()); return data; })); - (void)py::class_>(m_sub, "Float") + (void)py::class_>(m_sub, "Float") .def(py::init()) .def(py::init(), py::arg("nbits")) .def(py::pickle( diff --git a/tests/ut/python/pipeline/parse/test_isinstance.py b/tests/ut/python/pipeline/parse/test_isinstance.py index 4507293bb0..9e54be0a24 100644 --- a/tests/ut/python/pipeline/parse/test_isinstance.py +++ b/tests/ut/python/pipeline/parse/test_isinstance.py @@ -20,7 +20,7 @@ import mindspore.nn as nn from mindspore import Tensor, Parameter from mindspore import context -context.set_context(mode=context.GRAPH_MODE, save_graphs=True) +context.set_context(mode=context.GRAPH_MODE) def test_isinstance(): @@ -35,6 +35,7 @@ def test_isinstance(): self.tuple_member = (1, 1.0, True, "abcd", self.tensor_member) self.list_member = list(self.tuple_member) self.weight = Parameter(1.0) + self.empty_list = [] def construct(self, x, y): is_int = isinstance(self.int_member, int) @@ -54,7 +55,9 @@ def test_isinstance(): bool_is_string = isinstance(self.bool_member, str) tensor_is_tuple = isinstance(x, tuple) tuple_is_list = isinstance(self.tuple_member, list) - return is_int, is_float, is_bool, is_string, is_parameter, is_tensor_const, is_tensor_var, \ + is_empty_list = isinstance(self.empty_list, list) + return is_int, is_float, is_bool, is_string, \ + is_empty_list, is_parameter, is_tensor_const, is_tensor_var, \ is_tuple_const, is_tuple_var, is_list_const, is_list_var, \ is_int_or_float_or_tensor_or_tuple, is_list_or_tensor, \ float_is_int, bool_is_string, tensor_is_tuple, tuple_is_list @@ -62,7 +65,7 @@ def test_isinstance(): net = Net() x = Tensor(np.arange(4)) y = Tensor(np.arange(5)) - assert net(x, y) == (True,) * 13 + (False,) * 4 + assert net(x, y) == (True,) * 14 + (False,) * 4 def test_isinstance_not_supported(): @@ -77,7 +80,8 @@ def test_isinstance_not_supported(): net = Net() with pytest.raises(TypeError) as err: net() - assert "The type 'None' is not supported for 'isinstance'" in str(err.value) + assert "The second arg of 'isinstance' should be bool, int, float, str, list, tuple, Tensor, Parameter, " \ + "or a tuple only including these types, but got None" in str(err.value) def test_isinstance_second_arg_is_list(): diff --git a/tests/ut/python/pipeline/parse/test_ms_function_pass_non_tensor_inputs.py b/tests/ut/python/pipeline/parse/test_ms_function_pass_non_tensor_inputs.py index 35c4c64486..3267b3fa90 100644 --- a/tests/ut/python/pipeline/parse/test_ms_function_pass_non_tensor_inputs.py +++ b/tests/ut/python/pipeline/parse/test_ms_function_pass_non_tensor_inputs.py @@ -15,7 +15,7 @@ """ test ms_function pass non_tensor inputs""" import numpy as np -from mindspore import Tensor, ms_function, Parameter +from mindspore import Tensor, ms_function from mindspore import context from mindspore.ops import operations as P @@ -56,5 +56,5 @@ def tensor_reduce(tensor_x, axis, tensor_y): def test_tensor_reduce(): tensor_x = Tensor(np.ones((2, 3, 4, 5), np.float32)) axis = (0, 1) - tensor_y = Parameter(Tensor(np.ones((4, 5), np.float32) * 2)) + tensor_y = Tensor(np.ones((4, 5), np.float32) * 2) tensor_reduce(tensor_x, axis, tensor_y) diff --git a/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py b/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py index d6bc8167f0..aaede8a5a8 100644 --- a/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py +++ b/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py @@ -14,47 +14,110 @@ # ============================================================================ """ test outermost net pass non_tensor inputs""" import numpy as np +import pytest import mindspore.nn as nn -from mindspore import Tensor +from mindspore import Tensor, Parameter from mindspore import context from mindspore.ops import composite as C context.set_context(mode=context.GRAPH_MODE) -def test_outermost_net_pass_scalar_tuple_list_dict(): - class TestNet(nn.Cell): - def __init__(self): - super(TestNet, self).__init__() - - def construct(self, tuple_a, z, list_m, w, s, dict_n): - return z - tuple_a[2] + list_m[1][1]["x"] - w + s - dict_n["y"] - - class GradNet(nn.Cell): - def __init__(self, net): - super(GradNet, self).__init__() - self.forward_net = net - self.sens = Tensor(np.ones((2, 2), np.float32) * 5) - self.grad_all = C.GradOperation(get_all=True) - - def construct(self, tuple_a, z, list_m, w, s, dict_n): - return self.grad_all(self.forward_net)(tuple_a, z, list_m, w, s, dict_n) - - x = Tensor(np.ones((2, 2), np.float32)) - y = Tensor(np.ones((2, 2), np.float32) * 2) - z = Tensor(np.ones((2, 2), np.float32) * 3) - w = Tensor(np.ones((2, 2), np.float32) * 4) - arg_t0 = (x, y, z, w) - arg_t1 = (w, y, z, w) - arg_l0 = [[x, x], [[x, y], {"x": x, "y": y, "z": x, "p": y}]] - arg_l1 = [[x, x], [[x, y], {"x": x, "y": y, "z": x, "p": y}]] - args_d0 = {"x": x, "y": y} - args_d1 = {"x": x, "y": y} - forward_net = TestNet() - forward_net(arg_t0, z, arg_l0, w, 6, args_d0) - forward_net(arg_t1, z, arg_l1, x, 6, args_d1) - - grad_net = GradNet(forward_net) - grad_net(arg_t0, z, arg_l0, w, 6, args_d0) - grad_net(arg_t1, z, arg_l1, x, 6, args_d1) +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, tuple_a, tensor_x, list_b, tensor_y, scalar, dict_c, flag): + if flag: + return tensor_x - tuple_a[2] + list_b[1][1]["x"] - tensor_y + scalar - dict_c["x"] + return tensor_x + tuple_a[2] - list_b[1][1]["y"] + tensor_y - scalar + dict_c["y"] + + +class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.forward_net = net + self.sens = Tensor(np.ones((2, 2), np.float32) * 5) + self.grad_all = C.GradOperation(get_all=True) + + def construct(self, tuple_a, tensor_x, list_b, tensor_y, scalar, dict_c, flag): + return self.grad_all(self.forward_net)(tuple_a, tensor_x, list_b, tensor_y, scalar, dict_c, flag) + + +x = Tensor(np.ones((2, 2), np.float32)) +y = Tensor(np.ones((2, 2), np.float32) * 2) +z = Tensor(np.ones((2, 2), np.float32) * 3) +w = Tensor(np.ones((2, 2), np.float32) * 4) +sl = 6 +s = "ok" +arg_t0 = (x, y, z, w) +arg_t1 = (w, y, z, w) +arg_l0 = [[x, x], [[x, y], {"x": x, "y": y, "z": x, "p": y}]] +arg_l1 = [[x, x], [[x, y], {"x": x, "y": y, "z": x, "p": y}]] +args_d0 = {"x": x, "y": y} +args_d1 = {"x": x, "y": y} +flag_0 = True +flag_1 = False + + +p = Parameter(x, name="weight") +a = np.ones((2, 2)) + +forward_net = Net() +grad_net = GradNet(forward_net) + + +def test_outermost_net_inputs_including_non_tensor(): + forward_net(arg_t0, z, arg_l0, w, sl, args_d0, flag_0) + forward_net(arg_t1, z, arg_l1, x, sl, args_d1, flag_1) + + +def test_grad_net_inputs_including_non_tensor(): + grad_net(arg_t0, z, arg_l0, w, sl, args_d0, flag_0) + grad_net(arg_t1, z, arg_l1, x, sl, args_d1, flag_1) + + +def test_net_inputs_including_str(): + with pytest.raises(TypeError) as err: + grad_net(arg_t0, s, arg_l0, w, sl, args_d0, flag_0) + assert "The inputs types of the outermost network support bool, int, float, tensor, " \ + "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \ + "and tuple or list containing only these types, and dict whose values are these types, " \ + "but got 1th arg is ok" in str(err.value) + + +def test_outermost_net_pass_parameter(): + with pytest.raises(TypeError) as err: + forward_net(arg_t0, p, arg_l0, w, sl, args_d0, flag_0) + assert "The inputs types of the outermost network support bool, int, float, tensor, " \ + "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \ + "and tuple or list containing only these types, and dict whose values are these types, " \ + "but got 1th arg is Parameter (name=weight)" in str(err.value) + + +def test_outermost_net_pass_tuple_including_parameter(): + with pytest.raises(TypeError) as err: + forward_net(arg_t0, z, arg_l0, sl, args_d0, flag_0, (z, w, p)) + assert "The inputs types of the outermost network support bool, int, float, tensor, " \ + "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \ + "and tuple or list containing only these types, and dict whose values are these types, " \ + "but got 6th arg is (" in str(err.value) + + +def test_outermost_net_pass_list_including_parameter(): + with pytest.raises(TypeError) as err: + forward_net(arg_t0, z, arg_l0, sl, [z, w, p], args_d0, flag_0) + assert "The inputs types of the outermost network support bool, int, float, tensor, " \ + "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \ + "and tuple or list containing only these types, and dict whose values are these types, " \ + "but got 4th arg is [" in str(err.value) + + +def test_grad_net_pass_dict_including_parameter(): + with pytest.raises(TypeError) as err: + grad_net(arg_t0, z, arg_l0, {"x": z, "y": w, "z": p}, sl, args_d0, flag_0) + assert "The inputs types of the outermost network support bool, int, float, tensor, " \ + "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \ + "and tuple or list containing only these types, and dict whose values are these types, " \ + "but got 3th arg is {" in str(err.value)