Browse Source

doc complete

feature/build-system-rewrite
wangjun 4 years ago
parent
commit
379782c3fe
3 changed files with 46 additions and 9 deletions
  1. +2
    -1
      docs/api/api_python/mindspore.context.rst
  2. +16
    -4
      docs/api/api_python/nn/mindspore.nn.Cell.rst
  3. +28
    -4
      mindspore/python/mindspore/nn/cell.py

+ 2
- 1
docs/api/api_python/mindspore.context.rst View File

@@ -198,10 +198,11 @@ MindSpore context,用于配置当前执行环境,包括执行模式、执行
- semi_auto_parallel:半自动并行模式。
- auto_parallel:自动并行模式。

- **search_mode** (str) - 表示有两种策略搜索模式,分别是recursive_programming和dynamic_programming。默认值:dynamic_programming。
- **search_mode** (str) - 表示有三种策略搜索模式,分别是recursive_programming,dynamic_programming和sharding_propagation。默认值:dynamic_programming。

- recursive_programming:表示双递归搜索模式。
- dynamic_programming:表示动态规划搜索模式。
- sharding_propagation:表示从已配置算子的切分策略传播到所有算子。

- **auto_parallel_search_mode** (str) - search_modes参数的兼容接口。将在后续的版本中删除。
- **parameter_broadcast** (bool) - 表示在训练前是否广播参数。在训练之前,为了使所有设备的网络初始化参数值相同,请将设备0上的参数广播到其他设备。不同并行模式下的参数广播不同。在data_parallel模式下,除layerwise_parallel属性为True的参数外,所有参数都会被广播。在hybrid_parallel、semi_auto_parallel和auto_parallel模式下,分段参数不参与广播。默认值:False。


+ 16
- 4
docs/api/api_python/nn/mindspore.nn.Cell.rst View File

@@ -343,11 +343,23 @@
.. py:method:: shard(in_axes, out_axes, device="Ascend", level=0)
指定输入/输出tensor的分布策略,其余算子的策略推导得到。在PyNative模式下,可以利用此方法指定某个cell以图模式进行分布式执行。
in_axes/out_axes需要为元组类型,其中的每一个元素指定对应的输入/输出的tensor分布策略,其类型需要为元组,
可参考: `mindspore.ops.Primitive.shard` 的描述,也可以设置为None,会默认以数据并行执行。
指定输入/输出Tensor的分布策略,其余算子的策略推导得到。在PyNative模式下,可以利用此方法指定某个Cell以图模式进行分布式执行。 in_axes/out_axes需要为元组类型,
其中的每一个元素指定对应的输入/输出的Tensor分布策略,可参考: `mindspore.ops.Primitive.shard` 的描述,也可以设置为None,会默认以数据并行执行。
其余算子的并行策略由输入输出指定的策略推导得到。
.. note:: 需设置为PyNative模式,并且全自动并行(AUTO_PARALLEL),同时设置`set_auto_parallel_context`中的搜索模式(search mode)为"sharding_propagation",或半自动并行(SEMI_AUTO_PARALLEL)。
.. note:: 需设置为Pyative模式,并且全自动并行(AUTO_PARALLEL),同时search mode为sharding_propagation,或半自动并行(SEMI_AUTO_PARALLEL)。
**参数:**
- **in_axes** (tuple) – 指定各输入的切分策略,输入元组的每个元素可以为元组或None,元组即具体指定输入每一维的切分策略,None则会默认以数据并行执行。
- **out_axes** (tuple) – 指定各输出的切分策略,用法同in_axes。
- **device** (string) - 指定执行设备,可以为["CPU", "GPU", "Ascend"]中任意一个,默认值:"Ascend"。目前尚未使能。
- **level** (int) - 指定搜索切分策略的目标函数,即是最大化计算通信比、最小化内存消耗、最大化执行速度等。可以为[0, 1, 2]中任意一个,默认值:0。目前仅支持
最大化计算通信比,其余模式尚未使能。
**返回:**
Cell类型,Cell本身。
.. py:method:: set_grad(requires_grad=True)


+ 28
- 4
mindspore/python/mindspore/nn/cell.py View File

@@ -455,17 +455,41 @@ class Cell(Cell_):
generated by sharding propagation. In_axes and out_axes define the input and output layout respectively.
In_axes/Out_axes should be a tuple each element of which corresponds to the desired layout of
this input/output and None represents data_parallel.

Note:
Only effective in PYNATIVE_MODE and auto_parallel_context in either ParallelMode.AUTO_PARALLEL and
search_mode = sharding_propagation or ParallelMode.SEMI_AUTO_PARALLEL.
Only effective in PYNATIVE_MODE and in either ParallelMode.AUTO_PARALLEL and
set search_mode in auto_parallel_context as sharding_propagation or ParallelMode.SEMI_AUTO_PARALLEL.

Inputs:
in_axes (tuple): Define the layout of inputs, each element of the tuple should be a tuple or None. Tuple
defines the layout of the corresponding input and None represents a data parallel strategy.
out_axes (tuple): Define the layout of outputs similar with in_axes.
device (string): Select a certain device target. It is not in use right now.
Support ["CPU", "GPU", "Ascend"]. Default: "Ascend".
level (int): Option for parallel strategy infer algorithm, namely the object function, maximize computation
over communication ratio, maximize speed performance, minimize memory usage etc. It is not in
use right now. Support ["0", "1", "2"]. Default: "0".

Returns:
Cell, the cell itself.

Examples:
>>> from mindspore.ops import functional as F
>>> import mindspore.nn as nn
>>>
>>> class Block(nn.Cell):
>>> def __init__(self):
>>> self.dense1 = nn.Dense(10, 10)
>>> self.relu = nn.ReLU()
>>> self.dense2 = nn.Dense2(10, 10)
>>> def construct(self, x):
>>> x = self.relu(self.dense2(self.relu(self.dense1(x))))
>>> return x
>>>
>>> class example(nn.Cell):
>>> def __init__(self):
>>> self.block1 = Block()
>>> self.block2 = Block()
>>> self.block2.shard(in_axes=(None, (2, 1)), out_axes=(None,))
>>> # self.parallel_block = F.shard(self.block2, in_axes=(None, (2, 1)), out_axes=(None,))
>>> def construct(self, x):
>>> x = self.block1(x)
>>> x = self.block2(x)


Loading…
Cancel
Save