Browse Source

remove the 'raise' in construct of Cell

tags/v1.1.0
buxue 5 years ago
parent
commit
aeeef7607a
6 changed files with 9 additions and 15 deletions
  1. +4
    -1
      mindspore/nn/cell.py
  2. +1
    -2
      tests/ut/python/nn/test_cell.py
  3. +1
    -3
      tests/ut/python/ops/test_ops_check.py
  4. +1
    -3
      tests/ut/python/pipeline/parse/test_grammar_constraints.py
  5. +1
    -4
      tests/ut/python/pipeline/parse/test_super.py
  6. +1
    -2
      tests/ut/python/pynative_mode/nn/test_cell.py

+ 4
- 1
mindspore/nn/cell.py View File

@@ -315,6 +315,9 @@ class Cell(Cell_):
return tuple(res) return tuple(res)


def __call__(self, *inputs, **kwargs): def __call__(self, *inputs, **kwargs):
if self.__class__.construct is Cell.construct:
logger.warning(f"The '{self.__class__}' does not override the method 'construct', "
f"will call the super class(Cell) 'construct'.")
if kwargs: if kwargs:
bound_args = inspect.signature(self.construct).bind(*inputs, **kwargs) bound_args = inspect.signature(self.construct).bind(*inputs, **kwargs)
inputs = bound_args.args inputs = bound_args.args
@@ -681,7 +684,7 @@ class Cell(Cell_):
Returns: Returns:
Tensor, returns the computed result. Tensor, returns the computed result.
""" """
raise NotImplementedError
return None


def init_parameters_data(self, auto_parallel_mode=False): def init_parameters_data(self, auto_parallel_mode=False):
""" """


+ 1
- 2
tests/ut/python/nn/test_cell.py View File

@@ -197,8 +197,7 @@ def test_exceptions():
ModError2(t) ModError2(t)


m = nn.Cell() m = nn.Cell()
with pytest.raises(NotImplementedError):
m.construct()
assert m.construct() is None




def test_cell_copy(): def test_cell_copy():


+ 1
- 3
tests/ut/python/ops/test_ops_check.py View File

@@ -63,9 +63,7 @@ def test_net_without_construct():
""" test_net_without_construct """ """ test_net_without_construct """
net = NetMissConstruct() net = NetMissConstruct()
inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32)) inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
with pytest.raises(RuntimeError) as err:
_executor.compile(net, inp)
assert "Unsupported syntax 'Raise' at " in str(err.value)
_executor.compile(net, inp)




class NetWithRaise(nn.Cell): class NetWithRaise(nn.Cell):


+ 1
- 3
tests/ut/python/pipeline/parse/test_grammar_constraints.py View File

@@ -196,6 +196,4 @@ def test_missing_construct():
np_input = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.bool_) np_input = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.bool_)
tensor = Tensor(np_input) tensor = Tensor(np_input)
net = NetMissConstruct() net = NetMissConstruct()
with pytest.raises(RuntimeError) as er:
net(tensor)
assert "Unsupported syntax 'Raise' at " in str(er.value)
assert net(tensor) is None

+ 1
- 4
tests/ut/python/pipeline/parse/test_super.py View File

@@ -14,7 +14,6 @@
# ============================================================================ # ============================================================================
""" test super""" """ test super"""
import numpy as np import numpy as np
import pytest


import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor from mindspore import Tensor
@@ -108,9 +107,7 @@ def test_super_cell():
net = Net(2) net = Net(2)
x = Tensor(np.ones([1, 2, 3], np.int32)) x = Tensor(np.ones([1, 2, 3], np.int32))
y = Tensor(np.ones([1, 2, 3], np.int32)) y = Tensor(np.ones([1, 2, 3], np.int32))
with pytest.raises(RuntimeError) as er:
net(x, y)
assert "Unsupported syntax 'Raise'" in str(er.value)
assert net(x, y) is None




def test_single_super_in(): def test_single_super_in():


+ 1
- 2
tests/ut/python/pynative_mode/nn/test_cell.py View File

@@ -212,8 +212,7 @@ def test_exceptions():
ModError2(t) ModError2(t)


m = nn.Cell() m = nn.Cell()
with pytest.raises(NotImplementedError):
m.construct()
assert m.construct() is None




def test_del(): def test_del():


Loading…
Cancel
Save