Browse Source

!6192 Improve the functions of class SequentialCell and add examples

Merge pull request !6192 from lijiaqi/sequentialcell
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
d0e49c5cf8
2 changed files with 26 additions and 1 deletions
  1. +12
    -1
      mindspore/nn/layer/container.py
  2. +14
    -0
      tests/ut/python/nn/test_container.py

+ 12
- 1
mindspore/nn/layer/container.py View File

@@ -146,9 +146,20 @@ class SequentialCell(Cell):
cell.set_grad(flag)

def append(self, cell):
"""Appends a given cell to the end of the list."""
"""Appends a given cell to the end of the list.

Examples:
>>> conv = nn.Conv2d(3, 2, 3, pad_mode='valid')
>>> bn = nn.BatchNorm2d(2)
>>> relu = nn.ReLU()
>>> seq = nn.SequentialCell([conv, bn])
>>> seq.append(relu)
>>> x = Tensor(np.random.random((1, 3, 4, 4)), dtype=mindspore.float32)
>>> seq(x)
"""
if _valid_cell(cell):
self._cells[str(len(self))] = cell
self.cell_list = list(self._cells.values())
return self

def construct(self, input_data):


+ 14
- 0
tests/ut/python/nn/test_container.py View File

@@ -84,6 +84,20 @@ class TestSequentialCell():
del m[:]
assert type(m).__name__ == 'SequentialCell'

def test_sequentialcell_append(self):
input_np = np.ones((1, 3)).astype(np.float32)
input_me = Tensor(input_np)
relu = nn.ReLU()
tanh = nn.Tanh()
seq = nn.SequentialCell([relu])
seq.append(tanh)
out_me = seq(input_me)

seq1 = nn.SequentialCell([relu, tanh])
out = seq1(input_me)

assert out[0][0] == out_me[0][0]


class TestCellList():
""" TestCellList """


Loading…
Cancel
Save