From: @yangzhenzhang
Reviewed-by: @kisnwang, @stsuteng
Signed-off-by: @stsuteng
pull/15766/MERGE
@@ -45,7 +45,6 @@ class SelectInfo : public OperatorInfo {
  protected:
   Status GetAttrs() override { return SUCCESS; }
   Status CheckStrategy(const StrategyPtr &strategy) override;
-  Status InferMirrorOps() override { return SUCCESS; }
   Status InferForwardCommunication() override { return SUCCESS; }
   Status InferTensorInfo() override;
   Status InferDevMatrixShape() override;
@@ -65,8 +65,19 @@ Status TileInfo::GetAttrs() {
   return SUCCESS;
 }
 
+// if some dimension of multiples > 1, split the multiple; otherwise, split the dimension of input
 Status TileInfo::CheckStrategy(const StrategyPtr &strategy) {
-  Shapes multiples = {full_multiples_};
+  Shape tmp;
+  for (size_t i = 0; i < full_multiples_.size(); ++i) {
+    if (full_multiples_[i] != 1) {
+      tmp.push_back(full_multiples_[i]);
+    } else {
+      tmp.push_back(inputs_shape_[0][i]);
+    }
+  }
+  Shapes multiples = {tmp};
+  MS_LOG(INFO) << name_ << ": The input shape is " << ShapeToString(inputs_shape_[0]) << ", the multiples is "
+               << ShapeToString(full_multiples_) << ", so the shape that can be split is " << ShapeToString(tmp);
   return CheckStrategyValue(strategy, multiples);
 }
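The rewritten CheckStrategy validates the strategy against a combined shape: where a multiple is greater than 1 the multiple itself is the splittable value, and where it is 1 the corresponding input dimension is. A minimal Python sketch of that derivation (the function name splittable_shape is illustrative, not from the patch):

    # Combined shape that CheckStrategy validates against: split the
    # multiple where multiples[i] > 1, else split the input dimension.
    def splittable_shape(input_shape, multiples):
        return [m if m != 1 else d for d, m in zip(input_shape, multiples)]

    # e.g. a [16, 16, 16] input tiled by (8, 1, 1):
    assert splittable_shape([16, 16, 16], [8, 1, 1]) == [8, 16, 16]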
@@ -74,7 +85,7 @@ Status TileInfo::InferDevMatrixShape() {
   MS_EXCEPTION_IF_NULL(strategy_);
   std::vector<Dimensions> stra = strategy_->GetInputDim();
   if (stra.empty()) {
-    MS_LOG(ERROR) << name_ << "The strategy is empty";
+    MS_LOG(ERROR) << name_ << ": The strategy is empty";
     return FAILED;
   }
   if (full_multiples_.size() != stra[0].size()) {
@@ -86,6 +97,9 @@ Status TileInfo::InferDevMatrixShape() {
   slice_multiples_ = full_multiples_;
   for (size_t i = 0; i < full_multiples_.size(); ++i) {
+    if (full_multiples_[i] == 1) {
+      continue;
+    }
     slice_multiples_[i] = slice_multiples_[i] / dev_matrix_shape_[i];
   }
   return SUCCESS;
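With that convention, InferDevMatrixShape only divides a multiple by its device-matrix dimension when the multiple is the thing being split; dimensions with multiple == 1 are skipped, since there the input is split instead. A quick sketch of the resulting slice multiples (illustrative names; it assumes the device matrix equals the chosen strategy, as it does for this operator):

    # slice_multiples[i] = full_multiples[i] / dev_matrix[i], skipping
    # dimensions where the input rather than the multiple is split.
    def slice_multiples(full_multiples, dev_matrix):
        return [m if m == 1 else m // s for m, s in zip(full_multiples, dev_matrix)]

    # e.g. multiples (8, 1, 1) under strategy (2, 2, 1):
    assert slice_multiples([8, 1, 1], [2, 2, 1]) == [4, 1, 1]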
@@ -95,13 +109,18 @@ Status TileInfo::InferTensorMap() {
   TensorMap input_tensor_map;
   TensorMap output_tensor_map;
   if (inputs_shape_.empty() || outputs_shape_.empty()) {
-    MS_LOG(ERROR) << name_ << "The inputs or outputs' shape is empty";
+    MS_LOG(ERROR) << name_ << ": The inputs or outputs' shape is empty";
     return FAILED;
   }
 
-  // the input tensor cannot be split
+  // if some dimension of multiples > 1, split the multiple; otherwise, split the dimension of input
   for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
-    input_tensor_map.push_back(MAP_NONE);
+    input_tensor_map.push_back(inputs_shape_[0].size() - i - 1);
+  }
+
+  for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
+    if (full_multiples_[i] != 1) {
+      input_tensor_map[i] = MAP_NONE;
+    }
   }
 
   // cannot use dev_matrix_shape_ replace outputs_shape_[0], because it may not be fully split in all devices.
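The input tensor map changes accordingly: instead of mapping no input dimension at all, the patch starts from the fully mapped layout and masks out only those dimensions whose multiples are split. A sketch with MAP_NONE written as -1, its conventional value in the parallel framework:

    MAP_NONE = -1

    def input_tensor_map(input_shape, full_multiples):
        n = len(input_shape)
        # Default layout: dimension i maps to device-matrix axis n - 1 - i.
        tmap = [n - 1 - i for i in range(n)]
        # A dimension whose multiple is split cannot also split the input.
        for i, m in enumerate(full_multiples):
            if m != 1:
                tmap[i] = MAP_NONE
        return tmap

    # e.g. multiples (8, 1, 1): only the last two input dims stay mapped.
    assert input_tensor_map([16, 16, 16], [8, 1, 1]) == [MAP_NONE, 1, 0]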
@@ -163,11 +182,11 @@ Status TileInfo::InferTensorInfo() {
 void TileInfo::UpdateMultiples(const CNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(cnode);
   if (cnode->size() != 3) {
-    MS_LOG(EXCEPTION) << "The size of tile cnode's inputs must be 3";
+    MS_LOG(EXCEPTION) << name_ << ": The size of tile cnode's inputs must be 3";
   }
 
   if (!IsValueNode<ValueTuple>(cnode->input(2))) {
-    MS_LOG(EXCEPTION) << "The input[2] of tile cnode is not ValueTuple.";
+    MS_LOG(EXCEPTION) << name_ << ": The input[2] of tile cnode is not ValueTuple.";
   }
 
   auto func_graph = cnode->func_graph();
@@ -199,7 +218,14 @@ Status TileInfo::GenerateStrategies(int64_t stage_id) {
   Shape multiples_split(full_multiples_.size(), 1);
   Shapes splittable_inputs = {multiples_split};
 
+  // if some dimension of multiples > 1, split the multiple; otherwise, split the dimension of input
   std::vector<StrategyPtr> sp_vector;
-  Shapes tmp_inputs_shape = {full_multiples_};
+  Shape tmp_input_shape = full_multiples_;
+  for (size_t i = 0; i < full_multiples_.size(); ++i) {
+    if (full_multiples_[i] == 1) {
+      tmp_input_shape[i] = inputs_shape_[0][i];
+    }
+  }
+  Shapes tmp_inputs_shape = {tmp_input_shape};
   if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_inputs, &sp_vector) != SUCCESS) {
     return FAILED;
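GenerateStrategies then searches over the same combined shape that CheckStrategy accepts. As a rough illustration of what the independent-inputs generator does with it, here is a simplified stand-in (exhaustive enumeration; the real GenerateStrategiesForIndependentInputs is more involved):

    from itertools import product

    # Enumerate per-dimension splits that divide the combined shape and
    # together occupy exactly dev_num devices. Illustrative only.
    def gen_strategies(shape, dev_num):
        axis_splits = [[s for s in range(1, d + 1) if d % s == 0] for d in shape]
        result = []
        for combo in product(*axis_splits):
            n = 1
            for s in combo:
                n *= s
            if n == dev_num:
                result.append(combo)
        return result

    # e.g. combined shape [8, 16, 16] on 8 devices yields (8, 1, 1),
    # (2, 2, 2), (1, 8, 1), and so on.
    print(gen_strategies([8, 16, 16], 8))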
@@ -94,6 +94,15 @@ def test_select_model_parallel():
     compile_net(net)
 
 
+def test_select_mirror():
+    context.set_auto_parallel_context(
+        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((1, 2, 2), (1, 2, 2))
+    strategy2 = ((1, 2, 2), (1, 2, 2), (1, 2, 2))
+    net = Net(_w1, _w2, strategy1, strategy2)
+    compile_net(net)
+
+
 def test_select_auto_parallel():
     context.set_auto_parallel_context(
         parallel_mode="auto_parallel", device_num=8, global_rank=0)
@@ -52,20 +52,38 @@ class Net2(Cell):
         out = self.tile(out, (8, 8, 4, 2))
         return out
 
 
+class Net3(Cell):
+    def __init__(self, weight, strategy1=None, strategy2=None, is_parameter=True):
+        super().__init__()
+        self.mul = P.Mul().shard(strategy1)
+        self.tile = P.Tile().shard(strategy2)
+        if is_parameter:
+            self.weight = Parameter(weight, "w1")
+        else:
+            self.weight = weight
+        self.mul2 = P.Mul()
+
+    def construct(self, x, b):
+        out = self.tile(self.weight, (8, 1, 1))
+        out = self.mul(x, out)
+        return out
+
+
 _x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+_x1 = Tensor(np.ones([128, 16, 16]), dtype=ms.float32)
 _w1 = Tensor(np.ones([16, 16, 16]), dtype=ms.float32)
 _w2 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+_w3 = Tensor(np.ones([128, 16, 16]), dtype=ms.float32)
 _b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
 
 
-def compile_net(net):
-    context.set_context(save_graphs=False)
+def compile_net(net, x=_b, b=_b):
+    context.set_context(save_graphs=True)
     optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
     train_net = TrainOneStepCell(net, optimizer)
     train_net.set_auto_parallel()
     train_net.set_train()
-    _executor.compile(train_net, _x, _b)
+    _executor.compile(train_net, x, b)
     context.reset_auto_parallel_context()
@@ -101,6 +119,14 @@ def test_tile_tensor_no_full_split():
     compile_net(net)
 
 
+def test_tile_tensor_no_full_split2():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((2, 2, 1), (2, 2, 1))
+    strategy2 = ((2, 2, 1),)
+    net = Net3(_w1, strategy1, strategy2)
+    compile_net(net, _x1, _b)
+
+
 def test_tile_output():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy1 = ((2, 2, 2), (2, 2, 2))
@@ -108,6 +134,7 @@ def test_tile_output():
     net = Net2(_w2, strategy1, strategy2)
     compile_net(net)
 
+
 def test_tile_output_no_full_split():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy1 = ((2, 2, 2), (2, 2, 2))
@@ -123,7 +150,14 @@ def test_tile_no_strategy():
     net = Net2(_w2, strategy1, strategy2)
     compile_net(net)
 
 
 def test_tile_auto_parallel():
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
     net = Net2(_w2)
     compile_net(net)
+
+
+def test_tile_auto_parallel_2():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    net = Net3(_w1)
+    compile_net(net, _x1, _b)