You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.nn.cell.rst 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. mindspore.nn.Cell
  2. ==================
  3. .. py:class:: mindspore.nn.Cell(auto_prefix=True, flags=None)
  4. 所有神经网络的基类。
  5. 一个 `Cell` 可以是单一的神经网络单元,如 :class:`mindspore.nn.Conv2d`, :class:`mindspore.nn.ReLU, :class:`mindspore.nn.BatchNorm`等,也可以是组成网络的 `Cell` 的结合体。
  6. .. note::
  7. 一般情况下,自动微分 (AutoDiff) 算法会自动调用梯度函数,但是如果使用反向传播方法 (bprop method),梯度函数将会被反向传播方法代替。反向传播函数会接收一个包含损失对输出的梯度张量 `dout` 和一个包含前向传播结果的张量 `out` 。反向传播过程需要计算损失对输入的梯度,损失对参数变量的梯度目前暂不支持。反向传播函数必须包含自身参数。
  8. **参数:**
  9. - **auto_prefix** (`Cell`) – 递归地生成作用域。默认值:True。
  10. - **flags** (`dict`) - Cell的配置信息,目前用于绑定Cell和数据集。用户也通过该参数自定义Cell属性。默认值:None。
  11. **支持平台**:
  12. ``Ascend`` ``GPU`` ``CPU``
  13. **样例** :
  14. >>> import mindspore.nn as nn
  15. >>> import mindspore.ops as ops
  16. >>> class MyCell(nn.Cell):
  17. ... def __init__(self):
  18. ... super(MyCell, self).__init__()
  19. ... self.relu = ops.ReLU()
  20. ...
  21. ... def construct(self, x):
  22. ... return self.relu(x)
  23. .. py:method:: add_flags(**flags)
  24. 为Cell添加自定义属性。
  25. 在实例化Cell类时,如果入参flags不为空,会调用此方法。
  26. **参数:**
  27. **flags** (`dict`) - Cell的配置信息,目前用于绑定Cell和数据集。用户也通过该参数自定义Cell属性。默认值:None。
  28. .. py:method:: add_flags_recursive(**flags)
  29. 如果Cell含有多个子Cell,此方法会递归得给所有子Cell添加自定义属性。
  30. **参数:**
  31. **flags** (`dict`) - Cell的配置信息,目前用于绑定Cell和数据集。用户也通过该参数自定义Cell属性。默认值:None。
  32. .. py:method:: cast_inputs(inputs, dst_type)
  33. 将输入转换为指定类型。
  34. **参数:**
  35. - **inputs** (`tuple[Tensor]`) - 输入。
  36. - **dst_type** (`mindspore.dtype`) - 指定的数据类型。
  37. **返回:**
  38. tuple[Tensor]类型,转换类型后的结果。
  39. .. py:method:: cast_param(param)
  40. 在PyNative模式下,根据自动混合精度的精度设置转换Cell中参数的类型。
  41. 该接口目前在自动混合精度场景下使用。
  42. **参数:**
  43. **param** (`Parameter`) – Parameter类型,需要被转换类型的输入参数。
  44. **返回:**
  45. Parameter类型,转换类型后的参数。
  46. .. py:method:: cells()
  47. 返回当前Cell的子Cell的迭代器。
  48. **返回:**
  49. Iteration类型,Cell的子Cell。
  50. .. py:method:: cells_and_names(cells=None, name_prefix="")
  51. 递归地获取当前Cell及输入 `cells` 的所有子Cell的迭代器,包括Cell的名称及其本身。
  52. **参数:**
  53. - **cell** (`str`) – 需要进行迭代的Cell。默认值:None。
  54. - **name_prefix** (`str`) – 作用域。默认值:''。
  55. **返回:**
  56. Iteration类型,当前Cell及输入 `cells` 的所有子Cell和相对应的名称。
  57. **样例:**
  58. >>> n = Net()
  59. >>> names = []
  60. >>> for m in n.cells_and_names():
  61. ... if m[0]:
  62. ... names.append(m[0])
  63. .. py:method:: check_names()
  64. 检查Cell中的网络参数名称是否重复。
  65. .. py:method:: compile(*inputs)
  66. 编译Cell。
  67. **参数:**
  68. **inputs** (`tuple`) – Cell的输入。
  69. .. py:method:: compile_and_run(*inputs)
  70. 编译并运行Cell。
  71. **参数:**
  72. **inputs** (`tuple`) – Cell的输入。
  73. **返回:**
  74. Object类型,执行的结果。
  75. .. py:method:: construct(*inputs, **kwargs)
  76. 定义要执行的计算逻辑。所有子类都必须重写此方法。
  77. **返回:**
  78. Tensor类型,返回计算结果。
  79. .. py:method:: exec_checkpoint_graph()
  80. 保存checkpoint图。
  81. .. py:method:: extend_repr()
  82. 设置Cell的扩展表示形式。
  83. 若需要在print时输出个性化的扩展信息,请在您的网络中重新实现此方法。
  84. .. py:method:: generate_scope()
  85. 为网络中的每个Cell对象生成作用域。
  86. .. py:method:: get_flags()
  87. 获取该Cell的自定义属性。自定义属性通过 `add_flags` 方法添加。
  88. .. py:method:: get_func_graph_proto()
  89. 返回图的二进制原型。
  90. .. py:method:: get_parameters(expand=True)
  91. 返回一个该Cell中parameter的迭代器。
  92. **参数:**
  93. **expand** (`bool`) – 如果为True,则递归地获取当前Cell和所有子Cell的parameter。否则,只生成当前Cell的子Cell的parameter。默认值:True。
  94. **返回:**
  95. Iteration类型,Cell的parameter。
  96. **样例:**
  97. >>> n = Net()
  98. >>> parameters = []
  99. >>> for item in net.get_parameters():
  100. ... parameters.append(item)
  101. .. py:method:: get_scope()
  102. 返回Cell的作用域。
  103. **返回:**
  104. String类型,网络的作用域。
  105. .. py:method:: insert_child_to_cell(child_name, child_cell)
  106. 将一个给定名称的子Cell添加到当前Cell。
  107. **参数:**
  108. - **child_name** (`str`) – 子Cell名称。
  109. - **child_cell** (`Cell`) – 要插入的子Cell。
  110. **异常:**
  111. - **KeyError** – 如果子Cell的名称不正确或与其他子Cell名称重复。
  112. - **TypeError** – 如果子Cell的类型不正确。
  113. .. py:method:: insert_param_to_cell(param_name, param, check_name=True)
  114. 向当前Cell添加参数。
  115. 将指定名称的参数插入Cell。目前在 `mindspore.nn.Cell.__setattr__` 中使用。
  116. **参数:**
  117. - **param_name** (`str`) – 参数名称。
  118. - **param** (`Parameter`) – 要插入到Cell的参数。
  119. - **check_name** (`bool`) – 是否对 `param_name` 中的"."进行检查。默认值:True。
  120. **异常:**
  121. - **KeyError** – 如果参数名称为空或包含"."。
  122. - **TypeError** – 如果参数的类型不是Parameter。
  123. .. py:method:: name_cells()
  124. 递归地获取一个Cell中所有子Cell的迭代器。
  125. 包括Cell名称和Cell本身。
  126. **返回:**
  127. Dict[String, Cell],Cell中的所有子Cell及其名称。
  128. .. py:method:: parameters_and_names(name_prefix='', expand=True)
  129. 返回Cell中parameter的迭代器。
  130. 包含参数名称和参数本身。
  131. **参数:**
  132. - **name_prefix** (str): 作用域。默认值: ''。
  133. - **expand** (bool): 如果为True,则递归地获取当前Cell和所有子Cell的参数及名称;如果为False,只生成当前Cell的子Cell的参数及名称。默认值:True.
  134. **返回:**
  135. 迭代器,Cell的名称和Cell本身。
  136. **样例:**
  137. >>> n = Net()
  138. >>> names = []
  139. >>> for m in n.parameters_and_names():
  140. ... if m[0]:
  141. ... names.append(m[0])
  142. .. py:method:: param_prefix
  143. :property:
  144. 当前Cell的子Cell的参数名前缀。
  145. .. py:method:: parameters_dict(recurse=True)
  146. 获取此Cell的parameter字典。
  147. **参数:**
  148. **recurse** (`bool`) – 是否递归得包含所有子Cell的parameter。默认值:True。
  149. **返回:**
  150. OrderedDict类型,返回参数字典。
  151. .. py:method:: remove_redundant_parameters()
  152. 删除冗余参数。
  153. 这个接口通常不需要显式调用。
  154. .. py:method:: set_comm_fusion(fusion_type, recurse=True)
  155. 为Cell中的参数设置融合类型。请参考 :class:`mindspore.Parameter.comm_fusion` 的描述。
  156. .. note:: 当函数被多次调用时,此属性值将被重写。
  157. **参数:**
  158. - **fusion_type** (`int`) – Parameter的 `comm_fusion` 属性的设置值。
  159. - **recurse** (`bool`) – 是否递归地设置子Cell的可训练参数。默认值:True。
  160. .. py:method:: set_grad(requires_grad=True)
  161. Cell的梯度设置。在PyNative模式下,该参数指定Cell是否需要梯度。如果为True,则在执行正向网络时,将生成需要计算梯度的反向网络。
  162. **参数:**
  163. **requires_grad** (`bool`) – 指定网络是否需要梯度,如果为True,PyNative模式下Cell将构建反向网络。默认值:True。
  164. **返回:**
  165. Cell类型,Cell本身。
  166. .. py:method:: set_train(mode=True)
  167. 将Cell设置为训练模式。
  168. 设置当前Cell和所有子Cell的训练模式。对于训练和预测具有不同结构的网络层(如 `BatchNorm`),将通过这个属性区分分支。如果设置为True,则执行训练分支,否则执行另一个分支。
  169. **参数:**
  170. **mode** (`bool`) – 指定模型是否为训练模式。默认值:True。
  171. **返回:**
  172. Cell类型,Cell本身。
  173. .. py:method:: to_float(dst_type)
  174. 在Cell和所有子Cell的输入上添加类型转换,以使用特定的浮点类型运行。
  175. 如果 `dst_type` 是 `mindspore.dtype.float16` ,Cell的所有输入(包括作为常量的input, Parameter, Tensor)都会被转换为float16。请参考 `mindspore.build_train_network` 的源代码中的用法。
  176. .. note:: 多次调用将产生覆盖。
  177. **参数:**
  178. **dst_type** (`mindspore.dtype`) – Cell转换为 `dst_type` 类型运行。 `dst_type` 可以是 `mindspore.dtype.float16` 或者 `mindspore.dtype.float32` 。
  179. **返回:**
  180. Cell类型,Cell本身。
  181. **异常:**
  182. **ValueError** – 如果 `dst_type` 不是 `mindspore.dtype.float32` ,也不是 `mindspore.dtype.float16`。
  183. .. py:method:: trainable_params(recurse=True)
  184. 返回Cell的可训练参数。
  185. 返回一个可训练参数的列表。
  186. **参数:**
  187. **recurse** (`bool`) – 是否递归地包含当前Cell的所有子Cell的可训练参数。默认值:True。
  188. **返回:**
  189. List类型,可训练参数列表。
  190. .. py:method:: untrainable_params(recurse=True)
  191. 返回Cell的不可训练参数。
  192. 返回一个不可训练参数的列表。
  193. **参数:**
  194. **recurse** (`bool`) – 是否递归地包含当前Cell的所有子Cell的不可训练参数。默认值:True。
  195. **返回:**
  196. List类型,不可训练参数列表。
  197. .. py:method:: update_cell_prefix()
  198. 递归地更新所有子Cell的 `param_prefix` 。
  199. 在调用此方法后,可以通过Cell的 `param_prefix` 属性获取该Cell的所有子Cell的名称前缀。
  200. .. py:method:: update_cell_type(cell_type)
  201. 量化感知训练网络场景下,更新当前Cell的类型。
  202. 此方法将Cell类型设置为 `cell_type` 。
  203. **参数:**
  204. **cell_type** (str) – 被更新的类型,`cell_type` 可以是"quant"或"second-order"。
  205. .. py:method:: update_parameters_name(prefix="", recurse=True)
  206. 给网络参数名称添加 `prefix` 前缀字符串。
  207. **参数:**
  208. - **prefix** (`str`) – 前缀字符串。默认值:''。
  209. - **recurse** (`bool`) – 是否递归地包含所有子Cell的参数。默认值:True。