From aa523992000e3c36dff4a8e949fc12edb667933a Mon Sep 17 00:00:00 2001
From: Xiaoda Zhang
Date: Fri, 5 Mar 2021 19:00:43 +0800
Subject: [PATCH] Making the Tile operator have more parallel strategies

---
 .../frontend/parallel/ops_info/select_info.h |  1 -
 .../frontend/parallel/ops_info/tile_info.cc  | 40 +++++++++++++++----
 tests/ut/python/parallel/test_select.py      |  9 +++++
 tests/ut/python/parallel/test_tile.py        | 40 +++++++++++++++++--
 4 files changed, 79 insertions(+), 11 deletions(-)

diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/select_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/select_info.h
index 9683812782..6fcf219f5c 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/select_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/select_info.h
@@ -45,7 +45,6 @@ class SelectInfo : public OperatorInfo {
  protected:
   Status GetAttrs() override { return SUCCESS; }
   Status CheckStrategy(const StrategyPtr &strategy) override;
-  Status InferMirrorOps() override { return SUCCESS; }
   Status InferForwardCommunication() override { return SUCCESS; }
   Status InferTensorInfo() override;
   Status InferDevMatrixShape() override;
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/tile_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/tile_info.cc
index 9ef8d6eaac..c7c8d1cc34 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/tile_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/tile_info.cc
@@ -65,8 +65,19 @@ Status TileInfo::GetAttrs() {
   return SUCCESS;
 }
 
+// if some dimension of multiples > 1, split the multiple; otherwise, split the dimension of input
 Status TileInfo::CheckStrategy(const StrategyPtr &strategy) {
-  Shapes multiples = {full_multiples_};
+  Shape tmp;
+  for (size_t i = 0; i < full_multiples_.size(); ++i) {
+    if (full_multiples_[i] != 1) {
+      tmp.push_back(full_multiples_[i]);
+    } else {
+      tmp.push_back(inputs_shape_[0][i]);
+    }
+  }
+  Shapes multiples = {tmp};
+  MS_LOG(INFO) << name_ << ": The input shape is " << ShapeToString(inputs_shape_[0]) << ", the multiples is "
+               << ShapeToString(full_multiples_) << ", so the shape that can be split is " << ShapeToString(tmp);
   return CheckStrategyValue(strategy, multiples);
 }
 
@@ -74,7 +85,7 @@ Status TileInfo::InferDevMatrixShape() {
   MS_EXCEPTION_IF_NULL(strategy_);
   std::vector<Dimensions> stra = strategy_->GetInputDim();
   if (stra.empty()) {
-    MS_LOG(ERROR) << name_ << "The strategy is empty";
+    MS_LOG(ERROR) << name_ << ": The strategy is empty";
     return FAILED;
   }
   if (full_multiples_.size() != stra[0].size()) {
@@ -86,6 +97,9 @@
   slice_multiples_ = full_multiples_;
   for (size_t i = 0; i < full_multiples_.size(); ++i) {
+    if (full_multiples_[i] == 1) {
+      continue;
+    }
     slice_multiples_[i] = slice_multiples_[i] / dev_matrix_shape_[i];
   }
   return SUCCESS;
 }
@@ -95,13 +109,18 @@ Status TileInfo::InferTensorMap() {
   TensorMap input_tensor_map;
   TensorMap output_tensor_map;
   if (inputs_shape_.empty() || outputs_shape_.empty()) {
-    MS_LOG(ERROR) << name_ << "The inputs or outputs' shape is empty";
+    MS_LOG(ERROR) << name_ << ": The inputs or outputs' shape is empty";
     return FAILED;
   }
 
-  // the input tensor cannot be split
+  // if some dimension of multiples > 1, split the multiple; otherwise, split the dimension of input
   for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
-    input_tensor_map.push_back(MAP_NONE);
+    input_tensor_map.push_back(inputs_shape_[0].size() - i - 1);
+  }
+  for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
+    if (full_multiples_[i] != 1) {
+      input_tensor_map[i] = MAP_NONE;
+    }
   }
 
   // cannot use dev_matrix_shape_ to replace outputs_shape_[0], because it may not be fully split in all devices.
@@ -163,11 +182,11 @@ Status TileInfo::InferTensorInfo() {
 void TileInfo::UpdateMultiples(const CNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(cnode);
   if (cnode->size() != 3) {
-    MS_LOG(EXCEPTION) << "The size of tile cnode's inputs must be 3";
+    MS_LOG(EXCEPTION) << name_ << ": The size of tile cnode's inputs must be 3";
   }
 
   if (!IsValueNode<ValueTuple>(cnode->input(2))) {
-    MS_LOG(EXCEPTION) << "The input[2] of tile cnode is not ValueTuple.";
+    MS_LOG(EXCEPTION) << name_ << ": The input[2] of tile cnode is not ValueTuple.";
   }
 
   auto func_graph = cnode->func_graph();
@@ -199,7 +218,14 @@ Status TileInfo::GenerateStrategies(int64_t stage_id) {
   Shape multiples_split(full_multiples_.size(), 1);
   Shapes splittable_inputs = {multiples_split};
 
+  // if some dimension of multiples > 1, split the multiple; otherwise, split the dimension of input
   std::vector<StrategyPtr> sp_vector;
+  Shape tmp_input_shape = full_multiples_;
+  for (size_t i = 0; i < full_multiples_.size(); ++i) {
+    if (full_multiples_[i] == 1) {
+      tmp_input_shape[i] = inputs_shape_[0][i];
+    }
+  }
-  Shapes tmp_inputs_shape = {full_multiples_};
+  Shapes tmp_inputs_shape = {tmp_input_shape};
   if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_inputs, &sp_vector) != SUCCESS) {
     return FAILED;
diff --git a/tests/ut/python/parallel/test_select.py b/tests/ut/python/parallel/test_select.py
index bfec7a5ba1..2c05e53143 100644
--- a/tests/ut/python/parallel/test_select.py
+++ b/tests/ut/python/parallel/test_select.py
@@ -94,6 +94,15 @@ def test_select_model_parallel():
     compile_net(net)
 
 
+def test_select_mirror():
+    context.set_auto_parallel_context(
+        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((1, 2, 2), (1, 2, 2))
+    strategy2 = ((1, 2, 2), (1, 2, 2), (1, 2, 2))
+    net = Net(_w1, _w2, strategy1, strategy2)
+    compile_net(net)
+
+
 def test_select_auto_parallel():
     context.set_auto_parallel_context(
         parallel_mode="auto_parallel", device_num=8, global_rank=0)
diff --git a/tests/ut/python/parallel/test_tile.py b/tests/ut/python/parallel/test_tile.py
index 6022fb0689..64731bf6dc 100644
--- a/tests/ut/python/parallel/test_tile.py
+++ b/tests/ut/python/parallel/test_tile.py
@@ -52,20 +52,38 @@ class Net2(Cell):
         out = self.tile(out, (8, 8, 4, 2))
         return out
 
+class Net3(Cell):
+    def __init__(self, weight, strategy1=None, strategy2=None, is_parameter=True):
+        super().__init__()
+        self.mul = P.Mul().shard(strategy1)
+        self.tile = P.Tile().shard(strategy2)
+        if is_parameter:
+            self.weight = Parameter(weight, "w1")
+        else:
+            self.weight = weight
+        self.mul2 = P.Mul()
+
+    def construct(self, x, b):
+        out = self.tile(self.weight, (8, 1, 1))
+        out = self.mul(x, out)
+        return out
+
 
 _x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+_x1 = Tensor(np.ones([128, 16, 16]), dtype=ms.float32)
 _w1 = Tensor(np.ones([16, 16, 16]), dtype=ms.float32)
 _w2 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+_w3 = Tensor(np.ones([128, 16, 16]), dtype=ms.float32)
 _b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
 
 
-def compile_net(net):
-    context.set_context(save_graphs=False)
+def compile_net(net, x=_b, b=_b):
+    context.set_context(save_graphs=True)
     optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
     train_net = TrainOneStepCell(net, optimizer)
     train_net.set_auto_parallel()
     train_net.set_train()
-    _executor.compile(train_net, _x, _b)
+    _executor.compile(train_net, x, b)
     context.reset_auto_parallel_context()
 
 
@@ -101,6 +119,14 @@ def test_tile_tensor_no_full_split():
     compile_net(net)
 
 
+def test_tile_tensor_no_full_split2():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((2, 2, 1), (2, 2, 1))
+    strategy2 = ((2, 2, 1),)
+    net = Net3(_w1, strategy1, strategy2)
+    compile_net(net, _x1, _b)
+
+
 def test_tile_output():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy1 = ((2, 2, 2), (2, 2, 2))
@@ -108,6 +134,7 @@ def test_tile_output():
     net = Net2(_w2, strategy1, strategy2)
     compile_net(net)
 
+
 def test_tile_output_no_full_split():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy1 = ((2, 2, 2), (2, 2, 2))
@@ -123,7 +150,14 @@ def test_tile_no_strategy():
     net = Net2(_w2, strategy1, strategy2)
     compile_net(net)
 
+
 def test_tile_auto_parallel():
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
     net = Net2(_w2)
     compile_net(net)
+
+
+def test_tile_auto_parallel_2():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    net = Net3(_w1)
+    compile_net(net, _x1, _b)
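
For reference, a minimal sketch of the sharding this patch enables, mirroring test_tile_tensor_no_full_split2 above: for a dimension whose multiple is greater than 1, the strategy value splits the multiple (the tiled output dimension); for a dimension whose multiple is 1, the strategy value can now split the input tensor dimension itself. The class name TileMulNet and the standalone setup below are illustrative only and are not part of the patch; the tests above compile such a net through compile_net, which wraps it in TrainOneStepCell and calls _executor.compile.

# Illustrative sketch only; assumes a MindSpore environment with 8 devices, as in the unit tests above.
import numpy as np
import mindspore as ms
from mindspore import Tensor, Parameter, context
from mindspore.nn import Cell
from mindspore.ops import operations as P


class TileMulNet(Cell):
    """Tile a [16, 16, 16] weight by (8, 1, 1), then multiply it with the input."""
    def __init__(self, weight, mul_strategy=None, tile_strategy=None):
        super().__init__()
        self.tile = P.Tile().shard(tile_strategy)
        self.mul = P.Mul().shard(mul_strategy)
        self.weight = Parameter(weight, "w1")

    def construct(self, x):
        out = self.tile(self.weight, (8, 1, 1))  # dim 0 is tiled, dims 1 and 2 are not
        return self.mul(x, out)


context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
# Tile strategy ((2, 2, 1),):
#   dim 0: multiple is 8, so the value 2 splits the multiple (the tiled output dimension);
#   dim 1: multiple is 1, so the value 2 splits the input dimension itself (new with this patch);
#   dim 2: left unsplit.
net = TileMulNet(Tensor(np.ones([16, 16, 16]), dtype=ms.float32),
                 mul_strategy=((2, 2, 1), (2, 2, 1)),
                 tile_strategy=((2, 2, 1),))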