Browse Source

check strategy for conv2d

tags/v1.5.0-rc1
yangzhenzhang 4 years ago
parent
commit
d18c813ee4
6 changed files with 227 additions and 41 deletions
  1. +83
    -39
      mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc
  2. +4
    -0
      mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h
  3. +14
    -0
      mindspore/ccsrc/frontend/parallel/ops_info/maxpool_info.cc
  4. +82
    -2
      tests/ut/python/parallel/test_conv2d.py
  5. +33
    -0
      tests/ut/python/parallel/test_conv2d_transpose.py
  6. +11
    -0
      tests/ut/python/parallel/test_maxpool_avgpool.py

+ 83
- 39
mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc View File

@@ -143,52 +143,48 @@ Status Conv2DInfo::CheckHWStrategyBase(int64_t h_strategy, int64_t w_strategy) {
return SUCCESS;
}

Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
if (CheckHWStrategyBase(h_strategy, w_strategy) != SUCCESS) {
return FAILED;
}

Status Conv2DInfo::CheckHWStrategySameMode(int64_t h_strategy, int64_t w_strategy) {
int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy;
int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy;

if (pad_mode_ == 0) { // 'pad' mode
MS_LOG(ERROR) << name_ << ": The 'pad' mode do not support to split H or W";
// H dimension
if (kernel_size_[0] > stride_[2] && h_strategy > 1) {
MS_LOG(ERROR) << name_ << ": The 'same' mode do not support to split H when kernel_size > stride";
return FAILED;
}

if (pad_mode_ == 1) { // 'same' mode
if ((kernel_size_[0] > stride_[2] || kernel_size_[1] > stride_[3]) && h_strategy > 1) {
MS_LOG(ERROR) << name_ << ": The 'same' mode do not support to split H when kernel_size > stride";
return FAILED;
}
if (h_strategy > 1 && (kernel_size_[0] <= stride_[2] && h_slice_shape % stride_[2] != 0)) {
MS_LOG(ERROR) << name_
<< ": The 'same' mode do not support to split H when kernel_size <= stride but slice shape "
"is not divisible by stride ";
return FAILED;
}

if (kernel_size_[0] <= stride_[2] || kernel_size_[1] <= stride_[3]) {
if (h_slice_shape % stride_[2] != 0 || w_slice_shape % stride_[3] != 0) {
MS_LOG(ERROR) << name_
<< ": The 'same' mode do not support to split H or W when kernel_size <= stride but slice shape "
"is not divisible by stride ";
return FAILED;
}
}
// W dimension
if (w_strategy > 1 && (kernel_size_[1] <= stride_[3] && w_slice_shape % stride_[3] != 0)) {
MS_LOG(ERROR) << name_
<< ": The 'same' mode do not support to split W when kernel_size <= stride but slice shape "
"is not divisible by stride ";
return FAILED;
}

if (pad_mode_ == 2) { // 'valid' mode
if ((kernel_size_[0] > stride_[2] && h_strategy > 1) || (kernel_size_[1] > stride_[3] && w_strategy > 1)) {
MS_LOG(ERROR) << name_ << ": The 'valid' mode do not support to split H or W when kernel_size > stride";
if (w_strategy > 1 && (kernel_size_[1] > stride_[3])) {
if (inputs_shape_[0][3] % stride_[3] != 0) {
MS_LOG(ERROR) << name_
<< ": The 'same' mode do not support to split W when kernel_size > stride but w shape is not "
"divisible by stride";
return FAILED;
}

if (kernel_size_[0] <= stride_[2] && h_slice_shape % stride_[2] != 0) {
if (w_slice_shape < ((kernel_size_[1] - stride_[3] + 1) / 2)) {
MS_LOG(ERROR) << name_
<< ": The 'valid' mode do not support to split H when kernel_size <= stride but slice shape is "
"not divisible by stride ";
<< ": The 'same' mode do not support to split W when kernel_size > stride but w slice shape is "
"smaller than (k - s + 1) / 2";
return FAILED;
}

if (kernel_size_[1] <= stride_[3] && w_slice_shape % stride_[3] != 0) {
MS_LOG(ERROR) << name_
<< ": The 'valid' mode do not support to split W when kernel_size <= stride but slice shape is "
"not divisible by stride ";
if (kernel_size_[1] - stride_[3] == 1) {
MS_LOG(ERROR) << name_ << ": The 'same' mode do not support to split W when kernel_size > stride but k - s == 1";
return FAILED;
}
}
@@ -196,6 +192,53 @@ Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
return SUCCESS;
}

// Validate the H/W split strategy under the 'valid' pad mode.
// A dimension may only be split when kernel_size <= stride and the resulting
// per-rank slice length is divisible by the stride; splitting with
// kernel_size > stride is rejected outright.
// Returns SUCCESS when the split is supported, FAILED (with an error log) otherwise.
Status Conv2DInfo::CheckHWStrategyValidMode(int64_t h_strategy, int64_t w_strategy) {
  const int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy;
  const int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy;

  const bool split_h_large_kernel = (kernel_size_[0] > stride_[2]) && (h_strategy > 1);
  const bool split_w_large_kernel = (kernel_size_[1] > stride_[3]) && (w_strategy > 1);
  if (split_h_large_kernel || split_w_large_kernel) {
    MS_LOG(ERROR) << name_ << ": The 'valid' mode do not support to split H or W when kernel_size > stride";
    return FAILED;
  }

  // kernel <= stride: each slice must end on a stride boundary.
  if (kernel_size_[0] <= stride_[2] && h_slice_shape % stride_[2] != 0) {
    MS_LOG(ERROR) << name_
                  << ": The 'valid' mode do not support to split H when kernel_size <= stride but slice shape is "
                     "not divisible by stride ";
    return FAILED;
  }

  if (kernel_size_[1] <= stride_[3] && w_slice_shape % stride_[3] != 0) {
    MS_LOG(ERROR) << name_
                  << ": The 'valid' mode do not support to split W when kernel_size <= stride but slice shape is "
                     "not divisible by stride ";
    return FAILED;
  }

  return SUCCESS;
}

// Top-level H/W strategy check: run the common base check, then dispatch to
// the pad-mode-specific validator ('pad' never supports H/W splitting).
// Returns SUCCESS when the split is supported, FAILED otherwise.
Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
  if (CheckHWStrategyBase(h_strategy, w_strategy) != SUCCESS) {
    return FAILED;
  }

  switch (pad_mode_) {
    case 0:  // 'pad' mode
      MS_LOG(ERROR) << name_ << ": The 'pad' mode do not support to split H or W";
      return FAILED;
    case 1:  // 'same' mode
      return CheckHWStrategySameMode(h_strategy, w_strategy);
    case 2:  // 'valid' mode
      return CheckHWStrategyValidMode(h_strategy, w_strategy);
    default:
      // Unknown pad modes fall through without restriction, matching the
      // original chain of ifs.
      return SUCCESS;
  }
}

Status Conv2DInfo::CheckStrategyBase(const StrategyPtr &strategy) {
MS_EXCEPTION_IF_NULL(strategy);
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
@@ -493,10 +536,18 @@ void Conv2DInfo::InferSendRecvFlag() {
<< right_need_recv_;

if (left_need_send_) {
if (left_rank_overlap_right_size_ > input_slice_shape_[3]) {
MS_LOG(EXCEPTION) << name_ << ": Do not support left overlap size(" << left_rank_overlap_right_size_
<< ") larger than slice shape in w dimension(" << input_slice_shape_[3] << ")";
}
send_rank_ids_.push_back(left_rank_id_);
}

if (right_need_send_) {
if (right_rank_overlap_left_size_ > input_slice_shape_[3]) {
MS_LOG(EXCEPTION) << name_ << ": Do not support left overlap size(" << right_rank_overlap_left_size_
<< ") larger than slice shape in w dimension(" << input_slice_shape_[3] << ")";
}
send_rank_ids_.push_back(right_rank_id_);
}

@@ -869,15 +920,8 @@ Status Conv2DBackpropInputInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_st
}

if (h_strategy > 1) {
if (inputs_shape_[0][2] * stride_[2] != outputs_shape_[0][2]) {
MS_LOG(ERROR) << name_ << ": Do not support to split h dimension when in_shape * stride != out_shape";
return FAILED;
}

if (kernel_size_[0] > stride_[2]) {
MS_LOG(ERROR) << name_ << ": Do not support to split h dimension when kernel size larger than stride";
return FAILED;
}
MS_LOG(ERROR) << name_ << ": Do not support to split h dimension";
return FAILED;
}

if (w_strategy > 1 && inputs_shape_[0][3] * stride_[3] != outputs_shape_[0][3]) {


+ 4
- 0
mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h View File

@@ -115,6 +115,10 @@ class Conv2DInfo : public OperatorInfo {
virtual void InferNewPadList();
virtual int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias);
virtual int64_t ComputeOverlapRightSizeByRankBias(int64_t rank_bias);

private:
Status CheckHWStrategySameMode(int64_t h_strategy, int64_t w_strategy);
Status CheckHWStrategyValidMode(int64_t h_strategy, int64_t w_strategy);
};

class Conv2DBackpropInputInfo : public Conv2DInfo {


+ 14
- 0
mindspore/ccsrc/frontend/parallel/ops_info/maxpool_info.cc View File

@@ -76,6 +76,20 @@ Status MaxPoolInfo::GetAttrs() {
}

Status MaxPoolInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
if (outputs_shape_[0][2] % h_strategy != 0) {
MS_LOG(ERROR) << name_
<< ": Do not support to split h dimension when out_shape of h dimension is not divisible by strategy "
"of h dimension";
return FAILED;
}

if (outputs_shape_[0][3] % w_strategy != 0) {
MS_LOG(ERROR) << name_
<< ": Do not support to split w dimension when out_shape of w dimension is not divisible by strategy "
"of w dimension";
return FAILED;
}

if (h_strategy > 1) {
if (kernel_size_[2] > stride_[2]) {
MS_LOG(ERROR) << name_ << ": It does not support to split H dimension when kernel_size > stride";


+ 82
- 2
tests/ut/python/parallel/test_conv2d.py View File

@@ -38,18 +38,20 @@ class Net(Cell):


_x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
_x2 = Tensor(np.ones([32, 16, 10, 10]), dtype=ms.float32)
_w0 = Tensor(np.ones([8, 16, 1, 1]), dtype=ms.float32)
_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
_w2 = Tensor(np.ones([8, 16, 3, 3]), dtype=ms.float32)
_w3 = Tensor(np.ones([8, 16, 5, 5]), dtype=ms.float32)
_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)


def compile_net(net, input_x=_x):
    """Wrap `net` in a TrainOneStepCell and compile it under the current parallel context.

    Defect fixed: the scraped diff merged the removed lines (`def compile_net(net):`
    and `_executor.compile(train_net, _x, _b)`) with the added ones, leaving two
    conflicting definitions; this is the reconstructed post-change function.

    Args:
        net: the Cell under test.
        input_x: input tensor fed to the compiled graph (defaults to the
            module-level `_x`); callers pass `_x2` for the 10x10 cases.

    Side effect: resets the auto-parallel context after compilation.
    """
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, optimizer)
    train_net.set_auto_parallel()
    train_net.set_train()
    _executor.compile(train_net, input_x, _b)
    context.reset_auto_parallel_context()


@@ -85,6 +87,12 @@ def test_conv2d_model_parallel3():
compile_net(net)


def test_conv2d_auto_parallel():
    """Conv2d ('same' mode, 3x3 kernel, stride 1) compiles under auto-parallel search."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    conv_net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1)
    compile_net(conv_net)


def test_conv2d_model_parallel4():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
strategy1 = ((2, 2, 1, 4), (2, 2, 1, 1))
@@ -102,6 +110,24 @@ def test_conv2d_left_and_right_no_need_to_send():
compile_net(net)


def test_conv2d_kernel_size_larger_than_stride_and_split_h():
    """'same' mode: splitting H while kernel_size > stride must be rejected."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
    in_strategy = ((2, 2, 4, 1), (2, 2, 1, 1))
    out_strategy = ((2, 2, 4, 1),)
    conv_net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1,
                   strategy1=in_strategy, strategy2=out_strategy)
    with pytest.raises(RuntimeError):
        compile_net(conv_net)


def test_conv2d_valid_mode_kernel_size_larger_than_stride():
    """'valid' mode: splitting W while kernel_size > stride must be rejected."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    in_strategy = ((2, 1, 1, 2), (1, 1, 1, 1))
    out_strategy = ((2, 1, 1, 4),)
    conv_net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="valid", stride=1,
                   strategy1=in_strategy, strategy2=out_strategy)
    with pytest.raises(RuntimeError):
        compile_net(conv_net)


def test_conv2d_output_can_not_divisible_by_strategy():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
@@ -109,3 +135,57 @@ def test_conv2d_output_can_not_divisible_by_strategy():
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net)


def test_split_kernel():
    """Splitting the kernel's H/W dimensions (weight strategy (1,1,2,2)) must be rejected."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    in_strategy = ((1, 1, 1, 1), (1, 1, 2, 2))
    out_strategy = ((1, 1, 1, 8),)
    conv_net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2,
                   strategy1=in_strategy, strategy2=out_strategy)
    with pytest.raises(RuntimeError):
        compile_net(conv_net)


def test_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_same_mode():
    """'same' mode, kernel < stride: slice (10/2=5) not divisible by stride 3 must be rejected."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    in_strategy = ((1, 1, 1, 2), (1, 1, 1, 1))
    out_strategy = ((1, 1, 1, 8),)
    conv_net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="same", stride=3,
                   strategy1=in_strategy, strategy2=out_strategy)
    with pytest.raises(RuntimeError):
        compile_net(conv_net, _x2)


def test_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_valid_mode():
    """'valid' mode, kernel < stride: slice not divisible by stride must be rejected."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    in_strategy = ((1, 1, 1, 2), (1, 1, 1, 1))
    out_strategy = ((1, 1, 1, 8),)
    conv_net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="valid", stride=3,
                   strategy1=in_strategy, strategy2=out_strategy)
    with pytest.raises(RuntimeError):
        compile_net(conv_net, _x2)


def test_kernel_size_larger_than_stride_and_input_can_not_divisible_by_stride():
    """'same' mode, kernel > stride: W shape (10) not divisible by stride 3 must be rejected."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    in_strategy = ((1, 1, 1, 2), (1, 1, 1, 1))
    out_strategy = ((1, 1, 1, 8),)
    conv_net = Net(_w3, out_channel=8, kernel_size=5, pad_mode="same", stride=3,
                   strategy1=in_strategy, strategy2=out_strategy)
    with pytest.raises(RuntimeError):
        compile_net(conv_net, _x2)


def test_kernel_size_larger_than_stride_and_slice_too_small():
    """'same' mode, kernel > stride: W slice (8/8=1) smaller than (k-s+1)/2 must be rejected."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    in_strategy = ((1, 1, 1, 8), (1, 1, 1, 1))
    out_strategy = ((1, 1, 1, 8),)
    conv_net = Net(_w3, out_channel=8, kernel_size=5, pad_mode="same", stride=1,
                   strategy1=in_strategy, strategy2=out_strategy)
    with pytest.raises(RuntimeError):
        compile_net(conv_net)


def test_kernel_size_larger_than_stride_and_left_pad_is_0():
    """'same' mode, kernel > stride with k - s == 1 (left pad 0) must be rejected."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    in_strategy = ((1, 1, 1, 4), (1, 1, 1, 1))
    out_strategy = ((1, 1, 1, 8),)
    conv_net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
                   strategy1=in_strategy, strategy2=out_strategy)
    with pytest.raises(RuntimeError):
        compile_net(conv_net)

+ 33
- 0
tests/ut/python/parallel/test_conv2d_transpose.py View File

@@ -13,6 +13,7 @@
# limitations under the License.

import numpy as np
import pytest

import mindspore as ms
from mindspore import context, Tensor, Parameter
@@ -54,6 +55,8 @@ class Net2(Cell):
_x = Tensor(np.ones([32, 8, 8, 8]), dtype=ms.float32)
_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
_w2 = Tensor(np.ones([8, 16, 4, 4]), dtype=ms.float32)
_w3 = Tensor(np.ones([8, 16, 10, 10]), dtype=ms.float32)
_w4 = Tensor(np.ones([8, 16, 3, 3]), dtype=ms.float32)
_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)


@@ -98,3 +101,33 @@ def test_conv2d_transpose_model_parallel3():
net = Net2(_w2, out_channel=8, kernel_size=(4, 4), pad_mode="same", stride=2,
strategy1=strategy1, strategy2=strategy2)
compile_net(net)


def test_conv2d_transpose_all_rank_no_need_overlap():
    """Conv2d transpose where no rank needs halo exchange should compile."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    in_strategy = ((2, 2, 1, 4), (2, 1, 1, 1))
    out_strategy = ((2, 2, 1, 4),)
    transpose_net = Net2(_w1, out_channel=8, kernel_size=(2, 2), pad_mode="same", stride=2,
                         strategy1=in_strategy, strategy2=out_strategy)
    compile_net(transpose_net)


def test_conv2d_transpose_overlap_size_too_large():
    """Conv2d transpose with a 10x10 kernel: overlap exceeding the W slice must be rejected."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    in_strategy = ((1, 1, 1, 8), (1, 1, 1, 1))
    out_strategy = ((1, 1, 1, 8),)
    transpose_net = Net2(_w3, out_channel=8, kernel_size=(10, 10), pad_mode="same", stride=2,
                         strategy1=in_strategy, strategy2=out_strategy)
    with pytest.raises(RuntimeError):
        compile_net(transpose_net)


def test_conv2d_transpose_rank0_no_need_overlap():
    """Conv2d transpose (3x3 kernel, stride 2) with this split must be rejected."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    in_strategy = ((2, 2, 1, 4), (2, 1, 1, 1))
    out_strategy = ((2, 2, 1, 4),)
    transpose_net = Net2(_w4, out_channel=8, kernel_size=(3, 3), pad_mode="same", stride=2,
                         strategy1=in_strategy, strategy2=out_strategy)
    with pytest.raises(RuntimeError):
        compile_net(transpose_net)

+ 11
- 0
tests/ut/python/parallel/test_maxpool_avgpool.py View File

@@ -13,6 +13,7 @@
# limitations under the License.

import numpy as np
import pytest

import mindspore as ms
from mindspore import context, Tensor, Parameter
@@ -98,6 +99,16 @@ def test_maxpool_auto_parallel():
compile_net(net)


def test_maxpool_output_can_not_divisible_by_strategy():
    """MaxPool: output H/W not divisible by the split strategy must be rejected."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    in_strategy = ((8, 1, 1, 1), (1, 1, 1, 1))
    out_strategy = ((1, 1, 1, 8),)
    pool_net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
                   pool_kernel_size=2, pool_strides=2,
                   strategy1=in_strategy, strategy2=out_strategy)
    with pytest.raises(RuntimeError):
        compile_net(pool_net)


def test_avgpool_data_parallel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))


Loading…
Cancel
Save