| @@ -529,8 +529,11 @@ bool Edge::CheckStrategyConsistency(StrategyPtr prev_stra, StrategyPtr next_stra | |||
| "different sharding strategies. These operators are: "; | |||
| auto const &succ_edges = prev_op_->succ_edges(); | |||
| for (auto const &succ_edge : succ_edges) { | |||
| if (succ_edge->next_operator()->cnodes().empty()) { | |||
| MS_LOG(EXCEPTION) << "No CNODE info has been set in operator: " << succ_edge->next_operator()->name(); | |||
| } | |||
| MS_LOG(ERROR) << succ_edge->next_operator()->name() << ", the corresponding fullname is: " | |||
| << succ_edge->next_operator()->cnode()->fullname_with_scope(); | |||
| << succ_edge->next_operator()->cnodes()[0]->fullname_with_scope(); | |||
| } | |||
| MS_LOG(EXCEPTION) << "Configure these operators with consistent sharding strategies."; | |||
| } | |||
| @@ -1770,10 +1770,7 @@ void OperatorInfo::set_swc_index(int64_t swc, int64_t depth) { | |||
| swc_index_ = swc; | |||
| } | |||
| // Returns the CNode held in cnode_. MS_EXCEPTION_IF_NULL guards the access, so a cnode_ that was never | |||
| // assigned is reported instead of being handed back to the caller as a null pointer. | |||
| CNodePtr OperatorInfo::cnode() { | |||
| MS_EXCEPTION_IF_NULL(cnode_); | |||
| return cnode_; | |||
| } | |||
| // Returns cnodes_ by value: each call gives the caller its own copy of the vector, and no emptiness check is | |||
| // performed here, so callers that index into the result must test empty() themselves. | |||
| std::vector<CNodePtr> OperatorInfo::cnodes() { return cnodes_; } | |||
| double OperatorInfo::GetForwardMemoryCostFromCNode() { | |||
| return operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0); | |||
| @@ -167,8 +167,11 @@ class OperatorInfo { | |||
| // Inline setters/getters; each one only stores into or reads from the matching member. | |||
| void set_input_value(const std::vector<ValuePtr> &input_value) { input_value_ = input_value; } | |||
| const std::vector<ValuePtr> &input_value() const { return input_value_; } | |||
| void set_outputs_dtype(const TypePtr &dtype) { outputs_dtype_ = dtype; } | |||
| // NOTE(review): set_cnode shows up twice in this span (one-line form here, multi-line form below), and | |||
| // cnode() is declared next to cnodes(); presumably these are the before/after versions of the same | |||
| // members -- confirm which ones the header actually keeps. | |||
| void set_cnode(const CNodePtr &cnode) { cnode_ = cnode; } | |||
| CNodePtr cnode(); | |||
| // Stores the node in cnode_ and also appends it to cnodes_. Because of the push_back, repeated calls | |||
| // grow cnodes_ instead of replacing its contents. | |||
| void set_cnode(const CNodePtr &cnode) { | |||
| cnode_ = cnode; | |||
| cnodes_.push_back(cnode); | |||
| } | |||
| // Declared to return the vector by value (not by reference). | |||
| std::vector<CNodePtr> cnodes(); | |||
| bool is_alive() const { return is_alive_; } | |||
| void SetNotAlive() { is_alive_ = false; } | |||
| StrategyPtr strategy() const { return strategy_; } | |||
| @@ -299,6 +302,7 @@ class OperatorInfo { | |||
| std::vector<bool> split_flag_list_; | |||
| std::string refkey_parameter_name_; | |||
| CNodePtr cnode_; | |||
| std::vector<CNodePtr> cnodes_; | |||
| int64_t used_devices_ = -1; | |||
| // the repeated_calc_num_ will be inserted to the last dimension of dev matrix in default | |||
| bool repeated_num_in_dev_matrix_right_ = true; | |||
| @@ -233,6 +233,7 @@ bool IsOperatorsInTwoSeparateLoops(const CNodePtr &a_cnode, const CNodePtr &b_cn | |||
| // 'configured_stra_ops_' includes all operators that are configured sharding strategies. | |||
| std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> configured_stra_ops_; | |||
| // 'ignore_candidate_' holds the operators registered through AddOperatorToIgnoreCandidates (CAST operators); | |||
| // IgnoreOperatorsInCostGraph resets the OperatorInfo user data on each of their CNodes. | |||
| std::set<OperatorInfoPtr> ignore_candidate_; | |||
| void InitCostGraph() { | |||
| if (entire_costgraph == nullptr) { | |||
| entire_costgraph = std::make_shared<CostGraph>(); | |||
| @@ -241,6 +242,7 @@ void InitCostGraph() { | |||
| CostModelContext::GetInstance()->PrintCostModel(); | |||
| entire_costgraph->Init(); | |||
| configured_stra_ops_.clear(); | |||
| ignore_candidate_.clear(); | |||
| } | |||
| void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const PrimitivePtr &prim, | |||
| @@ -297,6 +299,13 @@ void ApplyApproximationForNode(const OperatorInfoPtr &operator_info) { | |||
| } | |||
| } | |||
| // Registers 'operator_info' in ignore_candidate_ when 'prim' is a CAST primitive; every other primitive | |||
| // leaves the set untouched. | |||
| // 'prim' is dereferenced right away, and a stored operator is dereferenced again when the set is walked | |||
| // after strategy search, so null pointers are rejected here rather than surfacing later. | |||
| void AddOperatorToIgnoreCandidates(const PrimitivePtr &prim, const OperatorInfoPtr &operator_info) { | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| if (prim->name() == CAST) { | |||
| // never store a null operator: the consumer of ignore_candidate_ calls cnodes() on every entry | |||
| MS_EXCEPTION_IF_NULL(operator_info); | |||
| // add CAST into ignore_candidate | |||
| ignore_candidate_.insert(operator_info); | |||
| } | |||
| } | |||
| OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, bool is_last_nodes, | |||
| StrategyMap *stra_map) { | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| @@ -344,6 +353,8 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & | |||
| operator_info->set_input_value(input_value); | |||
| operator_info->set_outputs_dtype(cnode->Type()); | |||
| operator_info->set_cnode(cnode); | |||
| AddOperatorToIgnoreCandidates(prim, operator_info); | |||
| // key of strategy map | |||
| std::string strategy_key_name = ""; | |||
| auto param_names = NodeParameterName(cnode, -1, 0); | |||
| @@ -968,6 +979,17 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) { | |||
| } | |||
| } | |||
| // Walks ignore_candidate_ and, for every CNode reported by each operator's cnodes(), resets the | |||
| // OperatorInfo user data to nullptr. A null CNode is reported through MS_EXCEPTION_IF_NULL. | |||
| // Always returns SUCCESS. | |||
| Status IgnoreOperatorsInCostGraph() { | |||
| for (const auto &ignored_op : ignore_candidate_) { | |||
| // cnodes() returns a vector by value; the range-for keeps that temporary alive for the whole loop | |||
| for (const auto &node : ignored_op->cnodes()) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| node->set_user_data<OperatorInfo>(nullptr); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) { | |||
| // There are 4 meta-steps to determine the parallelization strategy for the ANF graph. | |||
| // Step 1: Traverse the ANF graph, and create NODEs for costgraph: | |||
| @@ -1035,7 +1057,6 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu | |||
| (ParallelContext::GetInstance()->sharding_propagation()); | |||
| if (use_sp) { | |||
| entire_costgraph->StrategyPropagate(configured_stra_ops_); | |||
| configured_stra_ops_.clear(); | |||
| } else if (GetStrategy(entire_costgraph) != SUCCESS) { | |||
| MS_LOG(ERROR) << "Strategy search for cost-graph fails"; | |||
| return FAILED; | |||
| @@ -1054,7 +1075,12 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu | |||
| MS_LOG(INFO) << op->name() << " : The strategy is:"; | |||
| PrintStrategy(s_strategy); | |||
| } | |||
| // Remove some operatorInfo from the CNODEs | |||
| IgnoreOperatorsInCostGraph(); | |||
| ops_in_a_loop_.clear(); | |||
| configured_stra_ops_.clear(); | |||
| ignore_candidate_.clear(); | |||
| return SUCCESS; | |||
| } | |||
| @@ -47,6 +47,8 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes); | |||
| void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes); | |||
| Status IgnoreOperatorsInCostGraph(); | |||
| Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root); | |||
| Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root); | |||
| @@ -84,9 +84,7 @@ def test_double_star_graph(): | |||
| net.set_train() | |||
| _cell_graph_executor.compile(net, x, y, z, w, phase='train') | |||
| strategies = _cell_graph_executor._get_shard_strategy(net) | |||
| expected_strategies = {'Default/network-Net/Cast-op1': [[8, 1]], | |||
| 'Default/network-Net/Cast-op3': [[1, 8]], | |||
| 'Default/network-Net/MatMul-op2': [[8, 1], [1, 1]], | |||
| expected_strategies = {'Default/network-Net/MatMul-op2': [[8, 1], [1, 1]], | |||
| 'Default/network-Net/MatMul-op4': [[1, 1], [1, 8]], | |||
| 'Default/network-Net/MatMul-op0': [[1, 8], [8, 1]]} | |||
| assert strategies == expected_strategies | |||