|
|
|
@@ -140,6 +140,9 @@ const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph, |
|
|
|
// process pattern as Relu(TensorAdd(BN#0, BN#1))
|
|
|
|
auto tuple_getitem = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 5);
|
|
|
|
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
|
|
|
if (!utils::isa<CNodePtr>(tuple_getitem) || AnfAlgo::GetCNodeName(tuple_getitem) != prim::kPrimTupleGetItem->name()) {
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
auto forward_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_getitem), 0);
|
|
|
|
if (AnfAlgo::GetCNodeName(forward_node) != kFusedBatchNormExWithAddAndActivation) {
|
|
|
|
return nullptr;
|
|
|
|
|