Browse Source

!20490 update check strategy for conv2d

Merge pull request !20490 from yangzhenzhang/update-check-strategy-for-conv2d
tags/v1.4.0
i-robot Gitee 4 years ago
parent
commit
6061194083
3 changed files with 44 additions and 7 deletions
  1. +29
    -3
      mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc
  2. +5
    -4
      mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h
  3. +10
    -0
      tests/ut/python/parallel/test_conv2d.py

+ 29
- 3
mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc View File

@@ -124,7 +124,29 @@ Status Conv2DInfo::GetAttrsBase() {

Status Conv2DInfo::GetAttrs() { return GetAttrsBase(); }

Status Conv2DInfo::CheckHWStrategyBase(int64_t h_strategy, int64_t w_strategy) {
if (outputs_shape_[0][2] % h_strategy != 0) {
MS_LOG(ERROR) << name_
<< ": Do not support to split h dimension when out_shape of h dimension is not divisible by strategy "
"of h dimension";
return FAILED;
}

if (outputs_shape_[0][3] % w_strategy != 0) {
MS_LOG(ERROR) << name_
<< ": Do not support to split w dimension when out_shape of w dimension is not divisible by strategy "
"of w dimension";
return FAILED;
}

return SUCCESS;
}

Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
if (CheckHWStrategyBase(h_strategy, w_strategy) != SUCCESS) {
return FAILED;
}

if (pad_mode_ == 0) { // 'pad' mode
MS_LOG(ERROR) << name_ << ": The 'pad' mode do not support to split H or W";
return FAILED;
@@ -642,6 +664,10 @@ Status Conv2DBackpropInputInfo::CheckStrategy(const StrategyPtr &strategy) {
}

Status Conv2DBackpropInputInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
if (CheckHWStrategyBase(h_strategy, w_strategy) != SUCCESS) {
return FAILED;
}

if (pad_mode_ != 1) { // only support same mode
MS_LOG(ERROR) << name_ << ": Do not support the pad mode " << pad_mode_ << " when split H or W dimension";
return FAILED;
@@ -649,18 +675,18 @@ Status Conv2DBackpropInputInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_st

if (h_strategy > 1) {
if (inputs_shape_[0][2] * stride_[2] != outputs_shape_[0][2]) {
MS_LOG(ERROR) << name_ << ": Do not support split h dimension when in_shape * stride != out_shape";
MS_LOG(ERROR) << name_ << ": Do not support to split h dimension when in_shape * stride != out_shape";
return FAILED;
}

if (kernel_size_[0] > stride_[2]) {
MS_LOG(ERROR) << name_ << ": Do not support split h dimension when kernel size larger than stride";
MS_LOG(ERROR) << name_ << ": Do not support to split h dimension when kernel size larger than stride";
return FAILED;
}
}

if (w_strategy > 1 && inputs_shape_[0][3] * stride_[3] != outputs_shape_[0][3]) {
MS_LOG(ERROR) << name_ << ": Do not support split w dimension when in_shape * stride != out_shape";
MS_LOG(ERROR) << name_ << ": Do not support to split w dimension when in_shape * stride != out_shape";
return FAILED;
}



+ 5
- 4
mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h View File

@@ -46,6 +46,7 @@ class Conv2DInfo : public OperatorInfo {
Status GetAttrsBase();
Status GetAttrs() override;
Status CheckStrategyBase(const StrategyPtr &strategy);
Status CheckHWStrategyBase(int64_t h_strategy, int64_t w_strategy);
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferForwardCommunication() override;
Status InferDevMatrixShape() override;
@@ -117,10 +118,10 @@ class Conv2DBackpropInputInfo : public Conv2DInfo {
Status InferTensorMap() override;
Status InferMirrorOps() override; // can not use OperatorInfo::InferMirrorOps(), since the 'out_shape' is not tensor

Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
void InferNewPadList();
int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias);
int64_t ComputeOverlapRightSizeByRankBias(int64_t rank_bias);
Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) override;
void InferNewPadList() override;
int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias) override;
int64_t ComputeOverlapRightSizeByRankBias(int64_t rank_bias) override;

private:
Shape out_shape_;


+ 10
- 0
tests/ut/python/parallel/test_conv2d.py View File

@@ -13,6 +13,7 @@
# limitations under the License.

import numpy as np
import pytest

import mindspore as ms
from mindspore import context, Tensor, Parameter
@@ -72,3 +73,12 @@ def test_conv2d_model_parallel2():
strategy2 = ((32, 1, 1, 1),)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
compile_net(net)


def test_conv2d_output_can_not_divisible_by_strategy():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
strategy2 = ((1, 1, 1, 8),)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net)

Loading…
Cancel
Save