You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.Model.rst 15 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. mindspore.Model
  2. ================
  3. .. py:class:: mindspore.Model(network, loss_fn=None, optimizer=None, metrics=None, eval_network=None, eval_indexes=None, amp_level="O0", acc_level="O0", **kwargs)
  4. 模型训练或推理的高阶接口。 `Model` 会根据用户传入的参数封装可训练或推理的实例。
  5. **参数:**
  6. - **network** (Cell) – 用于训练或推理的神经网络。
  7. - **loss_fn** (Cell) - 损失函数。如果 `loss_fn` 为None,`network` 中需要进行损失函数计算,必要时也需要进行并行计算。默认值:None。
  8. - **optimizer** (Cell) - 用于更新网络权重的优化器。如果 `optimizer` 为None, `network` 中需要进行反向传播和网络权重更新。默认值:None。
  9. - **metrics** (Union[dict, set]) - 用于模型评估的一组评价函数。例如:{'accuracy', 'recall'}。默认值:None。
  10. - **eval_network** (Cell) - 用于评估的神经网络。未定义情况下,`Model` 会使用 `network` 和 `loss_fn` 封装一个 `eval_network` 。默认值:None。
  11. - **eval_indexes** (list) - 在定义 `eval_network` 的情况下使用。如果 `eval_indexes` 为默认值None,`Model` 会将 `eval_network` 的所有输出传给 `metrics` 。如果配置 `eval_indexes` ,必须包含三个元素,分别为损失值、预测值和标签在 `eval_network` 输出中的位置,此时,损失值将传给损失评价函数,预测值和标签将传给其他评价函数。推荐使用评价函数的 `mindspore.nn.Metric.set_indexes` 代替 `eval_indexes` 。默认值:None。
  12. - **amp_level** (str) - `mindspore.build_train_network` 的可选参数 `level`,`level` 为混合精度等级,该参数支持["O0", "O2", "O3", "auto"]。默认值:"O0"。
  13. - O0: 无变化。
  14. - O2: 将网络精度转为float16,batchnorm保持float32精度,使用动态调整梯度放大系数(loss scale)的策略。
  15. - O3: 将网络精度(包括batchnorm)转为float16,不使用梯度调整策略。
  16. - auto: 为不同处理器设置专家推荐的混合精度等级,如在GPU上设为O2,在Ascend上设为O3。该设置方式可能在部分场景下不适用,建议用户根据具体的网络模型自定义设置 `amp_level` 。
  17. 在GPU上建议使用O2,在Ascend上建议使用O3。
  18. 通过`kwargs`设置`keep_batchnorm_fp32`,可修改batchnorm策略,`keep_batchnorm_fp32`必须为bool类型;通过`kwargs`设置`loss_scale_manager`可修改梯度放大策略,`loss_scale_manager`必须为:class:`mindspore.LossScaleManager`的子类,
  19. 关于 `amp_level` 详见 `mindpore.build_train_network`。
  20. **样例:**
  21. >>> from mindspore import Model, nn
  22. >>>
  23. >>> class Net(nn.Cell):
  24. ... def __init__(self, num_class=10, num_channel=1):
  25. ... super(Net, self).__init__()
  26. ... self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
  27. ... self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
  28. ... self.fc1 = nn.Dense(16*5*5, 120, weight_init='ones')
  29. ... self.fc2 = nn.Dense(120, 84, weight_init='ones')
  30. ... self.fc3 = nn.Dense(84, num_class, weight_init='ones')
  31. ... self.relu = nn.ReLU()
  32. ... self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
  33. ... self.flatten = nn.Flatten()
  34. ...
  35. ... def construct(self, x):
  36. ... x = self.max_pool2d(self.relu(self.conv1(x)))
  37. ... x = self.max_pool2d(self.relu(self.conv2(x)))
  38. ... x = self.flatten(x)
  39. ... x = self.relu(self.fc1(x))
  40. ... x = self.relu(self.fc2(x))
  41. ... x = self.fc3(x)
  42. ... return x
  43. >>>
  44. >>> net = Net()
  45. >>> loss = nn.SoftmaxCrossEntropyWithLogits()
  46. >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
  47. >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
  48. >>> # 如何构建数据集,请参考官方网站的数据集相关章节
  49. >>> dataset = create_custom_dataset()
  50. >>> model.train(2, dataset)
  51. .. py:method:: build(train_dataset=None, valid_dataset=None, sink_size=-1)
  52. 数据下沉模式下构建计算图和数据图。
  53. .. warning::这是一个实验性接口,后续可能删除或修改。
  54. .. note:: 如果预先调用该接口构建计算图,那么 `Model.train` 会直接执行计算图。预构建计算图目前仅支持GRAPH_MOD模式和Ascend处理器,仅支持数据下沉模式。
  55. **参数:**
  56. - **train_dataset** (Dataset) – 一个训练集迭代器。如果定义了 `train_dataset` ,将会构建训练计算图。默认值:None。
  57. - **valid_dataset** (Dataset) - 一个验证集迭代器。如果定义了 `valid_dataset` ,将会构建验证计算图,此时 `Model` 中的 `metrics` 不能为None。默认值:None。
  58. - **sink_size** (int) - 控制每次数据下沉的数据量。默认值:-1。
  59. - **epoch** (int) - 控制训练轮次。默认值:1。
  60. **样例:**
  61. >>> from mindspore import Model, nn, FixedLossScaleManager
  62. >>>
  63. >>> # 如何构建数据集,请参考官方网站的数据集相关章节
  64. >>> dataset = create_custom_dataset()
  65. >>> net = Net()
  66. >>> loss = nn.SoftmaxCrossEntropyWithLogits()
  67. >>> loss_scale_manager = FixedLossScaleManager()
  68. >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
  69. >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
  70. >>> model.build(dataset, epoch=2)
  71. >>> model.train(2, dataset)
  72. >>> model.train(2, dataset)
  73. .. py:method:: eval(valid_dataset, callbacks=None, dataset_sink_mode=True)
  74. 模型评估接口。
  75. 使用PyNative模式或CPU处理器时,模型评估流程将以非下沉模式执行。
  76. .. note::
  77. 如果 `dataset_sink_mode` 配置为True,数据将被送到处理器中。如果处理器是Ascend,数据特征将被逐一传输,每次数据传输的限制是256M。如果 `dataset_sink_mode` 配置为True,数据集仅能在当前模型中使用,而不能被其他模型使用。该接口会构建并执行计算图,如果使用前先执行了 `Model.build` ,那么它会直接执行计算图而不构建。
  78. **参数:**
  79. - **valid_dataset** (Dataset) – 评估模型的数据集。
  80. - **callbacks** (Optional[list(Callback), Callback]) - 评估过程中需要执行的回调对象或回调对象列表。默认值:None。
  81. - **dataset_sink_mode** (bool) - 是否通过数据通道获取数据。默认值:True。
  82. **返回:**
  83. Dict,键是用户定义的评价指标名称,值是以推理模式运行的评估结果。
  84. **样例:**
  85. >>> from mindspore import Model, nn
  86. >>>
  87. >>> # 如何构建数据集,请参考官方网站的数据集相关章节
  88. >>> dataset = create_custom_dataset()
  89. >>> net = Net()
  90. >>> loss = nn.SoftmaxCrossEntropyWithLogits()
  91. >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
  92. >>> acc = model.eval(dataset, dataset_sink_mode=False)
  93. .. py:method:: eval_network
  94. :property:
  95. 获取该模型的评价网络。
  96. **返回:**
  97. 评估网络实例。
  98. .. py:method:: infer_predict_layout(*predict_data)
  99. 在 `AUTO_PARALLEL` 或 `SEMI_AUTO_PARALLEL` 模式下为预测网络生成参数layout,数据可以是单个或多个张量。
  100. .. note:: 同一批次数据应放在一个张量中。
  101. **参数:**
  102. - **predict_data** (Tensor) – 单个或多个张量的预测数据。
  103. **返回:**
  104. Dict,用于加载分布式checkpoint的参数layout字典。它总是作为 `load_distributed_checkpoint()` 函数的一个入参。
  105. **异常:**
  106. - **RuntimeError** – 如果不是图模式(GRAPH_MODE)。
  107. **样例:**
  108. >>> # 该例子需要在多设备上运行。请参考mindpore.cn上的教程 > 分布式训练。
  109. >>> import numpy as np
  110. >>> import mindspore as ms
  111. >>> from mindspore import Model, context, Tensor
  112. >>> from mindspore.context import ParallelMode
  113. >>> from mindspore.communication import init
  114. >>>
  115. >>> context.set_context(mode=context.GRAPH_MODE)
  116. >>> init()
  117. >>> context.set_auto_parallel_context(full_batch=True, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  118. >>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), ms.float32)
  119. >>> model = Model(Net())
  120. >>> model.infer_predict_layout(input_data)
  121. .. py:method:: infer_train_layout(train_dataset, dataset_sink_mode=True, sink_size=-1)
  122. 在 `AUTO_PARALLEL` 或 `SEMI_AUTO_PARALLEL` 模式下为训练网络生成参数layout,当前只有数据下沉模式可支持使用。
  123. .. warning:: 这是一个实验性的原型,可能会被改变和/或删除。
  124. .. note:: 这是一个预编译函数。参数必须与Model.train()函数相同。
  125. **参数:**
  126. - **train_dataset** (Dataset) – 一个训练数据集迭代器。如果没有损失函数(loss_fn),返回一个包含多个数据的元组(data1, data2, data3, ...)并传递给网络。否则,返回一个元组(data, label),数据和标签将被分别传递给网络和损失函数。
  127. - **dataset_sink_mode** (bool) – 决定是否以数据集下沉模式进行训练。默认值:True。配置项是PyNative模式或CPU时,训练模型流程使用的是数据不下沉(non-sink)模式。默认值:True。
  128. - **sink_size** (int) – 控制每次数据下沉的数据量,如果 `sink_size` =-1,则每一次epoch下沉完整数据集。如果 `sink_size` >0,则每一次epoch下沉数据量为 `sink_size` 的数据集。如果 `dataset_sink_mode` 为False,则设置 `sink_size` 为无效。默认值:-1。
  129. **返回:**
  130. Dict,用于加载分布式checkpoint的参数layout字典。
  131. **样例:**
  132. >>> # 该例子需要在多设备上运行。请参考mindpore.cn上的教程 > 分布式训练。
  133. >>> import numpy as np
  134. >>> import mindspore as ms
  135. >>> from mindspore import Model, context, Tensor, nn, FixedLossScaleManager
  136. >>> from mindspore.context import ParallelMode
  137. >>> from mindspore.communication import init
  138. >>>
  139. >>> context.set_context(mode=context.GRAPH_MODE)
  140. >>> init()
  141. >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  142. >>>
  143. >>> # 如何构建数据集,请参考官方网站上关于[数据集]的章节。
  144. >>> dataset = create_custom_dataset()
  145. >>> net = Net()
  146. >>> loss = nn.SoftmaxCrossEntropyWithLogits()
  147. >>> loss_scale_manager = FixedLossScaleManager()
  148. >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
  149. >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
  150. >>> layout_dict = model.infer_train_layout(dataset)
  151. .. py:method:: predict(*predict_data)
  152. 输入样本得到预测结果。
  153. **参数:**
  154. - **predict_data** (Tensor) – 预测样本,数据可以是单个张量、张量列表或张量元组。
  155. **返回:**
  156. 返回预测结果,类型是张量或数组。
  157. **样例:**
  158. >>> import mindspore as ms
  159. >>> from mindspore import Model, Tensor
  160. >>>
  161. >>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), ms.float32)
  162. >>> model = Model(Net())
  163. >>> result = model.predict(input_data)
  164. .. py:method:: predict_network
  165. :property:
  166. 获得该模型的预测网络。
  167. **返回:**
  168. 预测网络实例。
  169. .. py:method:: train(epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1)
  170. 模型训练接口。
  171. 使用PYNATIVE_MODE模式或CPU处理器时,模型训练流程将以非下沉模式执行。
  172. .. note::
  173. 如果 `dataset_sink_mode` 配置为True,数据将被送到处理器中。如果处理器是Ascend,数据特征将被逐一传输,每次数据传输的限制是256M。如果 `dataset_sink_mode` 配置为True,仅在每个epoch结束时调用Callback实例的step_end方法。如果 `dataset_sink_mode` 配置为True,数据集仅能在当前模型中使用,而不能被其他模型使用。如果 `sink_size` 大于零,每次epoch可以无限次遍历数据集,直到遍历数据量等于 `sink_size` 为止。然后下次epoch是从上一次遍历的最后位置继续开始遍历。该接口会构建并执行计算图,如果使用前先执行了 `Model.build` ,那么它会直接执行计算图而不构建。
  174. **参数:**
  175. - **epoch** (int) – 训练执行轮次。通常每个epoch都会使用全量数据集进行训练。当 `dataset_sink_mode` 设置为True且 `sink_size` 大于零时,则每个epoch训练次数为 `sink_size` 而不是数据集的总步数。
  176. - **train_dataset** (Dataset) – 一个训练数据集迭代器。如果定义了 `loss_fn` ,则数据和标签会被分别传给 `network` 和 `loss_fn` ,此时数据集需要返回一个元组(data, label)。如果数据集中有多个数据或者标签,可以设置 `loss_fn` 为None,并在 `network` 中实现损失函数计算,此时数据集返回的所有数据组成的元组(data1, data2, data3, ...)会传给 `network` 。
  177. - **callback** (Optional[list[Callback], Callback]) – 训练过程中需要执行的回调对象或者回调对象列表。默认值:None。
  178. - **dataset_sink_mode** (bool) – 是否通过数据通道获取数据。使用PYNATIVE_MODE模式或CPU处理器时,模型训练流程将以非下沉模式执行。默认值:True。
  179. - **sink_size** (int) – 控制每次数据下沉的数据量。`dataset_sink_mode` 为False时 `sink_size` 无效。如果sink_size=-1,则每一次epoch下沉完整数据集。如果sink_size>0,则每一次epoch下沉数据量为sink_size的数据集。默认值:-1。
  180. **样例:**
  181. >>> from mindspore import Model, nn, FixedLossScaleManager
  182. >>>
  183. >>> # 如何构建数据集,请参考官方网站的数据集相关章节
  184. >>> dataset = create_custom_dataset()
  185. >>> net = Net()
  186. >>> loss = nn.SoftmaxCrossEntropyWithLogits()
  187. >>> loss_scale_manager = FixedLossScaleManager()
  188. >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
  189. >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
  190. >>> model.train(2, dataset)
  191. .. py:method:: train_network
  192. :property:
  193. 获得该模型的训练网络。
  194. **返回:**
  195. 预测网络实例。