Merge pull request !177 from yao_yf/fix_auto_parallel_prelu
@@ -52,7 +52,7 @@ Status PReLUInfo::CheckStrategy(const StrategyPtr& strategy) {
     }
     return FAILED;
   }
-  if ((stra[0][PRELU_CHANNEL_INDEX] != PRELU_CHANNEL_STRATEGY) || (stra[1][0] != PRELU_CHANNEL_STRATEGY)) {
+  if (stra[0][PRELU_CHANNEL_INDEX] != stra[1][0]) {
     if (is_auto_parallel_) {
       MS_LOG(DEBUG) << name_ << ": Invalid channel strategy.";
     } else {
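What changes in CheckStrategy: previously the channel dimension of the first input and the single dimension of the weight each had to equal PRELU_CHANNEL_STRATEGY (presumably 1, per the old test comment `// Success: {{2,1,8,16},{1}}`), so the channel axis could not be split at all. The relaxed check only requires the two splits to match, so the channel axis may be partitioned as long as the input and the weight are partitioned consistently. Below is a minimal Python sketch of the old and new rules; the helper names and constants are illustrative assumptions, not MindSpore API:

```python
# A minimal sketch of the old vs. new validation rule, assuming the channel
# axis is dimension 1 of the first input (PRELU_CHANNEL_INDEX) and that the
# old rule forced both splits to equal PRELU_CHANNEL_STRATEGY (presumably 1).
# Illustrative only; these helpers are not MindSpore API.
PRELU_CHANNEL_INDEX = 1
PRELU_CHANNEL_STRATEGY = 1

def check_channel_strategy_old(stra):
    # Old rule: neither the channel axis of input 0 nor the weight may be split.
    return (stra[0][PRELU_CHANNEL_INDEX] == PRELU_CHANNEL_STRATEGY
            and stra[1][0] == PRELU_CHANNEL_STRATEGY)

def check_channel_strategy_new(stra):
    # New rule: the channel axis may be split, as long as the input and the
    # weight are split the same way.
    return stra[0][PRELU_CHANNEL_INDEX] == stra[1][0]

# The strategy used by the updated unit test: channel split of 4 on both sides.
stra = [(2, 4, 8, 16), (4,)]
assert not check_channel_strategy_old(stra)  # rejected before this change
assert check_channel_strategy_new(stra)      # accepted after this change
```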
@@ -146,11 +146,10 @@ TEST_F(TestPReLUInfo, CheckStrategy1) {
 }
 
 TEST_F(TestPReLUInfo, CheckStrategy2) {
-  // Success: {{2,1,8,16},{1}}
   std::vector<Dimensions> inputs = {{2, 4, 8, 16}, {4}};
   StrategyPtr strategy = NewStrategy(0, inputs);
   Status ret = prelu->Init(strategy);
-  ASSERT_EQ(ret, FAILED);
+  ASSERT_EQ(ret, SUCCESS);
 }
 
 TEST_F(TestPReLUInfo, AutoStrategy1) {
@@ -252,11 +251,10 @@ TEST_F(TestPReLUInfo, CheckStrategy_2d1) {
 }
 
 TEST_F(TestPReLUInfo, CheckStrategy_2d2) {
-  // Success: {{2,1,8,16},{1}}
   std::vector<Dimensions> inputs = {{128, 4}, {4}};
   StrategyPtr strategy = NewStrategy(0, inputs);
   Status ret = prelu_2d->Init(strategy);
-  ASSERT_EQ(ret, FAILED);
+  ASSERT_EQ(ret, SUCCESS);
 }
 
 TEST_F(TestPReLUInfo, AutoStrategy_2d1) {
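Both unit tests flip from FAILED to SUCCESS for the same reason: `{{2, 4, 8, 16}, {4}}` and `{{128, 4}, {4}}` split the channel axis of the input by 4 and the weight by 4, which is consistent under the relaxed check, so `Init` is now expected to succeed. The stale `// Success: {{2,1,8,16},{1}}` comments, which documented the old fixed-channel requirement, are removed.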
@@ -149,3 +149,20 @@ def test_prelu_parallel_success3():
     w = Tensor(np.random.rand(16),dtype=ms.float32)
     net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
     _executor.compile(net, x, y, w)
+
+def test_prelu_parallel_success4():
+    class Net(nn.Cell):
+        def __init__(self, strategy):
+            super().__init__()
+            self.prelu = P.PReLU().set_strategy(strategy)
+        def construct(self, x, y):
+            out = self.prelu(x, y)
+            return out
+    context.reset_auto_parallel_context()
+    context.set_auto_parallel_context(device_num=64, global_rank=0)
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    strategy = ((2, 4, 4, 2), (4, ))
+    x = Tensor(np.random.rand(4, 16, 32, 64),dtype=ms.float32)
+    w = Tensor(np.random.rand(16),dtype=ms.float32)
+    net = GradWrap(NetWithLoss(Net(strategy)))
+    _executor.compile(net, x, w)
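The new Python test covers the relaxation end to end in semi_auto_parallel mode: with `device_num=64`, the strategy `((2, 4, 4, 2), (4,))` occupies 2*4*4*2 = 64 devices and splits the 16-element channel axis of `x` into 4 slices and the 16-element weight `w` into 4 slices. Since the two channel splits match, compilation succeeds, whereas the old check would have rejected any channel split other than 1.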