|
|
|
@@ -72,6 +72,38 @@ AnfNodePtr GetMul0(const FuncGraphPtr &graph, const AnfNodePtr &input2, const An |
|
|
|
} |
|
|
|
return mul0; |
|
|
|
} |
|
|
|
|
|
|
|
bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const AnfNodePtr &reduce_sum) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(mul0_anf); |
|
|
|
MS_EXCEPTION_IF_NULL(reduce_sum); |
|
|
|
if (!mul0_anf->isa<CNode>()) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
auto mul0 = mul0_anf->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(mul0); |
|
|
|
|
|
|
|
// when network is _VirtualDatasetCell, quit fusion |
|
|
|
if (mul0->fullname_with_scope().find("network-_VirtualDatasetCell") != std::string::npos) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
auto manager = graph->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
if (manager->node_users().find(reduce_sum) == manager->node_users().end()) { |
|
|
|
MS_LOG(EXCEPTION) << "node has no output in manager"; |
|
|
|
} |
|
|
|
const AnfNodeIndexSet &outputs_set = manager->node_users()[reduce_sum]; |
|
|
|
auto it = std::find_if(outputs_set.begin(), outputs_set.end(), [&mul0](const std::pair<AnfNodePtr, int> &node_index) { |
|
|
|
return node_index.first == mul0->input(1) || node_index.first == mul0; |
|
|
|
}); |
|
|
|
if (it != outputs_set.end()) { |
|
|
|
MS_LOG(INFO) << "ReduceSum's output node is mul0's input or mul0! If do fusion, graph will exist a circle"; |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
return false; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
const BaseRef ConfusionMulGradFusion::DefinePattern() const { |
|
|
|
@@ -90,9 +122,6 @@ const AnfNodePtr ConfusionMulGradFusion::Process(const FuncGraphPtr &graph, cons |
|
|
|
auto reduce_sum = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(reduce_sum); |
|
|
|
auto mul1 = reduce_sum->input(1); |
|
|
|
if (mul1->fullname_with_scope().find("bert/encoder") == std::string::npos) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (IsUsedByOthers(graph, mul1)) { |
|
|
|
MS_LOG(INFO) << "Mul1 is used by others, quit fusion!"; |
|
|
|
return nullptr; |
|
|
|
@@ -102,6 +131,9 @@ const AnfNodePtr ConfusionMulGradFusion::Process(const FuncGraphPtr &graph, cons |
|
|
|
MS_LOG(INFO) << "Mul0 do not exist, quit fusion"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (QuitFusion(graph, mul0, node)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
auto fusion_node = CreateFusionNode(graph, reduce_sum, mul0, input3); |
|
|
|
std::vector<AnfNodePtr> fusion_node_outputs; |
|
|
|
|