| @@ -45,7 +45,7 @@ bool IsOneOf(const AnfNodePtr &node, const std::vector<PrimitivePtr> &ops_prim) | |||||
| [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); | [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); | ||||
| } | } | ||||
| void ProcessThroughPassCNode(std::function<bool(const AnfNodePtr &)> pass_fn, | |||||
| void ProcessThroughPassCNode(const std::function<bool(const AnfNodePtr &)> &pass_fn, | |||||
| OrderedMap<AnfNodePtr, NodeRelation> *node_rels) { | OrderedMap<AnfNodePtr, NodeRelation> *node_rels) { | ||||
| std::set<AnfNodePtr> latter_to_be_erased; | std::set<AnfNodePtr> latter_to_be_erased; | ||||
| for (const auto &[node, node_rel] : (*node_rels)) { | for (const auto &[node, node_rel] : (*node_rels)) { | ||||
| @@ -464,7 +464,7 @@ std::tuple<AnfNodePtrList, std::vector<int>> ParallelOpFusion::GetAvaliableNodes | |||||
| } | } | ||||
| AnfNodePtrList target_nodes = {nodes[start]}; | AnfNodePtrList target_nodes = {nodes[start]}; | ||||
| std::vector<int> valid_indices; | std::vector<int> valid_indices; | ||||
| std::vector<int> unused; | |||||
| std::vector<size_t> unused; | |||||
| for (size_t i = start; i < used.size(); ++i) { | for (size_t i = start; i < used.size(); ++i) { | ||||
| if (!used[i] && excludes.count(i) == 0) { | if (!used[i] && excludes.count(i) == 0) { | ||||
| unused.push_back(i); | unused.push_back(i); | ||||
| @@ -475,7 +475,7 @@ std::tuple<AnfNodePtrList, std::vector<int>> ParallelOpFusion::GetAvaliableNodes | |||||
| if (offset >= limit) { | if (offset >= limit) { | ||||
| MS_LOG(EXCEPTION) << "Index offset is exceed the limit of unused nodes."; | MS_LOG(EXCEPTION) << "Index offset is exceed the limit of unused nodes."; | ||||
| } | } | ||||
| if (unused[offset] >= node_limit) { | |||||
| if (SizeToInt(unused[offset]) >= node_limit) { | |||||
| MS_LOG(EXCEPTION) << "Index offset is exceed the limit of nodes."; | MS_LOG(EXCEPTION) << "Index offset is exceed the limit of nodes."; | ||||
| } | } | ||||
| valid_indices.push_back(unused[offset]); | valid_indices.push_back(unused[offset]); | ||||
| @@ -507,7 +507,6 @@ std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::DoSea | |||||
| int max_benefit = 0; | int max_benefit = 0; | ||||
| ParallelInfo best_parallel_info; | ParallelInfo best_parallel_info; | ||||
| std::set<int> bad_set; | |||||
| size_t unused_num = 0; | size_t unused_num = 0; | ||||
| for (size_t j = i + 1; j < sorted_candidates_used.size(); ++j) { | for (size_t j = i + 1; j < sorted_candidates_used.size(); ++j) { | ||||
| unused_num += sorted_candidates_used[j] ? 0 : 1; | unused_num += sorted_candidates_used[j] ? 0 : 1; | ||||
| @@ -525,7 +524,7 @@ std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::DoSea | |||||
| std::iota(tc.begin(), tc.end(), 1); | std::iota(tc.begin(), tc.end(), 1); | ||||
| AnfNodePtrList other_candidates; | AnfNodePtrList other_candidates; | ||||
| std::tie(other_candidates, std::ignore) = | std::tie(other_candidates, std::ignore) = | ||||
| GetAvaliableNodesByOffset(i, tc, sorted_candidates_used, candidates, std::set<int>()); | |||||
| GetAvaliableNodesByOffset(SizeToInt(i), tc, sorted_candidates_used, candidates, std::set<int>()); | |||||
| int benefit; | int benefit; | ||||
| std::tie(std::ignore, benefit, std::ignore) = cost_model_ptr_->CalFuseInfo(other_candidates); | std::tie(std::ignore, benefit, std::ignore) = cost_model_ptr_->CalFuseInfo(other_candidates); | ||||
| if (benefit > 0) { | if (benefit > 0) { | ||||
| @@ -540,7 +539,7 @@ std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::DoSea | |||||
| std::iota(tc.begin(), tc.end(), 1); | std::iota(tc.begin(), tc.end(), 1); | ||||
| AnfNodePtrList other_candidates; | AnfNodePtrList other_candidates; | ||||
| std::tie(other_candidates, std::ignore) = | std::tie(other_candidates, std::ignore) = | ||||
| GetAvaliableNodesByOffset(i, tc, sorted_candidates_used, candidates, std::set<int>()); | |||||
| GetAvaliableNodesByOffset(SizeToInt(i), tc, sorted_candidates_used, candidates, std::set<int>()); | |||||
| auto [dim_infos, benefit, fusion_info] = cost_model_ptr_->CalFuseInfo(other_candidates); | auto [dim_infos, benefit, fusion_info] = cost_model_ptr_->CalFuseInfo(other_candidates); | ||||
| if (benefit <= 0) { | if (benefit <= 0) { | ||||
| MS_LOG(EXCEPTION) << "Internal error in candidate search!"; | MS_LOG(EXCEPTION) << "Internal error in candidate search!"; | ||||
| @@ -553,8 +552,8 @@ std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::DoSea | |||||
| if (max_benefit > 0) { | if (max_benefit > 0) { | ||||
| parallel_infos.push_back(best_parallel_info); | parallel_infos.push_back(best_parallel_info); | ||||
| for (const auto &node : best_parallel_info.nodes()) { | for (const auto &node : best_parallel_info.nodes()) { | ||||
| sorted_candidates_used[get_index(sorted_indices, node)] = true; | |||||
| origin_candidates_used[get_index(origin_indices, node)] = true; | |||||
| sorted_candidates_used[IntToSize(get_index(sorted_indices, node))] = true; | |||||
| origin_candidates_used[IntToSize(get_index(origin_indices, node))] = true; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -593,7 +592,7 @@ std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::Searc | |||||
| std::map<AnfNodePtr, int> sorted_indices; | std::map<AnfNodePtr, int> sorted_indices; | ||||
| for (size_t i = 0; i < candidates.size(); ++i) { | for (size_t i = 0; i < candidates.size(); ++i) { | ||||
| sorted_indices.insert({candidates[i], i}); | |||||
| sorted_indices.emplace(candidates[i], i); | |||||
| } | } | ||||
| return DoSearchInSortedCandidates(cs.size(), candidates, &origin_indices, &sorted_indices); | return DoSearchInSortedCandidates(cs.size(), candidates, &origin_indices, &sorted_indices); | ||||
| @@ -620,7 +619,7 @@ void ParallelOpFusion::SearchFuseNodesInParallelGroup(const std::vector<AnfNodeP | |||||
| } | } | ||||
| for (size_t id = 0; id < used.size(); ++id) { | for (size_t id = 0; id < used.size(); ++id) { | ||||
| if (used[id]) { | if (used[id]) { | ||||
| tails[id]++; | |||||
| ++tails[id]; | |||||
| } | } | ||||
| } | } | ||||
| }; | }; | ||||