You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.model.rst 3.8 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. mindspore.Model
  2. ================
  3. .. py:method:: infer_predict_layout(*predict_data)
  4. 在 `AUTO_PARALLEL` 或 `SEMI_AUTO_PARALLEL` 模式下为预测网络生成参数layout,数据可以是单个或多个张量。
  5. .. note:: 同一批次数据应放在一个张量中。
  6. **参数:**
  7. **predict_data** (`Tensor`) – 单个或多个张量的预测数据。
  8. **返回:**
  9. Dict,用于加载分布式checkpoint的参数layout字典。它总是作为 `load_distributed_checkpoint()` 函数的一个入参。
  10. **异常:**
  11. **RuntimeError** – 如果不是图模式(GRAPH_MODE)。
  12. **样例:**
  13. >>> # 该例子需要在多设备上运行。请参考mindpore.cn上的教程 > 分布式训练。
  14. >>> import numpy as np
  15. >>> import mindspore as ms
  16. >>> from mindspore import Model, context, Tensor
  17. >>> from mindspore.context import ParallelMode
  18. >>> from mindspore.communication import init
  19. >>>
  20. >>> context.set_context(mode=context.GRAPH_MODE)
  21. >>> init()
  22. >>> context.set_auto_parallel_context(full_batch=True, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  23. >>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), ms.float32)
  24. >>> model = Model(Net())
  25. >>> model.infer_predict_layout(input_data)
  26. .. py:method:: infer_train_layout(train_dataset, dataset_sink_mode=True, sink_size=-1)
  27. 在 `AUTO_PARALLEL` 或 `SEMI_AUTO_PARALLEL` 模式下为训练网络生成参数layout,当前只有数据下沉模式可支持使用。
  28. .. warning:: 这是一个实验性的原型,可能会被改变和/或删除。
  29. .. note:: 这是一个预编译函数。参数必须与model.train()函数相同。
  30. **参数:**
  31. - **train_dataset** (`Dataset`) – 一个训练数据集迭代器。如果没有损失函数(loss_fn),返回一个包含多个数据的元组(data1, data2, data3, ...)并传递给网络。否则,返回一个元组(data, label),数据和标签将被分别传递给网络和损失函数。
  32. - **dataset_sink_mode** (`bool`) – 决定是否以数据集下沉模式进行训练。默认值:True。配置项是PyNative模式或CPU时,训练模型流程使用的是数据不下沉(non-sink)模式。默认值:True。
  33. - **sink_size** (`int`) – 控制每次数据下沉的数据量,如果 `sink_size` =-1,则每一次epoch下沉完整数据集。如果 `sink_size` >0,则每一次epoch下沉数据量为 `sink_size` 的数据集。如果 `dataset_sink_mode` 为False,则设置 `sink_size` 为无效。默认值:-1。
  34. **返回:**
  35. Dict,用于加载分布式checkpoint的参数layout字典。
  36. **样例:**
  37. >>> # 该例子需要在多设备上运行。请参考mindpore.cn上的教程 > 分布式训练。
  38. >>> import numpy as np
  39. >>> import mindspore as ms
  40. >>> from mindspore import Model, context, Tensor, nn, FixedLossScaleManager
  41. >>> from mindspore.context import ParallelMode
  42. >>> from mindspore.communication import init
  43. >>>
  44. >>> context.set_context(mode=context.GRAPH_MODE)
  45. >>> init()
  46. >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  47. >>>
  48. >>> # 如何构建数据集,请参考官方网站上关于【数据集】的章节。
  49. >>> dataset = create_custom_dataset()
  50. >>> net = Net()
  51. >>> loss = nn.SoftmaxCrossEntropyWithLogits()
  52. >>> loss_scale_manager = FixedLossScaleManager()
  53. >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
  54. >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
  55. >>> layout_dict = model.infer_train_layout(dataset)