| @@ -23,42 +23,8 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| const std::vector<int> kOutputIndex{0, 1, 2, 3, 4}; | |||
| constexpr size_t kBatchNormLeastOutputNum = 1; | |||
| constexpr size_t kBatchNormRealInputNum = 3; | |||
| bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(bn_outputs); | |||
| auto manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto iter = manager->node_users().find(bn); | |||
| if (iter == manager->node_users().end()) { | |||
| return false; | |||
| } | |||
| size_t output_num = 0; | |||
| for (const auto &node_index : iter->second) { | |||
| AnfNodePtr output = node_index.first; | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) { | |||
| continue; | |||
| } | |||
| auto tuple_getiterm_cnode = output->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tuple_getiterm_cnode); | |||
| auto index_node = tuple_getiterm_cnode->input(kInputNodeOutputIndexInTupleGetItem); | |||
| MS_EXCEPTION_IF_NULL(index_node); | |||
| auto value_node = index_node->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| int index = GetValue<int>(value_node->value()); | |||
| if (std::find(kOutputIndex.begin(), kOutputIndex.end(), index) == kOutputIndex.end()) { | |||
| return false; | |||
| } | |||
| bn_outputs->push_back(output); | |||
| output_num++; | |||
| } | |||
| return output_num >= kBatchNormLeastOutputNum; | |||
| } | |||
| AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(bn); | |||
| @@ -140,34 +106,12 @@ const AnfNodePtr SingleBatchNormFission::Process(const FuncGraphPtr &func_graph, | |||
| MS_LOG(INFO) << "is training should be true if do fusion"; | |||
| return nullptr; | |||
| } | |||
| std::vector<AnfNodePtr> bn_outputs; | |||
| if (!GetBatchNormOutputs(func_graph, node, &bn_outputs)) { | |||
| MS_LOG(INFO) << "The BatchNorm node should only have output 0, 3 and 4. The node should not be changed"; | |||
| return nullptr; | |||
| } | |||
| AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node); | |||
| std::vector<AnfNodePtr> bn_training_reduce_outputs; | |||
| CreateMultipleOutputsOfAnfNode(func_graph, bn_training_reduce, kBNTrainingReduceOutputNum, | |||
| &bn_training_reduce_outputs); | |||
| AnfNodePtr bn_training_update_v3 = CreateBNTrainingUpdateV3(func_graph, node, bn_training_reduce_outputs); | |||
| std::vector<AnfNodePtr> bn_training_update_v3_outputs; | |||
| CreateMultipleOutputsOfAnfNode(func_graph, bn_training_update_v3, kBNTrainingUpdateV3OutputNum, | |||
| &bn_training_update_v3_outputs); | |||
| if (bn_training_update_v3_outputs.size() != kBNTrainingUpdateV3OutputNum) { | |||
| MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingUpdateV2OutputNum | |||
| << ", but it is " << bn_training_update_v3_outputs.size(); | |||
| } | |||
| auto manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| sort(bn_outputs.begin(), bn_outputs.end(), CompareTupleGetitem); | |||
| size_t output_index = 0; | |||
| for (const auto &output : bn_outputs) { | |||
| (void)manager->Replace(output, bn_training_update_v3_outputs[output_index]); | |||
| output_index++; | |||
| } | |||
| // Return the new node for control depends. | |||
| return bn_training_update_v3; | |||
| return CreateBNTrainingUpdateV3(func_graph, node, bn_training_reduce_outputs); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||