| @@ -143,52 +143,48 @@ Status Conv2DInfo::CheckHWStrategyBase(int64_t h_strategy, int64_t w_strategy) { | |||
| return SUCCESS; | |||
| } | |||
| Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) { | |||
| if (CheckHWStrategyBase(h_strategy, w_strategy) != SUCCESS) { | |||
| return FAILED; | |||
| } | |||
| Status Conv2DInfo::CheckHWStrategySameMode(int64_t h_strategy, int64_t w_strategy) { | |||
| int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy; | |||
| int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy; | |||
| if (pad_mode_ == 0) { // 'pad' mode | |||
| MS_LOG(ERROR) << name_ << ": The 'pad' mode do not support to split H or W"; | |||
| // H dimension | |||
| if (kernel_size_[0] > stride_[2] && h_strategy > 1) { | |||
| MS_LOG(ERROR) << name_ << ": The 'same' mode do not support to split H when kernel_size > stride"; | |||
| return FAILED; | |||
| } | |||
| if (pad_mode_ == 1) { // 'same' mode | |||
| if ((kernel_size_[0] > stride_[2] || kernel_size_[1] > stride_[3]) && h_strategy > 1) { | |||
| MS_LOG(ERROR) << name_ << ": The 'same' mode do not support to split H when kernel_size > stride"; | |||
| return FAILED; | |||
| } | |||
| if (h_strategy > 1 && (kernel_size_[0] <= stride_[2] && h_slice_shape % stride_[2] != 0)) { | |||
| MS_LOG(ERROR) << name_ | |||
| << ": The 'same' mode do not support to split H when kernel_size <= stride but slice shape " | |||
| "is not divisible by stride "; | |||
| return FAILED; | |||
| } | |||
| if (kernel_size_[0] <= stride_[2] || kernel_size_[1] <= stride_[3]) { | |||
| if (h_slice_shape % stride_[2] != 0 || w_slice_shape % stride_[3] != 0) { | |||
| MS_LOG(ERROR) << name_ | |||
| << ": The 'same' mode do not support to split H or W when kernel_size <= stride but slice shape " | |||
| "is not divisible by stride "; | |||
| return FAILED; | |||
| } | |||
| } | |||
| // W dimension | |||
| if (w_strategy > 1 && (kernel_size_[1] <= stride_[3] && w_slice_shape % stride_[3] != 0)) { | |||
| MS_LOG(ERROR) << name_ | |||
| << ": The 'same' mode do not support to split W when kernel_size <= stride but slice shape " | |||
| "is not divisible by stride "; | |||
| return FAILED; | |||
| } | |||
| if (pad_mode_ == 2) { // 'valid' mode | |||
| if ((kernel_size_[0] > stride_[2] && h_strategy > 1) || (kernel_size_[1] > stride_[3] && w_strategy > 1)) { | |||
| MS_LOG(ERROR) << name_ << ": The 'valid' mode do not support to split H or W when kernel_size > stride"; | |||
| if (w_strategy > 1 && (kernel_size_[1] > stride_[3])) { | |||
| if (inputs_shape_[0][3] % stride_[3] != 0) { | |||
| MS_LOG(ERROR) << name_ | |||
| << ": The 'same' mode do not support to split W when kernel_size > stride but w shape is not " | |||
| "divisible by stride"; | |||
| return FAILED; | |||
| } | |||
| if (kernel_size_[0] <= stride_[2] && h_slice_shape % stride_[2] != 0) { | |||
| if (w_slice_shape < ((kernel_size_[1] - stride_[3] + 1) / 2)) { | |||
| MS_LOG(ERROR) << name_ | |||
| << ": The 'valid' mode do not support to split H when kernel_size <= stride but slice shape is " | |||
| "not divisible by stride "; | |||
| << ": The 'same' mode do not support to split W when kernel_size > stride but w slice shape is " | |||
| "smaller than (k - s + 1) / 2"; | |||
| return FAILED; | |||
| } | |||
| if (kernel_size_[1] <= stride_[3] && w_slice_shape % stride_[3] != 0) { | |||
| MS_LOG(ERROR) << name_ | |||
| << ": The 'valid' mode do not support to split W when kernel_size <= stride but slice shape is " | |||
| "not divisible by stride "; | |||
| if (kernel_size_[1] - stride_[3] == 1) { | |||
| MS_LOG(ERROR) << name_ << ": The 'same' mode do not support to split W when kernel_size > stride but k - s == 1"; | |||
| return FAILED; | |||
| } | |||
| } | |||
| @@ -196,6 +192,53 @@ Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) { | |||
| return SUCCESS; | |||
| } | |||
| Status Conv2DInfo::CheckHWStrategyValidMode(int64_t h_strategy, int64_t w_strategy) { | |||
| int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy; | |||
| int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy; | |||
| if ((kernel_size_[0] > stride_[2] && h_strategy > 1) || (kernel_size_[1] > stride_[3] && w_strategy > 1)) { | |||
| MS_LOG(ERROR) << name_ << ": The 'valid' mode do not support to split H or W when kernel_size > stride"; | |||
| return FAILED; | |||
| } | |||
| if (kernel_size_[0] <= stride_[2] && h_slice_shape % stride_[2] != 0) { | |||
| MS_LOG(ERROR) << name_ | |||
| << ": The 'valid' mode do not support to split H when kernel_size <= stride but slice shape is " | |||
| "not divisible by stride "; | |||
| return FAILED; | |||
| } | |||
| if (kernel_size_[1] <= stride_[3] && w_slice_shape % stride_[3] != 0) { | |||
| MS_LOG(ERROR) << name_ | |||
| << ": The 'valid' mode do not support to split W when kernel_size <= stride but slice shape is " | |||
| "not divisible by stride "; | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) { | |||
| if (CheckHWStrategyBase(h_strategy, w_strategy) != SUCCESS) { | |||
| return FAILED; | |||
| } | |||
| if (pad_mode_ == 0) { // 'pad' mode | |||
| MS_LOG(ERROR) << name_ << ": The 'pad' mode do not support to split H or W"; | |||
| return FAILED; | |||
| } | |||
| if (pad_mode_ == 1) { // 'same' mode | |||
| return CheckHWStrategySameMode(h_strategy, w_strategy); | |||
| } | |||
| if (pad_mode_ == 2) { // 'valid' mode | |||
| return CheckHWStrategyValidMode(h_strategy, w_strategy); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status Conv2DInfo::CheckStrategyBase(const StrategyPtr &strategy) { | |||
| MS_EXCEPTION_IF_NULL(strategy); | |||
| if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { | |||
| @@ -493,10 +536,18 @@ void Conv2DInfo::InferSendRecvFlag() { | |||
| << right_need_recv_; | |||
| if (left_need_send_) { | |||
| if (left_rank_overlap_right_size_ > input_slice_shape_[3]) { | |||
| MS_LOG(EXCEPTION) << name_ << ": Do not support left overlap size(" << left_rank_overlap_right_size_ | |||
| << ") larger than slice shape in w dimension(" << input_slice_shape_[3] << ")"; | |||
| } | |||
| send_rank_ids_.push_back(left_rank_id_); | |||
| } | |||
| if (right_need_send_) { | |||
| if (right_rank_overlap_left_size_ > input_slice_shape_[3]) { | |||
| MS_LOG(EXCEPTION) << name_ << ": Do not support left overlap size(" << right_rank_overlap_left_size_ | |||
| << ") larger than slice shape in w dimension(" << input_slice_shape_[3] << ")"; | |||
| } | |||
| send_rank_ids_.push_back(right_rank_id_); | |||
| } | |||
| @@ -869,15 +920,8 @@ Status Conv2DBackpropInputInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_st | |||
| } | |||
| if (h_strategy > 1) { | |||
| if (inputs_shape_[0][2] * stride_[2] != outputs_shape_[0][2]) { | |||
| MS_LOG(ERROR) << name_ << ": Do not support to split h dimension when in_shape * stride != out_shape"; | |||
| return FAILED; | |||
| } | |||
| if (kernel_size_[0] > stride_[2]) { | |||
| MS_LOG(ERROR) << name_ << ": Do not support to split h dimension when kernel size larger than stride"; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(ERROR) << name_ << ": Do not support to split h dimension"; | |||
| return FAILED; | |||
| } | |||
| if (w_strategy > 1 && inputs_shape_[0][3] * stride_[3] != outputs_shape_[0][3]) { | |||
| @@ -115,6 +115,10 @@ class Conv2DInfo : public OperatorInfo { | |||
| virtual void InferNewPadList(); | |||
| virtual int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias); | |||
| virtual int64_t ComputeOverlapRightSizeByRankBias(int64_t rank_bias); | |||
| private: | |||
| Status CheckHWStrategySameMode(int64_t h_strategy, int64_t w_strategy); | |||
| Status CheckHWStrategyValidMode(int64_t h_strategy, int64_t w_strategy); | |||
| }; | |||
| class Conv2DBackpropInputInfo : public Conv2DInfo { | |||
| @@ -76,6 +76,20 @@ Status MaxPoolInfo::GetAttrs() { | |||
| } | |||
| Status MaxPoolInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) { | |||
| if (outputs_shape_[0][2] % h_strategy != 0) { | |||
| MS_LOG(ERROR) << name_ | |||
| << ": Do not support to split h dimension when out_shape of h dimension is not divisible by strategy " | |||
| "of h dimension"; | |||
| return FAILED; | |||
| } | |||
| if (outputs_shape_[0][3] % w_strategy != 0) { | |||
| MS_LOG(ERROR) << name_ | |||
| << ": Do not support to split w dimension when out_shape of w dimension is not divisible by strategy " | |||
| "of w dimension"; | |||
| return FAILED; | |||
| } | |||
| if (h_strategy > 1) { | |||
| if (kernel_size_[2] > stride_[2]) { | |||
| MS_LOG(ERROR) << name_ << ": It does not support to split H dimension when kernel_size > stride"; | |||
| @@ -38,18 +38,20 @@ class Net(Cell): | |||
# Shared all-ones fixtures for the conv2d parallel tests.
# Inputs: _x has an 8x8 spatial size, _x2 a 10x10 spatial size.
_x = Tensor(np.ones((32, 16, 8, 8)), dtype=ms.float32)
_x2 = Tensor(np.ones((32, 16, 10, 10)), dtype=ms.float32)
# Weights with 1x1, 2x2, 3x3 and 5x5 trailing dimensions respectively.
_w0 = Tensor(np.ones((8, 16, 1, 1)), dtype=ms.float32)
_w1 = Tensor(np.ones((8, 16, 2, 2)), dtype=ms.float32)
_w2 = Tensor(np.ones((8, 16, 3, 3)), dtype=ms.float32)
_w3 = Tensor(np.ones((8, 16, 5, 5)), dtype=ms.float32)
# Second positional tensor handed to the compile call.
_b = Tensor(np.ones((32, 16, 8, 8)), dtype=ms.float32)
| def compile_net(net): | |||
| def compile_net(net, input_x=_x): | |||
| optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||
| train_net = TrainOneStepCell(net, optimizer) | |||
| train_net.set_auto_parallel() | |||
| train_net.set_train() | |||
| _executor.compile(train_net, _x, _b) | |||
| _executor.compile(train_net, input_x, _b) | |||
| context.reset_auto_parallel_context() | |||
| @@ -85,6 +87,12 @@ def test_conv2d_model_parallel3(): | |||
| compile_net(net) | |||
def test_conv2d_auto_parallel():
    """Compile the 3x3 'same'-mode, stride-1 net in auto_parallel mode on 8 devices (no explicit strategy)."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    conv_net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1)
    compile_net(conv_net)
| def test_conv2d_model_parallel4(): | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0) | |||
| strategy1 = ((2, 2, 1, 4), (2, 2, 1, 1)) | |||
| @@ -102,6 +110,24 @@ def test_conv2d_left_and_right_no_need_to_send(): | |||
| compile_net(net) | |||
def test_conv2d_kernel_size_larger_than_stride_and_split_h():
    """'same' mode, kernel_size=3 > stride=1, third strategy axis split by 4: compilation must raise RuntimeError."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
    first_strategy = ((2, 2, 4, 1), (2, 2, 1, 1))
    second_strategy = ((2, 2, 4, 1),)
    conv_net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1,
                   strategy1=first_strategy, strategy2=second_strategy)
    with pytest.raises(RuntimeError):
        compile_net(conv_net)
def test_conv2d_valid_mode_kernel_size_larger_than_stride():
    """'valid' mode, kernel_size=3 > stride=1, last strategy axis split by 2: compilation must raise RuntimeError."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    first_strategy = ((2, 1, 1, 2), (1, 1, 1, 1))
    second_strategy = ((2, 1, 1, 4),)
    conv_net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="valid", stride=1,
                   strategy1=first_strategy, strategy2=second_strategy)
    with pytest.raises(RuntimeError):
        compile_net(conv_net)
| def test_conv2d_output_can_not_divisible_by_strategy(): | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) | |||
| strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1)) | |||
| @@ -109,3 +135,57 @@ def test_conv2d_output_can_not_divisible_by_strategy(): | |||
| net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2) | |||
| with pytest.raises(RuntimeError): | |||
| compile_net(net) | |||
def test_split_kernel():
    """Splitting the last two axes of the weight strategy (1, 1, 2, 2) must make compilation raise RuntimeError."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    first_strategy = ((1, 1, 1, 1), (1, 1, 2, 2))
    second_strategy = ((1, 1, 1, 8),)
    conv_net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2,
                   strategy1=first_strategy, strategy2=second_strategy)
    with pytest.raises(RuntimeError):
        compile_net(conv_net)
def test_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_same_mode():
    """'same' mode, kernel_size=1 < stride=3, 10-wide input split by 2: compilation must raise RuntimeError."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    first_strategy = ((1, 1, 1, 2), (1, 1, 1, 1))
    second_strategy = ((1, 1, 1, 8),)
    conv_net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="same", stride=3,
                   strategy1=first_strategy, strategy2=second_strategy)
    with pytest.raises(RuntimeError):
        # Uses the 10x10 input instead of the default one.
        compile_net(conv_net, _x2)
def test_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_valid_mode():
    """'valid' mode, kernel_size=1 < stride=3, 10-wide input split by 2: compilation must raise RuntimeError."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    first_strategy = ((1, 1, 1, 2), (1, 1, 1, 1))
    second_strategy = ((1, 1, 1, 8),)
    conv_net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="valid", stride=3,
                   strategy1=first_strategy, strategy2=second_strategy)
    with pytest.raises(RuntimeError):
        # Uses the 10x10 input instead of the default one.
        compile_net(conv_net, _x2)
def test_kernel_size_larger_than_stride_and_input_can_not_divisible_by_stride():
    """'same' mode, kernel_size=5 > stride=3, 10-wide input split by 2: compilation must raise RuntimeError."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    first_strategy = ((1, 1, 1, 2), (1, 1, 1, 1))
    second_strategy = ((1, 1, 1, 8),)
    conv_net = Net(_w3, out_channel=8, kernel_size=5, pad_mode="same", stride=3,
                   strategy1=first_strategy, strategy2=second_strategy)
    with pytest.raises(RuntimeError):
        # Uses the 10x10 input instead of the default one.
        compile_net(conv_net, _x2)
def test_kernel_size_larger_than_stride_and_slice_too_small():
    """'same' mode, kernel_size=5 > stride=1, last axis split by 8 on the default input: must raise RuntimeError."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    first_strategy = ((1, 1, 1, 8), (1, 1, 1, 1))
    second_strategy = ((1, 1, 1, 8),)
    conv_net = Net(_w3, out_channel=8, kernel_size=5, pad_mode="same", stride=1,
                   strategy1=first_strategy, strategy2=second_strategy)
    with pytest.raises(RuntimeError):
        compile_net(conv_net)
def test_kernel_size_larger_than_stride_and_left_pad_is_0():
    """'same' mode, kernel_size=2, stride=1 (k - s == 1), last axis split by 4: compilation must raise RuntimeError."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    first_strategy = ((1, 1, 1, 4), (1, 1, 1, 1))
    second_strategy = ((1, 1, 1, 8),)
    conv_net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
                   strategy1=first_strategy, strategy2=second_strategy)
    with pytest.raises(RuntimeError):
        compile_net(conv_net)
| @@ -13,6 +13,7 @@ | |||
| # limitations under the License. | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore as ms | |||
| from mindspore import context, Tensor, Parameter | |||
| @@ -54,6 +55,8 @@ class Net2(Cell): | |||
# Shared all-ones fixtures for the conv2d-transpose parallel tests.
_x = Tensor(np.ones((32, 8, 8, 8)), dtype=ms.float32)
# Weights with 2x2, 4x4, 10x10 and 3x3 trailing dimensions respectively.
_w1 = Tensor(np.ones((8, 16, 2, 2)), dtype=ms.float32)
_w2 = Tensor(np.ones((8, 16, 4, 4)), dtype=ms.float32)
_w3 = Tensor(np.ones((8, 16, 10, 10)), dtype=ms.float32)
_w4 = Tensor(np.ones((8, 16, 3, 3)), dtype=ms.float32)
_b = Tensor(np.ones((32, 16, 8, 8)), dtype=ms.float32)
| @@ -98,3 +101,33 @@ def test_conv2d_transpose_model_parallel3(): | |||
| net = Net2(_w2, out_channel=8, kernel_size=(4, 4), pad_mode="same", stride=2, | |||
| strategy1=strategy1, strategy2=strategy2) | |||
| compile_net(net) | |||
def test_conv2d_transpose_all_rank_no_need_overlap():
    """'same' mode with kernel_size (2, 2) equal to stride 2, last axis split by 4: compilation is expected to pass."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    first_strategy = ((2, 2, 1, 4), (2, 1, 1, 1))
    second_strategy = ((2, 2, 1, 4),)
    transpose_net = Net2(_w1, out_channel=8, kernel_size=(2, 2), pad_mode="same", stride=2,
                         strategy1=first_strategy, strategy2=second_strategy)
    compile_net(transpose_net)
def test_conv2d_transpose_overlap_size_too_large():
    """'same' mode, kernel_size (10, 10), stride 2, last axis split by 8: compilation must raise RuntimeError."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    first_strategy = ((1, 1, 1, 8), (1, 1, 1, 1))
    second_strategy = ((1, 1, 1, 8),)
    transpose_net = Net2(_w3, out_channel=8, kernel_size=(10, 10), pad_mode="same", stride=2,
                         strategy1=first_strategy, strategy2=second_strategy)
    with pytest.raises(RuntimeError):
        compile_net(transpose_net)
def test_conv2d_transpose_rank0_no_need_overlap():
    """'same' mode, kernel_size (3, 3), stride 2, last axis split by 4: compilation must raise RuntimeError."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    first_strategy = ((2, 2, 1, 4), (2, 1, 1, 1))
    second_strategy = ((2, 2, 1, 4),)
    transpose_net = Net2(_w4, out_channel=8, kernel_size=(3, 3), pad_mode="same", stride=2,
                         strategy1=first_strategy, strategy2=second_strategy)
    with pytest.raises(RuntimeError):
        compile_net(transpose_net)
| @@ -13,6 +13,7 @@ | |||
| # limitations under the License. | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore as ms | |||
| from mindspore import context, Tensor, Parameter | |||
| @@ -98,6 +99,16 @@ def test_maxpool_auto_parallel(): | |||
| compile_net(net) | |||
def test_maxpool_output_can_not_divisible_by_strategy():
    """Pool kernel 2 / stride 2 with a (1, 1, 1, 8) second strategy: compilation must raise RuntimeError."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    first_strategy = ((8, 1, 1, 1), (1, 1, 1, 1))
    second_strategy = ((1, 1, 1, 8),)
    pool_net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2,
                   pool_strides=2, strategy1=first_strategy, strategy2=second_strategy)
    with pytest.raises(RuntimeError):
        compile_net(pool_net)
| def test_avgpool_data_parallel(): | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) | |||
| strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1)) | |||