| @@ -315,6 +315,9 @@ class Cell(Cell_): | |||||
| return tuple(res) | return tuple(res) | ||||
| def __call__(self, *inputs, **kwargs): | def __call__(self, *inputs, **kwargs): | ||||
| if self.__class__.construct is Cell.construct: | |||||
| logger.warning(f"The '{self.__class__}' does not override the method 'construct', " | |||||
| f"will call the super class(Cell) 'construct'.") | |||||
| if kwargs: | if kwargs: | ||||
| bound_args = inspect.signature(self.construct).bind(*inputs, **kwargs) | bound_args = inspect.signature(self.construct).bind(*inputs, **kwargs) | ||||
| inputs = bound_args.args | inputs = bound_args.args | ||||
| @@ -681,7 +684,7 @@ class Cell(Cell_): | |||||
| Returns: | Returns: | ||||
| Tensor, returns the computed result. | Tensor, returns the computed result. | ||||
| """ | """ | ||||
| raise NotImplementedError | |||||
| return None | |||||
| def init_parameters_data(self, auto_parallel_mode=False): | def init_parameters_data(self, auto_parallel_mode=False): | ||||
| """ | """ | ||||
| @@ -197,8 +197,7 @@ def test_exceptions(): | |||||
| ModError2(t) | ModError2(t) | ||||
| m = nn.Cell() | m = nn.Cell() | ||||
| with pytest.raises(NotImplementedError): | |||||
| m.construct() | |||||
| assert m.construct() is None | |||||
| def test_cell_copy(): | def test_cell_copy(): | ||||
| @@ -63,9 +63,7 @@ def test_net_without_construct(): | |||||
| """ test_net_without_construct """ | """ test_net_without_construct """ | ||||
| net = NetMissConstruct() | net = NetMissConstruct() | ||||
| inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32)) | inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32)) | ||||
| with pytest.raises(RuntimeError) as err: | |||||
| _executor.compile(net, inp) | |||||
| assert "Unsupported syntax 'Raise' at " in str(err.value) | |||||
| _executor.compile(net, inp) | |||||
| class NetWithRaise(nn.Cell): | class NetWithRaise(nn.Cell): | ||||
| @@ -196,6 +196,4 @@ def test_missing_construct(): | |||||
| np_input = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.bool_) | np_input = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.bool_) | ||||
| tensor = Tensor(np_input) | tensor = Tensor(np_input) | ||||
| net = NetMissConstruct() | net = NetMissConstruct() | ||||
| with pytest.raises(RuntimeError) as er: | |||||
| net(tensor) | |||||
| assert "Unsupported syntax 'Raise' at " in str(er.value) | |||||
| assert net(tensor) is None | |||||
| @@ -14,7 +14,6 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """ test super""" | """ test super""" | ||||
| import numpy as np | import numpy as np | ||||
| import pytest | |||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| @@ -108,9 +107,7 @@ def test_super_cell(): | |||||
| net = Net(2) | net = Net(2) | ||||
| x = Tensor(np.ones([1, 2, 3], np.int32)) | x = Tensor(np.ones([1, 2, 3], np.int32)) | ||||
| y = Tensor(np.ones([1, 2, 3], np.int32)) | y = Tensor(np.ones([1, 2, 3], np.int32)) | ||||
| with pytest.raises(RuntimeError) as er: | |||||
| net(x, y) | |||||
| assert "Unsupported syntax 'Raise'" in str(er.value) | |||||
| assert net(x, y) is None | |||||
| def test_single_super_in(): | def test_single_super_in(): | ||||
| @@ -212,8 +212,7 @@ def test_exceptions(): | |||||
| ModError2(t) | ModError2(t) | ||||
| m = nn.Cell() | m = nn.Cell() | ||||
| with pytest.raises(NotImplementedError): | |||||
| m.construct() | |||||
| assert m.construct() is None | |||||
| def test_del(): | def test_del(): | ||||