|
|
|
@@ -146,9 +146,20 @@ class SequentialCell(Cell): |
|
|
|
cell.set_grad(flag) |
|
|
|
|
|
|
|
def append(self, cell): |
|
|
|
"""Appends a given cell to the end of the list.""" |
|
|
|
"""Appends a given cell to the end of the list. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> conv = nn.Conv2d(3, 2, 3, pad_mode='valid') |
|
|
|
>>> bn = nn.BatchNorm2d(2) |
|
|
|
>>> relu = nn.ReLU() |
|
|
|
>>> seq = nn.SequentialCell([conv, bn]) |
|
|
|
>>> seq.append(relu) |
|
|
|
>>> x = Tensor(np.random.random((1, 3, 4, 4)), dtype=mindspore.float32) |
|
|
|
>>> seq(x) |
|
|
|
""" |
|
|
|
if _valid_cell(cell): |
|
|
|
self._cells[str(len(self))] = cell |
|
|
|
self.cell_list = list(self._cells.values()) |
|
|
|
return self |
|
|
|
|
|
|
|
def construct(self, input_data): |
|
|
|
|