| @@ -273,19 +273,21 @@ Status Conv2DInfo::CheckStrategyBase(const StrategyPtr &strategy) { | |||
| int64_t input_except_n_shards = | |||
| std::accumulate(input_strategy.begin() + 1, input_strategy.end(), 1, std::multiplies<int64_t>()); | |||
| int64_t weight_shards = | |||
| std::accumulate(weight_strategy.begin() + 1, weight_strategy.end(), 1, std::multiplies<int64_t>()); | |||
| std::accumulate(weight_strategy.begin(), weight_strategy.end(), 1, std::multiplies<int64_t>()); | |||
| bool is_data_parallel = (input_except_n_shards * weight_shards == 1); | |||
| if (!is_data_parallel) { | |||
| if (std::any_of(dilation_.begin(), dilation_.end(), [](int64_t value) { return value != 1; })) { | |||
| MS_LOG(ERROR) << name_ << ": If it is not data parallel, the value of dilation must be 1, but got " << dilation_; | |||
| MS_LOG(ERROR) << name_ << ": It is not data parallel, the value of dilation must be 1, but got " << dilation_; | |||
| return FAILED; | |||
| } | |||
| } | |||
| if (group_ != 1) { | |||
| MS_LOG(ERROR) << name_ << ": If it is not data parallel, the group must be 1, but got " << group_; | |||
| return FAILED; | |||
| } | |||
| if (group_ != 1 && (weight_strategy[0] != 1 || weight_strategy[1] != 1)) { | |||
| MS_LOG(ERROR) << name_ << ": The group is " << group_ | |||
| << ", the cout and cin can not be split, but the shard num of cout is " << weight_strategy[0] | |||
| << ", the shard num of cin is " << weight_strategy[1]; | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -143,7 +143,7 @@ def test_conv2d_model_parallel_dilation(): | |||
| def test_conv2d_model_parallel_group(): | |||
| """ | |||
| Feature: test conv2d model parallel and group is not 1 | |||
| Description: model parallel and group is not 1 | |||
| Description: split cin and cout, and group is not 1 | |||
| Expectation: compile failed | |||
| """ | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) | |||
| @@ -155,6 +155,20 @@ def test_conv2d_model_parallel_group(): | |||
| compile_net(net) | |||
| def test_conv2d_model_parallel_group2(): | |||
| """ | |||
| Feature: test conv2d model parallel and group is not 1 | |||
| Description: has not to split cin and cout, and group is not 1 | |||
| Expectation: compile success | |||
| """ | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) | |||
| strategy1 = ((2, 1, 2, 2), (1, 1, 1, 1)) | |||
| strategy2 = ((8, 1, 1, 1),) | |||
| net = Net(_w4, out_channel=8, kernel_size=2, pad_mode="same", stride=1, group=2, | |||
| strategy1=strategy1, strategy2=strategy2) | |||
| compile_net(net) | |||
| def test_conv2d_model_parallel2(): | |||
| """ | |||
| Feature: same mode, stride = kernel_size, no need exchange | |||
| @@ -41,11 +41,11 @@ class Net(Cell): | |||
| class Net2(Cell): | |||
| def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, | |||
| def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, group=1, | |||
| strategy1=None, strategy2=None): | |||
| super().__init__() | |||
| self.conv2d_transpose = P.Conv2DTranspose(out_channel=out_channel, kernel_size=kernel_size, | |||
| pad_mode=pad_mode, stride=stride).shard(strategy1) | |||
| pad_mode=pad_mode, stride=stride, group=group).shard(strategy1) | |||
| self.neg = P.Neg().shard(strategy2) | |||
| self.weight = Parameter(conv2d_weight, "w1") | |||
| @@ -60,6 +60,7 @@ _w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32) | |||
| _w2 = Tensor(np.ones([8, 16, 4, 4]), dtype=ms.float32) | |||
| _w3 = Tensor(np.ones([8, 16, 10, 10]), dtype=ms.float32) | |||
| _w4 = Tensor(np.ones([8, 16, 3, 3]), dtype=ms.float32) | |||
| _w5 = Tensor(np.ones([8, 8, 4, 4]), dtype=ms.float32) | |||
| _b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32) | |||
| @@ -85,6 +86,20 @@ def test_conv2d_transpose_data_parallel(): | |||
| compile_net(net) | |||
| def test_conv2d_transpose_group(): | |||
| """ | |||
| Feature: test group is not 1 | |||
| Description: shard n/h/w, and group is 2 | |||
| Expectation: compile success | |||
| """ | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) | |||
| strategy1 = ((2, 1, 2, 2), (1, 1, 1, 1)) | |||
| strategy2 = ((8, 1, 1, 1),) | |||
| net = Net2(_w5, out_channel=8, kernel_size=4, pad_mode="same", stride=2, group=2, strategy1=strategy1, | |||
| strategy2=strategy2) | |||
| compile_net(net) | |||
| def test_conv2d_transpose_model_parallel1(): | |||
| """ | |||
| Feature: test model parallel strategy | |||