You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.nn.Cell.rst 26 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652
  1. mindspore.nn.Cell
  2. ==================
  3. .. py:class:: mindspore.nn.Cell(auto_prefix=True, flags=None)
  4. MindSpore中神经网络的基本构成单元。模型或神经网络层应当继承该基类。
  5. `mindspore.nn` 中神经网络层也是Cell的子类,如 :class:`mindspore.nn.Conv2d`、:class:`mindspore.nn.ReLU`、 :class:`mindspore.nn.BatchNorm` 等。Cell在GRAPH_MODE(静态图模式)下将编译为一张计算图,在PYNATIVE_MODE(动态图模式)下作为神经网络的基础模块。
  6. **参数:**
  7. - **auto_prefix** (bool) – 是否自动为Cell及其子Cell生成NameSpace。`auto_prefix` 的设置影响网络参数的命名,如果设置为True,则自动给网络参数的名称添加前缀,否则不添加前缀。默认值:True。
  8. - **flags** (dict) - Cell的配置信息,目前用于绑定Cell和数据集。用户也通过该参数自定义Cell属性。默认值:None。
  9. .. py:method:: add_flags(**flags)
  10. 为Cell添加自定义属性。
  11. 在实例化Cell类时,如果入参flags不为空,会调用此方法。
  12. **参数:**
  13. - **flags** (dict) - Cell的配置信息,目前用于绑定Cell和数据集。用户也通过该参数自定义Cell属性。默认值:None。
  14. .. py:method:: add_flags_recursive(**flags)
  15. 如果Cell含有多个子Cell,此方法会递归得给所有子Cell添加自定义属性。
  16. **参数:**
  17. - **flags** (dict) - Cell的配置信息,目前用于绑定Cell和数据集。用户也通过该参数自定义Cell属性。默认值:None。
  18. .. py:method:: auto_parallel_compile_and_run()
  19. 是否在‘AUTO_PARALLEL’或‘SEMI_AUTO_PARALLEL’模式下执行编译流程。
  20. **返回:**
  21. bool, `_auto_parallel_compile_and_run` 的值。
  22. .. py:method:: bprop_debug
  23. :property:
  24. 在图模式下使用,用于标识是否使用自定义的反向传播函数。
  25. .. py:method:: cast_inputs(inputs, dst_type)
  26. 将输入转换为指定类型。
  27. **参数:**
  28. - **inputs** (tuple[Tensor]) - 输入。
  29. - **dst_type** (mindspore.dtype) - 指定的数据类型。
  30. **返回:**
  31. tuple[Tensor]类型,转换类型后的结果。
  32. .. py:method:: cast_param(param)
  33. 在PyNative模式下,根据自动混合精度的精度设置转换Cell中参数的类型。
  34. 该接口目前在自动混合精度场景下使用。
  35. **参数:**
  36. - **param** (Parameter) – 需要被转换类型的输入参数。
  37. **返回:**
  38. Parameter类型,转换类型后的参数。
  39. .. py:method:: cells()
  40. 返回当前Cell的子Cell的迭代器。
  41. **返回:**
  42. Iteration类型,Cell的子Cell。
  43. .. py:method:: cells_and_names(cells=None, name_prefix="")
  44. 递归地获取当前Cell及输入 `cells` 的所有子Cell的迭代器,包括Cell的名称及其本身。
  45. **参数:**
  46. - **cells** (str) – 需要进行迭代的Cell。默认值:None。
  47. - **name_prefix** (str) – 作用域。默认值:''。
  48. **返回:**
  49. Iteration类型,当前Cell及输入 `cells` 的所有子Cell和相对应的名称。
  50. .. py:method:: check_names()
  51. 检查Cell中的网络参数名称是否重复。
  52. .. py:method:: set_inputs(*inputs)
  53. 设置编译计算图所需的输入,输入需与实例中定义的输入一致。
  54. **参数:**
  55. - **inputs** (tuple) – Cell的输入。
  56. .. note::
  57. 这是一个实验接口,可能会被更改或者删除。
  58. .. py:method:: compile(*inputs)
  59. 编译Cell为计算图,输入需与construct中定义的输入一致。
  60. **参数:**
  61. - **inputs** (tuple) – Cell的输入。
  62. .. py:method:: compile_and_run(*inputs)
  63. 编译并运行Cell,输入需与construct中定义的输入一致。
  64. .. note::
  65. 不推荐使用该函数,建议直接调用Cell实例。
  66. **参数:**
  67. - **inputs** (tuple) – Cell的输入。
  68. **返回:**
  69. Object类型,执行的结果。
  70. .. py:method:: construct(*inputs, **kwargs)
  71. 定义要执行的计算逻辑。所有子类都必须重写此方法。
  72. .. note::
  73. 当前不支持inputs同时输入tuple类型和非tuple类型。
  74. **参数:**
  75. - **inputs** – 可变参数列表,默认值:()。
  76. - **kwargs** – 可变的关键字参数的字典,默认值:{}。
  77. **返回:**
  78. Tensor类型,返回计算结果。
  79. .. py:method:: exec_checkpoint_graph()
  80. 保存checkpoint图。
  81. .. py:method:: extend_repr()
  82. 在原有描述基础上扩展Cell的描述。
  83. 若需要在print时输出个性化的扩展信息,请在您的网络中重新实现此方法。
  84. .. py:method:: generate_scope()
  85. 为网络中的每个Cell对象生成NameSpace。
  86. .. py:method:: get_flags()
  87. 获取该Cell的自定义属性。自定义属性通过 `add_flags` 方法添加。
  88. .. py:method:: get_func_graph_proto()
  89. 返回图的二进制原型。
  90. .. py:method:: get_parameters(expand=True)
  91. 返回Cell中parameter的迭代器。
  92. **参数:**
  93. - **expand** (bool) – 如果为True,则递归地获取当前Cell和所有子Cell的parameter。否则,只生成当前Cell的子Cell的parameter。默认值:True。
  94. **返回:**
  95. Iteration类型,Cell的parameter。
  96. .. py:method:: get_scope()
  97. 返回Cell的作用域。
  98. **返回:**
  99. String类型,网络的作用域。
  100. .. py:method:: get_inputs()
  101. 返回编译计算图所设置的输入。
  102. **返回:**
  103. Tuple类型,编译计算图所设置的输入。
  104. .. note::
  105. 这是一个实验接口,可能会被更改或者删除。
  106. .. py:method:: infer_param_pipeline_stage()
  107. 推导Cell中当前 `pipeline_stage` 的参数。
  108. .. note::
  109. - 如果某参数不属于任何已被设置 `pipeline_stage` 的Cell,此参数应使用 `add_pipeline_stage` 方法来添加它的 `pipeline_stage` 信息。
  110. - 如果某参数P被stageA和stageB两个不同stage的算子使用,那么参数P在使用 `infer_param_pipeline_stage` 之前,应使用 `P.add_pipeline_stage(stageA)` 和 `P.add_pipeline_stage(stageB)` 添加它的stage信息。
  111. **返回:**
  112. 属于当前 `pipeline_stage` 的参数。
  113. **异常:**
  114. - **RuntimeError** – 如果参数不属于任何stage。
  115. .. py:method:: init_parameters_data(auto_parallel_mode=False)
  116. 初始化并替换Cell中所有的parameter的值。
  117. .. note::
  118. 在调用 `init_parameters_data` 后,`trainable_params()` 或其他相似的接口可能返回不同的参数对象,不要保存这些结果。
  119. **参数:**
  120. - **auto_parallel_mode** (bool) – 是否在自动并行模式下执行。 默认值:False。
  121. **返回:**
  122. Dict[Parameter, Parameter], 返回一个原始参数和替换参数的字典。
  123. .. py:method:: insert_child_to_cell(child_name, child_cell)
  124. 将一个给定名称的子Cell添加到当前Cell。
  125. **参数:**
  126. - **child_name** (str) – 子Cell名称。
  127. - **child_cell** (Cell) – 要插入的子Cell。
  128. **异常:**
  129. - **KeyError** – 如果子Cell的名称不正确或与其他子Cell名称重复。
  130. - **TypeError** – 如果子Cell的类型不正确。
  131. .. py:method:: insert_param_to_cell(param_name, param, check_name_contain_dot=True)
  132. 向当前Cell添加参数。
  133. 将指定名称的参数添加到Cell中。目前在 `mindspore.nn.Cell.__setattr__` 中使用。
  134. **参数:**
  135. - **param_name** (str) – 参数名称。
  136. - **param** (Parameter) – 要插入到Cell的参数。
  137. - **check_name_contain_dot** (bool) – 是否对 `param_name` 中的"."进行检查。默认值:True。
  138. **异常:**
  139. - **KeyError** – 如果参数名称为空或包含"."。
  140. - **TypeError** – 如果参数的类型不是Parameter。
  141. .. py:method:: load_parameter_slice(params)
  142. 根据并行策略获取Tensor分片并替换原始参数。
  143. 请参考 `mindspore.common._Executor.compile` 源代码中的用法。
  144. **参数:**
  145. **params** (dict) – 用于初始化数据图的参数字典。
  146. .. py:method:: name_cells()
  147. 递归地获取一个Cell中所有子Cell的迭代器。
  148. 包括Cell名称和Cell本身。
  149. **返回:**
  150. Dict[String, Cell],Cell中的所有子Cell及其名称。
  151. .. py:method:: param_prefix
  152. :property:
  153. 当前Cell的子Cell的参数名前缀。
  154. .. py:method:: parameter_layout_dict
  155. :property:
  156. `parameter_layout_dict` 表示一个参数的张量layout,这种张量layout是由分片策略和分布式算子信息推断出来的。
  157. .. py:method:: parameters_and_names(name_prefix='', expand=True)
  158. 返回Cell中parameter的迭代器。
  159. 包含参数名称和参数本身。
  160. **参数:**
  161. - **name_prefix** (str): 作用域。默认值: ''。
  162. - **expand** (bool): 如果为True,则递归地获取当前Cell和所有子Cell的参数及名称;如果为False,只生成当前Cell的子Cell的参数及名称。默认值:True。
  163. **返回:**
  164. 迭代器,Cell的名称和Cell本身。
  165. .. py:method:: parameters_broadcast_dict(recurse=True)
  166. 获取这个Cell的参数广播字典。
  167. **参数:**
  168. - **recurse** (bool): 是否包含子Cell的参数。 默认: True。
  169. **返回:**
  170. OrderedDict, 返回参数广播字典。
  171. .. py:method:: parameters_dict(recurse=True)
  172. 获取此Cell的parameter字典。
  173. **参数:**
  174. - **recurse** (bool) – 是否递归得包含所有子Cell的parameter。默认值:True。
  175. **返回:**
  176. OrderedDict类型,返回参数字典。
  177. .. py:method:: recompute(**kwargs)
  178. 设置Cell重计算。Cell中输出算子以外的所有算子将被设置为重计算。如果一个算子的计算结果被输出到一些反向节点来进行梯度计算,且被设置成重计算,那么我们会在反向传播中重新计算它,而不去存储在前向传播中的中间激活层的计算结果。
  179. .. note::
  180. - 如果计算涉及到诸如随机化或全局变量之类的操作,那么目前还不能保证等价。
  181. - 如果该Cell中算子的重计算API也被调用,则该算子的重计算模式以算子的重计算API的设置为准。
  182. - 该接口仅配置一次,即当父Cell配置了,子Cell不需再配置。
  183. - Cell的输出算子默认不做重计算,这一点是基于我们减少内存占用的配置经验。如果一个Cell里面只有一个算子而且想要把这个算子设置为重计算的,那么请使用算子的重计算API。
  184. - 当应用了重计算且内存充足时,可以配置'mp_comm_recompute=False'来提升性能。
  185. - 当应用了重计算但内存不足时,可以配置'parallel_optimizer_comm_recompute=True'来节省内存。有相同融合group的Cell应该配置相同的parallel_optimizer_comm_recompute。
  186. **参数:**
  187. - **mp_comm_recompute** (bool) – 表示在自动并行或半自动并行模式下,指定Cell内部由模型并行引入的通信操作是否重计算。默认值:True。
  188. - **parallel_optimizer_comm_recompute** (bool) – 表示在自动并行或半自动并行模式下,指定Cell内部由优化器并行引入的AllGather通信是否重计算。默认值:False。
  189. .. py:method:: register_forward_pre_hook(hook_fn)
  190. 设置Cell对象的正向pre_hook函数。
  191. .. note::
  192. - `register_forward_pre_hook(hook_fn)` 在图模式下,或者在PyNative模式下使用 `ms_function` 功能时不起作用。
  193. - hook_fn必须有如下代码定义。 `cell_id` 是已注册Cell对象的信息,包括名称和ID。 `inputs` 是网络正向传播时Cell对象的输入数据。用户可以在hook_fn中打印输入数据或者返回新的输入数据。
  194. - hook_fn返回新的输入数据或者None:hook_fn(cell_id, inputs) -> New inputs or None。
  195. - 为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 `construct` 函数中调用 `register_forward_pre_hook(hook_fn)` 。
  196. - PyNative模式下,如果在Cell对象的 `construct` 函数中调用 `register_forward_pre_hook(hook_fn)` ,那么Cell对象每次运行都将增加一个 `hook_fn` 。
  197. **参数:**
  198. - **hook_fn** (function) – 捕获Cell对象信息和正向输入数据的hook_fn函数。
  199. **返回:**
  200. `mindspore.common.hook_handle.HookHandle` 类型,与 `hook_fn` 函数对应的 `handle` 对象。可通过调用 `handle.remove()` 来删除添加的 `hook_fn` 函数。
  201. **异常:**
  202. - **TypeError** – 如果 `hook_fn` 不是Python函数。
  203. .. py:method:: register_forward_hook(hook_fn)
  204. 设置Cell对象的正向hook函数。
  205. .. note::
  206. - `register_forward_hook(hook_fn)` 在图模式下,或者在PyNative模式下使用 `ms_function` 功能时不起作用。
  207. - hook_fn必须有如下代码定义。 `cell_id` 是已注册Cell对象的信息,包括名称和ID。 `inputs` 是网络正向传播时Cell对象的输入数据。 `outputs` 是网络正向传播时Cell对象的输出数据。用户可以在hook_fn中打印数据或者返回新的输出数据。
  208. - hook_fn返回新的输出数据或者None:hook_fn(cell_id, inputs, outputs) -> New outputs or None。
  209. - 为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 `construct` 函数中调用 `register_forward_hook(hook_fn)` 。
  210. - PyNative模式下,如果在Cell对象的 `construct` 函数中调用 `register_forward_hook(hook_fn)` ,那么Cell对象每次运行都将增加一个 `hook_fn` 。
  211. **参数:**
  212. - **hook_fn** (function) – 捕获Cell对象信息和正向输入,输出数据的hook_fn函数。
  213. **返回:**
  214. `mindspore.common.hook_handle.HookHandle` 类型,与 `hook_fn` 函数对应的 `handle` 对象。可通过调用 `handle.remove()` 来删除添加的 `hook_fn` 函数。
  215. **异常:**
  216. - **TypeError** – 如果 `hook_fn` 不是Python函数。
  217. .. py:method:: register_backward_hook(hook_fn)
  218. 设置Cell对象的反向hook函数。
  219. .. note::
  220. - `register_backward_hook(hook_fn)` 在图模式下,或者在PyNative模式下使用 `ms_function` 功能时不起作用。
  221. - hook_fn必须有如下代码定义。 `cell_id` 是已注册Cell对象的信息,包括名称和ID。 `grad_input` 是反向传递给Cell对象的梯度。 `grad_output` 是Cell对象的反向输出梯度。用户可以在hook_fn中打印梯度数据或者返回新的输出梯度。
  222. - hook_fn返回新的输出梯度或者None:hook_fn(cell_id, grad_input, grad_output) -> New grad_output or None。
  223. - 为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 `construct` 函数中调用 `register_backward_hook(hook_fn)` 。
  224. - PyNative模式下,如果在Cell对象的 `construct` 函数中调用 `register_backward_hook(hook_fn)` ,那么Cell对象每次运行都将增加一个 `hook_fn` 。
  225. **参数:**
  226. - **hook_fn** (function) – 捕获Cell对象信息和反向输入,输出梯度的hook_fn函数。
  227. **返回:**
  228. `mindspore.common.hook_handle.HookHandle` 类型,与 `hook_fn` 函数对应的 `handle` 对象。可通过调用 `handle.remove()` 来删除添加的 `hook_fn` 函数。
  229. **异常:**
  230. - **TypeError** – 如果 `hook_fn` 不是Python函数。
  231. .. py:method:: remove_redundant_parameters()
  232. 删除冗余参数。
  233. 这个接口通常不需要显式调用。
  234. .. py:method:: run_construct(cast_inputs, kwargs)
  235. 运行construct方法。
  236. .. note::
  237. - 该函数已经弃用,将会在未来版本中删除,不推荐使用此函数。
  238. **参数:**
  239. - **cast_inputs** (tuple) – Cell的输入。
  240. - **kwargs** (dict) – 关键字参数。
  241. **返回:**
  242. Cell的输出。
  243. .. py:method:: set_auto_parallel()
  244. 将Cell设置为自动并行模式。
  245. .. note:: 如果一个Cell需要使用自动并行或半自动并行模式来进行训练、评估或预测,则该Cell需要调用此接口。
  246. .. py:method:: set_comm_fusion(fusion_type, recurse=True)
  247. 为Cell中的参数设置融合类型。请参考 :class:`mindspore.Parameter.comm_fusion` 的描述。
  248. .. note:: 当函数被多次调用时,此属性值将被重写。
  249. **参数:**
  250. - **fusion_type** (int) – Parameter的 `comm_fusion` 属性的设置值。
  251. - **recurse** (bool) – 是否递归地设置子Cell的可训练参数。默认值:True。
  252. .. py:method:: set_broadcast_flag(mode=True)
  253. 设置该Cell的参数广播模式。
  254. **参数:**
  255. - **mode** (bool) – 指定当前模式是否进行参数广播。默认值:True。
  256. .. py:method:: set_data_parallel()
  257. 递归设置该Cell中的所有算子的并行策略为数据并行。
  258. .. note:: 仅在图模式、全自动并行(AUTO_PARALLEL)模式下生效。
  259. .. py:method:: shard(in_strategy, out_strategy, device="Ascend", level=0)
  260. 指定输入/输出Tensor的分布策略,其余算子的策略推导得到。在PyNative模式下,可以利用此方法指定某个Cell以图模式进行分布式执行。 in_strategy/out_strategy需要为元组类型,
  261. 其中的每一个元素指定对应的输入/输出的Tensor分布策略,可参考: `mindspore.ops.Primitive.shard` 的描述,也可以设置为None,会默认以数据并行执行。
  262. 其余算子的并行策略由输入输出指定的策略推导得到。
  263. .. note:: 需设置为PyNative模式,并且全自动并行(AUTO_PARALLEL),同时设置 `set_auto_parallel_context` 中的搜索模式(search mode)为"sharding_propagation"。
  264. **参数:**
  265. - **in_strategy** (tuple) – 指定各输入的切分策略,输入元组的每个元素可以为元组或None,元组即具体指定输入每一维的切分策略,None则会默认以数据并行执行。
  266. - **out_strategy** (tuple) – 指定各输出的切分策略,用法同in_strategy。
  267. - **device** (string) - 指定执行设备,可以为["CPU", "GPU", "Ascend"]中任意一个,默认值:"Ascend"。目前尚未使能。
  268. - **level** (int) - 指定搜索切分策略的目标函数,即是最大化计算通信比、最小化内存消耗、最大化执行速度等。可以为[0, 1, 2]中任意一个,默认值:0。目前仅支持
  269. 最大化计算通信比,其余模式尚未使能。
  270. **返回:**
  271. Cell类型,Cell本身。
  272. .. py:method:: auto_cast_inputs(inputs)
  273. 在混合精度下,自动对输入进行类型转换。
  274. **参数:**
  275. **inputs** (tuple) – construct方法的输入。
  276. **返回:**
  277. Tuple类型,经过类型转换后的输入。
  278. .. py:method:: set_grad(requires_grad=True)
  279. Cell的梯度设置。在PyNative模式下,该参数指定Cell是否需要梯度。如果为True,则在执行正向网络时,将生成需要计算梯度的反向网络。
  280. **参数:**
  281. - **requires_grad** (bool) – 指定网络是否需要梯度,如果为True,PyNative模式下Cell将构建反向网络。默认值:True。
  282. **返回:**
  283. Cell类型,Cell本身。
  284. .. py:method:: set_parallel_input_with_inputs(*inputs)
  285. 通过并行策略对输入张量进行切分。
  286. **参数:**
  287. **inputs** (tuple) – construct方法的输入。
  288. .. py:method:: set_param_fl(push_to_server=False, pull_from_server=False, requires_aggr=True)
  289. 设置参数与服务器交互的方式。
  290. **参数:**
  291. - **push_to_server** (bool) – 是否将参数推送到服务器。默认值:False。
  292. - **pull_from_server** (bool) – 是否从服务器提取参数。默认值:False。
  293. - **requires_aggr** (bool) – 是否在服务器中聚合参数。默认值:True。
  294. .. py:method:: set_param_ps(recurse=True, init_in_server=False)
  295. 设置可训练参数是否由参数服务器更新,以及是否在服务器上初始化可训练参数。
  296. .. note:: 只在运行的任务处于参数服务器模式时有效。
  297. **参数:**
  298. - **recurse** (bool) – 是否设置子网络的可训练参数。默认值:True。
  299. - **init_in_server** (bool) – 是否在服务器上初始化由参数服务器更新的可训练参数。默认值:False。
  300. .. py:method:: set_train(mode=True)
  301. 将Cell设置为训练模式。
  302. 设置当前Cell和所有子Cell的训练模式。对于训练和预测具有不同结构的网络层(如 `BatchNorm`),将通过这个属性区分分支。如果设置为True,则执行训练分支,否则执行另一个分支。
  303. **参数:**
  304. - **mode** (bool) – 指定模型是否为训练模式。默认值:True。
  305. **返回:**
  306. Cell类型,Cell本身。
  307. .. py:method:: to_float(dst_type)
  308. 在Cell和所有子Cell的输入上添加类型转换,以使用特定的浮点类型运行。
  309. 如果 `dst_type` 是 `mindspore.dtype.float16` ,Cell的所有输入(包括作为常量的input, Parameter, Tensor)都会被转换为float16。请参考 `mindspore.build_train_network` 的源代码中的用法。
  310. .. note:: 多次调用将产生覆盖。
  311. **参数:**
  312. - **dst_type** (mindspore.dtype) – Cell转换为 `dst_type` 类型运行。 `dst_type` 可以是 `mindspore.dtype.float16` 或者 `mindspore.dtype.float32` 。
  313. **返回:**
  314. Cell类型,Cell本身。
  315. **异常:**
  316. - **ValueError** – 如果 `dst_type` 不是 `mindspore.dtype.float32` ,也不是 `mindspore.dtype.float16`。
  317. .. py:method:: set_boost(boost_type)
  318. 为了提升网络性能,可以配置boost内的算法让框架自动使能该算法来加速网络训练。
  319. 请确保 `boost_type` 所选择的算法在
  320. `algorithm library <https://gitee.com/mindspore/mindspore/tree/master/mindspore/python/mindspore/boost>`_ 算法库中。
  321. .. note:: 部分加速算法可能影响网络精度,请谨慎选择。
  322. **参数:**
  323. - **boost_type** (str) – 加速算法。
  324. **返回:**
  325. Cell类型,Cell本身。
  326. **异常:**
  327. - **ValueError** – 如果 `boost_type` 不在boost算法库内。
  328. .. py:method:: trainable_params(recurse=True)
  329. 返回Cell的可训练参数。
  330. 返回一个可训练参数的列表。
  331. **参数:**
  332. - **recurse** (bool) – 是否递归地包含当前Cell的所有子Cell的可训练参数。默认值:True。
  333. **返回:**
  334. List类型,可训练参数列表。
  335. .. py:method:: untrainable_params(recurse=True)
  336. 返回Cell的不可训练参数。
  337. 返回一个不可训练参数的列表。
  338. **参数:**
  339. - **recurse** (bool) – 是否递归地包含当前Cell的所有子Cell的不可训练参数。默认值:True。
  340. **返回:**
  341. List类型,不可训练参数列表。
  342. .. py:method:: update_cell_prefix()
  343. 递归地更新所有子Cell的 `param_prefix` 。
  344. 在调用此方法后,可以通过Cell的 `param_prefix` 属性获取该Cell的所有子Cell的名称前缀。
  345. .. py:method:: update_cell_type(cell_type)
  346. 量化感知训练网络场景下,更新当前Cell的类型。
  347. 此方法将Cell类型设置为 `cell_type` 。
  348. **参数:**
  349. - **cell_type** (str) – 被更新的类型,`cell_type` 可以是"quant"或"second-order"。
  350. .. py:method:: update_parameters_name(prefix='', recurse=True)
  351. 给网络参数名称添加 `prefix` 前缀字符串。
  352. **参数:**
  353. - **prefix** (str) – 前缀字符串。默认值:''。
  354. - **recurse** (bool) – 是否递归地包含所有子Cell的参数。默认值:True。