From: @yangzhenzhang Reviewed-by: @kisnwang, @stsuteng Signed-off-by: @stsuteng pull/15192/MERGE
| @@ -3232,7 +3232,24 @@ ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode) | |||
| return parameter_users_info; | |||
| } | |||
| Shape ParameterSliceShape(const std::pair<AnfNodePtr, int64_t> ¶m_info) { | |||
| RankList GetGroupByTensorInfo(const TensorInfo &tensor_info) { | |||
| CheckGlobalDeviceManager(); | |||
| int64_t rank = g_device_manager->global_rank(); | |||
| RankList stage_device_list = g_device_manager->GetDeviceListInThisStage(); | |||
| Shape dev_matrix_shape = tensor_info.tensor_layout().device_arrangement().array(); | |||
| Shape tensor_map = tensor_info.tensor_layout().tensor_map().array(); | |||
| DeviceMatrix dev_matrix(rank, stage_device_list, dev_matrix_shape); | |||
| RankList group_devices; | |||
| if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "Get devices by tensor map failed"; | |||
| } | |||
| std::sort(group_devices.begin(), group_devices.end()); | |||
| return group_devices; | |||
| } | |||
| ParameterSliceInfo GetParameterSliceInfo(const std::pair<AnfNodePtr, int64_t> ¶m_info) { | |||
| auto user_cnode = param_info.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(user_cnode); | |||
| auto user_input_index = param_info.second; | |||
| @@ -3245,10 +3262,14 @@ Shape ParameterSliceShape(const std::pair<AnfNodePtr, int64_t> ¶m_info) { | |||
| << ", but the index is " << user_input_index - 1; | |||
| } | |||
| TensorInfo tensor_info = op_info->inputs_tensor_info()[user_input_index - 1]; | |||
| ParameterSliceInfo parameter_slice_info; | |||
| parameter_slice_info.slice_shape = tensor_info.slice_shape(); | |||
| parameter_slice_info.group_ranks = GetGroupByTensorInfo(tensor_info); | |||
| MS_LOG(DEBUG) << "The op name is " << op_info->name() << ", the parameter index is " << user_input_index - 1 | |||
| << ", the slice shape is " << ShapeToString(tensor_info.slice_shape()) << ", the origin shape is " | |||
| << ShapeToString(tensor_info.shape()); | |||
| return tensor_info.slice_shape(); | |||
| << ", the slice shape is " << tensor_info.slice_shape() << ", the origin shape is " | |||
| << tensor_info.shape() << ", the group rank list is " << parameter_slice_info.group_ranks; | |||
| return parameter_slice_info; | |||
| } | |||
| void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) { | |||
| @@ -3262,13 +3283,24 @@ void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) { | |||
| auto parameter_name = parameter_users_info.first; | |||
| MS_LOG(INFO) << "The parameter: " << parameter_name << " has " << users_set.size() << " users"; | |||
| auto first_user = users_set.pop(); | |||
| Shape first_user_slice_shape = ParameterSliceShape(first_user); | |||
| ParameterSliceInfo parameter_slice_info = GetParameterSliceInfo(first_user); | |||
| Shape first_user_slice_shape = parameter_slice_info.slice_shape; | |||
| RankList first_user_group_list = parameter_slice_info.group_ranks; | |||
| for (auto &user : users_set) { | |||
| Shape user_slice_shape = ParameterSliceShape(user); | |||
| ParameterSliceInfo user_slice_info = GetParameterSliceInfo(user); | |||
| Shape user_slice_shape = user_slice_info.slice_shape; | |||
| RankList user_group_list = user_slice_info.group_ranks; | |||
| if (first_user_slice_shape != user_slice_shape) { | |||
| MS_LOG(EXCEPTION) << "The parameter: " << parameter_name | |||
| << " has multiple users, but the split strategies are different"; | |||
| << " has multiple users, but the slice shapes are different"; | |||
| } | |||
| if (ParallelContext::GetInstance()->pipeline_stage_split_num() == 1 && first_user_group_list != user_group_list) { | |||
| MS_LOG(EXCEPTION) << "The parameter: " << parameter_name | |||
| << " has multiple users, but the group rank list are different, " | |||
| << "the group rank list for first user is " << first_user_group_list | |||
| << ", and the group rank list for this user is " << user_group_list; | |||
| } | |||
| } | |||
| } | |||
| @@ -46,6 +46,11 @@ struct LossNodeInfo { | |||
| CNodePtr loss_node = nullptr; | |||
| }; | |||
// Sharding information for one use of a parameter: the local slice shape and
// the device group that shares that slice. Two users of the same parameter
// must agree on both fields (see CheckParameterSplit).
struct ParameterSliceInfo {
  Shape slice_shape;     // shape of this rank's slice of the parameter
  RankList group_ranks;  // sorted ranks holding the same slice (sorted by GetGroupByTensorInfo)
};
| std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name); | |||
| std::string CreateInstanceName(const CNodePtr &node, size_t index); | |||
| void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node); | |||
| @@ -47,9 +47,22 @@ class Net2(Cell): | |||
| return out | |||
| _x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32) | |||
| _w = Tensor(np.ones([128, 64, 32]), dtype=ms.float32) | |||
| _b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32) | |||
class Net3(Cell):
    """Two cascaded MatMuls sharing one weight parameter.

    Used to exercise the parameter-split consistency check: with different
    shard strategies the two users of `mul_weight` end up in different
    device groups, which compilation must reject.
    """

    def __init__(self, mul_weight, strategy1=None, strategy2=None):
        super().__init__()
        # NOTE(review): attributes are named "mul" but hold MatMul ops —
        # presumably mirroring Net/Net2 naming in this file; confirm.
        self.mul = P.MatMul().shard(strategy1)
        self.mul2 = P.MatMul().shard(strategy2)
        self.mul_weight = Parameter(mul_weight, "w1")

    def construct(self, x, b):
        # The same weight is consumed by both MatMuls.
        out = self.mul(x, self.mul_weight)
        out = self.mul2(out, self.mul_weight)
        return out
# Shared 2-D test inputs sized to match Net3's 16x16 MatMul operands.
_x = Tensor(np.ones([16, 16]), dtype=ms.float32)
_w = Tensor(np.ones([16, 16]), dtype=ms.float32)
_b = Tensor(np.ones([16, 16]), dtype=ms.float32)
| def compile_net(net): | |||
| @@ -63,16 +76,16 @@ def compile_net(net): | |||
| def test_parameter_same_split(): | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) | |||
| strategy1 = ((16, 1, 1), (16, 1, 1)) | |||
| strategy2 = ((16, 1, 1), (16, 1, 1)) | |||
| strategy1 = ((16, 1), (16, 1)) | |||
| strategy2 = ((16, 1), (16, 1)) | |||
| net = Net(_w, strategy1, strategy2) | |||
| compile_net(net) | |||
| def test_parameter_different_split(): | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) | |||
| strategy1 = ((16, 1, 1), (16, 1, 1)) | |||
| strategy2 = ((4, 4, 1), (4, 4, 1)) | |||
| strategy1 = ((16, 1), (16, 1)) | |||
| strategy2 = ((4, 4), (4, 4)) | |||
| net = Net(_w, strategy1, strategy2) | |||
| with pytest.raises(RuntimeError): | |||
| compile_net(net) | |||
| @@ -80,16 +93,25 @@ def test_parameter_different_split(): | |||
| def test_input_same_split(): | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) | |||
| strategy1 = ((16, 1, 1), (16, 1, 1)) | |||
| strategy2 = ((16, 1, 1), (16, 1, 1)) | |||
| strategy1 = ((16, 1), (16, 1)) | |||
| strategy2 = ((16, 1), (16, 1)) | |||
| net = Net(_w, strategy1, strategy2) | |||
| compile_net(net) | |||
| def test_input_different_split(): | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) | |||
| strategy1 = ((16, 1, 1), (16, 1, 1)) | |||
| strategy2 = ((4, 4, 1), (4, 4, 1)) | |||
| strategy1 = ((16, 1), (16, 1)) | |||
| strategy2 = ((4, 4), (4, 4)) | |||
| net = Net2(_w, strategy1, strategy2) | |||
| with pytest.raises(RuntimeError): | |||
| compile_net(net) | |||
def test_parameter_different_group():
    """Feature: parameter split consistency check.

    Description: the two users of the shared weight shard it over different
        device groups (pipeline parallel off, 16 devices).
    Expectation: compilation raises RuntimeError.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    first_strategy = ((1, 2), (2, 1))
    second_strategy = ((8, 2), (2, 1))
    net = Net3(_w, first_strategy, second_strategy)
    with pytest.raises(RuntimeError):
        compile_net(net)