Browse Source

!15192 check layouts for shared parameter

From: @yangzhenzhang
Reviewed-by: @kisnwang,@stsuteng
Signed-off-by: @stsuteng
pull/15192/MERGE
mindspore-ci-bot Gitee 5 years ago
parent
commit
a492a1cd52
3 changed files with 77 additions and 18 deletions
  1. +39
    -7
      mindspore/ccsrc/frontend/parallel/step_parallel.cc
  2. +5
    -0
      mindspore/ccsrc/frontend/parallel/step_parallel.h
  3. +33
    -11
      tests/ut/python/parallel/test_parameter_multi_users.py

+ 39
- 7
mindspore/ccsrc/frontend/parallel/step_parallel.cc View File

@@ -3232,7 +3232,24 @@ ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)
return parameter_users_info;
}

Shape ParameterSliceShape(const std::pair<AnfNodePtr, int64_t> &param_info) {
RankList GetGroupByTensorInfo(const TensorInfo &tensor_info) {
CheckGlobalDeviceManager();
int64_t rank = g_device_manager->global_rank();
RankList stage_device_list = g_device_manager->GetDeviceListInThisStage();
Shape dev_matrix_shape = tensor_info.tensor_layout().device_arrangement().array();
Shape tensor_map = tensor_info.tensor_layout().tensor_map().array();

DeviceMatrix dev_matrix(rank, stage_device_list, dev_matrix_shape);
RankList group_devices;
if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
MS_LOG(EXCEPTION) << "Get devices by tensor map failed";
}

std::sort(group_devices.begin(), group_devices.end());
return group_devices;
}

ParameterSliceInfo GetParameterSliceInfo(const std::pair<AnfNodePtr, int64_t> &param_info) {
auto user_cnode = param_info.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(user_cnode);
auto user_input_index = param_info.second;
@@ -3245,10 +3262,14 @@ Shape ParameterSliceShape(const std::pair<AnfNodePtr, int64_t> &param_info) {
<< ", but the index is " << user_input_index - 1;
}
TensorInfo tensor_info = op_info->inputs_tensor_info()[user_input_index - 1];

ParameterSliceInfo parameter_slice_info;
parameter_slice_info.slice_shape = tensor_info.slice_shape();
parameter_slice_info.group_ranks = GetGroupByTensorInfo(tensor_info);
MS_LOG(DEBUG) << "The op name is " << op_info->name() << ", the parameter index is " << user_input_index - 1
<< ", the slice shape is " << ShapeToString(tensor_info.slice_shape()) << ", the origin shape is "
<< ShapeToString(tensor_info.shape());
return tensor_info.slice_shape();
<< ", the slice shape is " << tensor_info.slice_shape() << ", the origin shape is "
<< tensor_info.shape() << ", the group rank list is " << parameter_slice_info.group_ranks;
return parameter_slice_info;
}

void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) {
@@ -3262,13 +3283,24 @@ void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) {
auto parameter_name = parameter_users_info.first;
MS_LOG(INFO) << "The parameter: " << parameter_name << " has " << users_set.size() << " users";
auto first_user = users_set.pop();
Shape first_user_slice_shape = ParameterSliceShape(first_user);
ParameterSliceInfo parameter_slice_info = GetParameterSliceInfo(first_user);
Shape first_user_slice_shape = parameter_slice_info.slice_shape;
RankList first_user_group_list = parameter_slice_info.group_ranks;

for (auto &user : users_set) {
Shape user_slice_shape = ParameterSliceShape(user);
ParameterSliceInfo user_slice_info = GetParameterSliceInfo(user);
Shape user_slice_shape = user_slice_info.slice_shape;
RankList user_group_list = user_slice_info.group_ranks;
if (first_user_slice_shape != user_slice_shape) {
MS_LOG(EXCEPTION) << "The parameter: " << parameter_name
<< " has multiple users, but the split strategies are different";
<< " has multiple users, but the slice shapes are different";
}

if (ParallelContext::GetInstance()->pipeline_stage_split_num() == 1 && first_user_group_list != user_group_list) {
MS_LOG(EXCEPTION) << "The parameter: " << parameter_name
<< " has multiple users, but the group rank list are different, "
<< "the group rank list for first user is " << first_user_group_list
<< ", and the group rank list for this user is " << user_group_list;
}
}
}


+ 5
- 0
mindspore/ccsrc/frontend/parallel/step_parallel.h View File

@@ -46,6 +46,11 @@ struct LossNodeInfo {
CNodePtr loss_node = nullptr;
};

// Layout facts gathered per user of a shared parameter, compared across
// all users in CheckParameterSplit: users must agree on both fields.
struct ParameterSliceInfo {
  Shape slice_shape;     // shape of this rank's slice of the parameter
  RankList group_ranks;  // sorted ranks sharing the same slice (from the tensor map)
};

std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name);
std::string CreateInstanceName(const CNodePtr &node, size_t index);
void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node);


+ 33
- 11
tests/ut/python/parallel/test_parameter_multi_users.py View File

@@ -47,9 +47,22 @@ class Net2(Cell):
return out


_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_w = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
class Net3(Cell):
    """Network whose single weight parameter feeds two MatMul users.

    `mul` and `mul2` can be given different shard strategies for the same
    weight, which exercises the shared-parameter layout (group) check.
    """

    def __init__(self, mul_weight, strategy1=None, strategy2=None):
        super().__init__()
        self.mul = P.MatMul().shard(strategy1)
        self.mul2 = P.MatMul().shard(strategy2)
        self.mul_weight = Parameter(mul_weight, "w1")

    def construct(self, x, b):
        # `b` matches the shared test-harness signature; it is not consumed.
        first = self.mul(x, self.mul_weight)
        second = self.mul2(first, self.mul_weight)
        return second


# Module-level fixtures: 16x16 all-ones float32 tensors used as the network
# input (_x), the shared weight (_w), and the unused second input (_b).
_x = Tensor(np.ones([16, 16]), dtype=ms.float32)
_w = Tensor(np.ones([16, 16]), dtype=ms.float32)
_b = Tensor(np.ones([16, 16]), dtype=ms.float32)


def compile_net(net):
@@ -63,16 +76,16 @@ def compile_net(net):

def test_parameter_same_split():
    """Both users shard the shared parameter identically: compile succeeds.

    NOTE(review): the stripped diff had left the removed 3-tuple strategy
    lines interleaved with the added 2-tuple ones; only the added lines
    are kept, and the scrape-lost indentation is restored.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((16, 1), (16, 1))
    strategy2 = ((16, 1), (16, 1))
    net = Net(_w, strategy1, strategy2)
    compile_net(net)


def test_parameter_different_split():
    """Users shard the shared parameter differently: compile must raise.

    NOTE(review): stale removed 3-tuple strategy lines from the stripped
    diff are dropped; indentation restored.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((16, 1), (16, 1))
    strategy2 = ((4, 4), (4, 4))
    net = Net(_w, strategy1, strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)
@@ -80,16 +93,25 @@ def test_parameter_different_split():

def test_input_same_split():
    """Identical strategies on a shared input: compile succeeds.

    NOTE(review): stale removed 3-tuple strategy lines from the stripped
    diff are dropped; indentation restored.  The diff shows `Net` (not
    `Net2`) being constructed here — preserved as-is.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((16, 1), (16, 1))
    strategy2 = ((16, 1), (16, 1))
    net = Net(_w, strategy1, strategy2)
    compile_net(net)


def test_input_different_split():
    """Different strategies on a shared input (Net2): compile must raise.

    NOTE(review): stale removed 3-tuple strategy lines from the stripped
    diff are dropped; indentation restored.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((16, 1), (16, 1))
    strategy2 = ((4, 4), (4, 4))
    net = Net2(_w, strategy1, strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_parameter_different_group():
    """Strategies yielding the same slice shape but different rank groups
    for the shared weight (Net3): compile must raise."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((1, 2), (2, 1))
    strategy2 = ((8, 2), (2, 1))
    net = Net3(_w, strategy1, strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)

Loading…
Cancel
Save