You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

container.py 9.8 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """container"""
  16. from collections import OrderedDict
  17. from abc import abstractmethod
  18. from ..cell import Cell
  19. __all__ = ['SequentialCell', 'CellList']
  20. def _valid_index(cell_num, index):
  21. if not isinstance(index, int):
  22. raise TypeError("Index {} is not int type")
  23. if not -cell_num <= index < cell_num:
  24. raise IndexError("Index should be a number in range [{}, {}), but got {}"
  25. .format(-cell_num, cell_num, index))
  26. return index % cell_num
  27. def _valid_cell(cell):
  28. if issubclass(cell.__class__, Cell):
  29. return True
  30. raise TypeError('Cell {} is not subclass of Cell'.format(cell))
  31. class _CellListBase():
  32. """
  33. An interface for base the cell as list.
  34. The sequential cell may be iterated using the construct method using for-in statement.
  35. But there are some scenarios that the construct method built-in does not fit.
  36. For convenience, we provide an interface that indicates the sequential
  37. cell may be interpretated as list of cells, so it can be accessed using
  38. iterator or subscript when a sequential cell instantiate is accessed
  39. by iterator or subscript , it will be interpretated as a list of cells.
  40. """
  41. def __init__(self):
  42. self.__cell_as_list__ = True
  43. @abstractmethod
  44. def __len__(self):
  45. pass
  46. @abstractmethod
  47. def __getitem__(self, index):
  48. pass
  49. def construct(self):
  50. raise NotImplementedError
  51. class SequentialCell(Cell):
  52. """
  53. Sequential cell container.
  54. A list of Cells will be added to it in the order they are passed in the constructor.
  55. Alternatively, an ordered dict of cells can also be passed in.
  56. Args:
  57. args (list, OrderedDict): List of subclass of Cell.
  58. Raises:
  59. TypeError: If the type of the argument is not list or OrderedDict.
  60. Inputs:
  61. - **input** (Tensor) - Tensor with shape according to the first Cell in the sequence.
  62. Outputs:
  63. Tensor, the output Tensor with shape depending on the input and defined sequence of Cells.
  64. Examples:
  65. >>> conv = nn.Conv2d(3, 2, 3, pad_mode='valid')
  66. >>> bn = nn.BatchNorm2d(2)
  67. >>> relu = nn.ReLU()
  68. >>> seq = nn.SequentialCell([conv, bn, relu])
  69. >>>
  70. >>> x = Tensor(np.random.random((1, 3, 4, 4)), dtype=mindspore.float32)
  71. >>> seq(x)
  72. [[[[0.02531557 0. ]
  73. [0.04933941 0.04880078]]
  74. [[0. 0. ]
  75. [0. 0. ]]]]
  76. """
  77. def __init__(self, *args):
  78. super(SequentialCell, self).__init__()
  79. if len(args) == 1:
  80. cells = args[0]
  81. if isinstance(cells, list):
  82. for index, cell in enumerate(cells):
  83. self.insert_child_to_cell(str(index), cell)
  84. elif isinstance(cells, OrderedDict):
  85. for name, cell in cells.items():
  86. self.insert_child_to_cell(name, cell)
  87. else:
  88. raise TypeError('Cells must be list or orderedDict')
  89. else:
  90. for index, cell in enumerate(args):
  91. self.insert_child_to_cell(str(index), cell)
  92. self.cell_list = list(self._cells.values())
  93. def __getitem__(self, index):
  94. if isinstance(index, slice):
  95. return self.__class__(
  96. OrderedDict(list(self._cells.items())[index]))
  97. index = _valid_index(len(self), index)
  98. return list(self._cells.values())[index]
  99. def __setitem__(self, index, cell):
  100. if _valid_cell(cell):
  101. index = _valid_index(len(self), index)
  102. key = list(self._cells.keys())[index]
  103. self._cells[key] = cell
  104. self.cell_list = list(self._cells.values())
  105. def __delitem__(self, index):
  106. if isinstance(index, int):
  107. index = _valid_index(len(self), index)
  108. key = list(self._cells.keys())[index]
  109. del self._cells[key]
  110. elif isinstance(index, slice):
  111. keys = list(self._cells.keys())[index]
  112. for key in keys:
  113. del self._cells[key]
  114. else:
  115. raise TypeError('Index {} is not int type or slice type'.format(index))
  116. self.cell_list = list(self._cells.values())
  117. def __len__(self):
  118. return len(self._cells)
  119. def set_grad(self, flag=True):
  120. self.requires_grad = flag
  121. for cell in self._cells.values():
  122. cell.set_grad(flag)
  123. def append(self, cell):
  124. """Appends a given cell to the end of the list.
  125. Examples:
  126. >>> conv = nn.Conv2d(3, 2, 3, pad_mode='valid')
  127. >>> bn = nn.BatchNorm2d(2)
  128. >>> relu = nn.ReLU()
  129. >>> seq = nn.SequentialCell([conv, bn])
  130. >>> seq.append(relu)
  131. >>> x = Tensor(np.ones([1, 3, 4, 4]), dtype=mindspore.float32)
  132. >>> seq(x)
  133. [[[[0.12445523 0.12445523]
  134. [0.12445523 0.12445523]]
  135. [[0. 0. ]
  136. [0. 0. ]]]]
  137. """
  138. if _valid_cell(cell):
  139. self._cells[str(len(self))] = cell
  140. self.cell_list = list(self._cells.values())
  141. return self
  142. def construct(self, input_data):
  143. for cell in self.cell_list:
  144. input_data = cell(input_data)
  145. return input_data
  146. class CellList(_CellListBase, Cell):
  147. """
  148. Holds Cells in a list.
  149. CellList can be used like a regular Python list, support
  150. '__getitem__', '__setitem__', '__delitem__', '__len__', '__iter__' and '__iadd__',
  151. but cells it contains are properly registered, and will be visible by all Cell methods.
  152. Args:
  153. args (list, optional): List of subclass of Cell.
  154. Examples:
  155. >>> conv = nn.Conv2d(100, 20, 3)
  156. >>> bn = nn.BatchNorm2d(20)
  157. >>> relu = nn.ReLU()
  158. >>> cell_ls = nn.CellList([bn])
  159. >>> cell_ls.insert(0, conv)
  160. >>> cell_ls.append(relu)
  161. >>> x = Tensor(np.random.random((1, 3, 4, 4)), dtype=mindspore.float32)
  162. >>> # not same as nn.SequentialCell, `cell_ls(x)` is not correct
  163. >>> cell_ls
  164. CellList< (0): Conv2d<input_channels=100, ..., bias_init=None>
  165. (1): BatchNorm2d<num_features=20, ..., moving_variance=Parameter (name=variance)>
  166. (2): ReLU<> >
  167. """
  168. def __init__(self, *args):
  169. _CellListBase.__init__(self)
  170. Cell.__init__(self)
  171. if len(args) == 1:
  172. self.extend(args[0])
  173. def __getitem__(self, index):
  174. if isinstance(index, slice):
  175. return self.__class__(list(self._cells.values())[index])
  176. if isinstance(index, int):
  177. index = _valid_index(len(self), index)
  178. return self._cells[str(index)]
  179. raise TypeError('Index {} is not int type or slice type'.format(index))
  180. def __setitem__(self, index, cell):
  181. if not isinstance(index, int) and _valid_cell(cell):
  182. raise TypeError('Index {} is not int type'.format(index))
  183. index = _valid_index(len(self), index)
  184. self._cells[str(index)] = cell
  185. def __delitem__(self, index):
  186. if isinstance(index, int):
  187. index = _valid_index(len(self), index)
  188. del self._cells[str(index)]
  189. elif isinstance(index, slice):
  190. keys = list(self._cells.keys())[index]
  191. for key in keys:
  192. del self._cells[key]
  193. else:
  194. raise TypeError('Index {} is not int type or slice type'.format(index))
  195. # adjust orderedDict
  196. temp_dict = OrderedDict()
  197. for idx, cell in enumerate(self._cells.values()):
  198. temp_dict[str(idx)] = cell
  199. self._cells = temp_dict
  200. def __len__(self):
  201. return len(self._cells)
  202. def __iter__(self):
  203. return iter(self._cells.values())
  204. def __iadd__(self, cells):
  205. self.extend(cells)
  206. return self
  207. def insert(self, index, cell):
  208. """Inserts a given cell before a given index in the list."""
  209. idx = _valid_index(len(self), index)
  210. _valid_cell(cell)
  211. length = len(self)
  212. while length > idx:
  213. self._cells[str(length)] = self._cells[str(length - 1)]
  214. length -= 1
  215. self._cells[str(idx)] = cell
  216. def extend(self, cells):
  217. """
  218. Appends cells from a Python iterable to the end of the list.
  219. Raises:
  220. TypeError: If the cells are not a list of subcells.
  221. """
  222. if not isinstance(cells, list):
  223. raise TypeError('Cells {} should be list of subcells'.format(cells))
  224. for cell in cells:
  225. if _valid_cell(cell):
  226. self._cells[str(len(self))] = cell
  227. return self
  228. def append(self, cell):
  229. """Appends a given cell to the end of the list."""
  230. if _valid_cell(cell):
  231. self._cells[str(len(self))] = cell
  232. return self
  233. def set_grad(self, flag=True):
  234. self.requires_grad = flag
  235. for cell in self._cells.values():
  236. cell.set_grad(flag)
  237. def construct(self, *inputs):
  238. raise NotImplementedError