| @@ -18,7 +18,7 @@ Wrap cells for networks. | |||||
| Use the Wrapper to combine the loss or build the training steps. | Use the Wrapper to combine the loss or build the training steps. | ||||
| """ | """ | ||||
| from .cell_wrapper import TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, DataWrapper, \ | from .cell_wrapper import TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, DataWrapper, \ | ||||
| ParameterUpdate, GetNextSingleOp | |||||
| ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple | |||||
| from .loss_scale import TrainOneStepWithLossScaleCell, DynamicLossScaleUpdateCell, FixedLossScaleUpdateCell | from .loss_scale import TrainOneStepWithLossScaleCell, DynamicLossScaleUpdateCell, FixedLossScaleUpdateCell | ||||
| from .grad_reducer import DistributedGradReducer | from .grad_reducer import DistributedGradReducer | ||||
| @@ -33,5 +33,6 @@ __all__ = [ | |||||
| "DistributedGradReducer", | "DistributedGradReducer", | ||||
| "ParameterUpdate", | "ParameterUpdate", | ||||
| "DynamicLossScaleUpdateCell", | "DynamicLossScaleUpdateCell", | ||||
| "FixedLossScaleUpdateCell" | |||||
| "FixedLossScaleUpdateCell", | |||||
| "VirtualDatasetCellTriple" | |||||
| ] | ] | ||||
| @@ -278,6 +278,36 @@ class _VirtualDatasetCell(Cell): | |||||
| return self._backbone(data_, label_) | return self._backbone(data_, label_) | ||||
| class VirtualDatasetCellTriple(Cell): | |||||
| """ | |||||
| Wrap the network with virtual dataset to convert data parallel layout to model parallel layout. | |||||
| VirtualDatasetCellTriple is a virtual Primitive, it does not exist in the final executing graph. Inputs and outputs | |||||
| of VirtualDatasetCellTriple are distributed in data parallel pattern, tensor redistribution Primitives is inserted | |||||
| dynamically during the graph compile process. | |||||
| Note: | |||||
| Only used in semi auto parallel and auto parallel mode. There are three inputs, as contrary to two inputs in | |||||
| _VirtualDatasetCell. | |||||
| Args: | |||||
| backbone (Cell): The target network to wrap. | |||||
| Examples: | |||||
| >>> net = Net() | |||||
| >>> net = VirtualDatasetCellTriple(net) | |||||
| """ | |||||
| def __init__(self, backbone): | |||||
| super(VirtualDatasetCellTriple, self).__init__(auto_prefix=False) | |||||
| self._backbone = backbone | |||||
| self._virtual_dataset = _VirtualDataset() | |||||
| def construct(self, a, b, c): | |||||
| a_, b_, c_ = self._virtual_dataset(a, b, c) | |||||
| return self._backbone(a_, b_, c_) | |||||
| class WithEvalCell(Cell): | class WithEvalCell(Cell): | ||||
| r""" | r""" | ||||
| Cell that returns loss, output and label for evaluation. | Cell that returns loss, output and label for evaluation. | ||||
| @@ -21,6 +21,7 @@ import mindspore as ms | |||||
| from mindspore.common.api import _executor | from mindspore.common.api import _executor | ||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| from mindspore.ops.operations.comm_ops import _VirtualDataset | from mindspore.ops.operations.comm_ops import _VirtualDataset | ||||
| from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple | |||||
| from mindspore import context | from mindspore import context | ||||
| @@ -73,6 +74,29 @@ def test_virtual_dataset_3_input(): | |||||
| net.set_auto_parallel() | net.set_auto_parallel() | ||||
| _executor.compile(net, x, y, b) | _executor.compile(net, x, y, b) | ||||
| def test_virtualdataset_cell_3_inputs(): | |||||
| class Net(nn.Cell): | |||||
| def __init__(self, strategy0, strategy1, strategy2, strategy3): | |||||
| super().__init__() | |||||
| self.matmul1 = P.MatMul().set_strategy(strategy1) | |||||
| self.matmul2 = P.MatMul().set_strategy(strategy2) | |||||
| self.gelu = P.Gelu().set_strategy(strategy3) | |||||
| def construct(self, x, y, b): | |||||
| out = self.gelu(self.matmul1(x, y)) | |||||
| out = self.matmul2(out, b) | |||||
| return out | |||||
| net = GradWrap(VirtualDatasetCellTriple(NetWithLoss(Net(None, None, None, None)))) | |||||
| context.set_context(save_graphs=True) | |||||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||||
| context.set_auto_parallel_context(device_num=8, global_rank=0) | |||||
| x = Tensor(np.ones([128, 32]), dtype=ms.float32) | |||||
| y = Tensor(np.ones([32, 64]), dtype=ms.float32) | |||||
| b = Tensor(np.ones([64, 2048]), dtype=ms.float32) | |||||
| net.set_auto_parallel() | |||||
| _executor.compile(net, x, y, b) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| test_virtual_dataset_3_input() | test_virtual_dataset_3_input() | ||||