|
|
@@ -23,6 +23,8 @@ |
|
|
namespace mindspore { |
|
|
namespace mindspore { |
|
|
namespace opt { |
|
|
namespace opt { |
|
|
namespace { |
|
|
namespace { |
|
|
|
|
|
// TupleGetItem indexes (3 and 4) of the original BatchNorm whose users are
// rewired to the corresponding BNTrainingUpdate outputs in Process().
constexpr size_t kReplaceOutputIndex0 = 3;
constexpr size_t kReplaceOutputIndex1 = 4;
|
|
bool IsC(const BaseRef &n) { |
|
|
bool IsC(const BaseRef &n) { |
|
|
if (utils::isa<AnfNodePtr>(n)) { |
|
|
if (utils::isa<AnfNodePtr>(n)) { |
|
|
AnfNodePtr in = utils::cast<AnfNodePtr>(n); |
|
|
AnfNodePtr in = utils::cast<AnfNodePtr>(n); |
|
|
@@ -32,52 +34,6 @@ bool IsC(const BaseRef &n) { |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
// Walks backwards from the matched Depend node through the
// AssignSub -> Mul -> Sub -> TupleGetItem chain of the pattern and returns
// the node that TupleGetItem reads from (the BatchNorm anchor of the match).
AnfNodePtr GetBatchNormNode(const AnfNodePtr &node) {
  // Helper: require `anf` to be a CNode with `expected_size` inputs and
  // return its input at `input_idx`.
  auto nth_input = [](const AnfNodePtr &anf, size_t expected_size, size_t input_idx) -> AnfNodePtr {
    MS_EXCEPTION_IF_NULL(anf);
    auto cnode = anf->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    CheckCNodeInputSize(cnode, expected_size);
    return cnode->input(input_idx);
  };
  AnfNodePtr assign_sub = nth_input(node, kDependInputNum, 2);
  AnfNodePtr mul = nth_input(assign_sub, kAssignSubInputNum, 2);
  AnfNodePtr sub = nth_input(mul, kMulInputNum, 1);
  AnfNodePtr tuple_getitem = nth_input(sub, kSubInputNum, 2);
  // The TupleGetItem's first real input is the tuple producer itself.
  return nth_input(tuple_getitem, kTupleGetitemInputNum, 1);
}
|
|
|
|
|
|
|
|
|
|
|
bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(n1); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(n2); |
|
|
|
|
|
auto n1_cnode = n1->cast<CNodePtr>(); |
|
|
|
|
|
auto n2_cnode = n2->cast<CNodePtr>(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(n1_cnode); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(n2_cnode); |
|
|
|
|
|
auto index_input1 = n1_cnode->input(kInputNodeOutputIndexInTupleGetItem); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(index_input1); |
|
|
|
|
|
auto value_node1 = index_input1->cast<ValueNodePtr>(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(value_node1); |
|
|
|
|
|
auto index_input2 = n2_cnode->input(kInputNodeOutputIndexInTupleGetItem); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(index_input2); |
|
|
|
|
|
auto value_node2 = index_input2->cast<ValueNodePtr>(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(value_node2); |
|
|
|
|
|
return GetValue<int>(value_node1->value()) < GetValue<int>(value_node2->value()); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void GetBNOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) { |
|
|
void GetBNOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) { |
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
MS_EXCEPTION_IF_NULL(bn); |
|
|
MS_EXCEPTION_IF_NULL(bn); |
|
|
@@ -92,54 +48,35 @@ void GetBNOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vect |
|
|
MS_EXCEPTION_IF_NULL(output); |
|
|
MS_EXCEPTION_IF_NULL(output); |
|
|
bn_outputs->push_back(output); |
|
|
bn_outputs->push_back(output); |
|
|
} |
|
|
} |
|
|
sort(bn_outputs->begin(), bn_outputs->end(), CompareTupleGetitem); |
|
|
|
|
|
} |
|
|
} |
|
|
} // namespace |
|
|
} // namespace |
|
|
|
|
|
|
|
|
// Pattern to match: a BatchNorm whose outputs 1 and 2 (matched here as
// TupleGetItem with constant indexes) feed two moving-average updates of the
// form AssignSub(var, Mul(Sub(var, getitem), const)), all ordered by Depend.
// NOTE(review): the original block contained interleaved old/new diff lines
// (prim_batch_norm + *_var0_ vs batch_norm_var_ + *0_var_); the *_var_ member
// spellings used by the rest of this file are kept here — confirm against the
// class declaration in the header.
const BaseRef FusedBatchNormFusion::DefinePattern() const {
  std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
  VarPtr index0 = std::make_shared<CondVar>(IsC);
  VarPtr index1 = std::make_shared<CondVar>(IsC);
  VarPtr index2 = std::make_shared<CondVar>(IsC);
  VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs});
  VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0});
  VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1});
  VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2});
  VectorRef sub0 = VectorRef({prim::kPrimSub, variable_input0_var_, tuple_getitem1});
  VectorRef sub1 = VectorRef({prim::kPrimSub, variable_input1_var_, tuple_getitem2});
  VectorRef mul0 = VectorRef({prim::kPrimMul, sub0, constant_input0_var_});
  VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_});
  VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, mul0});
  VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, mul1});
  VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0});
  return VectorRef({prim::kPrimDepend, depend0, assign_sub1});
}
|
|
|
|
|
|
|
|
// Builds the 5-element abstract tuple of the fused batch-norm node: the
// BatchNorm's element 0, the abstracts of the two matched variable inputs,
// then the BatchNorm's elements 3 and 4.
// NOTE(review): this function uses members variable_input_var0_/_var1_ while
// other blocks in this file use variable_input0_var_/1_var_ — confirm which
// spelling the header actually declares.
abstract::AbstractTuplePtr FusedBatchNormFusion::CreateAbstractOfFusedBatchNorm(const EquivPtr &equiv,
                                                                                const AnfNodePtr &bn) const {
  MS_EXCEPTION_IF_NULL(equiv);
  MS_EXCEPTION_IF_NULL(bn);
  auto var_in0 = utils::cast<AnfNodePtr>((*equiv)[variable_input_var0_]);
  MS_EXCEPTION_IF_NULL(var_in0);
  auto var_in1 = utils::cast<AnfNodePtr>((*equiv)[variable_input_var1_]);
  MS_EXCEPTION_IF_NULL(var_in1);
  auto bn_tuple = dyn_cast<abstract::AbstractTuple>(bn->abstract());
  MS_EXCEPTION_IF_NULL(bn_tuple);
  const auto &bn_elems = bn_tuple->elements();
  if (bn_elems.size() != kBnOutputNum) {
    MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBnOutputNum << ", but it is "
                      << bn_elems.size();
  }
  AbstractBasePtrList fused_bn_abstract_list{bn_elems[0], var_in0->abstract(), var_in1->abstract(), bn_elems[3],
                                             bn_elems[4]};
  return std::make_shared<abstract::AbstractTuple>(fused_bn_abstract_list);
}
|
|
|
|
|
|
|
|
|
|
|
ValuePtr FusedBatchNormFusion::GetFactor(const EquivPtr &equiv) const { |
|
|
ValuePtr FusedBatchNormFusion::GetFactor(const EquivPtr &equiv) const { |
|
|
MS_EXCEPTION_IF_NULL(equiv); |
|
|
MS_EXCEPTION_IF_NULL(equiv); |
|
|
auto constant_input = utils::cast<AnfNodePtr>((*equiv)[constant_input_var0_]); |
|
|
|
|
|
|
|
|
auto iter_constant_input0 = (*equiv).find(constant_input0_var_); |
|
|
|
|
|
if (iter_constant_input0 == (*equiv).end()) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the constant_input0 var after matched."; |
|
|
|
|
|
} |
|
|
|
|
|
auto constant_input = utils::cast<AnfNodePtr>(iter_constant_input0->second); |
|
|
MS_EXCEPTION_IF_NULL(constant_input); |
|
|
MS_EXCEPTION_IF_NULL(constant_input); |
|
|
if (!constant_input->isa<ValueNode>()) { |
|
|
if (!constant_input->isa<ValueNode>()) { |
|
|
return nullptr; |
|
|
return nullptr; |
|
|
@@ -158,53 +95,187 @@ ValuePtr FusedBatchNormFusion::GetFactor(const EquivPtr &equiv) const { |
|
|
return MakeValue(tensor_data[0]); |
|
|
return MakeValue(tensor_data[0]); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
// Creates a BNTrainingReduce CNode fed by the BatchNorm's first data input.
// Its output abstract is a 2-tuple taken from the abstracts of the second and
// third matched data inputs (presumably scale/offset-shaped — TODO confirm
// against the BNTrainingReduce op spec).
// NOTE(review): the original block was diff residue — a stray old Process()
// signature and deleted-side casts of data_input_var0_/var2_ were interleaved
// with the new iterator-based lookups; only the new side is kept.
AnfNodePtr FusedBatchNormFusion::CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                        const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(equiv);
  // Set input to create node
  auto iter_data_input0 = (*equiv).find(data_input0_var_);
  if (iter_data_input0 == (*equiv).end()) {
    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input0 var after matched.";
  }
  std::vector<AnfNodePtr> bn_training_reduce_inputs = {
    NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)),
    utils::cast<AnfNodePtr>(iter_data_input0->second)};
  auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs);
  MS_EXCEPTION_IF_NULL(bn_training_reduce);
  bn_training_reduce->set_scope(node->scope());
  // Set abstract
  auto iter_data_input1 = (*equiv).find(data_input1_var_);
  if (iter_data_input1 == (*equiv).end()) {
    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input1 var after matched.";
  }
  auto data_input1 = utils::cast<AnfNodePtr>(iter_data_input1->second);
  MS_EXCEPTION_IF_NULL(data_input1);
  auto iter_data_input2 = (*equiv).find(data_input2_var_);
  if (iter_data_input2 == (*equiv).end()) {
    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input2 var after matched.";
  }
  auto data_input2 = utils::cast<AnfNodePtr>(iter_data_input2->second);
  MS_EXCEPTION_IF_NULL(data_input2);
  AbstractBasePtrList abstract_list{data_input1->abstract(), data_input2->abstract()};
  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
  bn_training_reduce->set_abstract(abstract_tuple);
  return bn_training_reduce;
}
|
|
|
|
|
|
|
|
|
|
|
void FusedBatchNormFusion::GetBNTrainingUpdateInputs(const EquivPtr &equiv, |
|
|
|
|
|
const std::vector<AnfNodePtr> &bn_training_reduce_outputs, |
|
|
|
|
|
std::vector<AnfNodePtr> *bn_training_update_inputs) const { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(equiv); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(bn_training_update_inputs); |
|
|
|
|
|
auto iter_data_input0 = (*equiv).find(data_input0_var_); |
|
|
|
|
|
if (iter_data_input0 == (*equiv).end()) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input0 var after matched."; |
|
|
|
|
|
} |
|
|
|
|
|
auto iter_data_input1 = (*equiv).find(data_input1_var_); |
|
|
|
|
|
if (iter_data_input1 == (*equiv).end()) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input1 var after matched."; |
|
|
|
|
|
} |
|
|
|
|
|
auto iter_data_input2 = (*equiv).find(data_input2_var_); |
|
|
|
|
|
if (iter_data_input2 == (*equiv).end()) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input2 var after matched."; |
|
|
|
|
|
} |
|
|
|
|
|
auto iter_variable_input0 = (*equiv).find(variable_input0_var_); |
|
|
|
|
|
if (iter_variable_input0 == (*equiv).end()) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input0 var after matched."; |
|
|
|
|
|
} |
|
|
|
|
|
auto iter_variable_input1 = (*equiv).find(variable_input1_var_); |
|
|
|
|
|
if (iter_variable_input1 == (*equiv).end()) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input1 var after matched."; |
|
|
|
|
|
} |
|
|
|
|
|
if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingReduceOutputNum |
|
|
|
|
|
<< ", but it is " << bn_training_reduce_outputs.size(); |
|
|
|
|
|
} |
|
|
|
|
|
*bn_training_update_inputs = { |
|
|
|
|
|
NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateOpName)), |
|
|
|
|
|
utils::cast<AnfNodePtr>(iter_data_input0->second), |
|
|
|
|
|
bn_training_reduce_outputs[0], |
|
|
|
|
|
bn_training_reduce_outputs[1], |
|
|
|
|
|
utils::cast<AnfNodePtr>(iter_data_input1->second), |
|
|
|
|
|
utils::cast<AnfNodePtr>(iter_data_input2->second), |
|
|
|
|
|
utils::cast<AnfNodePtr>(iter_variable_input0->second), |
|
|
|
|
|
utils::cast<AnfNodePtr>(iter_variable_input1->second), |
|
|
|
|
|
}; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void FusedBatchNormFusion::GetBNTrainingUpdateAbstractList(const EquivPtr &equiv, const AnfNodePtr &bn, |
|
|
|
|
|
std::vector<AbstractBasePtr> *abstract_list) const { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(equiv); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(bn); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(abstract_list); |
|
|
|
|
|
auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn->abstract()); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(bn_abstract_tuple); |
|
|
|
|
|
if (bn_abstract_tuple->elements().size() < kBnOutputNum) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "The abstract size of node bn must not be less than " << kBnOutputNum << ", but it is " |
|
|
|
|
|
<< bn_abstract_tuple->elements().size(); |
|
|
|
|
|
} |
|
|
|
|
|
auto iter_variable_input0 = (*equiv).find(variable_input0_var_); |
|
|
|
|
|
if (iter_variable_input0 == (*equiv).end()) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input0 var after matched."; |
|
|
|
|
|
} |
|
|
|
|
|
auto variable_input0 = utils::cast<AnfNodePtr>(iter_variable_input0->second); |
|
|
MS_EXCEPTION_IF_NULL(variable_input0); |
|
|
MS_EXCEPTION_IF_NULL(variable_input0); |
|
|
auto variable_input1 = utils::cast<AnfNodePtr>((*equiv)[variable_input_var1_]); |
|
|
|
|
|
|
|
|
auto iter_variable_input1 = (*equiv).find(variable_input1_var_); |
|
|
|
|
|
if (iter_variable_input1 == (*equiv).end()) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input1 var after matched."; |
|
|
|
|
|
} |
|
|
|
|
|
auto variable_input1 = utils::cast<AnfNodePtr>(iter_variable_input1->second); |
|
|
MS_EXCEPTION_IF_NULL(variable_input1); |
|
|
MS_EXCEPTION_IF_NULL(variable_input1); |
|
|
std::vector<AnfNodePtr> fused_bn_inputs = { |
|
|
|
|
|
NewValueNode(prim::kPrimFusedBatchNorm), data_input0, data_input1, data_input2, variable_input0, variable_input1}; |
|
|
|
|
|
auto fused_bn = func_graph->NewCNode(fused_bn_inputs); |
|
|
|
|
|
fused_bn->set_scope(node->scope()); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(fused_bn); |
|
|
|
|
|
|
|
|
*abstract_list = {bn_abstract_tuple->elements()[0], variable_input0->abstract(), variable_input1->abstract(), |
|
|
|
|
|
bn_abstract_tuple->elements()[1], bn_abstract_tuple->elements()[2]}; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Builds the BNTrainingUpdate CNode that replaces the matched BatchNorm plus
// its AssignSub moving-average updates. Returns nullptr when no usable factor
// constant can be extracted (the caller then abandons the fusion).
// NOTE(review): the original block was diff residue — deleted-side code using
// GetBatchNormNode/fused_bn (including a duplicate `bn` declaration and an
// old replace-outputs loop) was interleaved with the new implementation; only
// the new side is kept.
AnfNodePtr FusedBatchNormFusion::CreateBNTrainingUpdate(
  const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
  const std::vector<AnfNodePtr> &bn_training_reduce_outputs) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(equiv);
  // Set input
  std::vector<AnfNodePtr> bn_training_update_inputs;
  GetBNTrainingUpdateInputs(equiv, bn_training_reduce_outputs, &bn_training_update_inputs);
  auto bn_training_update = func_graph->NewCNode(bn_training_update_inputs);
  MS_EXCEPTION_IF_NULL(bn_training_update);
  // Set abstract, derived from the original BatchNorm recorded in the equiv map
  auto iter_batch_norm = (*equiv).find(batch_norm_var_);
  if (iter_batch_norm == (*equiv).end()) {
    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the batch_norm var after matched.";
  }
  AnfNodePtr bn = utils::cast<AnfNodePtr>(iter_batch_norm->second);
  MS_EXCEPTION_IF_NULL(bn);
  AbstractBasePtrList abstract_list;
  GetBNTrainingUpdateAbstractList(equiv, bn, &abstract_list);
  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
  bn_training_update->set_abstract(abstract_tuple);
  AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn, bn_training_update);
  // The factor comes from the constant matched in the pattern's Mul; a null
  // result means it was not a usable constant, so the fusion is aborted.
  ValuePtr factor = GetFactor(equiv);
  if (factor == nullptr) {
    return nullptr;
  }
  AnfAlgo::SetNodeAttr(kAttrFactor, factor, bn_training_update);
  AnfAlgo::SetNodeAttr(kAttrIsRef, MakeValue(true), bn_training_update);
  bn_training_update->set_scope(node->scope());
  return bn_training_update;
}
|
|
|
|
|
|
|
|
|
|
|
// Entry point of the fusion pass: builds BNTrainingReduce + BNTrainingUpdate
// for the matched subgraph, rewires the BatchNorm users whose TupleGetItem
// index is kReplaceOutputIndex0/1 to the corresponding update outputs, and
// returns the update's first output as the replacement for `node`.
// NOTE(review): the original block was diff residue — deleted-side code
// (`manager->Replace(bn_outputs[3], fused_bn_outputs[3])`, an unreachable
// `return fused_bn_outputs[0]`, and an exact-size check on bn_outputs) was
// interleaved with the new index-driven loop; only the new side is kept. The
// loop tolerates a partial set of bn outputs, which is why the exact-size
// check is not reinstated — confirm against the upstream refactor.
const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                               const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(equiv);
  MS_EXCEPTION_IF_NULL(node);
  AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node, equiv);
  std::vector<AnfNodePtr> bn_training_reduce_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, bn_training_reduce, kBNTrainingReduceOutputNum,
                                 &bn_training_reduce_outputs);
  AnfNodePtr bn_training_update = CreateBNTrainingUpdate(func_graph, node, equiv, bn_training_reduce_outputs);
  if (bn_training_update == nullptr) {
    MS_LOG(DEBUG) << "Create BNTrainingUpdate failed for bn node " << node->DebugString();
    return nullptr;
  }
  std::vector<AnfNodePtr> bn_training_update_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, bn_training_update, kBNTrainingUpdateOutputNum,
                                 &bn_training_update_outputs);
  if (bn_training_update_outputs.size() < kBNTrainingUpdateOutputNum) {
    MS_LOG(EXCEPTION) << "The output size of node bn must be " << kBNTrainingUpdateOutputNum << ", but it is "
                      << bn_training_update_outputs.size();
  }
  // Replace old bn outputs with new outputs
  auto iter_batch_norm = (*equiv).find(batch_norm_var_);
  if (iter_batch_norm == (*equiv).end()) {
    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the batch_norm var after matched.";
  }
  AnfNodePtr bn = utils::cast<AnfNodePtr>(iter_batch_norm->second);
  std::vector<AnfNodePtr> bn_outputs;
  GetBNOutput(func_graph, bn, &bn_outputs);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  for (const auto &output : bn_outputs) {
    MS_EXCEPTION_IF_NULL(output);
    auto tuple_getitem_cnode = output->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(tuple_getitem_cnode);
    AnfNodePtr index_node = tuple_getitem_cnode->input(kInputNodeOutputIndexInTupleGetItem);
    MS_EXCEPTION_IF_NULL(index_node);
    auto value_node = index_node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    int index = GetValue<int>(value_node->value());
    // Only the getitems at the designated indexes are rewired.
    if (index == kReplaceOutputIndex0 || index == kReplaceOutputIndex1) {
      (void)manager->Replace(output, bn_training_update_outputs[index]);
    }
  }
  return bn_training_update_outputs[0];
}
|
|
} // namespace opt |
|
|
} // namespace opt |
|
|
} // namespace mindspore |
|
|
} // namespace mindspore |