|
|
|
@@ -81,7 +81,7 @@ bool CheckBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &batchnormgrad |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *batchnormgrad) { |
|
|
|
bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *batchnorm_grad) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto tuple_getitem = node->cast<CNodePtr>(); |
|
|
|
@@ -93,12 +93,12 @@ bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *bat |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr batchnormgrad_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem); |
|
|
|
MS_EXCEPTION_IF_NULL(batchnormgrad_anf); |
|
|
|
MS_EXCEPTION_IF_NULL(batchnormgrad); |
|
|
|
*batchnormgrad = batchnormgrad_anf->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(*batchnormgrad); |
|
|
|
return CheckBatchNormGrad(graph, *batchnormgrad); |
|
|
|
AnfNodePtr batchnorm_grad_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem); |
|
|
|
MS_EXCEPTION_IF_NULL(batchnorm_grad_anf); |
|
|
|
MS_EXCEPTION_IF_NULL(batchnorm_grad); |
|
|
|
*batchnorm_grad = batchnorm_grad_anf->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(*batchnorm_grad); |
|
|
|
return CheckBatchNormGrad(graph, *batchnorm_grad); |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
@@ -117,11 +117,13 @@ const AnfNodePtr BatchNormGrad2BNInferGrad::Process(const FuncGraphPtr &graph, c |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
|
|
|
|
CNodePtr batchnormgrad = nullptr; |
|
|
|
if (!NeedFusion(graph, node, &batchnormgrad)) { |
|
|
|
CNodePtr batchnorm_grad = nullptr; |
|
|
|
if (!NeedFusion(graph, node, &batchnorm_grad)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
return CreateBNInferGrad(graph, batchnormgrad, node); |
|
|
|
auto bn_infer_grad = CreateBNInferGrad(graph, batchnorm_grad, node); |
|
|
|
TransferDepend(batchnorm_grad, graph, bn_infer_grad); |
|
|
|
return bn_infer_grad; |
|
|
|
} |
|
|
|
} // namespace opt |
|
|
|
} // namespace mindspore |