|
|
|
@@ -28,14 +28,14 @@ |
|
|
|
namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
namespace { |
|
|
|
void CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &bn_cnode, |
|
|
|
bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &bn_cnode, |
|
|
|
std::vector<AnfNodePtr> *bn_training_reduce_outputs) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(bn_cnode); |
|
|
|
if (bn_cnode->inputs().size() != kBnInputNum) { |
|
|
|
MS_LOG(EXCEPTION) << "BN node has wrong input size"; |
|
|
|
MS_LOG(INFO) << "FusedbatchNorm's input size less than " << kBnInputNum << ". " << bn_cnode->DebugString(); |
|
|
|
return false; |
|
|
|
} |
|
|
|
// All the inputs of BNTrainingReduce are from the inputs of BN |
|
|
|
std::vector<AnfNodePtr> bn_training_reduce_inputs = { |
|
|
|
NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName))}; |
|
|
|
bn_training_reduce_inputs.push_back(bn_cnode->input(1)); |
|
|
|
@@ -45,8 +45,9 @@ void CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr & |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
bn_training_reduce->set_kernel_info(kernel_info); |
|
|
|
std::vector<size_t> bn_shape_i0 = AnfAlgo::GetPrevNodeOutputInferShape(bn_cnode, 0); |
|
|
|
if (bn_shape_i0.size() != kShape4dDims) { |
|
|
|
MS_LOG(EXCEPTION) << "Get shape of FusedBatchNorm fail"; |
|
|
|
if (bn_shape_i0.size() < kShape2dDims) { |
|
|
|
MS_LOG(INFO) << "The FusedBatchNorm's first input's shape dims less than " << kShape2dDims; |
|
|
|
return false; |
|
|
|
} |
|
|
|
std::vector<size_t> bn_training_reduce_shape = {bn_shape_i0[1]}; |
|
|
|
auto types = {kNumberTypeFloat32, kNumberTypeFloat32}; |
|
|
|
@@ -56,6 +57,7 @@ void CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr & |
|
|
|
AnfAlgo::CopyNodeAttrs(bn_cnode, bn_training_reduce); |
|
|
|
|
|
|
|
CreateMultipleOutputsOfAnfNode(graph, bn_training_reduce, kBNTrainingReduceOutputNum, bn_training_reduce_outputs); |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr CreateOutputsOfBNTrainingUpdate(const FuncGraphPtr &graph, const CNodePtr &bn_cnode, |
|
|
|
@@ -99,11 +101,15 @@ AnfNodePtr SplitFusedBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNo |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
if (cnode->inputs().size() < kBnInputNum) { |
|
|
|
MS_LOG(EXCEPTION) << "op[FusedBatchNorm] has less than " << kBnInputNum << " inputs."; |
|
|
|
MS_LOG(INFO) << "op[FusedBatchNorm] has less than " << kBnInputNum << " inputs."; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
// Create BNTrainingReduce node and get outputs of BNTrainingReduce |
|
|
|
std::vector<AnfNodePtr> bn_training_reduce_outputs; |
|
|
|
CreateOutputsOfBNTrainingReduce(func_graph, cnode, &bn_training_reduce_outputs); |
|
|
|
if (!CreateOutputsOfBNTrainingReduce(func_graph, cnode, &bn_training_reduce_outputs)) { |
|
|
|
MS_LOG(WARNING) << "Create BNTrainingReduce fail, quit split"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (bn_training_reduce_outputs.size() != kBN1OutputNum) { |
|
|
|
MS_LOG(EXCEPTION) << "make outputs of op BNTrainingReduce fail"; |
|
|
|
} |
|
|
|
|