Browse Source

remove the 'raise' in construct of Cell

tags/v1.1.0
buxue 5 years ago
parent
commit
aeeef7607a
6 changed files with 9 additions and 15 deletions
  1. +4
    -1
      mindspore/nn/cell.py
  2. +1
    -2
      tests/ut/python/nn/test_cell.py
  3. +1
    -3
      tests/ut/python/ops/test_ops_check.py
  4. +1
    -3
      tests/ut/python/pipeline/parse/test_grammar_constraints.py
  5. +1
    -4
      tests/ut/python/pipeline/parse/test_super.py
  6. +1
    -2
      tests/ut/python/pynative_mode/nn/test_cell.py

+ 4
- 1
mindspore/nn/cell.py View File

@@ -315,6 +315,9 @@ class Cell(Cell_):
return tuple(res)

def __call__(self, *inputs, **kwargs):
if self.__class__.construct is Cell.construct:
logger.warning(f"The '{self.__class__}' does not override the method 'construct', "
f"will call the super class(Cell) 'construct'.")
if kwargs:
bound_args = inspect.signature(self.construct).bind(*inputs, **kwargs)
inputs = bound_args.args
@@ -681,7 +684,7 @@ class Cell(Cell_):
Returns:
Tensor, returns the computed result.
"""
raise NotImplementedError
return None

def init_parameters_data(self, auto_parallel_mode=False):
"""


+ 1
- 2
tests/ut/python/nn/test_cell.py View File

@@ -197,8 +197,7 @@ def test_exceptions():
ModError2(t)

m = nn.Cell()
with pytest.raises(NotImplementedError):
m.construct()
assert m.construct() is None


def test_cell_copy():


+ 1
- 3
tests/ut/python/ops/test_ops_check.py View File

@@ -63,9 +63,7 @@ def test_net_without_construct():
""" test_net_without_construct """
net = NetMissConstruct()
inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
with pytest.raises(RuntimeError) as err:
_executor.compile(net, inp)
assert "Unsupported syntax 'Raise' at " in str(err.value)
_executor.compile(net, inp)


class NetWithRaise(nn.Cell):


+ 1
- 3
tests/ut/python/pipeline/parse/test_grammar_constraints.py View File

@@ -196,6 +196,4 @@ def test_missing_construct():
np_input = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.bool_)
tensor = Tensor(np_input)
net = NetMissConstruct()
with pytest.raises(RuntimeError) as er:
net(tensor)
assert "Unsupported syntax 'Raise' at " in str(er.value)
assert net(tensor) is None

+ 1
- 4
tests/ut/python/pipeline/parse/test_super.py View File

@@ -14,7 +14,6 @@
# ============================================================================
""" test super"""
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor
@@ -108,9 +107,7 @@ def test_super_cell():
net = Net(2)
x = Tensor(np.ones([1, 2, 3], np.int32))
y = Tensor(np.ones([1, 2, 3], np.int32))
with pytest.raises(RuntimeError) as er:
net(x, y)
assert "Unsupported syntax 'Raise'" in str(er.value)
assert net(x, y) is None


def test_single_super_in():


+ 1
- 2
tests/ut/python/pynative_mode/nn/test_cell.py View File

@@ -212,8 +212,7 @@ def test_exceptions():
ModError2(t)

m = nn.Cell()
with pytest.raises(NotImplementedError):
m.construct()
assert m.construct() is None


def test_del():


Loading…
Cancel
Save