| @@ -88,6 +88,17 @@ std::string GetBaseNameForIR(int stage_idx, const std::string &action_name) { | |||||
| oss << stage_idx << "_" << action_name; | oss << stage_idx << "_" << action_name; | ||||
| return oss.str(); | return oss.str(); | ||||
| } | } | ||||
| void CheckArgIsTensor(const ValuePtr &arg, std::size_t idx) { | |||||
| MS_EXCEPTION_IF_NULL(arg); | |||||
| auto tensor_arg = arg->cast<TensorPtr>(); | |||||
| if (tensor_arg == nullptr) { | |||||
| MS_EXCEPTION(TypeError) << "For 'graph mode', the " << idx << "th arg: " << arg->ToString() << " is not a tensor."; | |||||
| } | |||||
| if (tensor_arg->is_parameter()) { | |||||
| MS_EXCEPTION(TypeError) << "The inputs could not be Parameter."; | |||||
| } | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| py::tuple GenerateKey(const std::string &name, const std::unordered_map<std::string, py::object> &defaults) { | py::tuple GenerateKey(const std::string &name, const std::unordered_map<std::string, py::object> &defaults) { | ||||
| @@ -460,6 +471,9 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons | |||||
| if (!succ) { | if (!succ) { | ||||
| MS_LOG(EXCEPTION) << "Args convert error"; | MS_LOG(EXCEPTION) << "Args convert error"; | ||||
| } | } | ||||
| if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) { | |||||
| CheckArgIsTensor(converted, i); | |||||
| } | |||||
| bool broaden = true; | bool broaden = true; | ||||
| args_spec.push_back(abstract::FromValue(converted, broaden)); | args_spec.push_back(abstract::FromValue(converted, broaden)); | ||||
| } | } | ||||
| @@ -701,15 +715,6 @@ void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef | |||||
| if (!succ) { | if (!succ) { | ||||
| MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed."; | MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed."; | ||||
| } | } | ||||
| if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == 0) { | |||||
| if (!converted->isa<tensor::Tensor>()) { | |||||
| MS_EXCEPTION(TypeError) << "For 'graph mode', the " << i << "th arg: " << converted->ToString() | |||||
| << " is not tensor."; | |||||
| } | |||||
| if (converted->cast<TensorPtr>()->is_parameter()) { | |||||
| MS_EXCEPTION(TypeError) << "The inputs could not be Parameter."; | |||||
| } | |||||
| } | |||||
| arg_list->push_back(converted); | arg_list->push_back(converted); | ||||
| } | } | ||||
| @@ -15,7 +15,7 @@ | |||||
| """ Test Dynamic Learning Rate """ | """ Test Dynamic Learning Rate """ | ||||
| import pytest | import pytest | ||||
| from mindspore import Tensor, Parameter | |||||
| from mindspore import Tensor | |||||
| from mindspore.nn import learning_rate_schedule as lr_schedules | from mindspore.nn import learning_rate_schedule as lr_schedules | ||||
| from mindspore.common.api import _executor | from mindspore.common.api import _executor | ||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| @@ -29,7 +29,7 @@ warmup_steps = 2 | |||||
| min_lr = 0.01 | min_lr = 0.01 | ||||
| max_lr = 0.1 | max_lr = 0.1 | ||||
| power = 0.5 | power = 0.5 | ||||
| global_step = Parameter(Tensor(2, mstype.int32), 'global_step') | |||||
| global_step = Tensor(2, mstype.int32) | |||||
| class TestInit: | class TestInit: | ||||
| @@ -104,7 +104,6 @@ def test_pow(): | |||||
| result = testpow(input_tensor, power) | result = testpow(input_tensor, power) | ||||
| assert np.all(result.asnumpy() == expect) | assert np.all(result.asnumpy() == expect) | ||||
| net = PowNet() | net = PowNet() | ||||
| net(input_tensor, True) | |||||
| net(input_tensor, power2) | net(input_tensor, power2) | ||||
| @@ -85,6 +85,33 @@ class NetForConcat1(nn.Cell): | |||||
| return self.concat((x1, x2)) | return self.concat((x1, x2)) | ||||
| class NetForConcat2(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetForConcat2, self).__init__() | |||||
| self.concat = P.Concat(axis=2) | |||||
| def construct(self, x1, x2): | |||||
| return self.concat((x1, x2)) | |||||
| class NetForConcat3(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetForConcat3, self).__init__() | |||||
| self.concat = P.Concat(axis=0) | |||||
| def construct(self, x1, x2, x3): | |||||
| return self.concat((x1, x2, x3)) | |||||
| class NetForConcat4(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetForConcat4, self).__init__() | |||||
| self.concat = P.Concat(axis=-1) | |||||
| def construct(self, x1, x2, x3): | |||||
| return self.concat((x1, x2, x3)) | |||||
| class NetForPackInput(nn.Cell): | class NetForPackInput(nn.Cell): | ||||
| def __init__(self, op): | def __init__(self, op): | ||||
| super(NetForPackInput, self).__init__() | super(NetForPackInput, self).__init__() | ||||
| @@ -1080,7 +1107,7 @@ test_case_math_ops = [ | |||||
| 'desc_bprop': [Tensor(np.ones((2, 3, 4, 5), np.bool_))]}), | 'desc_bprop': [Tensor(np.ones((2, 3, 4, 5), np.bool_))]}), | ||||
| ('NotEqual_0', { | ('NotEqual_0', { | ||||
| 'block': P.NotEqual(), | 'block': P.NotEqual(), | ||||
| 'desc_inputs': [1, [2, 3, 4, 5]], | |||||
| 'desc_inputs': [Tensor(np.array(1).astype(np.int32)), [2, 3, 4, 5]], | |||||
| 'desc_bprop': [Tensor(np.ones((2, 3, 4, 5), np.bool_))], | 'desc_bprop': [Tensor(np.ones((2, 3, 4, 5), np.bool_))], | ||||
| 'skip': ['backward']}), | 'skip': ['backward']}), | ||||
| ('ApproximateEqual', { | ('ApproximateEqual', { | ||||
| @@ -1893,15 +1920,15 @@ test_case_array_ops = [ | |||||
| 'desc_inputs': [(Tensor(np.array([-1.6, -0.1, 1.5, 2.0]).astype(np.float32)))], | 'desc_inputs': [(Tensor(np.array([-1.6, -0.1, 1.5, 2.0]).astype(np.float32)))], | ||||
| 'skip': ['backward']}), | 'skip': ['backward']}), | ||||
| ('ConcatV2_0', { | ('ConcatV2_0', { | ||||
| 'block': P.Concat(), | |||||
| 'block': NetForConcat1(), | |||||
| 'desc_inputs': [ | 'desc_inputs': [ | ||||
| (Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32)), | |||||
| Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32)))], | |||||
| Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32)), | |||||
| Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32))], | |||||
| 'desc_bprop': [([4, 2], {'dtype': np.int32})]}), | 'desc_bprop': [([4, 2], {'dtype': np.int32})]}), | ||||
| ('ConcatV2_1', { | ('ConcatV2_1', { | ||||
| 'block': P.Concat(axis=2), | |||||
| 'desc_inputs': [(Tensor(np.array([[[0, 1, 2]], [[2, 1, 2]]]).astype(np.int32)), | |||||
| Tensor(np.array([[[0, 1]], [[2, 1]]]).astype(np.int32)))], | |||||
| 'block': NetForConcat2(), | |||||
| 'desc_inputs': [Tensor(np.array([[[0, 1, 2]], [[2, 1, 2]]]).astype(np.int32)), | |||||
| Tensor(np.array([[[0, 1]], [[2, 1]]]).astype(np.int32))], | |||||
| 'desc_bprop': [([2, 1, 5], {'dtype': np.int32})]}), | 'desc_bprop': [([2, 1, 5], {'dtype': np.int32})]}), | ||||
| ('ConcatV2_2', { | ('ConcatV2_2', { | ||||
| 'block': NetForConcat(), | 'block': NetForConcat(), | ||||
| @@ -1912,17 +1939,17 @@ test_case_array_ops = [ | |||||
| 'desc_inputs': [[2, 2], [2, 2]], | 'desc_inputs': [[2, 2], [2, 2]], | ||||
| 'desc_bprop': [[4, 2]]}), | 'desc_bprop': [[4, 2]]}), | ||||
| ('ConcatV2_4', { | ('ConcatV2_4', { | ||||
| 'block': P.Concat(axis=0), | |||||
| 'block': NetForConcat3(), | |||||
| 'desc_inputs': [ | 'desc_inputs': [ | ||||
| (Tensor(np.ones((3, 2, 3), np.float32)), | |||||
| Tensor(np.ones((5, 2, 3), np.float32)), | |||||
| Tensor(np.ones((6, 2, 3), np.float32)))], | |||||
| Tensor(np.ones((3, 2, 3), np.float32)), | |||||
| Tensor(np.ones((5, 2, 3), np.float32)), | |||||
| Tensor(np.ones((6, 2, 3), np.float32))], | |||||
| 'desc_bprop': [[14, 2, 3]]}), | 'desc_bprop': [[14, 2, 3]]}), | ||||
| ('ConcatV2_5', { | ('ConcatV2_5', { | ||||
| 'block': P.Concat(axis=-1), | |||||
| 'desc_inputs': [(Tensor(np.array([1], np.float32)), | |||||
| Tensor(np.array([1], np.float32)), | |||||
| Tensor(np.array([1], np.float32)))], | |||||
| 'block': NetForConcat4(), | |||||
| 'desc_inputs': [Tensor(np.array([1], np.float32)), | |||||
| Tensor(np.array([1], np.float32)), | |||||
| Tensor(np.array([1], np.float32))], | |||||
| 'desc_bprop': [[3, ]]}), | 'desc_bprop': [[3, ]]}), | ||||
| ('Pack_0', { | ('Pack_0', { | ||||
| 'block': NetForPackInput(P.Pack()), | 'block': NetForPackInput(P.Pack()), | ||||
| @@ -74,7 +74,7 @@ def test_remove_and_fv_2(): | |||||
| return ret | return ret | ||||
| @ms_function | @ms_function | ||||
| def out_loop(input1, input_data): | |||||
| def out_loop(input1, input_data0, input_data1): | |||||
| ret = () | ret = () | ||||
| def fv_func1(y): | def fv_func1(y): | ||||
| @@ -82,14 +82,15 @@ def test_remove_and_fv_2(): | |||||
| def fv_func2(y): | def fv_func2(y): | ||||
| return input1 - y | return input1 - y | ||||
| fv_func_list = [fv_func1, fv_func2] | fv_func_list = [fv_func1, fv_func2] | ||||
| ele0 = inner_loop(input1, input_data[0], fv_func_list) | |||||
| ele1 = inner_loop(input1, input_data[1], fv_func_list) | |||||
| ele0 = inner_loop(input1, input_data0, fv_func_list) | |||||
| ele1 = inner_loop(input1, input_data1, fv_func_list) | |||||
| ret = (ele0, ele1) | ret = (ele0, ele1) | ||||
| return ret | return ret | ||||
| input_data = (Tensor(normal(0, 0.1, (3, 3))), Tensor(normal(0, 0.1, (3, 1)))) | |||||
| input_data0 = Tensor(normal(0, 0.1, (3, 3))) | |||||
| input_data1 = Tensor(normal(0, 0.1, (3, 1))) | |||||
| input1 = Tensor(normal(0, 0.1, (3, 3))) | input1 = Tensor(normal(0, 0.1, (3, 3))) | ||||
| out_loop(input1, input_data) | |||||
| out_loop(input1, input_data0, input_data1) | |||||
| # test cell as high order argument | # test cell as high order argument | ||||
| @@ -466,7 +466,7 @@ def test_tensor_assign(): | |||||
| # Error for A[Slice] = Number | # Error for A[Slice] = Number | ||||
| # 1. A[Slice] = Number, Slice error | # 1. A[Slice] = Number, Slice error | ||||
| with pytest.raises(IndexError): | with pytest.raises(IndexError): | ||||
| net_e2(t, 2) | |||||
| net_e2(t, Tensor(2, mstype.int32)) | |||||
| # Error for A[Slice] = U, U is a Tensor | # Error for A[Slice] = U, U is a Tensor | ||||
| # 1. A[Slice] = U, u.size is error | # 1. A[Slice] = U, u.size is error | ||||
| @@ -493,7 +493,7 @@ def test_tensor_assign(): | |||||
| # Error for A[Tuple(Slice...)] = Number | # Error for A[Tuple(Slice...)] = Number | ||||
| # 1. A[Tuple(Slice...)] = Number, Slice error | # 1. A[Tuple(Slice...)] = Number, Slice error | ||||
| with pytest.raises(IndexError): | with pytest.raises(IndexError): | ||||
| net_e1(Ta, 2) | |||||
| net_e1(Ta, Tensor(2, mstype.int32)) | |||||
| net = TensorAssignWithInteger() | net = TensorAssignWithInteger() | ||||
| # Error for A[Number] = scalar/Tensor | # Error for A[Number] = scalar/Tensor | ||||
| @@ -675,12 +675,12 @@ def test_tensor_assign_bool_index(): | |||||
| with pytest.raises(AttributeError): | with pytest.raises(AttributeError): | ||||
| net3(Ta, Tb, Tc, u_tensor) | net3(Ta, Tb, Tc, u_tensor) | ||||
| with pytest.raises(AttributeError): | with pytest.raises(AttributeError): | ||||
| net3(Ta, Tb, Tc, u_scalar) | |||||
| net3(Ta, Tb, Tc, Tensor(u_scalar, mstype.int32)) | |||||
| net4 = TensorAssignWithBoolTensorIndex2Error() | net4 = TensorAssignWithBoolTensorIndex2Error() | ||||
| with pytest.raises(AttributeError): | with pytest.raises(AttributeError): | ||||
| net4(Ta, u_tensor) | net4(Ta, u_tensor) | ||||
| with pytest.raises(AttributeError): | with pytest.raises(AttributeError): | ||||
| net4(Ta, u_scalar) | |||||
| net4(Ta, Tensor(u_scalar, mstype.int32)) | |||||
| test_cases = [ | test_cases = [ | ||||
| @@ -32,7 +32,8 @@ class NetWork_1(Cell): | |||||
| super(NetWork_1, self).__init__() | super(NetWork_1, self).__init__() | ||||
| self.addN = P.AddN() | self.addN = P.AddN() | ||||
| def construct(self, tensor_tuple): | |||||
| def construct(self, tensor1, tensor2, tensor3, tensor4, tensor5, tensor6): | |||||
| tensor_tuple = (tensor1, tensor2, tensor3, tensor4, tensor5, tensor6) | |||||
| tensor_tuple_slice0 = tensor_tuple[:] | tensor_tuple_slice0 = tensor_tuple[:] | ||||
| tensor_tuple_slice1 = tensor_tuple[:3] | tensor_tuple_slice1 = tensor_tuple[:3] | ||||
| tensor_tuple_slice2 = tensor_tuple[1:] | tensor_tuple_slice2 = tensor_tuple[1:] | ||||
| @@ -52,7 +53,8 @@ class NetWork_2(Cell): | |||||
| super(NetWork_2, self).__init__() | super(NetWork_2, self).__init__() | ||||
| self.addN = P.AddN() | self.addN = P.AddN() | ||||
| def construct(self, tensor_tuple): | |||||
| def construct(self, tensor1, tensor2, tensor3, tensor4, tensor5, tensor6): | |||||
| tensor_tuple = (tensor1, tensor2, tensor3, tensor4, tensor5, tensor6) | |||||
| tensor_tuple_slice0 = tensor_tuple[::-1] | tensor_tuple_slice0 = tensor_tuple[::-1] | ||||
| tensor_tuple_slice1 = tensor_tuple[-1::-1] | tensor_tuple_slice1 = tensor_tuple[-1::-1] | ||||
| tensor_tuple_slice2 = tensor_tuple[:-4:-1] | tensor_tuple_slice2 = tensor_tuple[:-4:-1] | ||||
| @@ -94,21 +96,21 @@ class NetWorkOutOfBounds(Cell): | |||||
| test_cases = [ | test_cases = [ | ||||
| ('SlicePositive', { | ('SlicePositive', { | ||||
| 'block': NetWork_1(), | 'block': NetWork_1(), | ||||
| 'desc_inputs': [(Tensor(np.ones([2, 3, 4], np.int32)), | |||||
| Tensor(np.zeros([2, 3, 4], np.int32)), | |||||
| Tensor(np.ones([2, 3, 4], np.int32)), | |||||
| Tensor(np.ones([2, 3, 4], np.int32)), | |||||
| Tensor(np.zeros([2, 3, 4], np.int32)), | |||||
| Tensor(np.ones([2, 3, 4], np.int32)))], | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3, 4], np.int32)), | |||||
| Tensor(np.zeros([2, 3, 4], np.int32)), | |||||
| Tensor(np.ones([2, 3, 4], np.int32)), | |||||
| Tensor(np.ones([2, 3, 4], np.int32)), | |||||
| Tensor(np.zeros([2, 3, 4], np.int32)), | |||||
| Tensor(np.ones([2, 3, 4], np.int32))], | |||||
| }), | }), | ||||
| ('SliceNegative', { | ('SliceNegative', { | ||||
| 'block': NetWork_2(), | 'block': NetWork_2(), | ||||
| 'desc_inputs': [(Tensor(np.ones([2, 3, 4], np.int32)), | |||||
| Tensor(np.zeros([2, 3, 4], np.int32)), | |||||
| Tensor(np.ones([2, 3, 4], np.int32)), | |||||
| Tensor(np.ones([2, 3, 4], np.int32)), | |||||
| Tensor(np.zeros([2, 3, 4], np.int32)), | |||||
| Tensor(np.ones([2, 3, 4], np.int32)))], | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3, 4], np.int32)), | |||||
| Tensor(np.zeros([2, 3, 4], np.int32)), | |||||
| Tensor(np.ones([2, 3, 4], np.int32)), | |||||
| Tensor(np.ones([2, 3, 4], np.int32)), | |||||
| Tensor(np.zeros([2, 3, 4], np.int32)), | |||||
| Tensor(np.ones([2, 3, 4], np.int32))], | |||||
| }), | }), | ||||
| ] | ] | ||||
| @@ -98,7 +98,7 @@ def test_dup_context(): | |||||
| return net1() + net2() | return net1() + net2() | ||||
| Net()(5.0) | |||||
| Net()(Tensor(np.array(5.0).astype(np.float32))) | |||||
| def test_maybe_poly_func(): | def test_maybe_poly_func(): | ||||
| @@ -125,4 +125,4 @@ def test_maybe_poly_func(): | |||||
| y_input = Tensor(np.array([1, 2]).astype(np.int32)) | y_input = Tensor(np.array([1, 2]).astype(np.int32)) | ||||
| z_input = Tensor(np.array([[2, 2], [3, 3]]).astype(np.int32)) | z_input = Tensor(np.array([[2, 2], [3, 3]]).astype(np.int32)) | ||||
| Net()(1, y_input, z_input) | |||||
| Net()(Tensor(np.array(1).astype(np.int32)), y_input, z_input) | |||||
| @@ -192,10 +192,10 @@ def test_enumerate_start_type_error(): | |||||
| super(Net, self).__init__() | super(Net, self).__init__() | ||||
| def construct(self, x): | def construct(self, x): | ||||
| return enumerate(x, start=1.2) | |||||
| return enumerate((x, x), start=1.2) | |||||
| x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))) | x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))) | ||||
| net = Net() | net = Net() | ||||
| with pytest.raises(TypeError) as ex: | with pytest.raises(TypeError) as ex: | ||||
| net((x, x)) | |||||
| net(x) | |||||
| assert "For 'enumerate', the 'start'" in str(ex.value) | assert "For 'enumerate', the 'start'" in str(ex.value) | ||||
| @@ -179,7 +179,8 @@ def test_bprop_with_wrong_output_num(): | |||||
| return BpropWithWrongOutputNum()(x, y) | return BpropWithWrongOutputNum()(x, y) | ||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| grad_all(BpropWithWrongOutputNumCell())(1, 2) | |||||
| grad_all(BpropWithWrongOutputNumCell())(Tensor(np.array(1).astype(np.int32)), | |||||
| Tensor(np.array(2).astype(np.int32))) | |||||
| def test_bprop_with_wrong_output_type(): | def test_bprop_with_wrong_output_type(): | ||||
| context.set_context(check_bprop=True) | context.set_context(check_bprop=True) | ||||
| @@ -25,7 +25,21 @@ from mindspore.ops.functional import depend | |||||
| context.set_context(mode=context.GRAPH_MODE) | context.set_context(mode=context.GRAPH_MODE) | ||||
| def test_output_const_tuple(): | |||||
| def test_output_const_tuple_0(): | |||||
| class Net(Cell): | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| self.x = (1, 2, 3) | |||||
| def construct(self): | |||||
| return self.x | |||||
| x = (1, 2, 3) | |||||
| net = Net() | |||||
| assert net() == x | |||||
| def test_output_const_tuple_1(): | |||||
| class Net(Cell): | class Net(Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| super(Net, self).__init__() | super(Net, self).__init__() | ||||
| @@ -83,32 +97,6 @@ def test_output_const_str(): | |||||
| assert net() == "hello world" | assert net() == "hello world" | ||||
| def test_output_parameter_tuple(): | |||||
| class Net(Cell): | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| def construct(self, x): | |||||
| return x | |||||
| x = (1, 2, 3) | |||||
| net = Net() | |||||
| assert net(x) == x | |||||
| def test_output_parameter_list(): | |||||
| class Net(Cell): | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| def construct(self, x): | |||||
| return x | |||||
| x = [1, 2, 3] | |||||
| net = Net() | |||||
| assert net(x) == x | |||||
| def test_output_parameter_int(): | def test_output_parameter_int(): | ||||
| class Net(Cell): | class Net(Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| @@ -117,7 +105,7 @@ def test_output_parameter_int(): | |||||
| def construct(self, x): | def construct(self, x): | ||||
| return x | return x | ||||
| x = 88 | |||||
| x = Tensor(np.array(88).astype(np.int32)) | |||||
| net = Net() | net = Net() | ||||
| assert net(x) == x | assert net(x) == x | ||||
| @@ -126,13 +114,14 @@ def test_output_parameter_str(): | |||||
| class Net(Cell): | class Net(Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| super(Net, self).__init__() | super(Net, self).__init__() | ||||
| self.x = "hello world" | |||||
| def construct(self, x): | |||||
| return x | |||||
| def construct(self): | |||||
| return self.x | |||||
| x = "hello world" | x = "hello world" | ||||
| net = Net() | net = Net() | ||||
| assert net(x) == x | |||||
| assert net() == x | |||||
| def test_tuple_tuple_0(): | def test_tuple_tuple_0(): | ||||