|
|
|
@@ -455,17 +455,41 @@ class Cell(Cell_): |
|
|
|
generated by sharding propagation. In_axes and out_axes define the input and output layout respectively. |
|
|
|
In_axes/Out_axes should be a tuple, each element of which corresponds to the desired layout of |
|
|
|
this input/output and None represents data_parallel. |
|
|
|
|
|
|
|
Note: |
|
|
|
Only effective in PYNATIVE_MODE when auto_parallel_context is in either ParallelMode.AUTO_PARALLEL with |
|
|
|
search_mode = sharding_propagation, or in ParallelMode.SEMI_AUTO_PARALLEL. |
|
|
|
Only effective in PYNATIVE_MODE, and in either ParallelMode.AUTO_PARALLEL with |
|
|
|
search_mode in auto_parallel_context set to sharding_propagation, or in ParallelMode.SEMI_AUTO_PARALLEL. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
in_axes (tuple): Define the layout of inputs, each element of the tuple should be a tuple or None. Tuple |
|
|
|
defines the layout of the corresponding input and None represents a data parallel strategy. |
|
|
|
out_axes (tuple): Define the layout of outputs similar with in_axes. |
|
|
|
device (string): Select a certain device target. It is not in use right now. |
|
|
|
Support ["CPU", "GPU", "Ascend"]. Default: "Ascend". |
|
|
|
level (int): Option for parallel strategy infer algorithm, namely the object function, maximize computation |
|
|
|
over communication ratio, maximize speed performance, minimize memory usage etc. It is not in |
|
|
|
use right now. Support [0, 1, 2]. Default: 0. |
|
|
|
|
|
|
|
Returns: |
|
|
|
Cell, the cell itself. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> from mindspore.ops import functional as F |
|
|
|
>>> import mindspore.nn as nn |
|
|
|
>>> |
|
|
|
>>> class Block(nn.Cell): |
|
|
|
>>> def __init__(self): |
|
|
|
>>> self.dense1 = nn.Dense(10, 10) |
|
|
|
>>> self.relu = nn.ReLU() |
|
|
|
>>> self.dense2 = nn.Dense(10, 10) |
|
|
|
>>> def construct(self, x): |
|
|
|
>>> x = self.relu(self.dense2(self.relu(self.dense1(x)))) |
|
|
|
>>> return x |
|
|
|
>>> |
|
|
|
>>> class example(nn.Cell): |
|
|
|
>>> def __init__(self): |
|
|
|
>>> self.block1 = Block() |
|
|
|
>>> self.block2 = Block() |
|
|
|
>>> self.block2.shard(in_axes=(None, (2, 1)), out_axes=(None,)) |
|
|
|
>>> # self.parallel_block = F.shard(self.block2, in_axes=(None, (2, 1)), out_axes=(None,)) |
|
|
|
>>> def construct(self, x): |
|
|
|
>>> x = self.block1(x) |
|
|
|
>>> x = self.block2(x) |
|
|
|
|