You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

container.py 16 kB

4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """container"""
  16. from collections import OrderedDict
  17. from abc import abstractmethod
  18. from ..cell import Cell
  19. __all__ = ['SequentialCell', 'CellList']
  20. def _valid_index(cell_num, index, op_name=None):
  21. """Internal function, used to detect the value and type of index."""
  22. msg_prefix = f"For '{op_name}', the" if op_name else "The"
  23. if not isinstance(index, int):
  24. raise TypeError(f"{msg_prefix} type of 'index' should be int, but got {type(index).__name__}.")
  25. if not -cell_num <= index < cell_num:
  26. raise IndexError(f"{msg_prefix} value of 'index' should be a number in range [{-cell_num}, {cell_num}), "
  27. f"but got {index}.")
  28. return index % cell_num
  29. def _valid_cell(cell, op_name=None):
  30. """Internal function, used to check whether the input cell is a subclass of Cell."""
  31. if issubclass(cell.__class__, Cell):
  32. return True
  33. msg_prefix = f"For '{op_name}'," if op_name else ""
  34. raise TypeError(f'{msg_prefix} each cell should be subclass of Cell, but got {type(cell).__name__}.')
  35. def _get_prefix_and_index(cells):
  36. """get prefix and index of parameter name in sequential cell or cell list."""
  37. prefix = ""
  38. index = 0
  39. if not cells:
  40. return prefix, index
  41. cell_list = list(cells.items())
  42. first_param, first_key = None, None
  43. second_param, second_key = None, None
  44. for key, cell in cell_list:
  45. try:
  46. _, param = next(cell.parameters_and_names())
  47. except StopIteration:
  48. continue
  49. if first_param is None:
  50. first_param = param
  51. first_key = key
  52. continue
  53. second_param = param
  54. second_key = key
  55. break
  56. if first_param is None:
  57. return prefix, index
  58. split_names = first_param.name.split(".")
  59. for idx, name in enumerate(split_names):
  60. if name == first_key:
  61. prefix = ".".join(split_names[:idx])
  62. prefix = prefix + "." if prefix else prefix
  63. index = idx
  64. if second_param is not None and second_param.name.split(".")[idx] == second_key:
  65. break
  66. return prefix, index
  67. class _CellListBase:
  68. """
  69. An interface for base the cell as list.
  70. The sequential cell may be iterated using the construct method using for-in statement.
  71. But there are some scenarios that the construct method built-in does not fit.
  72. For convenience, we provide an interface that indicates the sequential
  73. cell may be interpreted as list of cells, so it can be accessed using
  74. iterator or subscript when a sequential cell instantiate is accessed
  75. by iterator or subscript , it will be interpreted as a list of cells.
  76. """
  77. def __init__(self):
  78. """Initialize _CellListBase."""
  79. self.__cell_as_list__ = True
  80. @abstractmethod
  81. def __len__(self):
  82. pass
  83. @abstractmethod
  84. def __getitem__(self, index):
  85. pass
  86. def construct(self):
  87. raise NotImplementedError
  88. class SequentialCell(Cell):
  89. """
  90. Sequential cell container.
  91. A list of Cells will be added to it in the order they are passed in the constructor.
  92. Alternatively, an ordered dict of cells can also be passed in.
  93. Args:
  94. args (list, OrderedDict): List of subclass of Cell.
  95. Inputs:
  96. - **x** (Tensor) - Tensor with shape according to the first Cell in the sequence.
  97. Outputs:
  98. Tensor, the output Tensor with shape depending on the input `x` and defined sequence of Cells.
  99. Raises:
  100. TypeError: If the type of the `args` is not list or OrderedDict.
  101. Supported Platforms:
  102. ``Ascend`` ``GPU`` ``CPU``
  103. Examples:
  104. >>> conv = nn.Conv2d(3, 2, 3, pad_mode='valid', weight_init="ones")
  105. >>> relu = nn.ReLU()
  106. >>> seq = nn.SequentialCell([conv, relu])
  107. >>> x = Tensor(np.ones([1, 3, 4, 4]), dtype=mindspore.float32)
  108. >>> output = seq(x)
  109. >>> print(output)
  110. [[[[27. 27.]
  111. [27. 27.]]
  112. [[27. 27.]
  113. [27. 27.]]]]
  114. """
  115. def __init__(self, *args):
  116. """Initialize SequentialCell."""
  117. super(SequentialCell, self).__init__()
  118. self._is_dynamic_name = []
  119. if len(args) == 1:
  120. cells = args[0]
  121. if isinstance(cells, list):
  122. for index, cell in enumerate(cells):
  123. self.insert_child_to_cell(str(index), cell)
  124. cell.update_parameters_name(str(index) + ".")
  125. self._is_dynamic_name.append(True)
  126. elif isinstance(cells, OrderedDict):
  127. for name, cell in cells.items():
  128. self.insert_child_to_cell(name, cell)
  129. cell.update_parameters_name(name + ".")
  130. self._is_dynamic_name.append(False)
  131. else:
  132. raise TypeError(f"For '{self.__class__.__name__}', the 'args[0]' must be list or orderedDict, "
  133. f"but got {type(cells).__name__}")
  134. else:
  135. for index, cell in enumerate(args):
  136. self.insert_child_to_cell(str(index), cell)
  137. cell.update_parameters_name(str(index) + ".")
  138. self._is_dynamic_name.append(True)
  139. self.cell_list = list(self._cells.values())
  140. def __getitem__(self, index):
  141. if isinstance(index, slice):
  142. return self.__class__(
  143. OrderedDict(list(self._cells.items())[index]))
  144. index = _valid_index(len(self), index, self.__class__.__name__)
  145. return list(self._cells.values())[index]
  146. def __setitem__(self, index, cell):
  147. cls_name = self.__class__.__name__
  148. if _valid_cell(cell, cls_name):
  149. prefix, _ = _get_prefix_and_index(self._cells)
  150. index = _valid_index(len(self), index, cls_name)
  151. key = list(self._cells.keys())[index]
  152. self._cells[key] = cell
  153. cell.update_parameters_name(prefix + key + ".")
  154. self.cell_list = list(self._cells.values())
  155. def __delitem__(self, index):
  156. cls_name = self.__class__.__name__
  157. if isinstance(index, int):
  158. index = _valid_index(len(self), index, cls_name)
  159. key = list(self._cells.keys())[index]
  160. del self._cells[key]
  161. del self._is_dynamic_name[index]
  162. elif isinstance(index, slice):
  163. keys = list(self._cells.keys())[index]
  164. for key in keys:
  165. del self._cells[key]
  166. del self._is_dynamic_name[index]
  167. else:
  168. raise TypeError(f"For '{cls_name}', the type of index should be int type or slice type, "
  169. f"but got {type(index).__name__}")
  170. prefix, key_index = _get_prefix_and_index(self._cells)
  171. temp_dict = OrderedDict()
  172. for idx, key in enumerate(self._cells.keys()):
  173. cell = self._cells[key]
  174. if self._is_dynamic_name[idx]:
  175. for _, param in cell.parameters_and_names():
  176. param.name = prefix + str(idx) + "." + ".".join(param.name.split(".")[key_index+1:])
  177. temp_dict[str(idx)] = cell
  178. else:
  179. temp_dict[key] = cell
  180. self._cells = temp_dict
  181. self.cell_list = list(self._cells.values())
  182. def __len__(self):
  183. return len(self._cells)
  184. def set_grad(self, flag=True):
  185. self.requires_grad = flag
  186. for cell in self._cells.values():
  187. cell.set_grad(flag)
  188. def append(self, cell):
  189. """
  190. Appends a given cell to the end of the list.
  191. Args:
  192. cell(Cell): The subcell to be appended.
  193. Examples:
  194. >>> conv = nn.Conv2d(3, 2, 3, pad_mode='valid', weight_init="ones")
  195. >>> bn = nn.BatchNorm2d(2)
  196. >>> relu = nn.ReLU()
  197. >>> seq = nn.SequentialCell([conv, bn])
  198. >>> seq.append(relu)
  199. >>> x = Tensor(np.ones([1, 3, 4, 4]), dtype=mindspore.float32)
  200. >>> output = seq(x)
  201. >>> print(output)
  202. [[[[26.999863 26.999863]
  203. [26.999863 26.999863]]
  204. [[26.999863 26.999863]
  205. [26.999863 26.999863]]]]
  206. """
  207. if _valid_cell(cell, self.__class__.__name__):
  208. prefix, _ = _get_prefix_and_index(self._cells)
  209. cell.update_parameters_name(prefix + str(len(self)) + ".")
  210. self._is_dynamic_name.append(True)
  211. self._cells[str(len(self))] = cell
  212. self.cell_list = list(self._cells.values())
  213. def construct(self, input_data):
  214. for cell in self.cell_list:
  215. input_data = cell(input_data)
  216. return input_data
  217. class CellList(_CellListBase, Cell):
  218. """
  219. Holds Cells in a list.
  220. CellList can be used like a regular Python list, support
  221. '__getitem__', '__setitem__', '__delitem__', '__len__', '__iter__' and '__iadd__',
  222. but cells it contains are properly registered, and will be visible by all Cell methods.
  223. Args:
  224. args (list, optional): List of subclass of Cell.
  225. Supported Platforms:
  226. ``Ascend`` ``GPU`` ``CPU``
  227. Examples:
  228. >>> conv = nn.Conv2d(100, 20, 3)
  229. >>> bn = nn.BatchNorm2d(20)
  230. >>> relu = nn.ReLU()
  231. >>> cell_ls = nn.CellList([bn])
  232. >>> cell_ls.insert(0, conv)
  233. >>> cell_ls.append(relu)
  234. >>> print(cell_ls)
  235. CellList<
  236. (0): Conv2d<input_channels=100, output_channels=20, kernel_size=(3, 3),stride=(1, 1), pad_mode=same,
  237. padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
  238. (1): BatchNorm2d<num_features=20, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=1.gamma,
  239. shape=(20,), dtype=Float32, requires_grad=True), beta=Parameter (name=1.beta, shape=(20,), dtype=Float32,
  240. requires_grad=True), moving_mean=Parameter (name=1.moving_mean, shape=(20,), dtype=Float32,
  241. requires_grad=False), moving_variance=Parameter (name=1.moving_variance, shape=(20,), dtype=Float32,
  242. requires_grad=False)>
  243. (2): ReLU<>
  244. >
  245. """
  246. def __init__(self, *args, **kwargs):
  247. """Initialize CellList."""
  248. auto_prefix = kwargs["auto_prefix"] if "auto_prefix" in kwargs.keys() else True
  249. _CellListBase.__init__(self)
  250. Cell.__init__(self, auto_prefix)
  251. if len(args) == 1:
  252. self.extend(args[0])
  253. def __getitem__(self, index):
  254. cls_name = self.__class__.__name__
  255. if isinstance(index, slice):
  256. return self.__class__(list(self._cells.values())[index])
  257. if isinstance(index, int):
  258. index = _valid_index(len(self), index, cls_name)
  259. return self._cells[str(index)]
  260. raise TypeError(f"For '{cls_name}', the type of 'index' should be int or slice, "
  261. f"but got {type(index).__name__}.")
  262. def __setitem__(self, index, cell):
  263. cls_name = self.__class__.__name__
  264. if not isinstance(index, int) and _valid_cell(cell, cls_name):
  265. raise TypeError(f"For '{cls_name}', the type of 'index' should be int, "
  266. f"but got {type(index).__name__}.")
  267. index = _valid_index(len(self), index, cls_name)
  268. if self._auto_prefix:
  269. prefix, _ = _get_prefix_and_index(self._cells)
  270. cell.update_parameters_name(prefix + str(index) + ".")
  271. self._cells[str(index)] = cell
  272. def __delitem__(self, index):
  273. cls_name = self.__class__.__name__
  274. if isinstance(index, int):
  275. index = _valid_index(len(self), index, cls_name)
  276. del self._cells[str(index)]
  277. elif isinstance(index, slice):
  278. keys = list(self._cells.keys())[index]
  279. for key in keys:
  280. del self._cells[key]
  281. else:
  282. raise TypeError(f"For '{cls_name}', the type of 'index' should be int or slice, "
  283. f"but got {type(index).__name__}.")
  284. # adjust orderedDict
  285. prefix, key_index = _get_prefix_and_index(self._cells)
  286. temp_dict = OrderedDict()
  287. for idx, cell in enumerate(self._cells.values()):
  288. if self._auto_prefix:
  289. for _, param in cell.parameters_and_names():
  290. param.name = prefix + str(idx) + "." + ".".join(param.name.split(".")[key_index+1:])
  291. temp_dict[str(idx)] = cell
  292. self._cells = temp_dict
  293. def __len__(self):
  294. return len(self._cells)
  295. def __iter__(self):
  296. return iter(self._cells.values())
  297. def __iadd__(self, cells):
  298. self.extend(cells)
  299. return self
  300. def insert(self, index, cell):
  301. """
  302. Inserts a given cell before a given index in the list.
  303. Args:
  304. index(int): The Insert index in the CellList.
  305. cell(Cell): The subcell to be inserted.
  306. """
  307. cls_name = self.__class__.__name__
  308. idx = _valid_index(len(self), index, cls_name)
  309. _valid_cell(cell, cls_name)
  310. length = len(self)
  311. prefix, key_index = _get_prefix_and_index(self._cells)
  312. while length > idx:
  313. if self._auto_prefix:
  314. tmp_cell = self._cells[str(length-1)]
  315. for _, param in tmp_cell.parameters_and_names():
  316. param.name = prefix + str(length) + "." + ".".join(param.name.split(".")[key_index+1:])
  317. self._cells[str(length)] = self._cells[str(length - 1)]
  318. length -= 1
  319. self._cells[str(idx)] = cell
  320. if self._auto_prefix:
  321. cell.update_parameters_name(prefix + str(idx) + ".")
  322. def extend(self, cells):
  323. """
  324. Appends cells from a Python iterable to the end of the list.
  325. Args:
  326. cells(list): The subcells to be extended.
  327. Raises:
  328. TypeError: If the cells are not a list of subcells.
  329. """
  330. cls_name = self.__class__.__name__
  331. if not isinstance(cells, list):
  332. raise TypeError(f"For '{cls_name}', the new cells wanted to append "
  333. f"should be instance of list, but got {type(cells).__name__}.")
  334. prefix, _ = _get_prefix_and_index(self._cells)
  335. for cell in cells:
  336. if _valid_cell(cell, cls_name):
  337. if self._auto_prefix:
  338. cell.update_parameters_name(prefix + str(len(self)) + ".")
  339. self._cells[str(len(self))] = cell
  340. return self
  341. def append(self, cell):
  342. """
  343. Appends a given cell to the end of the list.
  344. Args:
  345. cell(Cell): The subcell to be appended.
  346. """
  347. if _valid_cell(cell, self.__class__.__name__):
  348. if self._auto_prefix:
  349. prefix, _ = _get_prefix_and_index(self._cells)
  350. cell.update_parameters_name(prefix + str(len(self)) + ".")
  351. self._cells[str(len(self))] = cell
  352. def set_grad(self, flag=True):
  353. self.requires_grad = flag
  354. for cell in self._cells.values():
  355. cell.set_grad(flag)
  356. def construct(self, *inputs):
  357. raise NotImplementedError