From: @yangzhenzhang
Reviewed-by: @kisnwang, @stsuteng
Signed-off-by: @stsuteng
pull/15766/MERGE
@@ -45,7 +45,6 @@ class SelectInfo : public OperatorInfo {
  protected:
   Status GetAttrs() override { return SUCCESS; }
   Status CheckStrategy(const StrategyPtr &strategy) override;
-  Status InferMirrorOps() override { return SUCCESS; }
   Status InferForwardCommunication() override { return SUCCESS; }
   Status InferTensorInfo() override;
   Status InferDevMatrixShape() override;
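Dropping the trivial `InferMirrorOps` override means `SelectInfo` now falls back to `OperatorInfo`'s default mirror-operator inference, so a sharded Select gets gradient mirroring for its parameters; the new `test_select_mirror` case below exercises exactly this path.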
@@ -65,8 +65,19 @@ Status TileInfo::GetAttrs() {
   return SUCCESS;
 }

+// if some dimension of multiples > 1, split the multiple; otherwise, split the dimension of input
 Status TileInfo::CheckStrategy(const StrategyPtr &strategy) {
-  Shapes multiples = {full_multiples_};
+  Shape tmp;
+  for (size_t i = 0; i < full_multiples_.size(); ++i) {
+    if (full_multiples_[i] != 1) {
+      tmp.push_back(full_multiples_[i]);
+    } else {
+      tmp.push_back(inputs_shape_[0][i]);
+    }
+  }
+  Shapes multiples = {tmp};
+  MS_LOG(INFO) << name_ << ": The input shape is " << ShapeToString(inputs_shape_[0]) << ", the multiples is "
+               << ShapeToString(full_multiples_) << ", so the shape that can be split is " << ShapeToString(tmp);
   return CheckStrategyValue(strategy, multiples);
 }
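To make the new `CheckStrategy` rule concrete: a strategy is no longer validated against the raw multiples but against a synthetic shape that keeps each multiple greater than 1 and substitutes the input dimension wherever the multiple is 1. A minimal pure-Python sketch of that derivation (the name `splittable_shape` is illustrative, not part of the PR):

```python
def splittable_shape(input_shape, multiples):
    """Mirror of the `tmp` shape in TileInfo::CheckStrategy: where a
    multiple is > 1 the multiple itself may be split; where it is 1,
    the corresponding input dimension may be split instead."""
    return [m if m != 1 else d for d, m in zip(input_shape, multiples)]


# Tiling a [128, 16, 16] tensor by (8, 1, 1): the 8 copies and the two
# untiled input dims are what a strategy is allowed to cut.
assert splittable_shape([128, 16, 16], (8, 1, 1)) == [8, 16, 16]
```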
@@ -74,7 +85,7 @@ Status TileInfo::InferDevMatrixShape() {
   MS_EXCEPTION_IF_NULL(strategy_);
   std::vector<Dimensions> stra = strategy_->GetInputDim();
   if (stra.empty()) {
-    MS_LOG(ERROR) << name_ << "The strategy is empty";
+    MS_LOG(ERROR) << name_ << ": The strategy is empty";
     return FAILED;
   }
   if (full_multiples_.size() != stra[0].size()) {
@@ -86,6 +97,9 @@ Status TileInfo::InferDevMatrixShape() {
   slice_multiples_ = full_multiples_;
   for (size_t i = 0; i < full_multiples_.size(); ++i) {
+    if (full_multiples_[i] == 1) {
+      continue;
+    }
     slice_multiples_[i] = slice_multiples_[i] / dev_matrix_shape_[i];
   }
   return SUCCESS;
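The `InferDevMatrixShape` change has the complementary effect on the per-device multiples: dimensions whose multiple is 1 are now skipped instead of divided. A pure-Python mirror of that loop (`slice_multiples` is a made-up name, and the device-matrix indexing is simplified relative to MindSpore's real layout):

```python
def slice_multiples(full_multiples, dev_matrix):
    """Mirror of the new InferDevMatrixShape loop: multiples of 1 are
    skipped (there the input dim is sharded, not the multiple), while
    multiples > 1 are divided by the device-matrix dim splitting them."""
    out = list(full_multiples)
    for i, m in enumerate(full_multiples):
        if m == 1:
            continue  # input dim is split instead; keep the multiple as-is
        out[i] = out[i] // dev_matrix[i]
    return out


# Multiples (8, 1, 1) under a (2, 2, 2) device matrix: each device along
# the first axis materializes 4 of the 8 copies locally.
assert slice_multiples((8, 1, 1), (2, 2, 2)) == [4, 1, 1]
```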
@@ -95,13 +109,18 @@ Status TileInfo::InferTensorMap() {
   TensorMap input_tensor_map;
   TensorMap output_tensor_map;
   if (inputs_shape_.empty() || outputs_shape_.empty()) {
-    MS_LOG(ERROR) << name_ << "The inputs or outputs' shape is empty";
+    MS_LOG(ERROR) << name_ << ": The inputs or outputs' shape is empty";
     return FAILED;
   }

-  // the input tensor cannot be split
+  // if some dimension of multiples > 1, split the multiple; otherwise, split the dimension of input
   for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
-    input_tensor_map.push_back(MAP_NONE);
+    input_tensor_map.push_back(inputs_shape_[0].size() - i - 1);
+  }
+  for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
+    if (full_multiples_[i] != 1) {
+      input_tensor_map[i] = MAP_NONE;
+    }
   }

   // cannot use dev_matrix_shape_ to replace outputs_shape_[0], because it may not be fully split in all devices.
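`InferTensorMap` now starts from the identity mapping and only masks out the dimensions whose multiple is split. Sketched in plain Python (the value `-1` for `MAP_NONE` matches MindSpore's convention; the helper name is invented for illustration):

```python
MAP_NONE = -1  # MindSpore's "this dim is not sharded" marker


def input_tensor_map(input_shape, multiples):
    """Mirror of the two InferTensorMap loops: start from the identity
    map (dim i -> device-matrix axis rank-1-i), then mask out every dim
    whose multiple is > 1, because there the multiple is split instead."""
    rank = len(input_shape)
    tensor_map = [rank - i - 1 for i in range(rank)]
    for i, m in enumerate(multiples):
        if m != 1:
            tensor_map[i] = MAP_NONE
    return tensor_map


# With multiples (8, 1, 1), the first input dim stays replicated while
# the two untiled dims map onto device-matrix axes 1 and 0.
assert input_tensor_map([128, 16, 16], (8, 1, 1)) == [MAP_NONE, 1, 0]
```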
@@ -163,11 +182,11 @@ Status TileInfo::InferTensorInfo() {
 void TileInfo::UpdateMultiples(const CNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(cnode);
   if (cnode->size() != 3) {
-    MS_LOG(EXCEPTION) << "The size of tile cnode's inputs must be 3";
+    MS_LOG(EXCEPTION) << name_ << ": The size of tile cnode's inputs must be 3";
   }
   if (!IsValueNode<ValueTuple>(cnode->input(2))) {
-    MS_LOG(EXCEPTION) << "The input[2] of tile cnode is not ValueTuple.";
+    MS_LOG(EXCEPTION) << name_ << ": The input[2] of tile cnode is not ValueTuple.";
   }

   auto func_graph = cnode->func_graph();
@@ -199,7 +218,14 @@ Status TileInfo::GenerateStrategies(int64_t stage_id) {
   Shape multiples_split(full_multiples_.size(), 1);
   Shapes splittable_inputs = {multiples_split};

+  // if some dimension of multiples > 1, split the multiple; otherwise, split the dimension of input
   std::vector<StrategyPtr> sp_vector;
-  Shapes tmp_inputs_shape = {full_multiples_};
+  Shape tmp_input_shape = full_multiples_;
+  for (size_t i = 0; i < full_multiples_.size(); ++i) {
+    if (full_multiples_[i] == 1) {
+      tmp_input_shape[i] = inputs_shape_[0][i];
+    }
+  }
+  Shapes tmp_inputs_shape = {tmp_input_shape};
   if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_inputs, &sp_vector) != SUCCESS) {
     return FAILED;
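`GenerateStrategies` applies the same substitution before enumerating candidate shardings over the synthetic shape. As a rough illustration of what that enumeration produces (`gen_strategies` is a stand-in, not `GenerateStrategiesForIndependentInputs`' actual contract):

```python
from itertools import product


def gen_strategies(shape, device_num):
    """Rough stand-in for GenerateStrategiesForIndependentInputs:
    enumerate per-dimension split factors that divide their dimension
    and whose product divides the device count."""
    def divisors(n):
        return [d for d in range(1, n + 1) if n % d == 0]

    candidates = []
    for stra in product(*(divisors(dim) for dim in shape)):
        prod = 1
        for s in stra:
            prod *= s
        if device_num % prod == 0:
            candidates.append(stra)
    return candidates


# Over the synthetic shape [8, 16, 16] with 8 devices, both "split the
# multiple" and "split the input dims" strategies come out.
cands = gen_strategies([8, 16, 16], 8)
assert (8, 1, 1) in cands and (1, 4, 2) in cands and (2, 2, 2) in cands
```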
@@ -94,6 +94,15 @@ def test_select_model_parallel():
     compile_net(net)


+def test_select_mirror():
+    context.set_auto_parallel_context(
+        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((1, 2, 2), (1, 2, 2))
+    strategy2 = ((1, 2, 2), (1, 2, 2), (1, 2, 2))
+    net = Net(_w1, _w2, strategy1, strategy2)
+    compile_net(net)
+
+
 def test_select_auto_parallel():
     context.set_auto_parallel_context(
         parallel_mode="auto_parallel", device_num=8, global_rank=0)
@@ -52,20 +52,38 @@ class Net2(Cell):
         out = self.tile(out, (8, 8, 4, 2))
         return out


+class Net3(Cell):
+    def __init__(self, weight, strategy1=None, strategy2=None, is_parameter=True):
+        super().__init__()
+        self.mul = P.Mul().shard(strategy1)
+        self.tile = P.Tile().shard(strategy2)
+        if is_parameter:
+            self.weight = Parameter(weight, "w1")
+        else:
+            self.weight = weight
+        self.mul2 = P.Mul()
+
+    def construct(self, x, b):
+        out = self.tile(self.weight, (8, 1, 1))
+        out = self.mul(x, out)
+        return out
+
+
 _x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
 _x1 = Tensor(np.ones([128, 16, 16]), dtype=ms.float32)
 _w1 = Tensor(np.ones([16, 16, 16]), dtype=ms.float32)
 _w2 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
 _w3 = Tensor(np.ones([128, 16, 16]), dtype=ms.float32)
 _b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)


-def compile_net(net):
-    context.set_context(save_graphs=False)
+def compile_net(net, x=_b, b=_b):
+    context.set_context(save_graphs=True)
     optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
     train_net = TrainOneStepCell(net, optimizer)
     train_net.set_auto_parallel()
     train_net.set_train()
-    _executor.compile(train_net, _x, _b)
+    _executor.compile(train_net, x, b)
     context.reset_auto_parallel_context()
@@ -101,6 +119,14 @@ def test_tile_tensor_no_full_split():
     compile_net(net)


+def test_tile_tensor_no_full_split2():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((2, 2, 1), (2, 2, 1))
+    strategy2 = ((2, 2, 1),)
+    net = Net3(_w1, strategy1, strategy2)
+    compile_net(net, _x1, _b)
+
+
 def test_tile_output():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy1 = ((2, 2, 2), (2, 2, 2))
@@ -108,6 +134,7 @@ def test_tile_output():
     net = Net2(_w2, strategy1, strategy2)
     compile_net(net)

+
 def test_tile_output_no_full_split():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy1 = ((2, 2, 2), (2, 2, 2))
@@ -123,7 +150,14 @@ def test_tile_no_strategy():
     net = Net2(_w2, strategy1, strategy2)
     compile_net(net)


 def test_tile_auto_parallel():
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
     net = Net2(_w2)
     compile_net(net)
+
+
+def test_tile_auto_parallel_2():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    net = Net3(_w1)
+    compile_net(net, _x1, _b)