| @@ -79,6 +79,8 @@ class StrategyWithCost { | |||
| public: | |||
| StrategyWithCost(StrategyPtr strategy, std::vector<TensorInfo> inputs_, std::vector<TensorInfo> outputs_) | |||
| : strategy_ptr(std::move(strategy)), inputs_ptr(std::move(inputs_)), outputs_ptr(std::move(outputs_)) {} | |||
| StrategyWithCost(StrategyPtr strategy, CostPtrList c_list) | |||
| : strategy_ptr(std::move(strategy)), cost_list(std::move(c_list)) {} | |||
| StrategyWithCost(const StrategyWithCost &swc) = delete; | |||
| StrategyWithCost(StrategyWithCost &&swc) | |||
| @@ -99,6 +101,7 @@ enum DecisionType { | |||
| EDGE_ELIMINATION, | |||
| MERGE_ELIMINATION, | |||
| CONTRACT_ELIMINATION, | |||
| SOURCE_ELIMINATION, | |||
| TRIANGLE_ELIMINATION, | |||
| STAR_ELIMINATION, | |||
| FINAL_TYPE, | |||
| @@ -199,6 +202,38 @@ struct ContractEliminationDecision : public Decision { | |||
| MS_DECLARE_PARENT(ContractEliminationDecision, Decision); | |||
| }; | |||
| /* 'SourceEliminationDecision' is for the source Elimination in DP algorithm: | |||
| * 1 1,5 | |||
| * / \ // \\ | |||
| * / \ // \\ | |||
| * / \ // \\ | |||
| * / \ // \\ | |||
| * 2 <- 5 -> 3 ==> 2 3 | |||
| * \ / \ / | |||
| * \ / \ / | |||
| * \ / \ / | |||
| * 4 4 | |||
| * | |||
| * In the original graph, '1' has two alive outgoing edges and no incoming edges. '5' has two alive outgoing edges and | |||
| * no incoming edges. '4' has two alive incoming edges and no outgoing edges. Source Elimination will merge '5' into | |||
| * '1' new edges are generated to replace the old ones incident to '1' and '5'. | |||
| * | |||
| */ | |||
| struct SourceEliminationDecision : public Decision { | |||
| SourceEliminationDecision(StrategyPtr op1_stra, CostPtr op1_c, StrategyPtr op2_stra, CostPtr op2_c) | |||
| : op1_strategy_(std::move(op1_stra)), | |||
| op1_cost_(std::move(op1_c)), | |||
| op2_strategy_(std::move(op2_stra)), | |||
| op2_cost_(std::move(op2_c)) { | |||
| type_ = DecisionType::SOURCE_ELIMINATION; | |||
| } | |||
| StrategyPtr op1_strategy_; | |||
| CostPtr op1_cost_; | |||
| StrategyPtr op2_strategy_; | |||
| CostPtr op2_cost_; | |||
| MS_DECLARE_PARENT(SourceEliminationDecision, Decision); | |||
| }; | |||
| /* 'TriangleEliminationDecision' is for the Triangle Elimination in DP algorithm: | |||
| * | |||
| * u | |||
| @@ -296,6 +331,7 @@ using OpEliminationDecisionPtr = std::shared_ptr<OpEliminationDecision>; | |||
| using EdgeEliminationDecisionPtr = std::shared_ptr<EdgeEliminationDecision>; | |||
| using MergeEliminationDecisionPtr = std::shared_ptr<MergeEliminationDecision>; | |||
| using ContractEliminationDecisionPtr = std::shared_ptr<ContractEliminationDecision>; | |||
| using SourceEliminationDecisionPtr = std::shared_ptr<SourceEliminationDecision>; | |||
| using TriangleEliminationDecisionPtr = std::shared_ptr<TriangleEliminationDecision>; | |||
| using StarEliminationDecisionPtr = std::shared_ptr<StarEliminationDecision>; | |||
| using FinalDecisionPtr = std::shared_ptr<FinalDecision>; | |||
| @@ -42,66 +42,76 @@ Status GetStrategy(const CostGraphPtr &graph) { | |||
| auto elimi = std::make_shared<OpElimination>(n_edge, l_edge, node, r_edge); | |||
| eliminations.emplace_back(std::move(elimi)); | |||
| } | |||
| auto edges = graph->CheckEdgeElimination(); | |||
| if ((!flag) && (!edges.empty())) { | |||
| // Applying the Edge Elimination | |||
| flag = true; | |||
| auto n_edge = graph->EliminationEdges(edges); | |||
| auto elimi = std::make_shared<EdgeElimination>(n_edge, edges); | |||
| eliminations.emplace_back(std::move(elimi)); | |||
| if (!flag) { | |||
| auto edges = graph->CheckEdgeElimination(); | |||
| if (!edges.empty()) { | |||
| // Applying the Edge Elimination | |||
| flag = true; | |||
| auto n_edge = graph->EliminationEdges(edges); | |||
| auto elimi = std::make_shared<EdgeElimination>(n_edge, edges); | |||
| eliminations.emplace_back(std::move(elimi)); | |||
| } | |||
| } | |||
| auto merge_node = graph->CheckMergeElimination(); | |||
| if ((!flag) && (merge_node != nullptr)) { | |||
| // Applying the Merge Elimination | |||
| flag = true; | |||
| auto succ_edge = merge_node->GetAliveSuccEdges()[0]; | |||
| auto target_node = graph->EliminationMerge(merge_node); | |||
| auto elimi = std::make_shared<MergeElimination>(merge_node, succ_edge, target_node); | |||
| eliminations.emplace_back(std::move(elimi)); | |||
| if (!flag) { | |||
| auto merge_node = graph->CheckMergeElimination(); | |||
| if (merge_node != nullptr) { | |||
| // Applying the Merge Elimination | |||
| flag = true; | |||
| auto succ_edge = merge_node->GetAliveSuccEdges()[0]; | |||
| auto target_node = graph->EliminationMerge(merge_node); | |||
| auto elimi = std::make_shared<MergeElimination>(merge_node, succ_edge, target_node); | |||
| eliminations.emplace_back(std::move(elimi)); | |||
| } | |||
| } | |||
| auto contracted_node = graph->CheckContractElimination(); | |||
| if ((!flag) && (contracted_node != nullptr)) { | |||
| // Applying the Contract Elimination | |||
| flag = true; | |||
| auto prev_edge = contracted_node->GetAlivePrevEdges()[0]; | |||
| auto target_node = graph->EliminationContract(contracted_node); | |||
| auto elimi = std::make_shared<ContractElimination>(target_node, prev_edge, contracted_node); | |||
| eliminations.emplace_back(std::move(elimi)); | |||
| if (!flag) { | |||
| auto contracted_node = graph->CheckContractElimination(); | |||
| if ((contracted_node != nullptr)) { | |||
| // Applying the Contract Elimination | |||
| flag = true; | |||
| auto prev_edge = contracted_node->GetAlivePrevEdges()[0]; | |||
| auto target_node = graph->EliminationContract(contracted_node); | |||
| auto elimi = std::make_shared<ContractElimination>(target_node, prev_edge, contracted_node); | |||
| eliminations.emplace_back(std::move(elimi)); | |||
| } | |||
| } | |||
| auto triangle_pair = graph->CheckTriangleElimination(); | |||
| if ((!flag) && (triangle_pair.first != nullptr)) { | |||
| // Applying the Triangle Elimination | |||
| flag = true; | |||
| auto eliminated_node = triangle_pair.first; | |||
| auto l_r_edge = triangle_pair.second; | |||
| if (!flag) { | |||
| auto triangle_pair = graph->CheckTriangleElimination(); | |||
| if (triangle_pair.first != nullptr) { | |||
| // Applying the Triangle Elimination | |||
| flag = true; | |||
| auto eliminated_node = triangle_pair.first; | |||
| auto l_r_edge = triangle_pair.second; | |||
| auto left_node = l_r_edge->prev_operator(); | |||
| auto left_edge = eliminated_node->GetAliveSuccEdges()[0]; | |||
| auto right_edge = eliminated_node->GetAliveSuccEdges()[1]; | |||
| MS_EXCEPTION_IF_NULL(left_edge); | |||
| if (left_edge->next_operator() != left_node) { | |||
| auto tmp = left_edge; | |||
| left_edge = right_edge; | |||
| right_edge = tmp; | |||
| auto left_node = l_r_edge->prev_operator(); | |||
| auto left_edge = eliminated_node->GetAliveSuccEdges()[0]; | |||
| auto right_edge = eliminated_node->GetAliveSuccEdges()[1]; | |||
| MS_EXCEPTION_IF_NULL(left_edge); | |||
| if (left_edge->next_operator() != left_node) { | |||
| auto tmp = left_edge; | |||
| left_edge = right_edge; | |||
| right_edge = tmp; | |||
| } | |||
| auto left_node_cpy = graph->EliminationTriangle(eliminated_node, l_r_edge); | |||
| auto right_node = l_r_edge->next_operator(); | |||
| auto elimi = | |||
| std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge, right_node); | |||
| eliminations.emplace_back(std::move(elimi)); | |||
| } | |||
| auto left_node_cpy = graph->EliminationTriangle(eliminated_node, l_r_edge); | |||
| auto right_node = l_r_edge->next_operator(); | |||
| auto elimi = | |||
| std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge, right_node); | |||
| eliminations.emplace_back(std::move(elimi)); | |||
| } | |||
| auto star_center = graph->CheckStarElimination(); | |||
| if ((!flag) && (star_center != nullptr)) { | |||
| // Applying the Star Elimination | |||
| flag = true; | |||
| auto succ_edges = graph->EliminationStar(star_center); | |||
| std::vector<OperatorInfoPtr> succ_nodes; | |||
| for (size_t i = 0; i < succ_edges.size(); ++i) { | |||
| MS_EXCEPTION_IF_NULL(succ_edges[i]); | |||
| succ_nodes.push_back(succ_edges[i]->next_operator()); | |||
| if (!flag) { | |||
| auto star_center = graph->CheckStarElimination(); | |||
| if (star_center != nullptr) { | |||
| // Applying the Star Elimination | |||
| flag = true; | |||
| auto succ_edges = graph->EliminationStar(star_center); | |||
| std::vector<OperatorInfoPtr> succ_nodes; | |||
| for (size_t i = 0; i < succ_edges.size(); ++i) { | |||
| MS_EXCEPTION_IF_NULL(succ_edges[i]); | |||
| succ_nodes.push_back(succ_edges[i]->next_operator()); | |||
| } | |||
| auto elimi = std::make_shared<StarElimination>(star_center, succ_edges, succ_nodes); | |||
| eliminations.emplace_back(std::move(elimi)); | |||
| } | |||
| auto elimi = std::make_shared<StarElimination>(star_center, succ_edges, succ_nodes); | |||
| eliminations.emplace_back(std::move(elimi)); | |||
| } | |||
| } | |||
| @@ -42,7 +42,7 @@ namespace parallel { | |||
| // the operators' strategies can be all determined. | |||
| struct Elimination : public Base { | |||
| enum EliminationType { OPERA, EDGE, MERGE, CONTRACT, TRIANGLE, STAR }; | |||
| enum EliminationType { OPERA, EDGE, MERGE, CONTRACT, SOURCE, TRIANGLE, STAR }; | |||
| Elimination(EdgePtr n_edge, EliminationType ty) : new_edge_(std::move(n_edge)), type_(ty) {} | |||
| EdgePtr new_edge_; | |||
| @@ -100,6 +100,26 @@ struct ContractElimination : public Elimination { | |||
| MS_DECLARE_PARENT(ContractElimination, Elimination); | |||
| }; | |||
| // Source Elimination | |||
| struct SourceElimination : public Elimination { | |||
| SourceElimination(OperatorInfoPtr p_source, std::vector<EdgePtr> p_succ_edges, std::vector<EdgePtr> p_new_succ_edges, | |||
| OperatorInfoPtr s_source, std::vector<EdgePtr> s_succ_edges, std::vector<EdgePtr> s_new_succ_edges) | |||
| : Elimination(nullptr, Elimination::EliminationType::SOURCE), | |||
| primary_source_(std::move(p_source)), | |||
| primary_succ_edges_(std::move(p_succ_edges)), | |||
| primary_new_succ_edges_(std::move(p_new_succ_edges)), | |||
| secondary_source_(std::move(s_source)), | |||
| secondary_succ_edges_(std::move(s_succ_edges)), | |||
| secondary_new_succ_edges_(std::move(s_new_succ_edges)) {} | |||
| OperatorInfoPtr primary_source_; | |||
| std::vector<EdgePtr> primary_succ_edges_; | |||
| std::vector<EdgePtr> primary_new_succ_edges_; | |||
| OperatorInfoPtr secondary_source_; | |||
| std::vector<EdgePtr> secondary_succ_edges_; | |||
| std::vector<EdgePtr> secondary_new_succ_edges_; | |||
| MS_DECLARE_PARENT(SourceElimination, Elimination); | |||
| }; | |||
| // Triangle Elimination | |||
| struct TriangleElimination : public Elimination { | |||
| TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge, | |||
| @@ -138,6 +158,7 @@ using OpEliminationPtr = std::shared_ptr<OpElimination>; | |||
| using EdgeEliminationPtr = std::shared_ptr<EdgeElimination>; | |||
| using MergeEliminationPtr = std::shared_ptr<MergeElimination>; | |||
| using ContractEliminationPtr = std::shared_ptr<ContractElimination>; | |||
| using SourceEliminationPtr = std::shared_ptr<SourceElimination>; | |||
| using TriangleEliminationPtr = std::shared_ptr<TriangleElimination>; | |||
| using StarEliminationPtr = std::shared_ptr<StarElimination>; | |||
| @@ -320,5 +320,17 @@ Status Edge::CalculateMemoryCostForInference() { | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| void Edge::SetCostMapAndInputOutput(std::map<CostPtrKey, CostPtrList> &cost_map) { | |||
| cost_map_ = cost_map; | |||
| pre_op_output_.clear(); | |||
| next_op_input_.clear(); | |||
| for (auto &key_value : cost_map_) { | |||
| auto &key_pair = key_value.first; | |||
| pre_op_output_.emplace_back(std::pair<StrategyPtr, std::vector<TensorInfo>>(key_pair.first, {})); | |||
| next_op_input_.emplace_back(std::pair<StrategyPtr, std::vector<TensorInfo>>(key_pair.second, {})); | |||
| } | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -80,6 +80,8 @@ class Edge { | |||
| std::string edge_name() const { return edge_name_; } | |||
| // Init cost_map_: for each output layout and input layout, calculate the cost | |||
| Status InitEdgeCost(); | |||
| std::map<CostPtrKey, CostPtrList> GetCostMap() { return cost_map_; } | |||
| void SetCostMapAndInputOutput(std::map<CostPtrKey, CostPtrList> &); | |||
| // For two operators u--->v, given the output tensor layout of u, | |||
| // and the input tensor layout of v, return the redistribution cost, | |||
| // and the op_list to carry out the redistribution. | |||
| @@ -794,6 +794,191 @@ OperatorInfoPtr CostGraph::CheckContractElimination() const { | |||
| return nullptr; | |||
| } | |||
| std::pair<OperatorInfoPtr, OperatorInfoPtr> CostGraph::CheckSourceElimination() const { | |||
| size_t source_count = 0; | |||
| std::vector<OperatorInfoPtr> op_vector(2, nullptr); | |||
| for (auto &op : ops_) { | |||
| MS_EXCEPTION_IF_NULL(op); | |||
| bool bool_test = op->is_alive() && op->GetAlivePrevEdges().empty() && op->GetAliveSuccEdges().size() > 0; | |||
| if (bool_test) { | |||
| op_vector[source_count++] = op; | |||
| if (source_count == 2) { | |||
| return std::make_pair(op_vector[0], op_vector[1]); | |||
| } | |||
| } | |||
| } | |||
| return std::make_pair(nullptr, nullptr); | |||
| } | |||
| void CostGraph::CreateSourceEliminationSubCostList(StrategyPtr op1_old_stra, const CostPtrList &op1_old_clist, | |||
| StrategyPtr op2_old_stra, const CostPtrList &op2_old_clist, | |||
| CostPtrList *op1_new_clist) { | |||
| for (auto &op1_cost : op1_old_clist) { | |||
| for (auto &op2_cost : op2_old_clist) { | |||
| double computation = op1_cost->computation_cost_ + op2_cost->computation_cost_; | |||
| double memory = op1_cost->memory_with_reuse_ + op2_cost->memory_with_reuse_; | |||
| double communication = op1_cost->communication_cost_ + op2_cost->communication_cost_; | |||
| double communication_forward = op1_cost->communication_forward_ + op2_cost->communication_forward_; | |||
| double communication_without_para = | |||
| op1_cost->communication_without_parameter_ + op2_cost->communication_without_parameter_; | |||
| auto decision = std::make_shared<SourceEliminationDecision>(op1_old_stra, op1_cost, op2_old_stra, op2_cost); | |||
| auto new_cost = std::make_shared<Cost>(computation, communication, decision); | |||
| MS_EXCEPTION_IF_NULL(new_cost); | |||
| new_cost->communication_without_parameter_ = communication_without_para; | |||
| new_cost->communication_with_partial_para_ = | |||
| communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); | |||
| new_cost->memory_with_reuse_ = memory; | |||
| new_cost->communication_forward_ = communication_forward; | |||
| MS_EXCEPTION_IF_NULL(op1_new_clist); | |||
| op1_new_clist->emplace_back(std::move(new_cost)); | |||
| } | |||
| } | |||
| } | |||
| std::pair<std::vector<std::shared_ptr<Edge>>, std::vector<std::shared_ptr<Edge>>> CostGraph::EliminationSources( | |||
| OperatorInfoPtr op1, OperatorInfoPtr op2) { | |||
| MS_EXCEPTION_IF_NULL(op1); | |||
| MS_EXCEPTION_IF_NULL(op2); | |||
| MS_LOG(INFO) << "Now source eliminating node: " << op2->name() << " to node: " << op1->name(); | |||
| auto op1_old_succ_edges = op1->GetAliveSuccEdges(); | |||
| std::vector<std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>>> op1_edges_reorganised_cost( | |||
| op1_old_succ_edges.size()); | |||
| std::vector<std::map<CostPtrKey, CostPtrList>> op1_new_edges_cost(op1_old_succ_edges.size()); | |||
| std::vector<std::shared_ptr<Edge>> op1_new_succ_edges(op1_old_succ_edges.size()); | |||
| auto op2_old_succ_edges = op2->GetAliveSuccEdges(); | |||
| std::vector<std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>>> op2_edges_reorganised_cost( | |||
| op2_old_succ_edges.size()); | |||
| std::vector<std::map<CostPtrKey, CostPtrList>> op2_new_edges_cost(op2_old_succ_edges.size()); | |||
| std::vector<std::shared_ptr<Edge>> op2_new_succ_edges(op2_old_succ_edges.size()); | |||
| // Construct cost_map for the data_structure of 'op1_edges_reorganised_cost' and 'op2_edges_reorganised_cost' | |||
| for (size_t i = 0; i < op1_old_succ_edges.size(); ++i) { | |||
| const auto &op1_cost_map = op1_old_succ_edges[i]->GetCostMap(); | |||
| std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>> from_tocost; | |||
| for (const auto &key_value : op1_cost_map) { | |||
| const auto &from_to_strategies = key_value.first; | |||
| const auto &costlist = key_value.second; | |||
| from_tocost[from_to_strategies.first].emplace_back(std::make_pair(from_to_strategies.second, costlist)); | |||
| } | |||
| op1_edges_reorganised_cost[i] = from_tocost; | |||
| } | |||
| for (size_t i = 0; i < op2_old_succ_edges.size(); ++i) { | |||
| const auto &op2_cost_map = op2_old_succ_edges[i]->GetCostMap(); | |||
| std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>> from_tocost; | |||
| for (const auto &key_value : op2_cost_map) { | |||
| const auto &from_to_strategies = key_value.first; | |||
| const auto &costlist = key_value.second; | |||
| from_tocost[from_to_strategies.first].emplace_back(std::make_pair(from_to_strategies.second, costlist)); | |||
| } | |||
| op2_edges_reorganised_cost[i] = from_tocost; | |||
| } | |||
| // Merge op2 into op1 | |||
| const auto &op1_old_stra_cost = op1->GetStrategyCost(); | |||
| const auto &op2_old_stra_cost = op2->GetStrategyCost(); | |||
| std::vector<std::shared_ptr<StrategyWithCost>> op1_new_stra_cost; | |||
| for (auto &op1_stra_cost : op1_old_stra_cost) { | |||
| auto op1_old_stra = op1_stra_cost->strategy_ptr; | |||
| auto op1_old_costlist = op1_stra_cost->cost_list; | |||
| for (auto &op2_stra_cost : op2_old_stra_cost) { | |||
| auto op2_stra = op2_stra_cost->strategy_ptr; | |||
| auto op2_costlist = op2_stra_cost->cost_list; | |||
| StrategyPtr op1_new_stra = std::make_shared<Strategy>(*op1_old_stra); | |||
| op1_new_stra->CoverStrategy(op2_stra); | |||
| CostPtrList op1_new_costlist; | |||
| // Calculate new cost for 'op1_new_costlist' | |||
| CreateSourceEliminationSubCostList(op1_old_stra, op1_old_costlist, op2_stra, op2_costlist, &op1_new_costlist); | |||
| std::shared_ptr<StrategyWithCost> swc = std::make_shared<StrategyWithCost>(op1_new_stra, op1_new_costlist); | |||
| op1_new_stra_cost.emplace_back(swc); | |||
| // Set cost for new successive edges of op1 and op2 | |||
| for (size_t i = 0; i < op1_old_succ_edges.size(); ++i) { | |||
| auto &from_tocost = op1_edges_reorganised_cost[i]; | |||
| auto &to_cost = from_tocost[op1_old_stra]; | |||
| auto &new_cost_map = op1_new_edges_cost[i]; | |||
| for (auto &stra_costlit : to_cost) { | |||
| auto &to_strategy = stra_costlit.first; | |||
| auto &edge_costlist = stra_costlit.second; | |||
| CostPtrKey new_key = {op1_new_stra, to_strategy}; | |||
| new_cost_map[new_key] = edge_costlist; | |||
| } | |||
| } | |||
| for (size_t i = 0; i < op2_old_succ_edges.size(); ++i) { | |||
| auto &from_tocost = op2_edges_reorganised_cost[i]; | |||
| auto &to_cost = from_tocost[op2_stra]; | |||
| auto &new_cost_map = op2_new_edges_cost[i]; | |||
| for (auto &stra_costlist : to_cost) { | |||
| auto &to_strategy = stra_costlist.first; | |||
| auto &edge_costlist = stra_costlist.second; | |||
| CostPtrKey new_key = {op1_new_stra, to_strategy}; | |||
| new_cost_map[new_key] = edge_costlist; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| op1->SetStrategyCost(op1_new_stra_cost); | |||
| op2->SetNotAlive(); | |||
| // Update the edges incident to op1, and edges incident to op2 | |||
| for (size_t i = 0; i < op1_old_succ_edges.size(); ++i) { | |||
| auto &new_cost_map = op1_new_edges_cost[i]; | |||
| auto &ith_edge = op1_old_succ_edges[i]; | |||
| std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + ith_edge->next_operator()->name(); | |||
| std::shared_ptr<Edge> new_edge; | |||
| if (ith_edge->is_combined()) { | |||
| std::vector<size_t> output_indexs, input_indexs; | |||
| output_indexs = ith_edge->prev_op_output_indexs(); | |||
| input_indexs = ith_edge->next_op_input_indexs(); | |||
| new_edge = | |||
| std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_indexs, input_indexs, true); | |||
| } else { | |||
| size_t output_index, input_index; | |||
| output_index = ith_edge->prev_op_output_index(); | |||
| input_index = ith_edge->next_op_input_index(); | |||
| new_edge = | |||
| std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_index, input_index, false); | |||
| } | |||
| new_edge->SetCostMapAndInputOutput(new_cost_map); | |||
| // replace the old successive edges with the new ones. | |||
| op1->ReplaceSuccEdge(ith_edge->next_operator(), new_edge); | |||
| ith_edge->next_operator()->ReplacePreEdge(op1, new_edge); | |||
| op1_new_succ_edges[i] = new_edge; | |||
| } | |||
| for (size_t i = 0; i < op2_old_succ_edges.size(); ++i) { | |||
| auto &new_cost_map = op2_new_edges_cost[i]; | |||
| auto &ith_edge = op2_old_succ_edges[i]; | |||
| const auto &destination = ith_edge->next_operator(); | |||
| std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + destination->name(); | |||
| std::shared_ptr<Edge> new_edge; | |||
| if (ith_edge->is_combined()) { | |||
| std::vector<size_t> output_indexs, input_indexs; | |||
| output_indexs = ith_edge->prev_op_output_indexs(); | |||
| input_indexs = ith_edge->next_op_input_indexs(); | |||
| new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_indexs, input_indexs, true); | |||
| } else { | |||
| size_t output_index, input_index; | |||
| output_index = ith_edge->prev_op_output_index(); | |||
| input_index = ith_edge->next_op_input_index(); | |||
| new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_index, input_index, false); | |||
| } | |||
| new_edge->SetCostMapAndInputOutput(new_cost_map); | |||
| // replace the old successive edges with the new ones. | |||
| destination->ReplacePreEdge(op2, new_edge); | |||
| op1->AddSuccEdge(new_edge); | |||
| op2_new_succ_edges[i] = new_edge; | |||
| } | |||
| MS_LOG(INFO) << "Source eliminating node: " << op2->name() << " to node: " << op1->name() + " succeeded."; | |||
| return {op1_new_succ_edges, op2_new_succ_edges}; | |||
| } | |||
| // Check the graph whether a TriangleElimination can be performed | |||
| std::pair<OperatorInfoPtr, std::shared_ptr<Edge>> CostGraph::CheckTriangleElimination() const { | |||
| for (auto &op : ops_) { | |||
| @@ -180,6 +180,14 @@ class CostGraph { | |||
| void CreateStarEliminationSubCostList(const StrategyPtr &, const CostPtrList &, const CostPtrList &, | |||
| const StrategyPtr &, const CostPtrList &, std::vector<StrategyPtr>, | |||
| CostPtrList &, CostPtrList &, CostPtrList *); | |||
| // Return <op1, op2>. we merge 'op2' into 'op1' | |||
| std::pair<OperatorInfoPtr, OperatorInfoPtr> CheckSourceElimination() const; | |||
| void CreateSourceEliminationSubCostList(StrategyPtr, const CostPtrList &, StrategyPtr, const CostPtrList &, | |||
| CostPtrList *); | |||
| // We merge 'op2' into op1. The returned value are '<Edges1, Edges2>'. 'Edges1' are newly updated edges for 'op1', | |||
| // 'Edges2' are newly updated edges for 'op2'. | |||
| std::pair<std::vector<std::shared_ptr<Edge>>, std::vector<std::shared_ptr<Edge>>> EliminationSources( | |||
| OperatorInfoPtr op1, OperatorInfoPtr op2); | |||
| // Calculate memory cost for training phase or inference phase. | |||
| Status CalculateMemoryCost(); | |||
| // When the input of a operator is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then | |||
| @@ -1330,5 +1330,9 @@ void OperatorInfo::CheckSelectedStrategy(const StrategyPtr &s_strategy) { | |||
| PrintStrategy(s_strategy); | |||
| } | |||
| } | |||
| void OperatorInfo::SetStrategyCost(const std::vector<std::shared_ptr<StrategyWithCost>> &stra_cost) { | |||
| strategy_cost_ = stra_cost; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -97,6 +97,7 @@ class OperatorInfo { | |||
| // is checked | |||
| Status SetCostUnderStrategyBase(const StrategyPtr &strategy); | |||
| std::vector<std::shared_ptr<StrategyWithCost>> GetStrategyCost() { return strategy_cost_; } | |||
| void SetStrategyCost(const std::vector<std::shared_ptr<StrategyWithCost>> &); | |||
| // In the training phase, when the input of a operator contains WEIGHT or a output from other operators involving | |||
| // WEIGHT, then these input should stay in memory until it is used in the backward phase, which is kept in memory | |||
| // at the end of forward phase. | |||
| @@ -36,7 +36,19 @@ using StrategyPtr = std::shared_ptr<Strategy>; | |||
| class Strategy { | |||
| public: | |||
| Strategy(int32_t stage, std::vector<Dimensions> inputs) : stage_(stage), inputs_(std::move(inputs)) {} | |||
| Strategy(int32_t stage, std::vector<Dimensions> inputs) | |||
| : stage_(stage), inputs_(std::move(inputs)), internal_size_(0), internal_stragies_() {} | |||
| Strategy(const Strategy &another_stra) : stage_(another_stra.GetInputStage()) { | |||
| inputs_ = another_stra.GetInputDim(); | |||
| internal_size_ = another_stra.GetInternalSize(); | |||
| if (internal_size_ != 0) { | |||
| internal_stragies_ = another_stra.GetInternalStrategies(); | |||
| } else { | |||
| internal_stragies_ = {}; | |||
| } | |||
| } | |||
| ~Strategy() = default; | |||
| size_t GetInputNumber() const { return inputs_.size(); } | |||
| std::vector<Dimensions> GetInputDim() const { return inputs_; } | |||
| @@ -47,7 +59,10 @@ class Strategy { | |||
| } | |||
| } | |||
| void ResetInputs(const std::vector<Dimensions> &input) { inputs_ = input; } | |||
| std::vector<StrategyPtr> GetInternalStrategies() const { return internal_stragies_; } | |||
| size_t GetInternalSize() const { return internal_size_; } | |||
| // TODO(Xiaoda): need fix for adapting 'CoverStrategy' | |||
| bool IsEqual(const StrategyPtr &another_stra) { | |||
| if (another_stra == nullptr) { | |||
| return false; | |||
| @@ -58,11 +73,19 @@ class Strategy { | |||
| return true; | |||
| } | |||
| // Include 'another_stra' into this strategy | |||
| void CoverStrategy(const StrategyPtr &another_stra) { | |||
| internal_stragies_.push_back(another_stra); | |||
| internal_size_++; | |||
| } | |||
| private: | |||
| const int32_t stage_; | |||
| // The size of Dimensions must equal to inputs_ tensor dimension. | |||
| std::vector<Dimensions> inputs_; | |||
| size_t internal_size_ = 0; | |||
| std::vector<StrategyPtr> internal_stragies_; | |||
| }; | |||
| inline StrategyPtr NewStrategy(const int32_t stage, const std::vector<Dimensions> &inputs) { | |||
| @@ -0,0 +1,114 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import numpy as np | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore.common.api import _executor | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import operations as P | |||
| from tests.ut.python.ops.test_math_ops import VirtualLoss | |||
| class NetWithLoss(nn.Cell): | |||
| def __init__(self, network): | |||
| super(NetWithLoss, self).__init__() | |||
| self.loss = VirtualLoss() | |||
| self.network = network | |||
| def construct(self, x, y, z, w, a): | |||
| predict = self.network(x, y, z, w, a) | |||
| return self.loss(predict) | |||
| class GradWrap(nn.Cell): | |||
| def __init__(self, network): | |||
| super(GradWrap, self).__init__() | |||
| self.network = network | |||
| def construct(self, x, y, z, w, a): | |||
| return C.grad_all(self.network)(x, y, z, w, a) | |||
| # model_parallel test | |||
| def test_double_source_graph(): | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.matmul1 = P.MatMul() | |||
| self.matmul2 = P.MatMul() | |||
| self.matmul3 = P.MatMul() | |||
| self.matmul4 = P.MatMul() | |||
| self.matmul5 = P.MatMul() | |||
| def construct(self, x, y, z, w, a): | |||
| m1_result = self.matmul1(x, y) | |||
| m2_result = self.matmul2(z, w) | |||
| m3_result = self.matmul3(m2_result, m1_result) | |||
| m4_result = self.matmul4(m2_result, m1_result) | |||
| out = self.matmul5(m3_result, m4_result) | |||
| return out | |||
| size = 8 | |||
| context.set_auto_parallel_context(device_num=size, global_rank=0) | |||
| x = Tensor(np.ones([32, 32]), dtype=ms.float32) | |||
| y = Tensor(np.ones([32, 32]), dtype=ms.float32) | |||
| z = Tensor(np.ones([32, 32]), dtype=ms.float32) | |||
| w = Tensor(np.ones([32, 32]), dtype=ms.float32) | |||
| a = Tensor(np.ones([32, 32]), dtype=ms.float32) | |||
| net = GradWrap(NetWithLoss(Net())) | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| net.set_auto_parallel() | |||
| _executor.compile(net, x, y, z, w, a) | |||
| def test_double_source_complex_graph(): | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.matmul1 = P.MatMul() | |||
| self.matmul2 = P.MatMul() | |||
| self.matmul3 = P.MatMul() | |||
| self.matmul4 = P.MatMul() | |||
| self.matmul5 = P.MatMul() | |||
| self.matmul6 = P.MatMul() | |||
| def construct(self, x, y, z, w, a): | |||
| m1_result = self.matmul1(x, y) | |||
| m6_result = self.matmul6(m1_result, a) | |||
| m2_result = self.matmul2(z, w) | |||
| m3_result = self.matmul3(m2_result, m6_result) | |||
| m4_result = self.matmul4(m2_result, m1_result) | |||
| out = self.matmul5(m3_result, m4_result) | |||
| return out | |||
| size = 8 | |||
| context.set_auto_parallel_context(device_num=size, global_rank=0) | |||
| x = Tensor(np.ones([32, 32]), dtype=ms.float32) | |||
| y = Tensor(np.ones([32, 32]), dtype=ms.float32) | |||
| z = Tensor(np.ones([32, 32]), dtype=ms.float32) | |||
| w = Tensor(np.ones([32, 32]), dtype=ms.float32) | |||
| a = Tensor(np.ones([32, 32]), dtype=ms.float32) | |||
| net = GradWrap(NetWithLoss(Net())) | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| net.set_auto_parallel() | |||
| _executor.compile(net, x, y, z, w, a) | |||