You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

cell.py 40 kB

5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """cell"""
  16. import inspect
  17. import time
  18. import gc
  19. import os
  20. from collections import OrderedDict
  21. import numpy
  22. from mindspore import log as logger
  23. from .. import context
  24. from ..common import dtype as mstype
  25. from ..common.api import _executor, _pynative_exec
  26. from .._checkparam import _check_str_by_regular
  27. from ..common.parameter import Parameter, ParameterTuple
  28. from .._c_expression import init_backend, Cell_
  29. from ..ops.primitive import Primitive
  30. from ..ops.operations import HookBackward
  31. from ..ops.functional import cast
  32. from ..parallel._tensor import _load_tensor_by_layout
  33. from ..common.tensor import Tensor
  34. class Cell(Cell_):
  35. """
  36. Base class for all neural networks.
  37. A 'Cell' could be a single neural network cell, such as conv2d, relu, batch_norm, etc. or a composition of
  38. cells to constructing a network.
  39. Note:
  40. In general, the autograd algorithm will automatically generate the implementation of the gradient function,
  41. but if bprop method is implemented, the gradient function
  42. will be replaced by the bprop. The bprop implementation will receive a Tensor `dout` containing the gradient
  43. of the loss w.r.t. the output, and a Tensor `out` containing the forward result. The bprop needs to compute the
  44. gradient of the loss w.r.t. the inputs, gradient of the loss w.r.t. Parameter variables are not supported
  45. currently.
  46. Args:
  47. auto_prefix (bool): Recursively generate namespaces. Default: True.
  48. Examples:
  49. >>> class MyCell(Cell):
  50. >>> def __init__(self):
  51. >>> super(MyCell, self).__init__()
  52. >>> self.relu = P.ReLU()
  53. >>>
  54. >>> def construct(self, x):
  55. >>> return self.relu(x)
  56. """
  57. IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_construct_inputs_names',
  58. '_construct_inputs_num', '_create_time', '_mindspore_flags', '_parallel_inputs_run',
  59. '_parameter_layout_dict', '_already_run', '_params_list', '_tensor_list', '_phase',
  60. '_auto_parallel_mode', '_backward_hook', '_bprop_debug', '_is_run', '_param_prefix',
  61. '_attr_synced', 'enable_hook', 'pynative', 'requires_grad',
  62. '_auto_parallel_compile_and_run', 'cell_type']
    def __init__(self, auto_prefix=True, flags=None):
        """
        Initialize the cell's containers and runtime state.

        Args:
            auto_prefix (bool): Recursively generate namespaces for child
                parameters. Default: True.
            flags (dict): Optional flags forwarded to `add_flags`. Default: None.
        """
        # Register this cell with the C++ backend under its class tag.
        Cell_.__init__(self, self._cell_tag)
        # Ordered containers so parameter/child iteration order is deterministic.
        self._params = OrderedDict()
        self._cells = OrderedDict()
        self._params_list = OrderedDict()
        self._tensor_list = OrderedDict()
        self.training = False
        self.requires_grad = False
        self.pynative = False
        self._attr_synced = False
        self._param_prefix = ''
        self._auto_prefix = auto_prefix
        self._scope = None
        self._phase = 'train'
        self._parameter_layout_dict = {}
        # Nanosecond timestamp; also used as a unique id for compiled graph resources.
        self._create_time = int(time.time() * 1e9)
        init_backend()
        # call gc to release GE session resources used by non-used cell objects
        if os.getenv('GC_COLLECT_IN_CELL') == '1':
            gc.collect()
        self._construct_inputs_num = 0
        self._construct_inputs_names = []
        self._auto_parallel_mode = False
        self._parallel_inputs_run = None
        if flags:
            self.add_flags(**flags)
        self._backward_hook = None
        self.enable_hook = False
        self._bprop_debug = False
        self._already_run = False
        self.cell_type = None
        self._auto_parallel_compile_and_run = False
    @property
    def already_run(self):
        """bool: Whether this cell has completed at least one forward run."""
        return self._already_run
    def __getstate__(self):
        """Support pickling: pair the C++ base state with the Python __dict__."""
        base = Cell_.__getstate__(self)
        return base, self.__dict__
    def __setstate__(self, state):
        """Restore pickled state; attributes must be re-synced to the C++ object."""
        base, dict_ = state
        Cell_.__setstate__(self, base)
        self.__dict__ = dict_
        # Force a fresh attribute sync on the next compile.
        self._attr_synced = False
    @property
    def _cell_tag(self):
        """str: Fully qualified class name used to tag the backing C++ object."""
        # `<class 'xxxxxxx'>`
        # -> `xxxxxxx`
        return str(self.__class__)[8:-2]
    @already_run.setter
    def already_run(self, value):
        # Plain flag assignment; no validation by design.
        self._already_run = value
    @property
    def create_time(self):
        """int: Nanosecond timestamp recorded when the cell was constructed."""
        return self._create_time
    @property
    def cell_init_args(self):
        """str: Stringified constructor arguments."""
        # NOTE(review): raises AttributeError if the setter was never called —
        # `_cell_init_args` is not initialized in __init__.
        return self._cell_init_args
    @property
    def param_prefix(self):
        """
        Param prefix is the prefix of current cell's direct child parameter.
        """
        return self._param_prefix
    @property
    def bprop_debug(self):
        """
        Get whether cell custom bprop debug is enabled.
        """
        return self._bprop_debug
  132. @bprop_debug.setter
  133. def bprop_debug(self, value):
  134. """
  135. Set whether to enable cell custom bprop debug.
  136. Note:
  137. When bprop is defined in cell, the bprop function will be executed
  138. in python interpreter when bprop debug is true, and will be parsed
  139. and add to graph when bprop debug is false.
  140. Args:
  141. value (bool): Specifies whether to enable bprop debug. Default: False.
  142. """
  143. if not isinstance(value, bool):
  144. raise TypeError("'bprop debug' value must be bool type.")
  145. self._bprop_debug = value
  146. def update_cell_prefix(self):
  147. """
  148. Update the all child cells' self.param_prefix.
  149. After being invoked, it can get all the cell's children's name prefix by '_param_prefix'.
  150. """
  151. cells_name = self.cells_and_names()
  152. for cell_name, cell in cells_name:
  153. cell._param_prefix = cell_name
    def update_cell_type(self, cell_type):
        """
        The current cell type is updated when a quantization aware training network is encountered.

        After being invoked, it can set the cell type to 'cell_type'.
        """
        self.cell_type = cell_type
  160. @cell_init_args.setter
  161. def cell_init_args(self, value):
  162. if not isinstance(value, str):
  163. raise TypeError("'cell_init_args' must be string type.")
  164. self._cell_init_args = value
    @property
    def phase(self):
        """str: Current execution phase label (e.g. 'train')."""
        return self._phase
  168. @phase.setter
  169. def phase(self, value):
  170. if not isinstance(value, str):
  171. raise TypeError("'phase' must be string type.")
  172. self._phase = value
    @property
    def parameter_layout_dict(self):
        """dict: Mapping from parameter/input name to its parallel layout."""
        return self._parameter_layout_dict
    @property
    def cls_name(self):
        """str: Short class name of this cell."""
        return self.__class__.__name__
  179. @parameter_layout_dict.setter
  180. def parameter_layout_dict(self, value):
  181. if not isinstance(value, dict):
  182. raise TypeError("'parameter_layout_dict' must be dict type.")
  183. self._parameter_layout_dict = value
  184. def get_func_graph_proto(self):
  185. """Return graph binary proto."""
  186. return _executor._get_func_graph_proto(self.phase + "." + str(self.create_time), "anf_ir", True)
    def __getattr__(self, name):
        """Resolve parameters, child cells and tensors stored in the side containers.

        Only called when normal attribute lookup fails; search order is
        _params -> _cells -> _tensor_list -> _params_list.
        """
        if '_params' in self.__dict__:
            params = self.__dict__['_params']
            if name in params:
                # In pynative mode, apply the cell's mixed-precision cast first.
                if context.get_context("mode") == context.PYNATIVE_MODE:
                    return self.cast_param(params[name])
                return params[name]
        if '_cells' in self.__dict__:
            cells = self.__dict__['_cells']
            if name in cells:
                return cells[name]
        if '_tensor_list' in self.__dict__:
            tensor_list = self.__dict__['_tensor_list']
            if name in tensor_list:
                return self.cast_param(tensor_list[name])
        if '_params_list' in self.__dict__:
            params_list = self.__dict__['_params_list']
            if name in params_list:
                # Rebuild the tuple with each member cast to the cell's precision.
                para_list = params_list[name]
                cast_list = list()
                for para in para_list:
                    cast_list.append(self.cast_param(para))
                para_list = ParameterTuple(cast_list)
                return para_list
        raise AttributeError("'{}' object has no attribute '{}'.".format(type(self).__name__, name))
    def __del__(self):
        """Release executor resources associated with this cell instance."""
        _pynative_exec.clear(str(id(self)))
        # Guard: __del__ may run on a partially-initialized instance.
        if hasattr(self, "_create_time"):
            _executor.del_net_res(str(self._create_time))
    def __delattr__(self, name):
        """Delete an attribute, removing it from the matching side container."""
        if name in self._params:
            del self._params[name]
        elif name in self._cells:
            del self._cells[name]
        else:
            if '_params_list' in self.__dict__ and name in self._params_list:
                del self._params_list[name]
            elif '_tensor_list' in self.__dict__ and name in self._tensor_list:
                del self._tensor_list[name]
            object.__delattr__(self, name)
        # Attribute set changed; C++ side must be re-synced before next compile.
        self._attr_synced = False
  228. def cast_inputs(self, inputs, dst_type):
  229. res = list()
  230. for item in inputs:
  231. if isinstance(item, tuple):
  232. res.append(self.cast_inputs(item, dst_type))
  233. else:
  234. res.append(cast(item, dst_type))
  235. return tuple(res)
    def __call__(self, *inputs, **kwargs):
        """Run the cell: compile-and-execute in graph mode, eager execute in pynative mode."""
        if context.get_context("mode") == context.GRAPH_MODE:
            if kwargs:
                raise ValueError("For 'graph' mode, the outermost network does not support passing "
                                 "key-value pair parameters and variable key-value pair parameters.")
            if self.enable_hook:
                raise ValueError("The graph mode does not support hook function.")
            out = self.compile_and_run(*inputs)
            return out
        # Pynative path from here on.
        if kwargs:
            # Normalize keyword arguments against construct's signature.
            bound_args = inspect.signature(self.construct).bind(*inputs, **kwargs)
            inputs = bound_args.args
            kwargs = bound_args.kwargs
        for item in inputs:
            if isinstance(item, numpy.ndarray):
                raise TypeError("cell inputs should not be numpy array.")
        origin_grad = []
        if self.requires_grad is True:
            _pynative_exec.set_grad_flag(True)
            _pynative_exec.new_graph(self, *inputs, **kwargs)
            # Temporarily force grad on child cells; restored after the run.
            for cell in self.cells():
                origin_grad.append(cell.requires_grad)
                cell.set_grad(True)
        else:
            _pynative_exec.set_grad_flag(False)
        cast_inputs = list()
        if hasattr(self, "_mindspore_flags"):
            # Mixed-precision flags: cast all inputs to the requested dtype.
            if self._mindspore_flags.get('fp16'):
                cast_inputs = self.cast_inputs(inputs, mstype.float16)
            if self._mindspore_flags.get('fp32'):
                cast_inputs = self.cast_inputs(inputs, mstype.float32)
        if not cast_inputs:
            cast_inputs = inputs
        if self.enable_hook:
            output = self._hook_construct(*cast_inputs, **kwargs)
        else:
            output = self.construct(*cast_inputs, **kwargs)
        if isinstance(output, Parameter):
            output = output.data
        if self.requires_grad is True:
            _pynative_exec.end_graph(self, output, *inputs, **kwargs)
            # Restore each child's original requires_grad flag.
            for i, cell in enumerate(self.cells()):
                cell.set_grad(origin_grad[i])
        self._already_run = True
        return output
    def _add_attr(self, name, value):
        """Forward an attribute to the C++ object, skipping dunders and bookkeeping names."""
        if name and name[:2] != '__' and name not in Cell.IGNORE_LIST:
            super(Cell, self)._add_attr(name, value)
    def _sync_attr_for_compile(self):
        """Sync the attr to c++ object."""
        # Idempotent: skip when nothing changed since the last sync.
        if self._attr_synced:
            return
        # Recursively sync child cells first, then register them as attrs.
        cells = self.__dict__.get('_cells')
        for key in cells:
            cell = cells[key]
            cell._sync_attr_for_compile()
            self._add_attr(key, cell)
        params = self.__dict__.get('_params')
        for key in params:
            # Dotted keys are nested parameter names, not direct attributes.
            if '.' in key:
                continue
            param = params[key]
            self._add_attr(key, param)
        params_list = self.__dict__.get('_params_list')
        for key in params_list:
            params_list_item = params_list[key]
            self._add_attr(key, params_list_item)
        # Finally mirror every plain instance attribute.
        for key in self.__dict__:
            value = self.__dict__[key]
            self._add_attr(key, value)
        self._attr_synced = True
    def __setattr__(self, name, value):
        """Route attribute assignment into the proper side container by value type.

        Parameters go to _params, ParameterTuples to _params_list (pynative),
        Cells to _cells, Tensors to _tensor_list (pynative); everything else is
        a plain attribute. The branch order is significant.
        """
        cells = self.__dict__.get('_cells')
        params = self.__dict__.get('_params')
        params_list = self.__dict__.get('_params_list')
        tensor_list = self.__dict__.get('_tensor_list')
        if isinstance(value, Parameter):
            if params is None:
                raise AttributeError("Can not assign params before Cell.__init__() call.")
            if name in self.__dict__:
                # A plain attribute with this name already exists.
                if self.__dict__[name] is not None:
                    raise TypeError("Expected type is not in (Parameter, Cell), but got Parameter.")
                del self.__dict__[name]
            if cells and name in cells:
                raise TypeError("Expected type is Cell, but got Parameter.")
            self.insert_param_to_cell(name, value)
        elif isinstance(value, ParameterTuple):
            if params is None:
                raise AttributeError("Can not assign params before Cell.__init__() call.")
            for item in value:
                # Tuple members keep their own names; skip name validation.
                self.insert_param_to_cell(item.name, item, check_name=False)
            if context.get_context("mode") == context.PYNATIVE_MODE:
                if name in self.__dict__:
                    del self.__dict__[name]
                if name in params:
                    del params[name]
                params_list[name] = value
            else:
                object.__setattr__(self, name, value)
        elif isinstance(value, Cell):
            if cells is None:
                raise AttributeError("Can not assign cells before Cell.__init__() call.")
            if name in self.__dict__:
                del self.__dict__[name]
            if params and name in params:
                raise TypeError("Expected type is Parameter, but got Cell.")
            if self._auto_prefix:
                # Prefix the child's parameter names with this attribute name.
                value.update_parameters_name(name + '.')
            cells[name] = value
        elif params and name in params:
            # Re-assigning over an existing parameter slot.
            if isinstance(value, Tensor) and self._params[name] is not None:
                self._params[name].set_data(value)
            elif value is not None:
                raise TypeError("Expected type in (Parameter, ParameterTuple), but got {}.".format(type(value)))
            else:
                self.insert_param_to_cell(name, None)
        elif cells and name in cells:
            # Re-assigning over an existing child-cell slot; only None is allowed.
            if value is not None:
                raise TypeError("Expected type is cell, but got {}.".format(type(value)))
            self._cells[name] = None
        elif isinstance(value, Tensor):
            if context.get_context("mode") == context.PYNATIVE_MODE:
                if name in self.__dict__:
                    del self.__dict__[name]
                tensor_list[name] = value
            else:
                object.__setattr__(self, name, value)
        else:
            if isinstance(value, Primitive):
                value.set_prim_instance_name(name)
            object.__setattr__(self, name, value)
        if name not in Cell.IGNORE_LIST:
            self._attr_synced = False
    def extend_repr(self):
        """
        Sets the extended representation of the Cell.

        To print customized extended information, re-implement this method in your own cells.
        """
        return ''
  375. def __str__(self):
  376. return self.__repr__()
  377. def __repr__(self):
  378. extra_str = self.extend_repr()
  379. info_str = self.__class__.__name__ + '<'
  380. if self._cells:
  381. sub_str = '\n'
  382. if extra_str:
  383. sub_str += '{}\n'.format(self.extend_repr())
  384. for key, value in self._cells.items():
  385. sub_str += '({}): {}\n'.format(key, repr(value))
  386. sub_str = sub_str.replace('\n', '\n ') + '>'
  387. info_str += sub_str
  388. else:
  389. info_str += extra_str + '>'
  390. return info_str
    def load_parameter_slice(self, params):
        """
        Replace parameters with sliced tensors by parallel strategies.

        Please refer to the usage in source code of `mindspore.common._Executor.compile`.

        Args:
            params (dict): The parameters dictionary used for initializing the data graph.

        Raises:
            TypeError: If `params` is neither None nor an OrderedDict.
        """
        if params is None:
            params = self.parameters_dict()
        if isinstance(params, OrderedDict):
            for key in params:
                tensor = params[key].data
                # Unused inputs may be absent from the layout dict; skip them.
                if key not in self.parameter_layout_dict:
                    logger.info("layout dict does not contain the key %s", key)
                    continue
                if params[key].sliced:
                    logger.debug("Param %s is already sliced.", key)
                    continue
                layout = self.parameter_layout_dict[key]
                new_tensor = _load_tensor_by_layout(tensor, layout)
                # Second argument marks the parameter as sliced.
                params[key].set_data(new_tensor, True)
        else:
            raise TypeError('Parameters need OrderedDict type, but got {}'.
                            format(type(params)))
    def _load_inputs(self, *inputs):
        """
        Slice inputs tensors by parallel strategies.

        Args:
            inputs (Function or Cell): inputs of construct method.

        Returns:
            tuple: Input tensors, each sliced by its layout when one exists.
        """
        parallel_inputs_run = []
        # judge if *args exists in input
        if self.argspec[1] is not None:
            # Synthesize names "<varargs>0", "<varargs>1", ... for *args inputs.
            prefix = self.argspec[1]
            for i in range(len(inputs)):
                key = prefix + str(i)
                self._construct_inputs_names = self._construct_inputs_names + (key,)
                self._construct_inputs_num = self._construct_inputs_num + 1
        for i, tensor in enumerate(inputs):
            key = self._construct_inputs_names[i]
            # if input is not used, self.parameter_layout_dict may not contain the key
            if key not in self.parameter_layout_dict:
                logger.warning("layout dict does not contain the key %s", key)
                parallel_inputs_run.append(tensor)
            else:
                layout = self.parameter_layout_dict[key]
                new_tensor = _load_tensor_by_layout(tensor, layout)
                parallel_inputs_run.append(new_tensor)
        return tuple(parallel_inputs_run)
    def set_parallel_input_with_inputs(self, *inputs):
        """
        Slice inputs tensors by parallel strategies, and set the sliced inputs to `_parallel_input_run`

        Args:
            inputs (tuple): inputs of construct method.
        """
        self._parallel_inputs_run = self._load_inputs(*inputs)
    def _get_construct_inputs_number_and_name(self):
        """Compute self._construct_inputs_names and self._construct_inputs_num"""
        # Imported here to avoid a circular import at module load time.
        from mindspore._extends.parse.parser import get_parse_method_of_class
        fn = get_parse_method_of_class(self)
        self.argspec = inspect.getfullargspec(fn)
        self._construct_inputs_num = fn.__code__.co_argcount
        self._construct_inputs_names = fn.__code__.co_varnames
        # Sanity checks: construct is a bound method, so 'self' is argument 0.
        assert self._construct_inputs_num > 0
        assert self._construct_inputs_names[0] == 'self'
        assert self._construct_inputs_num - 1 <= len(self._construct_inputs_names)
        # Drop 'self' from both the name tuple and the count.
        self._construct_inputs_names = self._construct_inputs_names[1:self._construct_inputs_num]
        self._construct_inputs_num = self._construct_inputs_num - 1
    def compile(self, *inputs):
        """
        Compiles cell.

        Args:
            inputs (tuple): Input parameters.
        """
        _executor.compile(self, *inputs, phase=self.phase, auto_parallel_mode=self._auto_parallel_mode)
    def compile_and_run(self, *inputs):
        """
        Compiles and runs cell.

        Args:
            inputs (tuple): Input parameters.

        Returns:
            Object, the result of executing.
        """
        self._auto_parallel_compile_and_run = True
        self.compile(*inputs)
        if self._auto_parallel_mode:
            if inputs and isinstance(inputs[0], Tensor) and inputs[0].virtual_flag:
                # get parallel inputs in sink mode, parallel inputs set in _executor.compile
                parallel_inputs_run = self._parallel_inputs_run
            else:
                parallel_inputs_run = inputs
            return _executor(self, *parallel_inputs_run, phase=self.phase)
        return _executor(self, *inputs, phase=self.phase)
    def auto_parallel_compile_and_run(self):
        """bool: Whether compile_and_run has been invoked on this cell."""
        return self._auto_parallel_compile_and_run
    def exec_checkpoint_graph(self):
        """Executes saving checkpoint graph operation."""
        _executor(self, phase='save')
    def insert_param_to_cell(self, param_name, param, check_name=True):
        """
        Adds a parameter to the current cell.

        Inserts a parameter with given name to the cell. Please refer to the usage in
        source code of `mindspore.nn.Cell.__setattr__`.

        Args:
            param_name (str): Name of the parameter.
            param (Parameter): Parameter to be inserted to the cell.
            check_name (bool): Determines whether the name input is compatible. Default: True.

        Raises:
            KeyError: If the name of parameter is null or contains dot.
            AttributeError: If user did not call init() first.
            TypeError: If the type of parameter is not Parameter.
        """
        if not param_name:
            raise KeyError("The name of parameter should not be null.")
        if check_name and '.' in param_name:
            raise KeyError("The name of parameter should not contain \".\"")
        if '_params' not in self.__dict__:
            raise AttributeError("You need call init() first.")
        # A same-named plain attribute would shadow the parameter.
        if hasattr(self, param_name) and param_name not in self._params:
            raise KeyError("Duplicated parameter name '{}'.".format(param_name))
        # None is allowed: it clears the slot (see __setattr__).
        if not isinstance(param, Parameter) and param is not None:
            raise TypeError("The type of parameter should be 'Parameter' if not None.")
        self._params[param_name] = param
    def cast_param(self, param):
        """
        Cast parameter according to auto mix precison level in pynative mode.

        Args:
            param (Parameter): The parameter to cast.

        Returns:
            Parameter, the same parameter with its cast dtype updated.
        """
        if hasattr(self, "_mindspore_flags"):
            # fp32 takes precedence over fp16 when both flags are set.
            if self._mindspore_flags.get('fp32'):
                param.set_cast_dtype(mstype.float32)
            elif self._mindspore_flags.get('fp16'):
                param.set_cast_dtype(mstype.float16)
            else:
                # retest dtype
                param.set_cast_dtype()
        return param
    def insert_child_to_cell(self, child_name, child):
        """
        Adds a child cell to the current cell.

        Inserts a subcell with a given name to the current cell.

        Args:
            child_name (str): Name of the child cell.
            child (Cell): The child cell to be inserted.

        Raises:
            KeyError: Child Cell's name is incorrect or duplicated with the other child name.
            TypeError: Child Cell's type is incorrect.
        """
        if not child_name or '.' in child_name:
            raise KeyError("Child cell name is incorrect.")
        # A same-named plain attribute would shadow the child cell.
        if hasattr(self, child_name) and child_name not in self._cells:
            raise KeyError("Duplicate child name '{}'.".format(child_name))
        # None is allowed: it clears the slot (see __setattr__).
        if not isinstance(child, Cell) and child is not None:
            raise TypeError("Child cell type is incorrect.")
        self._cells[child_name] = child
    def construct(self, *inputs, **kwargs):
        """
        Defines the computation to be performed.

        This method must be overridden by all subclasses.

        Note:
            The inputs of the top cell only allow Tensor.
            Other types (tuple, list, int etc.) are forbidden.

        Returns:
            Tensor, returns the computed result.
        """
        raise NotImplementedError
    def init_parameters_data(self, auto_parallel_mode=False):
        """
        Initialize all parameters and replace the original saved parameters in cell.

        Notes:
            trainable_params() and other similar interfaces may return different parameter instance after
            `init_parameters_data`, do not save these result.

        Args:
            auto_parallel_mode (bool): If running in auto_parallel_mode.

        Returns:
            Dict[Parameter, Parameter], returns a dict of original parameter and replaced parameter.
        """
        replace = dict()

        def _updata(param):
            # Memoize so a parameter shared by several cells is replaced once.
            if param in replace:
                return replace[param]
            layout = None
            set_sliced = False
            if auto_parallel_mode:
                set_sliced = True
                if param.name not in self.parameter_layout_dict:
                    logger.debug("Layout dict does not contain the key %s.", param.name)
                else:
                    layout = self.parameter_layout_dict[param.name]
            new_p = param.init_data(layout, set_sliced=set_sliced)
            replace[param] = new_p
            return new_p

        # replace all original usage.
        cells = self.cells_and_names()
        for _, cell in cells:
            params = cell._params.items()
            for param_name, param in params:
                cell._params[param_name] = _updata(param)
            # ParameterTuples stored as attributes must be rebuilt as well.
            cell_dict = cell.__dict__
            for key in cell_dict:
                if isinstance(cell_dict[key], ParameterTuple):
                    param_tuple = cell_dict[key]
                    new_param_tuple = []
                    for param in param_tuple:
                        new_param_tuple.append(_updata(param))
                    cell.__dict__[key] = ParameterTuple(new_param_tuple)
        return replace
  599. def parameters_dict(self, recurse=True):
  600. """
  601. Gets parameters dictionary.
  602. Gets the parameters dictionary of this cell.
  603. Args:
  604. recurse (bool): Whether contains the parameters of subcells. Default: True.
  605. Returns:
  606. OrderedDict, return parameters dictionary.
  607. """
  608. param_dict = OrderedDict()
  609. for param in self.get_parameters(expand=recurse):
  610. param_dict[param.name] = param
  611. return param_dict
  612. def parameters_broadcast_dict(self, recurse=True):
  613. param_dict = OrderedDict()
  614. for param in self.get_parameters(expand=recurse):
  615. if param.layerwise_parallel is False:
  616. param_dict[param.name] = param
  617. if not param_dict:
  618. return None
  619. return param_dict
    def update_parameters_name(self, prefix='', recurse=True):
        """
        Updates the names of parameters with given prefix string.

        Adds the given prefix to the names of parameters.

        Args:
            prefix (str): The prefix string.
            recurse (bool): Whether contains the parameters of subcells. Default: True.
        """
        _check_str_by_regular(prefix)
        for name, param in self.parameters_and_names(expand=recurse):
            if prefix != '':
                # Renaming invalidates any prior initialization state.
                param.is_init = False
            param.name = prefix + name
  633. def trainable_params(self, recurse=True):
  634. """
  635. Returns all trainable parameters.
  636. Returns a list of all trainable parmeters.
  637. Args:
  638. recurse (bool): Whether contains the trainable parameters of subcells. Default: True.
  639. Returns:
  640. List, the list of trainable parameters.
  641. """
  642. return list(filter(lambda x: x.requires_grad, self.get_parameters(expand=recurse)))
  643. def untrainable_params(self, recurse=True):
  644. """
  645. Returns all untrainable parameters.
  646. Returns a list of all untrainable parmeters.
  647. Args:
  648. recurse (bool): Whether contains the untrainable parameters of subcells. Default: True.
  649. Returns:
  650. List, the list of untrainable parameters.
  651. """
  652. return list(filter(lambda x: not x.requires_grad, self.get_parameters(expand=recurse)))
  653. def get_parameters(self, expand=True):
  654. """
  655. Returns an iterator over cell parameters.
  656. Yields parameters of this cell. If `expand` is True, yield parameters of this cell and all subcells.
  657. Args:
  658. expand (bool): If true, yields parameters of this cell and all subcells. Otherwise, only yield parameters
  659. that are direct members of this cell. Default: True.
  660. Examples:
  661. >>> net = Net()
  662. >>> for item in net.get_parameters():
  663. >>> print(item)
  664. """
  665. for _, param in self.parameters_and_names(expand=expand):
  666. yield param
  667. def check_names(self):
  668. names = set("")
  669. for value, param in self.parameters_and_names():
  670. if param.name in names:
  671. raise ValueError("The value of {} is {}, its name '{}' already exists.".
  672. format(value, param, param.name))
  673. names.add(param.name)
    def parameters_and_names(self, name_prefix='', expand=True):
        """
        Returns an iterator over cell parameters.

        Includes the parameter's name and itself.

        Args:
            name_prefix (str): Namespace. Default: ''.
            expand (bool): If true, yields parameters of this cell and all subcells. Otherwise, only yield parameters
                that are direct members of this cell. Default: True.

        Examples:
            >>> n = Net()
            >>> names = []
            >>> for m in n.parameters_and_names():
            >>>     if m[0]:
            >>>         names.append(m[0])
        """
        cells = []
        if expand:
            cells = self.cells_and_names(name_prefix=name_prefix)
        else:
            cells.append((name_prefix, self))
        # Deduplicate by object identity: a shared parameter is yielded once.
        params_set = set()
        for cell_name, cell in cells:
            params = cell._params.items()
            for par_name, par in params:
                # Prefer the initialized replacement parameter when it exists.
                if par.inited_param is not None:
                    par = par.inited_param
                if par is not None and id(par) not in params_set:
                    params_set.add(id(par))
                    par_new_name = par_name
                    if cell_name:
                        par_new_name = cell_name + '.' + par_new_name
                    yield par_new_name, par
  706. def cells_and_names(self, cells=None, name_prefix=''):
  707. """
  708. Returns an iterator over all cells in the network.
  709. Includes the cell's name and itself.
  710. Args:
  711. cells (str): Cells to iterate over. Default: None.
  712. name_prefix (str): Namespace. Default: ''.
  713. Examples:
  714. >>> n = Net()
  715. >>> names = []
  716. >>> for m in n.cells_and_names():
  717. >>> if m[0]:
  718. >>> names.append(m[0])
  719. """
  720. t_cells = cells if cells else set()
  721. if self in t_cells:
  722. return
  723. t_cells.add(self)
  724. yield name_prefix, self
  725. for name, cell in self._cells.items():
  726. if cell:
  727. cells_name_prefix = name
  728. if name_prefix:
  729. cells_name_prefix = name_prefix + '.' + cells_name_prefix
  730. for ele in cell.cells_and_names(t_cells, cells_name_prefix):
  731. yield ele
  732. def cells(self):
  733. """Returns an iterator over immediate cells."""
  734. return self.name_cells().values()
  735. def _set_scope(self, name):
  736. """Sets the name on the first time."""
  737. if self._scope is None:
  738. self._scope = name
  739. def _children_scope_recursive(self, parent_prefix='Default'):
  740. """Generates the scope of each layer of the network recursively."""
  741. reserve_class_name_in_scope = context.get_context("reserve_class_name_in_scope")
  742. for name, cell in self.name_cells().items():
  743. yield parent_prefix + "/" + name + (("-" + cell.__class__.__name__)
  744. if reserve_class_name_in_scope else ""), cell
  745. for name, cell in self.name_cells().items():
  746. for key, value in cell._children_scope_recursive(parent_prefix + "/" + name +
  747. (("-" + cell.__class__.__name__)
  748. if reserve_class_name_in_scope else "")):
  749. yield key, value
  750. def get_scope(self):
  751. """Returns the scope of a cell object in one network."""
  752. return self._scope
  753. def generate_scope(self):
  754. """Generate the scope for each cell object in the network."""
  755. for name, cell in self._children_scope_recursive():
  756. cell._set_scope(name)
  757. def name_cells(self):
  758. """
  759. Returns an iterator over all cells in the network.
  760. Include name of the cell and cell itself.
  761. """
  762. value_set = set()
  763. cells = OrderedDict()
  764. for name, cell in self._cells.items():
  765. if cell is not None and cell not in value_set:
  766. value_set.add(cell)
  767. cells[name] = cell
  768. return cells
  769. def add_flags(self, **flags):
  770. if not hasattr(self, "_mindspore_flags"):
  771. self._mindspore_flags = {}
  772. self._mindspore_flags.update({**flags})
  773. self.__dict__.update({**flags})
  774. return self
  775. def add_flags_recursive(self, **flags):
  776. self.add_flags(**flags)
  777. if hasattr(self, '_cell_init_args'):
  778. self._cell_init_args += str({**flags})
  779. for cell in self.cells():
  780. cell.add_flags_recursive(**flags)
  781. return self
  782. def get_flags(self):
  783. if not hasattr(self, "_mindspore_flags"):
  784. self._mindspore_flags = {}
  785. return self._mindspore_flags
  786. def to_float(self, dst_type):
  787. """
  788. Add cast on all inputs of cell and child cells to run with certain float type.
  789. If `dst_type is mindspore.dtype.float16`, all the inputs of Cell including input, Parameter, Tensor
  790. as const will be cast to float16. Please refer to the usage in source code of
  791. `mindspore.train.amp.build_train_network`.
  792. Note:
  793. Multiple calls will overwrite.
  794. Args:
  795. dst_type (:class:`mindspore.dtype`): Transfer Cell to Run with dst_type.
  796. dst_type can be `mindspore.dtype.float16` or `mindspore.dtype.float32`.
  797. Raises:
  798. ValueError: If dst_type is not float32 nor float16.
  799. """
  800. if dst_type not in (mstype.float16, mstype.float32):
  801. raise ValueError("dst_type should inside float32 or float16.")
  802. flags = {'fp16': dst_type == mstype.float16, 'fp32': dst_type == mstype.float32}
  803. self.add_flags_recursive(**flags)
  804. return self
  805. def set_grad(self, requires_grad=True):
  806. """
  807. Sets the cell flag for gradient.
  808. Args:
  809. requires_grad (bool): Specifies if the net need to grad, if it is
  810. True, cell will construct backward network in pynative mode. Default: True.
  811. """
  812. self.requires_grad = requires_grad
  813. return self
  814. def set_train(self, mode=True):
  815. """
  816. Sets the cell to training mode.
  817. The cell itself and all children cells will be set to training mode.
  818. Args:
  819. mode (bool): Specifies whether the model is training. Default: True.
  820. """
  821. if mode is False:
  822. self._phase = 'predict'
  823. else:
  824. self._phase = 'train'
  825. self.add_flags_recursive(training=mode)
  826. return self
  827. def set_broadcast_flag(self, mode=True):
  828. """
  829. Set the cell to data_parallel mode.
  830. The cell can be accessed as an attribute using the given name.
  831. Args:
  832. mode (bool): Specifies whether the model is data_parallel. Default: True.
  833. """
  834. self.add_flags_recursive(broadcast_flag=mode)
  835. return self
  836. def set_auto_parallel(self):
  837. """
  838. Set the cell to auto parallel mode.
  839. Note:
  840. If a cell needs to use the auto parallel or semi auto parallel mode for training, evaluation or prediction,
  841. this interface needs to be called by the cell.
  842. """
  843. self._auto_parallel_mode = True
  844. self.add_flags(auto_parallel=True)
  845. self._get_construct_inputs_number_and_name()
  846. def _hook_construct(self, *inputs, **kwargs):
  847. """Hook construct method to replace original construct method when hook function enabled."""
  848. inputs = self._backward_hook(*inputs)
  849. inputs = self.construct(inputs)
  850. outputs = self._backward_hook(inputs)
  851. return outputs
  852. def register_backward_hook(self, fn):
  853. """
  854. Set the cell backward hook function. Note that this function is only supported in Pynative Mode.
  855. Note:
  856. fn must be defined as the following code. `cell_name` is the name of registered cell.
  857. `grad_input` is gradient passed to the cell. `grad_output` is the gradient computed and passed to the
  858. next cell or primitve, which may be modified and returned.
  859. >>> hook_fn(cell_name, grad_input, grad_output) -> Tensor or None
  860. Args:
  861. fn (function): Specifies the hook function with grad as input.
  862. """
  863. self._backward_hook = HookBackward(fn, self.cls_name + "(" + str(id(self)) + ")")
  864. self.enable_hook = True
  865. def set_param_ps(self, recurse=True, init_in_server=False):
  866. """
  867. Set whether the trainable parameter is updated by parameter server.
  868. Note:
  869. It only works when a running task is in the parameter server mode.
  870. Args:
  871. recurse (bool): Whether sets the trainable parameters of subcells. Default: True.
  872. """
  873. params = self.trainable_params(recurse)
  874. for param in params:
  875. param.set_param_ps(init_in_server)
  876. class GraphKernel(Cell):
  877. """
  878. Base class for GraphKernel.
  879. A `GraphKernel` a composite of basic primitives and can be compiled into a fused kernel automatically when
  880. enable_graph_kernel in context is set to True.
  881. Examples:
  882. >>> class Relu(GraphKernel):
  883. >>> def __init__(self):
  884. >>> super(Relu, self).__init__()
  885. >>> self.max = P.Maximum()
  886. >>>
  887. >>> def construct(self, x):
  888. >>> return self.max(P.Fill()(P.DType()(x), P.Shape()(x), 0.0), x)
  889. """
  890. def __init__(self, auto_prefix=True, pips=None):
  891. super(GraphKernel, self).__init__(auto_prefix, pips)
  892. class_name = self.__class__.__name__
  893. self.add_flags(graph_kernel=class_name)
  894. def construct(self):
  895. raise NotImplementedError