|
|
|
@@ -23,42 +23,8 @@ |
|
|
|
namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
namespace { |
|
|
|
// TupleGetItem indices that GetBatchNormOutputs accepts on a user of the BatchNorm node.
const std::vector<int> kOutputIndex = {0, 1, 2, 3, 4};

// Minimum number of TupleGetItem users GetBatchNormOutputs must collect to report success.
constexpr size_t kBatchNormLeastOutputNum{1};

// NOTE(review): not referenced in the visible part of this file; presumably the
// number of real (non-primitive) inputs of BatchNorm -- confirm against its users.
constexpr size_t kBatchNormRealInputNum{3};
|
|
|
|
|
|
|
bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(bn_outputs); |
|
|
|
auto manager = func_graph->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
auto iter = manager->node_users().find(bn); |
|
|
|
if (iter == manager->node_users().end()) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
size_t output_num = 0; |
|
|
|
for (const auto &node_index : iter->second) { |
|
|
|
AnfNodePtr output = node_index.first; |
|
|
|
MS_EXCEPTION_IF_NULL(output); |
|
|
|
if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto tuple_getiterm_cnode = output->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(tuple_getiterm_cnode); |
|
|
|
auto index_node = tuple_getiterm_cnode->input(kInputNodeOutputIndexInTupleGetItem); |
|
|
|
MS_EXCEPTION_IF_NULL(index_node); |
|
|
|
auto value_node = index_node->cast<ValueNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(value_node); |
|
|
|
int index = GetValue<int>(value_node->value()); |
|
|
|
if (std::find(kOutputIndex.begin(), kOutputIndex.end(), index) == kOutputIndex.end()) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
bn_outputs->push_back(output); |
|
|
|
output_num++; |
|
|
|
} |
|
|
|
return output_num >= kBatchNormLeastOutputNum; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(bn); |
|
|
|
@@ -140,34 +106,12 @@ const AnfNodePtr SingleBatchNormFission::Process(const FuncGraphPtr &func_graph, |
|
|
|
MS_LOG(INFO) << "is training should be true if do fusion"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
std::vector<AnfNodePtr> bn_outputs; |
|
|
|
if (!GetBatchNormOutputs(func_graph, node, &bn_outputs)) { |
|
|
|
MS_LOG(INFO) << "The BatchNorm node should only have output 0, 3 and 4. The node should not be changed"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node); |
|
|
|
std::vector<AnfNodePtr> bn_training_reduce_outputs; |
|
|
|
CreateMultipleOutputsOfAnfNode(func_graph, bn_training_reduce, kBNTrainingReduceOutputNum, |
|
|
|
&bn_training_reduce_outputs); |
|
|
|
|
|
|
|
AnfNodePtr bn_training_update_v3 = CreateBNTrainingUpdateV3(func_graph, node, bn_training_reduce_outputs); |
|
|
|
std::vector<AnfNodePtr> bn_training_update_v3_outputs; |
|
|
|
CreateMultipleOutputsOfAnfNode(func_graph, bn_training_update_v3, kBNTrainingUpdateV3OutputNum, |
|
|
|
&bn_training_update_v3_outputs); |
|
|
|
if (bn_training_update_v3_outputs.size() != kBNTrainingUpdateV3OutputNum) { |
|
|
|
MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingUpdateV2OutputNum |
|
|
|
<< ", but it is " << bn_training_update_v3_outputs.size(); |
|
|
|
} |
|
|
|
auto manager = func_graph->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
sort(bn_outputs.begin(), bn_outputs.end(), CompareTupleGetitem); |
|
|
|
size_t output_index = 0; |
|
|
|
for (const auto &output : bn_outputs) { |
|
|
|
(void)manager->Replace(output, bn_training_update_v3_outputs[output_index]); |
|
|
|
output_index++; |
|
|
|
} |
|
|
|
// Return the new node for control depends. |
|
|
|
return bn_training_update_v3; |
|
|
|
return CreateBNTrainingUpdateV3(func_graph, node, bn_training_reduce_outputs); |
|
|
|
} |
|
|
|
} // namespace opt |
|
|
|
} // namespace mindspore |