| @@ -42,6 +42,7 @@ itemsize_map = {mstype.bool_: 1, mstype.int8: 1, mstype.uint8: 1, | |||
| mstype.float32: 4, mstype.int32: 4, mstype.uint32: 4, | |||
| mstype.float64: 8, mstype.int64: 8, mstype.uint64: 8} | |||
| def mean(x, axis=(), keep_dims=False): | |||
| """ | |||
| Reduces a dimension of a tensor by averaging all elements in the dimension. | |||
| @@ -218,11 +219,11 @@ def swapaxes(x, axis1, axis2): | |||
| perm = F.make_range(0, x.ndim) | |||
| new_perm = None | |||
| if axis2 + 1 < x.ndim: | |||
| new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \ | |||
| perm[axis1+1:axis2] + perm[axis1:axis1+1] + perm[axis2+1:] | |||
| new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \ | |||
| perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1] + perm[axis2 + 1:] | |||
| else: | |||
| new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \ | |||
| perm[axis1+1:axis2] + perm[axis1:axis1+1] | |||
| new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \ | |||
| perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1] | |||
| return F.transpose(x, new_perm) | |||
| @@ -343,7 +344,7 @@ def isinstance_(x, base_type): | |||
| def while_cond(x): | |||
| """For while condtion, if the condition is a tensor, the loop will not be unrolled""" | |||
| """For while condition, if the condition is a tensor, the loop will not be unrolled""" | |||
| if F.issubclass_(F.typeof(x), F.typeof(mstype.tensor)): | |||
| is_cond = check_is_tensor_bool_cond(F.shape(x)) | |||
| if is_cond: | |||
| @@ -373,7 +374,8 @@ def check_type_same(x_type, base_type): | |||
| target_type = pytype_to_mstype[base_type] | |||
| return isinstance(x_type, target_type) | |||
| except KeyError: | |||
| raise TypeError(f"The type '{base_type}' is not supported for 'isinstance'") | |||
| raise TypeError(f"The second arg of 'isinstance' should be bool, int, float, str, list, tuple, " | |||
| f"Tensor, Parameter, or a tuple only including these types, but got {base_type}") | |||
| @constexpr | |||
| @@ -441,7 +443,7 @@ def check_view_shape(x): | |||
| return x | |||
| # convert noraml param_check functions to constexpr functions | |||
| # convert normal param_check functions to constexpr functions | |||
| check_astype_dtype_const = constexpr(validator.check_astype_dtype) | |||
| check_transpose_axis_const = constexpr(validator.check_transpose_axis) | |||
| check_reshape_shp_const = constexpr(validator.check_reshape_shp) | |||
| @@ -449,8 +451,9 @@ check_flatten_order_const = constexpr(validator.check_flatten_order) | |||
| check_swapaxes_axis_const = constexpr(validator.check_swapaxes_axis) | |||
| prepare_shape_for_squeeze_const = constexpr(validator.prepare_shape_for_squeeze) | |||
| def tensor_bool(x): | |||
| """tensor as conditon, if is constant, return immediate bool value""" | |||
| """tensor as condition, if is constant, return immediate bool value""" | |||
| is_cond = check_is_tensor_bool_cond(F.shape(x)) | |||
| if is_cond and F.isconstant(x): | |||
| return const_tensor_to_bool(x) | |||
| @@ -382,7 +382,7 @@ REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) { | |||
| .def(py::init<>()); | |||
| })); | |||
| FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) { | |||
| FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) const { | |||
| MS_EXCEPTION_IF_NULL(sequeue); | |||
| FuncGraphPtr ret = std::make_shared<FuncGraph>(); | |||
| @@ -104,7 +104,7 @@ class Tail : public MetaFuncGraph { | |||
| MS_DECLARE_PARENT(Tail, MetaFuncGraph) | |||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; | |||
| FuncGraphPtr GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue); | |||
| FuncGraphPtr GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) const; | |||
| friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; } | |||
| @@ -156,9 +156,12 @@ bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, | |||
| } | |||
| bool IsAllFuncInValueSequence(const std::vector<ValuePtr> &value_vec) { | |||
| if (value_vec.empty()) { | |||
| return false; | |||
| } | |||
| for (auto &elem : value_vec) { | |||
| if (elem->isa<ValueTuple>() || elem->isa<ValueList>()) { | |||
| const auto &vec = GetValue<std::vector<ValuePtr>>(elem); | |||
| const auto &vec = GetValue<ValuePtrList>(elem); | |||
| auto is_graph = IsAllFuncInValueSequence(vec); | |||
| if (!is_graph) { | |||
| return false; | |||
| @@ -194,20 +197,20 @@ AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const F | |||
| return cnode; | |||
| } | |||
| // transform the ValueTuple or ValueList of graph/primitve node to make tuple of const graph/primitve node | |||
| // transform the ValueTuple or ValueList of graph/primitive node to make tuple of const graph/primitive node | |||
| bool TransformVectorFuncValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph, | |||
| const ValueNodePtr &value_node, AnfNodePtr *const transformed) { | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| const auto &value_vec = GetValue<std::vector<ValuePtr>>(value_node->value()); | |||
| const auto &value_vec = GetValue<ValuePtrList>(value_node->value()); | |||
| if (!IsAllFuncInValueSequence(value_vec)) { | |||
| return false; | |||
| } | |||
| // (1) The celllist or ordered_cell will be parsed as valuetuple of const graph in it, | |||
| // So if has graph in list, try to replace the node with make tuple of graph value node. | |||
| // we do this because the graphmanger won't investigate the graph inside valuetuple, | |||
| // we do this because the graph manager won't investigate the graph inside valuetuple, | |||
| // change the vector of graph to be make_tuple of graph value node. | |||
| // (2) the primitve valuetuple or valuelist may encounter to abstract error, make it all | |||
| // (2) the primitive valuetuple or valuelist may encounter to abstract error, make it all | |||
| // independent nodes. | |||
| auto node_tuple_graphs = TransformToMakeTupleNodes(manager, func_graph, value_vec); | |||
| // replace the ret ptr to be make tuple of graph value node | |||
| @@ -69,10 +69,6 @@ namespace pipeline { | |||
| using Tensor = mindspore::tensor::Tensor; | |||
| using MetaTensor = mindspore::tensor::MetaTensor; | |||
| using TensorOrderMap = std::map<std::string, std::shared_ptr<Tensor>>; | |||
| using mindspore::abstract::AbstractDictionary; | |||
| using mindspore::abstract::AbstractDictionaryPtr; | |||
| using mindspore::abstract::AbstractList; | |||
| using mindspore::abstract::AbstractListPtr; | |||
| using mindspore::abstract::AbstractTensor; | |||
| using mindspore::abstract::AbstractTensorPtr; | |||
| using mindspore::abstract::AbstractTuple; | |||
| @@ -103,6 +99,33 @@ AbstractBasePtr ArgsToAbstract(const ValuePtr &value) { | |||
| return abstract::FromValue(value, broaden); | |||
| } | |||
| bool CheckArgValid(const py::handle &arg) { | |||
| if (py::isinstance<py::list>(arg) || py::isinstance<py::tuple>(arg)) { | |||
| auto vector_arg = py::cast<py::list>(arg); | |||
| return std::all_of(vector_arg.begin(), vector_arg.end(), CheckArgValid); | |||
| } | |||
| if (py::isinstance<py::dict>(arg)) { | |||
| auto dict_arg = py::cast<py::dict>(arg); | |||
| return std::all_of(dict_arg.begin(), dict_arg.end(), [](const auto &pair) { return CheckArgValid(pair.second); }); | |||
| } | |||
| return py::isinstance<py::int_>(arg) || py::isinstance<py::float_>(arg) || py::isinstance<Number>(arg) || | |||
| (py::isinstance<Tensor>(arg) && !py::hasattr(arg, "__parameter__")); | |||
| } | |||
| void CheckArgsValid(const py::tuple &args) { | |||
| for (size_t i = 0; i < args.size(); i++) { | |||
| if (!CheckArgValid(args[i])) { | |||
| MS_EXCEPTION(TypeError) | |||
| << "The inputs types of the outermost network support bool, int, float, tensor, " | |||
| "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " | |||
| "and tuple or list containing only these types, and dict whose values are these types, but got " | |||
| << i << "th arg is " << py::str(args[i]); | |||
| } | |||
| } | |||
| } | |||
| std::string GetCompileExceptionInfo() { | |||
| std::ostringstream oss; | |||
| trace::TraceGraphEval(); | |||
| @@ -470,11 +493,13 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons | |||
| MS_LOG(ERROR) << "Arg phase must be string."; | |||
| return false; | |||
| } | |||
| // check the arg valid? | |||
| // check the function or net is valid | |||
| if (py::isinstance<py::none>(obj)) { | |||
| MS_LOG(ERROR) << "Find error: parse obj is None."; | |||
| return false; | |||
| } | |||
| // check the args of function or net is valid | |||
| CheckArgsValid(args); | |||
| #ifdef ENABLE_GE | |||
| GetGeBackendPolicy(); | |||
| #endif | |||
| @@ -52,7 +52,7 @@ REGISTER_PYBIND_DEFINE( | |||
| return t->DeepCopy(); | |||
| }); | |||
| (void)py::class_<Number, Type, std::shared_ptr<Number>>(m_sub, "Number").def(py::init()); | |||
| (void)py::class_<Bool, Type, std::shared_ptr<Bool>>(m_sub, "Bool") | |||
| (void)py::class_<Bool, Number, std::shared_ptr<Bool>>(m_sub, "Bool") | |||
| .def(py::init()) | |||
| .def(py::pickle( | |||
| [](const Bool &) { // __getstate__ | |||
| @@ -61,7 +61,7 @@ REGISTER_PYBIND_DEFINE( | |||
| [](const py::tuple &) { // __setstate__ | |||
| return std::make_shared<Bool>(); | |||
| })); | |||
| (void)py::class_<Int, Type, std::shared_ptr<Int>>(m_sub, "Int") | |||
| (void)py::class_<Int, Number, std::shared_ptr<Int>>(m_sub, "Int") | |||
| .def(py::init()) | |||
| .def(py::init<int>(), py::arg("nbits")) | |||
| .def(py::pickle( | |||
| @@ -77,7 +77,7 @@ REGISTER_PYBIND_DEFINE( | |||
| Int data(t[0].cast<py::int_>()); | |||
| return data; | |||
| })); | |||
| (void)py::class_<UInt, Type, std::shared_ptr<UInt>>(m_sub, "UInt") | |||
| (void)py::class_<UInt, Number, std::shared_ptr<UInt>>(m_sub, "UInt") | |||
| .def(py::init()) | |||
| .def(py::init<int>(), py::arg("nbits")) | |||
| .def(py::pickle( | |||
| @@ -93,7 +93,7 @@ REGISTER_PYBIND_DEFINE( | |||
| UInt data(t[0].cast<py::int_>()); | |||
| return data; | |||
| })); | |||
| (void)py::class_<Float, Type, std::shared_ptr<Float>>(m_sub, "Float") | |||
| (void)py::class_<Float, Number, std::shared_ptr<Float>>(m_sub, "Float") | |||
| .def(py::init()) | |||
| .def(py::init<int>(), py::arg("nbits")) | |||
| .def(py::pickle( | |||
| @@ -20,7 +20,7 @@ import mindspore.nn as nn | |||
| from mindspore import Tensor, Parameter | |||
| from mindspore import context | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| def test_isinstance(): | |||
| @@ -35,6 +35,7 @@ def test_isinstance(): | |||
| self.tuple_member = (1, 1.0, True, "abcd", self.tensor_member) | |||
| self.list_member = list(self.tuple_member) | |||
| self.weight = Parameter(1.0) | |||
| self.empty_list = [] | |||
| def construct(self, x, y): | |||
| is_int = isinstance(self.int_member, int) | |||
| @@ -54,7 +55,9 @@ def test_isinstance(): | |||
| bool_is_string = isinstance(self.bool_member, str) | |||
| tensor_is_tuple = isinstance(x, tuple) | |||
| tuple_is_list = isinstance(self.tuple_member, list) | |||
| return is_int, is_float, is_bool, is_string, is_parameter, is_tensor_const, is_tensor_var, \ | |||
| is_empty_list = isinstance(self.empty_list, list) | |||
| return is_int, is_float, is_bool, is_string, \ | |||
| is_empty_list, is_parameter, is_tensor_const, is_tensor_var, \ | |||
| is_tuple_const, is_tuple_var, is_list_const, is_list_var, \ | |||
| is_int_or_float_or_tensor_or_tuple, is_list_or_tensor, \ | |||
| float_is_int, bool_is_string, tensor_is_tuple, tuple_is_list | |||
| @@ -62,7 +65,7 @@ def test_isinstance(): | |||
| net = Net() | |||
| x = Tensor(np.arange(4)) | |||
| y = Tensor(np.arange(5)) | |||
| assert net(x, y) == (True,) * 13 + (False,) * 4 | |||
| assert net(x, y) == (True,) * 14 + (False,) * 4 | |||
| def test_isinstance_not_supported(): | |||
| @@ -77,7 +80,8 @@ def test_isinstance_not_supported(): | |||
| net = Net() | |||
| with pytest.raises(TypeError) as err: | |||
| net() | |||
| assert "The type 'None' is not supported for 'isinstance'" in str(err.value) | |||
| assert "The second arg of 'isinstance' should be bool, int, float, str, list, tuple, Tensor, Parameter, " \ | |||
| "or a tuple only including these types, but got None" in str(err.value) | |||
| def test_isinstance_second_arg_is_list(): | |||
| @@ -15,7 +15,7 @@ | |||
| """ test ms_function pass non_tensor inputs""" | |||
| import numpy as np | |||
| from mindspore import Tensor, ms_function, Parameter | |||
| from mindspore import Tensor, ms_function | |||
| from mindspore import context | |||
| from mindspore.ops import operations as P | |||
| @@ -56,5 +56,5 @@ def tensor_reduce(tensor_x, axis, tensor_y): | |||
| def test_tensor_reduce(): | |||
| tensor_x = Tensor(np.ones((2, 3, 4, 5), np.float32)) | |||
| axis = (0, 1) | |||
| tensor_y = Parameter(Tensor(np.ones((4, 5), np.float32) * 2)) | |||
| tensor_y = Tensor(np.ones((4, 5), np.float32) * 2) | |||
| tensor_reduce(tensor_x, axis, tensor_y) | |||
| @@ -14,47 +14,110 @@ | |||
| # ============================================================================ | |||
| """ test outermost net pass non_tensor inputs""" | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import Tensor, Parameter | |||
| from mindspore import context | |||
| from mindspore.ops import composite as C | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| def test_outermost_net_pass_scalar_tuple_list_dict(): | |||
| class TestNet(nn.Cell): | |||
| def __init__(self): | |||
| super(TestNet, self).__init__() | |||
| def construct(self, tuple_a, z, list_m, w, s, dict_n): | |||
| return z - tuple_a[2] + list_m[1][1]["x"] - w + s - dict_n["y"] | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.forward_net = net | |||
| self.sens = Tensor(np.ones((2, 2), np.float32) * 5) | |||
| self.grad_all = C.GradOperation(get_all=True) | |||
| def construct(self, tuple_a, z, list_m, w, s, dict_n): | |||
| return self.grad_all(self.forward_net)(tuple_a, z, list_m, w, s, dict_n) | |||
| x = Tensor(np.ones((2, 2), np.float32)) | |||
| y = Tensor(np.ones((2, 2), np.float32) * 2) | |||
| z = Tensor(np.ones((2, 2), np.float32) * 3) | |||
| w = Tensor(np.ones((2, 2), np.float32) * 4) | |||
| arg_t0 = (x, y, z, w) | |||
| arg_t1 = (w, y, z, w) | |||
| arg_l0 = [[x, x], [[x, y], {"x": x, "y": y, "z": x, "p": y}]] | |||
| arg_l1 = [[x, x], [[x, y], {"x": x, "y": y, "z": x, "p": y}]] | |||
| args_d0 = {"x": x, "y": y} | |||
| args_d1 = {"x": x, "y": y} | |||
| forward_net = TestNet() | |||
| forward_net(arg_t0, z, arg_l0, w, 6, args_d0) | |||
| forward_net(arg_t1, z, arg_l1, x, 6, args_d1) | |||
| grad_net = GradNet(forward_net) | |||
| grad_net(arg_t0, z, arg_l0, w, 6, args_d0) | |||
| grad_net(arg_t1, z, arg_l1, x, 6, args_d1) | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| def construct(self, tuple_a, tensor_x, list_b, tensor_y, scalar, dict_c, flag): | |||
| if flag: | |||
| return tensor_x - tuple_a[2] + list_b[1][1]["x"] - tensor_y + scalar - dict_c["x"] | |||
| return tensor_x + tuple_a[2] - list_b[1][1]["y"] + tensor_y - scalar + dict_c["y"] | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.forward_net = net | |||
| self.sens = Tensor(np.ones((2, 2), np.float32) * 5) | |||
| self.grad_all = C.GradOperation(get_all=True) | |||
| def construct(self, tuple_a, tensor_x, list_b, tensor_y, scalar, dict_c, flag): | |||
| return self.grad_all(self.forward_net)(tuple_a, tensor_x, list_b, tensor_y, scalar, dict_c, flag) | |||
| x = Tensor(np.ones((2, 2), np.float32)) | |||
| y = Tensor(np.ones((2, 2), np.float32) * 2) | |||
| z = Tensor(np.ones((2, 2), np.float32) * 3) | |||
| w = Tensor(np.ones((2, 2), np.float32) * 4) | |||
| sl = 6 | |||
| s = "ok" | |||
| arg_t0 = (x, y, z, w) | |||
| arg_t1 = (w, y, z, w) | |||
| arg_l0 = [[x, x], [[x, y], {"x": x, "y": y, "z": x, "p": y}]] | |||
| arg_l1 = [[x, x], [[x, y], {"x": x, "y": y, "z": x, "p": y}]] | |||
| args_d0 = {"x": x, "y": y} | |||
| args_d1 = {"x": x, "y": y} | |||
| flag_0 = True | |||
| flag_1 = False | |||
| p = Parameter(x, name="weight") | |||
| a = np.ones((2, 2)) | |||
| forward_net = Net() | |||
| grad_net = GradNet(forward_net) | |||
| def test_outermost_net_inputs_including_non_tensor(): | |||
| forward_net(arg_t0, z, arg_l0, w, sl, args_d0, flag_0) | |||
| forward_net(arg_t1, z, arg_l1, x, sl, args_d1, flag_1) | |||
| def test_grad_net_inputs_including_non_tensor(): | |||
| grad_net(arg_t0, z, arg_l0, w, sl, args_d0, flag_0) | |||
| grad_net(arg_t1, z, arg_l1, x, sl, args_d1, flag_1) | |||
| def test_net_inputs_including_str(): | |||
| with pytest.raises(TypeError) as err: | |||
| grad_net(arg_t0, s, arg_l0, w, sl, args_d0, flag_0) | |||
| assert "The inputs types of the outermost network support bool, int, float, tensor, " \ | |||
| "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \ | |||
| "and tuple or list containing only these types, and dict whose values are these types, " \ | |||
| "but got 1th arg is ok" in str(err.value) | |||
| def test_outermost_net_pass_parameter(): | |||
| with pytest.raises(TypeError) as err: | |||
| forward_net(arg_t0, p, arg_l0, w, sl, args_d0, flag_0) | |||
| assert "The inputs types of the outermost network support bool, int, float, tensor, " \ | |||
| "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \ | |||
| "and tuple or list containing only these types, and dict whose values are these types, " \ | |||
| "but got 1th arg is Parameter (name=weight)" in str(err.value) | |||
| def test_outermost_net_pass_tuple_including_parameter(): | |||
| with pytest.raises(TypeError) as err: | |||
| forward_net(arg_t0, z, arg_l0, sl, args_d0, flag_0, (z, w, p)) | |||
| assert "The inputs types of the outermost network support bool, int, float, tensor, " \ | |||
| "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \ | |||
| "and tuple or list containing only these types, and dict whose values are these types, " \ | |||
| "but got 6th arg is (" in str(err.value) | |||
| def test_outermost_net_pass_list_including_parameter(): | |||
| with pytest.raises(TypeError) as err: | |||
| forward_net(arg_t0, z, arg_l0, sl, [z, w, p], args_d0, flag_0) | |||
| assert "The inputs types of the outermost network support bool, int, float, tensor, " \ | |||
| "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \ | |||
| "and tuple or list containing only these types, and dict whose values are these types, " \ | |||
| "but got 4th arg is [" in str(err.value) | |||
| def test_grad_net_pass_dict_including_parameter(): | |||
| with pytest.raises(TypeError) as err: | |||
| grad_net(arg_t0, z, arg_l0, {"x": z, "y": w, "z": p}, sl, args_d0, flag_0) | |||
| assert "The inputs types of the outermost network support bool, int, float, tensor, " \ | |||
| "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \ | |||
| "and tuple or list containing only these types, and dict whose values are these types, " \ | |||
| "but got 3th arg is {" in str(err.value) | |||