Merge pull request !6411 from Xiaoda/23-fix-the-triangle-elimination-problem (tags/v1.0.0)
| @@ -246,14 +246,15 @@ struct SourceEliminationDecision : public Decision { | |||
| */ | |||
| struct TriangleEliminationDecision : public Decision { | |||
| TriangleEliminationDecision(StrategyPtr elimi_stra, CostPtr elimi_op_cost, CostPtr l_edge_cost, CostPtr r_edge_cost, | |||
| StrategyPtr left_stra, CostPtr l_node_cost, StrategyPtr right_stra) | |||
| StrategyPtr left_stra, CostPtr l_node_cost, StrategyPtr right_stra, CostPtr r_node_cost) | |||
| : eliminated_op_strategy_(std::move(elimi_stra)), | |||
| eliminated_op_cost_(std::move(elimi_op_cost)), | |||
| left_edge_cost_(std::move(l_edge_cost)), | |||
| right_edge_cost_(std::move(r_edge_cost)), | |||
| left_node_strategy_(std::move(left_stra)), | |||
| left_node_cost_(std::move(l_node_cost)), | |||
| right_node_strategy_(std::move(right_stra)) { | |||
| right_node_strategy_(std::move(right_stra)), | |||
| right_node_cost_(std::move(r_node_cost)) { | |||
| type_ = DecisionType::TRIANGLE_ELIMINATION; | |||
| } | |||
| @@ -264,6 +265,7 @@ struct TriangleEliminationDecision : public Decision { | |||
| StrategyPtr left_node_strategy_; | |||
| CostPtr left_node_cost_; | |||
| StrategyPtr right_node_strategy_; | |||
| CostPtr right_node_cost_; | |||
| MS_DECLARE_PARENT(TriangleEliminationDecision, Decision); | |||
| }; | |||
| @@ -199,9 +199,16 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) { | |||
| eliminated_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_); | |||
| left_edge->set_selected_cost(decision->left_edge_cost_); | |||
| right_edge->set_selected_cost(decision->right_edge_cost_); | |||
| // Since Triangle is eliminated into 'left_node', only 'left_node' is needed to recover the strategy. | |||
| // 'left_node' recovers the strategy. | |||
| left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_); | |||
| right_node->CheckSelectedStrategy(decision->right_node_strategy_); | |||
| if (TRIANGLE_STRATEGY_OVERWRITE) { | |||
| // 'right_node' recovers the strategy. | |||
| MS_LOG(INFO) << "Overwrite the right-node: " << right_node->name() << " in recovering triangle elimination."; | |||
| right_node->SetSelectedStrategyAndCost(decision->right_node_strategy_, decision->right_node_cost_); | |||
| } else { | |||
| // In this case, 'right_node' is not overwriten strategy, and it checks strategy consistency. | |||
| right_node->CheckSelectedStrategy(decision->right_node_strategy_); | |||
| } | |||
| MS_LOG(INFO) << "Recover triangleElimination succeeded."; | |||
| } else if ((*rit)->isa<StarElimination>()) { | |||
| auto elimination = (*rit)->cast<StarEliminationPtr>(); | |||
| @@ -40,6 +40,7 @@ bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES; | |||
| bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; | |||
| bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS; | |||
| int32_t RUN_PHASE = DEFAULT_RUN_PHASE; | |||
| bool TRIANGLE_STRATEGY_OVERWRITE = DEFAULT_TRIANGLE_STRATEGY_OVERWRITE; | |||
| void CostGraph::SetDeviceMemoryAndCostParameter() { | |||
| MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance()); | |||
| @@ -154,6 +155,14 @@ void CostGraph::SetDeviceMemoryAndCostParameter() { | |||
| MS_LOG(INFO) << "multi_subgraphs: false."; | |||
| } | |||
| auto overwrite = CostModelContext::GetInstance()->triangle_strategy_overwrite(); | |||
| TRIANGLE_STRATEGY_OVERWRITE = overwrite; | |||
| if (TRIANGLE_STRATEGY_OVERWRITE) { | |||
| MS_LOG(INFO) << "triangle_strategy_overwrite: true."; | |||
| } else { | |||
| MS_LOG(INFO) << "triangle_strategy_overwrite: false."; | |||
| } | |||
| // RUN_PHASE | |||
| auto phase = CostModelContext::GetInstance()->run_phase(); | |||
| if (phase != 0 && phase != 1) { | |||
| @@ -1294,8 +1303,17 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, | |||
| elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ + | |||
| left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_; | |||
| auto decision = std::make_shared<TriangleEliminationDecision>( | |||
| elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost, left_op_stra, left_node_cost, right_op_stra); | |||
| if (TRIANGLE_STRATEGY_OVERWRITE) { | |||
| new_computation += right_op_cost->computation_cost_; | |||
| new_memory += right_op_cost->memory_with_reuse_; | |||
| new_commu_cost += right_op_cost->communication_cost_; | |||
| new_commu_forward += right_op_cost->communication_forward_; | |||
| new_commu_without += right_op_cost->communication_without_parameter_; | |||
| } | |||
| auto decision = | |||
| std::make_shared<TriangleEliminationDecision>(elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost, | |||
| left_op_stra, left_node_cost, right_op_stra, right_op_cost); | |||
| auto new_cost = std::make_shared<Cost>(new_computation, new_commu_cost, decision); | |||
| new_cost->communication_without_parameter_ = new_commu_without; | |||
| new_cost->communication_with_partial_para_ = | |||
| @@ -46,6 +46,7 @@ extern bool FULLY_USE_DEVICES; | |||
| extern bool ELEMENTWISE_OP_STRA_FOLLOW; | |||
| extern bool MULTI_SUBGRAPHS; | |||
| extern int32_t RUN_PHASE; | |||
| extern bool TRIANGLE_STRATEGY_OVERWRITE; | |||
| class CostGraph { | |||
| // 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have | |||
| @@ -64,6 +64,7 @@ void CostModelContext::ResetAlgoParameters() { | |||
| tensor_slice_alignment_size_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE; | |||
| fully_use_device_ = DEFAULT_FULLY_USE_DEVICES; | |||
| elementwise_stra_follow_ = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; | |||
| triangle_strategy_overwrite_ = DEFAULT_TRIANGLE_STRATEGY_OVERWRITE; | |||
| } | |||
| void CostModelContext::set_costmodel_context_for_device(const std::string &device_target) { | |||
| @@ -133,6 +134,8 @@ void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) { | |||
| elementwise_stra_follow_ = elementwise_follow; | |||
| } | |||
| void CostModelContext::set_triangle_strategy_overwrite(bool overwrite) { triangle_strategy_overwrite_ = overwrite; } | |||
| void CostModelContext::set_run_phase(int32_t phase) { run_phase_ = phase; } | |||
| struct CostRegister { | |||
| @@ -44,6 +44,8 @@ namespace parallel { | |||
| #define DEFAULT_RUN_PHASE 0 | |||
| #define TRAINING_PHASE 0 | |||
| #define INFERENCE_PHASE 1 | |||
// Default for whether the right-node strategy is overwritten when recovering
// a triangle elimination. NOTE: object-like macros must NOT end with ';' --
// the original trailing semicolon was expanded into every use site, breaking
// expression contexts such as `if (flag == DEFAULT_TRIANGLE_STRATEGY_OVERWRITE)`.
#define DEFAULT_TRIANGLE_STRATEGY_OVERWRITE true
| class CostModelContext { | |||
| public: | |||
| ~CostModelContext() = default; | |||
| @@ -133,6 +135,9 @@ class CostModelContext { | |||
| void set_elementwise_stra_follow(bool); | |||
| bool elementwise_stra_follow() const { return elementwise_stra_follow_; } | |||
| void set_triangle_strategy_overwrite(bool); | |||
| bool triangle_strategy_overwrite() const { return triangle_strategy_overwrite_; } | |||
| void set_run_phase(int32_t); | |||
| int32_t run_phase() const { return run_phase_; } | |||
| @@ -167,6 +172,10 @@ class CostModelContext { | |||
| // MULTI_SUBGRAPHS | |||
| bool is_multi_subgraphs_; | |||
| // In the recovery phase of DP algorithm, when encountering triangle structure, | |||
| // whether overwrite the right-node strategy | |||
| bool triangle_strategy_overwrite_; | |||
| int32_t run_phase_; // 0: 'training', 1: 'inference' | |||
| int32_t costmodel_allreduce_fusion_algorithm_; | |||
| @@ -0,0 +1,73 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import numpy as np | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore.common.api import _executor | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import operations as P | |||
| from mindspore.parallel._utils import _reset_op_id as reset_op_id | |||
| from mindspore import context, Tensor, Parameter | |||
| from tests.ut.python.ops.test_math_ops import VirtualLoss | |||
| grad_all = C.GradOperation(get_all=True) | |||
class NetWithLoss(nn.Cell):
    """Wraps a network with a VirtualLoss head so it can be compiled for training."""

    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.network = network
        self.loss = VirtualLoss()

    def construct(self, x):
        # Run the wrapped network, then feed its prediction into the loss.
        return self.loss(self.network(x))
class GradWarp(nn.Cell):
    """Computes gradients of the wrapped network with respect to all its inputs."""

    def __init__(self, network):
        super(GradWarp, self).__init__()
        self.network = network

    def construct(self, x):
        # grad_all is GradOperation(get_all=True): gradients w.r.t. every input.
        gradient_fn = grad_all(self.network)
        return gradient_fn(x)
def test_triangle_strategy_consistency():
    """Auto-parallel compilation of a triangle-shaped graph.

    mul1's output feeds BOTH mul2 and ba1, whose results merge in add --
    the triangle structure whose strategy recovery this change targets.
    The test passes if compilation completes without a strategy-consistency
    failure.
    """

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            # Sharded ops pin strategies at the apex (mul1) and the merge
            # point (add); mul2/ba1/relu are left to the cost model.
            self.mul1 = P.Mul().shard(((2, 4), (2, 4)))
            self.mul2 = P.Mul()
            self.ba1 = P.BiasAdd()
            self.weight = Parameter(Tensor(np.ones([128, 1000]), dtype=ms.float32), name="weight")
            self.bias = Parameter(Tensor(np.ones([1000]), dtype=ms.float32), name="bias")
            self.add = P.TensorAdd().shard(((1, 8), (1, 8)))
            self.relu = P.ReLU()

        def construct(self, x):
            shared = self.mul1(x, self.weight)       # triangle apex: two consumers below
            left_branch = self.mul2(shared, self.weight)
            right_branch = self.ba1(shared, self.bias)
            merged = self.add(left_branch, right_branch)
            return self.relu(merged)

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    input_tensor = Tensor(np.ones([128, 1000]), dtype=ms.float32)
    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    reset_op_id()
    _executor.compile(net, input_tensor, phase='train')