|
|
|
@@ -25,6 +25,7 @@ namespace opt { |
|
|
|
namespace { |
|
|
|
const std::vector<int> kOutputIndex{0, 3, 4, 5}; |
|
|
|
constexpr size_t kBatchNormRealOutputNum = 3; |
|
|
|
constexpr size_t kBatchNormRealInputNum = 3; |
|
|
|
|
|
|
|
bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2) { |
|
|
|
MS_EXCEPTION_IF_NULL(n1); |
|
|
|
@@ -56,6 +57,9 @@ bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, s |
|
|
|
for (const auto &node_index : manager->node_users()[bn]) { |
|
|
|
AnfNodePtr output = node_index.first; |
|
|
|
MS_EXCEPTION_IF_NULL(output); |
|
|
|
if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto tuple_getiterm_cnode = output->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(tuple_getiterm_cnode); |
|
|
|
auto index_node = tuple_getiterm_cnode->input(kInputNodeOutputIndexInTupleGetItem); |
|
|
|
@@ -77,7 +81,10 @@ AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodeP |
|
|
|
MS_EXCEPTION_IF_NULL(bn); |
|
|
|
auto bn_cnode = bn->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(bn_cnode); |
|
|
|
CheckCNodeInputSize(bn_cnode, kBatchNormInputNum + 1); |
|
|
|
if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) { |
|
|
|
MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than " |
|
|
|
<< kBatchNormRealInputNum + 1; |
|
|
|
} |
|
|
|
std::vector<AnfNodePtr> bn_training_reduce_inputs = { |
|
|
|
NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)), bn_cnode->input(1)}; |
|
|
|
auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs); |
|
|
|
@@ -100,7 +107,10 @@ AnfNodePtr CreateBNTrainingUpdateV2(const FuncGraphPtr &func_graph, const AnfNod |
|
|
|
MS_EXCEPTION_IF_NULL(bn); |
|
|
|
auto bn_cnode = bn->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(bn_cnode); |
|
|
|
CheckCNodeInputSize(bn_cnode, kBatchNormInputNum + 1); |
|
|
|
if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) { |
|
|
|
MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than " |
|
|
|
<< kBatchNormRealInputNum + 1; |
|
|
|
} |
|
|
|
if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) { |
|
|
|
MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingReduceOutputNum |
|
|
|
<< ", but it is " << bn_training_reduce_outputs.size(); |
|
|
|
@@ -164,7 +174,8 @@ const AnfNodePtr BatchNormBertFission::Process(const FuncGraphPtr &func_graph, c |
|
|
|
(void)manager->Replace(output, bn_training_update_v2_outputs[output_index]); |
|
|
|
output_index++; |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
// Return the new node for control depends. |
|
|
|
return bn_training_update_v2; |
|
|
|
} |
|
|
|
} // namespace opt |
|
|
|
} // namespace mindspore |