Merge pull request !1156 from YuJianfeng/mastertags/v0.3.0-alpha
| @@ -0,0 +1,169 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h" | |||
| #include <vector> | |||
| #include "pre_activate/common/helper.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| constexpr size_t kBatchNormGradInferOutputNum = 3; | |||
| bool CheckOutputsIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| if (manager->node_users().find(node) == manager->node_users().end()) { | |||
| MS_LOG(DEBUG) << "The node " << node->DebugString() << " should have some outputs"; | |||
| return false; | |||
| } | |||
| for (const auto &node_index : manager->node_users()[node]) { | |||
| AnfNodePtr output = node_index.first; | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| auto tuple_getiterm_cnode = output->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tuple_getiterm_cnode); | |||
| auto index_node = tuple_getiterm_cnode->input(kInputNodeOutputIndexInTupleGetItem); | |||
| MS_EXCEPTION_IF_NULL(index_node); | |||
| auto value_node = index_node->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| int index = GetValue<int>(value_node->value()); | |||
| if (index == kBatchNormGradInferOutputNum || index == kBatchNormGradInferOutputNum + 1) { | |||
| MS_LOG(DEBUG) << "The output " << index << " of node " << node->DebugString() << " is not null, no need change"; | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace | |||
| AnfNodePtr BatchNormGradInferFission::CreateBNInferGrad(const FuncGraphPtr &func_graph, const AnfNodePtr &bn_grad, | |||
| const EquivPtr &equiv) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(bn_grad); | |||
| MS_EXCEPTION_IF_NULL(equiv); | |||
| // Set inputs | |||
| auto iter_input0 = (*equiv).find(input0_var_); | |||
| if (iter_input0 == (*equiv).end()) { | |||
| MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input0 var after matched."; | |||
| } | |||
| auto iter_input2 = (*equiv).find(input2_var_); | |||
| if (iter_input2 == (*equiv).end()) { | |||
| MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input2 var after matched."; | |||
| } | |||
| auto iter_input4 = (*equiv).find(input4_var_); | |||
| if (iter_input4 == (*equiv).end()) { | |||
| MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input4 var after matched."; | |||
| } | |||
| std::vector<AnfNodePtr> bn_infer_grad_inputs = { | |||
| NewValueNode(std::make_shared<Primitive>(kBNInferGradOpName)), utils::cast<AnfNodePtr>(iter_input0->second), | |||
| utils::cast<AnfNodePtr>(iter_input2->second), utils::cast<AnfNodePtr>(iter_input4->second)}; | |||
| auto bn_infer_grad = func_graph->NewCNode(bn_infer_grad_inputs); | |||
| MS_EXCEPTION_IF_NULL(bn_infer_grad); | |||
| // Set abstract, the output of new node is taking the place of the 0th output of bn_grad. | |||
| auto bn_grad_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn_grad->abstract()); | |||
| MS_EXCEPTION_IF_NULL(bn_grad_abstract_tuple); | |||
| if (bn_grad_abstract_tuple->elements().empty()) { | |||
| MS_LOG(EXCEPTION) << "The abstract tuple of node " << bn_grad->DebugString() << "should not be empty"; | |||
| } | |||
| bn_infer_grad->set_abstract(bn_grad_abstract_tuple->elements()[0]); | |||
| AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad, bn_infer_grad); | |||
| bn_infer_grad->set_scope(bn_grad->scope()); | |||
| return bn_infer_grad; | |||
| } | |||
| AnfNodePtr BatchNormGradInferFission::CreateBNTrainingUpdateGrad(const FuncGraphPtr &func_graph, | |||
| const AnfNodePtr &bn_grad, | |||
| const EquivPtr &equiv) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(bn_grad); | |||
| MS_EXCEPTION_IF_NULL(equiv); | |||
| // Set inputs | |||
| auto iter_input0 = (*equiv).find(input0_var_); | |||
| if (iter_input0 == (*equiv).end()) { | |||
| MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input0 var after matched."; | |||
| } | |||
| auto iter_input1 = (*equiv).find(input1_var_); | |||
| if (iter_input1 == (*equiv).end()) { | |||
| MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input1 var after matched."; | |||
| } | |||
| auto iter_input3 = (*equiv).find(input3_var_); | |||
| if (iter_input3 == (*equiv).end()) { | |||
| MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input3 var after matched."; | |||
| } | |||
| auto iter_input4 = (*equiv).find(input4_var_); | |||
| if (iter_input4 == (*equiv).end()) { | |||
| MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input4 var after matched."; | |||
| } | |||
| std::vector<AnfNodePtr> bn_training_update_grad_inputs = { | |||
| NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateGradOpName)), | |||
| utils::cast<AnfNodePtr>(iter_input0->second), utils::cast<AnfNodePtr>(iter_input1->second), | |||
| utils::cast<AnfNodePtr>(iter_input3->second), utils::cast<AnfNodePtr>(iter_input4->second)}; | |||
| auto bn_training_update_grad = func_graph->NewCNode(bn_training_update_grad_inputs); | |||
| MS_EXCEPTION_IF_NULL(bn_training_update_grad); | |||
| // Set abstract, the outputs of new node are taking the place of the 1st and 2nd outputs of bn_grad. | |||
| auto bn_grad_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn_grad->abstract()); | |||
| MS_EXCEPTION_IF_NULL(bn_grad_abstract_tuple); | |||
| if (bn_grad_abstract_tuple->elements().size() < kBatchNormGradInferOutputNum) { | |||
| MS_LOG(EXCEPTION) << "The abstract tuple of node " << bn_grad->DebugString() << "should not be less than 3"; | |||
| } | |||
| std::vector<AbstractBasePtr> abstract_list{bn_grad_abstract_tuple->elements()[1], | |||
| bn_grad_abstract_tuple->elements()[2]}; | |||
| auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list); | |||
| bn_training_update_grad->set_abstract(abstract_tuple); | |||
| AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad, bn_training_update_grad); | |||
| bn_training_update_grad->set_scope(bn_grad->scope()); | |||
| return bn_training_update_grad; | |||
| } | |||
| const BaseRef BatchNormGradInferFission::DefinePattern() const { | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| return VectorRef({prim::kPrimBatchNormGrad, input0_var_, input1_var_, input2_var_, input3_var_, input4_var_, Xs}); | |||
| } | |||
| const AnfNodePtr BatchNormGradInferFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &equiv) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, node->cast<CNodePtr>())) { | |||
| MS_LOG(DEBUG) << "The BatchNormGrad " << node->DebugString() << " has no is_training attr, should not be changed"; | |||
| return nullptr; | |||
| } | |||
| if (AnfAlgo::GetNodeAttr<bool>(node, kAttrIsTraining)) { | |||
| MS_LOG(DEBUG) << "The is_training attr value of " << node->DebugString() << " is true, no need change"; | |||
| return nullptr; | |||
| } | |||
| if (!CheckOutputsIndex(func_graph, node)) { | |||
| MS_LOG(DEBUG) << "The output 3 or 4 of BatchNormGrad is not null, no need change"; | |||
| return nullptr; | |||
| } | |||
| AnfNodePtr bn_infer_grad = CreateBNInferGrad(func_graph, node, equiv); | |||
| AnfNodePtr bn_training_update_grad = CreateBNTrainingUpdateGrad(func_graph, node, equiv); | |||
| std::vector<AnfNodePtr> bn_training_update_grad_outputs; | |||
| CreateMultipleOutputsOfAnfNode(func_graph, bn_training_update_grad, kBNTrainingUpdateGradOutputNum, | |||
| &bn_training_update_grad_outputs); | |||
| if (bn_training_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) { | |||
| MS_LOG(EXCEPTION) << "The output size of " << bn_training_update_grad << " should be " | |||
| << kBNTrainingUpdateGradOutputNum << ", but it is " << bn_training_update_grad_outputs.size(); | |||
| } | |||
| std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), bn_infer_grad, | |||
| bn_training_update_grad_outputs[0], bn_training_update_grad_outputs[1]}; | |||
| auto make_tuple = func_graph->NewCNode(make_tuple_inputs); | |||
| MS_EXCEPTION_IF_NULL(make_tuple); | |||
| return make_tuple; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,50 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_ | |||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_ | |||
| #include <memory> | |||
| #include "pre_activate/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
// Fission pass: splits an inference-mode (is_training == false) BatchNormGrad
// whose 3rd/4th outputs are unused into two ops — BNInferGrad (dx) and
// BNTrainingUpdateGrad (scale/bias gradients).
class BatchNormGradInferFission : public PatternProcessPass {
 public:
  explicit BatchNormGradInferFission(bool multigraph = true)
      : PatternProcessPass("batch_norm_grad_infer_fission", multigraph),
        input0_var_(std::make_shared<Var>()),
        input1_var_(std::make_shared<Var>()),
        input2_var_(std::make_shared<Var>()),
        input3_var_(std::make_shared<Var>()),
        input4_var_(std::make_shared<Var>()) {}
  ~BatchNormGradInferFission() override = default;
  // Matches BatchNormGrad(input0..input4, ...variadic tail).
  const BaseRef DefinePattern() const override;
  // Returns the replacement make_tuple node, or nullptr when no change applies.
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 private:
  // Creates the BNInferGrad node replacing output 0 of the matched BatchNormGrad.
  AnfNodePtr CreateBNInferGrad(const FuncGraphPtr &func_graph, const AnfNodePtr &bn_grad, const EquivPtr &equiv) const;
  // Creates the BNTrainingUpdateGrad node replacing outputs 1 and 2.
  AnfNodePtr CreateBNTrainingUpdateGrad(const FuncGraphPtr &func_graph, const AnfNodePtr &bn_grad,
                                        const EquivPtr &equiv) const;
  // Pattern variables bound to the first five BatchNormGrad inputs.
  VarPtr input0_var_;
  VarPtr input1_var_;
  VarPtr input2_var_;
  VarPtr input3_var_;
  VarPtr input4_var_;
};
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_ | |||
| @@ -139,6 +139,7 @@ constexpr auto kFusionOpConv2DBackpropInputAddNReluGradV2Name = "FusionOp_Conv2D | |||
constexpr auto kLabelSetOpName = "LabelSet";
constexpr auto kLabelSwitchOpName = "LabelSwitch";
constexpr auto kLabelGotoOpName = "LabelGoto";
// Op emitted by BatchNormGradInferFission to compute dx in inference mode.
constexpr auto kBNInferGradOpName = "BNInferGrad";
// attr key name
constexpr auto kAttrInputNames = "input_names";
| @@ -0,0 +1,91 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h" | |||
| #include "common/backend_common_test.h" | |||
| #include "common/py_func_graph_fetcher.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
// Test fixture: loads the Python graph definitions used by the fission tests.
class TestHWBatchNormGradInferFission : public BackendCommon {
 public:
  TestHWBatchNormGradInferFission()
      : get_py_fun_("gtest_input.pre_activate.batch_norm_grad_infer_fission_test", true) {}
  ~TestHWBatchNormGradInferFission() override = default;
  // Fetches FuncGraphs by (function, tag) from the Python test-input module.
  UT::PyFuncGraphFetcher get_py_fun_;
};
| TEST_F(TestHWBatchNormGradInferFission, test_batch_norm_grad_infer_fission) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_batch_norm_grad_infer_fission", "before"); | |||
| EXPECT_NE(g, nullptr); | |||
| std::vector<int> shp_x{32, 64, 112, 112}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 5; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::BatchNormGradInferFission>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(kg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_batch_norm_grad_infer_fission", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWBatchNormGradInferFission, test_batch_norm_grad_infer_no_fission1) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_batch_norm_grad_infer_fission", "before_is_training"); | |||
| EXPECT_NE(g, nullptr); | |||
| std::vector<int> shp_x{32, 64, 112, 112}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 5; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::BatchNormGradInferFission>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(kg); | |||
| EXPECT_TRUE(CheckEqualGraph(kg, new_graph)); | |||
| } | |||
| TEST_F(TestHWBatchNormGradInferFission, test_batch_norm_grad_infer_no_fission2) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_batch_norm_grad_infer_fission", "before_output3_not_null"); | |||
| EXPECT_NE(g, nullptr); | |||
| std::vector<int> shp_x{32, 64, 112, 112}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 5; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::BatchNormGradInferFission>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(kg); | |||
| EXPECT_TRUE(CheckEqualGraph(kg, new_graph)); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,71 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| from mindspore.ops.operations import _grad_ops as G | |||
| from mindspore.ops import Primitive | |||
# Primitives used to build the graphs exercised by the fission test cases.
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')
# BatchNormGrad in training vs. inference mode; only inference mode is fissioned.
BatchNormGradTraining = G.BatchNormGrad(is_training=True)
BatchNormGradInfer = G.BatchNormGrad(is_training=False)
BNInferGrad = Primitive('BNInferGrad')
BNTrainingUpdateGrad = Primitive('BNTrainingUpdateGrad')
class FnDict:
    """Decorator-based registry mapping a function's name to the function."""

    def __init__(self):
        # name -> function
        self.fnDict = {}

    def __call__(self, fn):
        # Register fn and return it, so a name decorated with `@fns` stays bound
        # to the function instead of being rebound to None.
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        return self.fnDict[name]
def test_batch_norm_grad_infer_fission(tag):
    """Return the graph-builder function registered under ``tag``.

    Tags:
        before                  -- inference-mode BatchNormGrad, outputs 0-2 used
                                   (fission applies)
        before_is_training      -- training-mode BatchNormGrad (pass must skip)
        before_output3_not_null -- output 3 is consumed (pass must skip)
        after                   -- expected graph after fission
    """
    fns = FnDict()
    @fns
    def before(input0, input1, input2, input3, input4):
        batch_norm = BatchNormGradInfer(input0, input1, input2, input3, input4)
        outputs = make_tuple(tuple_getitem(batch_norm, 0), tuple_getitem(batch_norm, 1), tuple_getitem(batch_norm, 2))
        output = tuple_getitem(outputs, 0)
        return output
    @fns
    def before_is_training(input0, input1, input2, input3, input4):
        batch_norm = BatchNormGradTraining(input0, input1, input2, input3, input4)
        outputs = make_tuple(tuple_getitem(batch_norm, 0), tuple_getitem(batch_norm, 1), tuple_getitem(batch_norm, 2))
        output = tuple_getitem(outputs, 0)
        return output
    @fns
    def before_output3_not_null(input0, input1, input2, input3, input4):
        batch_norm = BatchNormGradInfer(input0, input1, input2, input3, input4)
        # tuple_getitem(batch_norm, 3) consumes a reserve output, which blocks the fission.
        outputs = make_tuple(tuple_getitem(batch_norm, 0), tuple_getitem(batch_norm, 1), tuple_getitem(batch_norm, 3))
        output = tuple_getitem(outputs, 0)
        return output
    @fns
    def after(input0, input1, input2, input3, input4):
        # BNInferGrad replaces output 0; BNTrainingUpdateGrad replaces outputs 1 and 2.
        bn_infer_grad = BNInferGrad(input0, input2, input4)
        bn_training_update_grad = BNTrainingUpdateGrad(input0, input1, input3, input4)
        outputs = make_tuple(bn_infer_grad, tuple_getitem(bn_training_update_grad, 0),
                             tuple_getitem(bn_training_update_grad, 1))
        new_outputs = make_tuple(tuple_getitem(outputs, 0), tuple_getitem(outputs, 1), tuple_getitem(outputs, 2))
        output = tuple_getitem(new_outputs, 0)
        return make_tuple(output)
    return fns[tag]