| @@ -97,6 +97,8 @@ | |||
| #include "backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h" | |||
| #include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h" | |||
| #include "backend/optimizer/ascend/format_type/remove_internal_output.h" | |||
| #include "backend/optimizer/ascend/ir_fission/concat_fission.h" | |||
| #include "backend/optimizer/ascend/ir_fission/pack_fission.h" | |||
| #include "utils/context/ms_context.h" | |||
| #include "utils/config_manager.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| @@ -153,6 +155,8 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { | |||
| ir_fusion_pm->AddPass(std::make_shared<BatchNormGradInferFission>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<SplitFission>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<PackFission>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<ConcatFission>()); | |||
| } | |||
| } // namespace | |||
| @@ -0,0 +1,107 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/ascend/ir_fission/concat_fission.h" | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| AnfNodePtr CreateNewConcat(const FuncGraphPtr &func_graph, const CNodePtr &origin_concat_cnode, size_t begin_index, | |||
| size_t offset) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(origin_concat_cnode); | |||
| std::vector<AnfNodePtr> new_concat_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))}; | |||
| for (size_t i = begin_index; i < begin_index + offset; ++i) { | |||
| new_concat_inputs.push_back(origin_concat_cnode->input(i)); | |||
| } | |||
| CNodePtr new_concat = func_graph->NewCNode(new_concat_inputs); | |||
| MS_EXCEPTION_IF_NULL(new_concat); | |||
| new_concat->set_scope(origin_concat_cnode->scope()); | |||
| // Set attrs | |||
| AnfAlgo::CopyNodeAttr(kAttrAxis, origin_concat_cnode, new_concat); | |||
| AnfAlgo::CopyNodeAttr(kAttrT, origin_concat_cnode, new_concat); | |||
| AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(offset)), new_concat); | |||
| AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(SizeToInt(offset)), new_concat); | |||
| std::vector<int> dyn_input_sizes{SizeToInt(offset)}; | |||
| AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_concat); | |||
| // infer shape | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_concat_cnode, 0); | |||
| auto axis = AnfAlgo::GetNodeAttr<int>(origin_concat_cnode, kAttrAxis); | |||
| if (axis < 0) { | |||
| axis += input_shape.size(); | |||
| } | |||
| auto output_shape = AnfAlgo::GetOutputInferShape(origin_concat_cnode, 0); | |||
| if (axis < 0 || axis >= SizeToInt(output_shape.size()) || axis >= SizeToInt(input_shape.size())) { | |||
| MS_LOG(EXCEPTION) << "The concat_dim value " << axis << "is out of range"; | |||
| } | |||
| output_shape[axis] = input_shape[axis] * offset; | |||
| AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_concat_cnode, 0)}, {output_shape}, | |||
| new_concat.get()); | |||
| return new_concat; | |||
| } | |||
| } // namespace | |||
| const BaseRef ConcatFission::DefinePattern() const { | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| return VectorRef({prim::kPrimConcat, Xs}); | |||
| } | |||
// Splits a Concat with more than inputs_divisor_ inputs into a tree of smaller
// Concats: each round groups up to inputs_divisor_ inputs into sub-Concats and
// repeats until the top node has at most inputs_divisor_ inputs.
const AnfNodePtr ConcatFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                        const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // The real input begins with index 1 (input 0 is the primitive).
  size_t origin_input_size = cnode->inputs().size() - 1;
  if (origin_input_size <= inputs_divisor_) {
    // Small enough already; leave the node untouched.
    return nullptr;
  }
  CNodePtr new_cnode = cnode;
  while (origin_input_size > inputs_divisor_) {
    MS_EXCEPTION_IF_NULL(new_cnode);
    std::vector<AnfNodePtr> base_concat_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
    size_t cur_input_index = 1;
    // Divide the inputs of concat by inputs_divisor_.
    while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) {
      base_concat_inputs.push_back(CreateNewConcat(func_graph, new_cnode, cur_input_index, inputs_divisor_));
      cur_input_index += inputs_divisor_;
    }
    // Leftover inputs (fewer than inputs_divisor_) are attached directly; a
    // later round of the outer loop regroups them if the node is still too wide.
    for (size_t i = cur_input_index; i <= origin_input_size; i++) {
      base_concat_inputs.push_back(new_cnode->input(i));
    }
    CNodePtr base_concat = func_graph->NewCNode(base_concat_inputs);
    MS_EXCEPTION_IF_NULL(base_concat);
    base_concat->set_scope(new_cnode->scope());
    // The root of the tree keeps the original node's output abstract.
    base_concat->set_abstract(new_cnode->abstract());
    // Set attrs
    AnfAlgo::CopyNodeAttr(kAttrAxis, new_cnode, base_concat);
    AnfAlgo::CopyNodeAttr(kAttrT, new_cnode, base_concat);
    AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(base_concat_inputs.size() - 1)), base_concat);
    AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(SizeToInt(base_concat_inputs.size() - 1)), base_concat);
    std::vector<int> dyn_input_sizes{SizeToInt(base_concat_inputs.size() - 1)};
    AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), base_concat);
    new_cnode = base_concat;
    origin_input_size = base_concat->inputs().size() - 1;
  }
  return new_cnode;
}
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_CONCAT_FISSION_H_ | |||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_CONCAT_FISSION_H_ | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
// Presumably the maximum input count supported by the backend Concat kernel —
// TODO confirm against the kernel spec.
constexpr size_t kConcatInputsDivisor = 63;
// Pass that splits a Concat node with more than inputs_divisor_ inputs into a
// tree of smaller Concat nodes (see concat_fission.cc for the algorithm).
class ConcatFission : public PatternProcessPass {
 public:
  explicit ConcatFission(bool multigraph = true)
      : PatternProcessPass("concat_fission", multigraph), inputs_divisor_(kConcatInputsDivisor) {}
  ~ConcatFission() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  // Maximum number of real inputs each produced Concat may take.
  size_t inputs_divisor_;
};
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_CONCAT_FISSION_H_ | |||
| @@ -0,0 +1,107 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/ascend/ir_fission/pack_fission.h" | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| AnfNodePtr CreateNewPack(const FuncGraphPtr &func_graph, const CNodePtr &origin_pack_cnode, size_t begin_index, | |||
| size_t offset) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(origin_pack_cnode); | |||
| std::vector<AnfNodePtr> new_pack_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimPack->name()))}; | |||
| for (size_t i = begin_index; i < begin_index + offset; ++i) { | |||
| new_pack_inputs.push_back(origin_pack_cnode->input(i)); | |||
| } | |||
| CNodePtr new_pack = func_graph->NewCNode(new_pack_inputs); | |||
| MS_EXCEPTION_IF_NULL(new_pack); | |||
| new_pack->set_scope(origin_pack_cnode->scope()); | |||
| new_pack->set_abstract(origin_pack_cnode->abstract()); | |||
| AnfAlgo::CopyNodeAttr(kAttrAxis, origin_pack_cnode, new_pack); | |||
| AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(offset)), new_pack); | |||
| AnfAlgo::SetNodeAttr(kAttrNum, MakeValue(SizeToInt(offset)), new_pack); | |||
| std::vector<int> dyn_input_sizes{SizeToInt(offset)}; | |||
| AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_pack); | |||
| // infer shape | |||
| auto output_shape = AnfAlgo ::GetOutputInferShape(origin_pack_cnode, 0); | |||
| auto axis = AnfAlgo::GetNodeAttr<int>(new_pack, kAttrAxis); | |||
| if (axis < 0) { | |||
| axis += output_shape.size(); | |||
| } | |||
| if (axis < 0) { | |||
| MS_LOG(EXCEPTION) << "The concat_dim value " << axis << "is out of range"; | |||
| } | |||
| std::vector<size_t> new_shape; | |||
| for (size_t i = 0; i < output_shape.size() + 1; ++i) { | |||
| if (i < IntToSize(axis)) { | |||
| new_shape.push_back(output_shape[i]); | |||
| } else if (i == IntToSize(axis)) { | |||
| new_shape.push_back(offset); | |||
| } else { | |||
| new_shape.push_back(output_shape[i - 1]); | |||
| } | |||
| } | |||
| new_shape.erase(new_shape.begin() + axis + 1); | |||
| AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_pack_cnode, 0)}, {output_shape}, | |||
| new_pack.get()); | |||
| return new_pack; | |||
| } | |||
| } // namespace | |||
| const BaseRef PackFission::DefinePattern() const { | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| return VectorRef({prim::kPrimPack, Xs}); | |||
| } | |||
// Splits a Pack node with more than inputs_divisor_ inputs into several
// smaller Pack nodes whose outputs are joined by a single Concat along the
// pack axis.
const AnfNodePtr PackFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // The real input begins with index 1 (input 0 is the primitive).
  size_t origin_input_size = cnode->inputs().size() - 1;
  if (origin_input_size <= inputs_divisor_) {
    // Small enough already; leave the node untouched.
    return nullptr;
  }
  std::vector<AnfNodePtr> base_concat_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
  size_t cur_input_index = 1;
  // Divide the inputs of pack by inputs_divisor_.
  while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) {
    base_concat_inputs.push_back(CreateNewPack(func_graph, cnode, cur_input_index, inputs_divisor_));
    cur_input_index += inputs_divisor_;
  }
  // The remainder (fewer than inputs_divisor_ inputs) still needs its own Pack,
  // because Pack inserts the new dimension that the Concat then joins along.
  if (cur_input_index <= origin_input_size) {
    base_concat_inputs.push_back(
      CreateNewPack(func_graph, cnode, cur_input_index, origin_input_size - cur_input_index + 1));
  }
  CNodePtr base_concat = func_graph->NewCNode(base_concat_inputs);
  MS_EXCEPTION_IF_NULL(base_concat);
  base_concat->set_scope(cnode->scope());
  base_concat->set_abstract(cnode->abstract());
  AnfAlgo::CopyNodeAttr(kAttrAxis, cnode, base_concat);
  // NOTE(review): ConcatFission also copies kAttrT onto its new Concat nodes;
  // it is not copied here — confirm the Concat kernel tolerates a missing T.
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(base_concat_inputs.size() - 1)), base_concat);
  AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(SizeToInt(base_concat_inputs.size() - 1)), base_concat);
  std::vector<int> dyn_input_sizes{SizeToInt(base_concat_inputs.size() - 1)};
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), base_concat);
  return base_concat;
}
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_PACK_FISSION_H_ | |||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_PACK_FISSION_H_ | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
// Presumably the maximum input count supported by the backend Pack kernel —
// TODO confirm against the kernel spec.
constexpr size_t kPackInputsDivisor = 63;
// Pass that splits a Pack node with more than inputs_divisor_ inputs into
// several smaller Packs joined by a Concat (see pack_fission.cc).
class PackFission : public PatternProcessPass {
 public:
  explicit PackFission(bool multigraph = true)
      : PatternProcessPass("pack_fission", multigraph), inputs_divisor_(kPackInputsDivisor) {}
  ~PackFission() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  // Maximum number of real inputs each produced Pack may take.
  size_t inputs_divisor_;
};
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_PACK_FISSION_H_ | |||
| @@ -241,6 +241,9 @@ constexpr auto kAttrOffset = "offset"; | |||
constexpr auto kAttrPsKey = "ps_key";
constexpr auto kAttrOptimizerType = "optim_type";
constexpr auto kAttrChildGraph = "child_graph";
// Attribute keys written by the concat/pack fission passes.
constexpr auto kAttrInputNums = "inputNums";
constexpr auto kAttrT = "T";
constexpr auto kAttrNum = "num";
| // attr value | |||
| constexpr auto kValueTargetSwitch = "target_switch"; | |||
| @@ -0,0 +1,160 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/backend_common_test.h" | |||
| #include "common/py_func_graph_fetcher.h" | |||
| #define private public | |||
| #define protected public | |||
| #include "backend/optimizer/ascend/ir_fission/concat_fission.h" | |||
| #undef private | |||
| #undef protected | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class TestHWConcatFission : public BackendCommon { | |||
| public: | |||
| TestHWConcatFission() : get_py_fun_("gtest_input.pre_activate.concat_fission_test", true) {} | |||
| ~TestHWConcatFission() override = default; | |||
| UT::PyFuncGraphFetcher get_py_fun_; | |||
| }; | |||
| TEST_F(TestHWConcatFission, test_concat_fission_divided_by_2) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_concat_fission", "before"); | |||
| EXPECT_NE(g, nullptr); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 9; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| auto concat_fission = std::make_shared<opt::ConcatFission>(); | |||
| concat_fission->inputs_divisor_ = 2; | |||
| pm->AddPass(concat_fission); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(kg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_2"); | |||
| EXPECT_NE(g_after, nullptr); | |||
| auto kg_after = GetKernelGraph(g_after, args_spec_list); | |||
| EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWConcatFission, test_concat_fission_divided_by_3) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_concat_fission", "before"); | |||
| EXPECT_NE(g, nullptr); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 9; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| auto concat_fission = std::make_shared<opt::ConcatFission>(); | |||
| concat_fission->inputs_divisor_ = 3; | |||
| pm->AddPass(concat_fission); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(kg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_3"); | |||
| EXPECT_NE(g_after, nullptr); | |||
| auto kg_after = GetKernelGraph(g_after, args_spec_list); | |||
| EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWConcatFission, test_concat_fission_divided_by_4) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_concat_fission", "before"); | |||
| EXPECT_NE(g, nullptr); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 9; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| auto concat_fission = std::make_shared<opt::ConcatFission>(); | |||
| concat_fission->inputs_divisor_ = 4; | |||
| pm->AddPass(concat_fission); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(kg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_4"); | |||
| EXPECT_NE(g_after, nullptr); | |||
| auto kg_after = GetKernelGraph(g_after, args_spec_list); | |||
| EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWConcatFission, test_concat_fission_divided_by_8) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_concat_fission", "before"); | |||
| EXPECT_NE(g, nullptr); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 9; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| auto concat_fission = std::make_shared<opt::ConcatFission>(); | |||
| concat_fission->inputs_divisor_ = 8; | |||
| pm->AddPass(concat_fission); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(kg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_8"); | |||
| EXPECT_NE(g_after, nullptr); | |||
| auto kg_after = GetKernelGraph(g_after, args_spec_list); | |||
| EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWConcatFission, test_concat_fission_divided_by_9) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_concat_fission", "before"); | |||
| EXPECT_NE(g, nullptr); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 9; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| auto concat_fission = std::make_shared<opt::ConcatFission>(); | |||
| concat_fission->inputs_divisor_ = 9; | |||
| pm->AddPass(concat_fission); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(kg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_9"); | |||
| EXPECT_NE(g_after, nullptr); | |||
| auto kg_after = GetKernelGraph(g_after, args_spec_list); | |||
| EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph)); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,83 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/backend_common_test.h" | |||
| #include "common/py_func_graph_fetcher.h" | |||
| #define private public | |||
| #define protected public | |||
| #include "backend/optimizer/ascend/ir_fission/pack_fission.h" | |||
| #undef private | |||
| #undef protected | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class TestHWPackFission : public BackendCommon { | |||
| public: | |||
| TestHWPackFission() : get_py_fun_("gtest_input.pre_activate.pack_fission_test", true) {} | |||
| ~TestHWPackFission() override = default; | |||
| UT::PyFuncGraphFetcher get_py_fun_; | |||
| }; | |||
| TEST_F(TestHWPackFission, test_pack_fission_divided_by_3) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_pack_fission", "before"); | |||
| EXPECT_NE(g, nullptr); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 9; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| auto pack_fission = std::make_shared<opt::PackFission>(); | |||
| pack_fission->inputs_divisor_ = 3; | |||
| pm->AddPass(pack_fission); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(kg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_pack_fission", "after_divided_by_3"); | |||
| EXPECT_NE(g_after, nullptr); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWPackFission, test_pack_fission_divided_by_4) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_pack_fission", "before"); | |||
| EXPECT_NE(g, nullptr); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 9; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| auto pack_fission = std::make_shared<opt::PackFission>(); | |||
| pack_fission->inputs_divisor_ = 4; | |||
| pm->AddPass(pack_fission); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(kg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_pack_fission", "after_divided_by_4"); | |||
| EXPECT_NE(g_after, nullptr); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,73 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| from mindspore.ops import operations as P | |||
| concat = P.Concat() | |||
class FnDict:
    """Name-keyed registry of test-graph builder functions.

    Instances are used as decorators: ``@fns`` registers the decorated
    function under its ``__name__``; ``fns[name]`` retrieves it later.
    """

    def __init__(self):
        # Maps function __name__ -> function object.
        self.fnDict = {}

    def __call__(self, fn):
        self.fnDict[fn.__name__] = fn
        # Return the function so the decorated name stays bound to it.
        # (Previously the decorator returned None, clobbering the name.)
        return fn

    def __getitem__(self, name):
        return self.fnDict[name]
def test_concat_fission(tag):
    """ test_concat_fission """
    # Graph builders for the C++ TestHWConcatFission cases: 'before' is one
    # 9-input Concat; each 'after_divided_by_N' is the expected Concat tree
    # produced by ConcatFission with inputs_divisor_ == N.
    # (The docstring previously said "test_adam_apply_one_with_decay_rule",
    # a copy-paste from another test input file.)
    fns = FnDict()

    @fns
    def before(input0, input1, input2, input3, input4, input5, input6, input7, input8):
        return concat((input0, input1, input2, input3, input4, input5, input6, input7, input8))

    @fns
    def after_divided_by_2(input0, input1, input2, input3, input4, input5, input6, input7, input8):
        # Round 1 pairs the first eight inputs; input8 is left over and joined
        # at the root in a later round.
        a = concat((input0, input1))
        b = concat((input2, input3))
        c = concat((input4, input5))
        d = concat((input6, input7))
        f = concat((a, b))
        g = concat((c, d))
        i = concat((f, g))
        return concat((i, input8))

    @fns
    def after_divided_by_3(input0, input1, input2, input3, input4, input5, input6, input7, input8):
        a = concat((input0, input1, input2))
        b = concat((input3, input4, input5))
        c = concat((input6, input7, input8))
        return concat((a, b, c))

    @fns
    def after_divided_by_4(input0, input1, input2, input3, input4, input5, input6, input7, input8):
        a = concat((input0, input1, input2, input3))
        b = concat((input4, input5, input6, input7))
        return concat((a, b, input8))

    @fns
    def after_divided_by_8(input0, input1, input2, input3, input4, input5, input6, input7, input8):
        a = concat((input0, input1, input2, input3, input4, input5, input6, input7))
        return concat((a, input8))

    @fns
    def after_divided_by_9(input0, input1, input2, input3, input4, input5, input6, input7, input8):
        # Nine inputs with divisor 9: no fission expected.
        return concat((input0, input1, input2, input3, input4, input5, input6, input7, input8))

    return fns[tag]
| @@ -0,0 +1,57 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import Primitive | |||
| pack = P.Pack() | |||
| concat = P.Concat() | |||
| make_tuple = Primitive('make_tuple') | |||
class FnDict:
    """Name-keyed registry of test-graph builder functions.

    Instances are used as decorators: ``@fns`` registers the decorated
    function under its ``__name__``; ``fns[name]`` retrieves it later.
    """

    def __init__(self):
        # Maps function __name__ -> function object.
        self.fnDict = {}

    def __call__(self, fn):
        self.fnDict[fn.__name__] = fn
        # Return the function so the decorated name stays bound to it.
        # (Previously the decorator returned None, clobbering the name.)
        return fn

    def __getitem__(self, name):
        return self.fnDict[name]
def test_pack_fission(tag):
    """ test_pack_fission """
    # Graph builders for the C++ TestHWPackFission cases: 'before' is one
    # 9-input Pack; each 'after_divided_by_N' is the expected graph after
    # PackFission with inputs_divisor_ == N (sub-Packs joined by a Concat).
    # (The docstring previously said "test_adam_apply_one_with_decay_rule",
    # a copy-paste from another test input file.)
    fns = FnDict()

    @fns
    def before(input0, input1, input2, input3, input4, input5, input6, input7, input8):
        return pack((input0, input1, input2, input3, input4, input5, input6, input7, input8))

    @fns
    def after_divided_by_3(input0, input1, input2, input3, input4, input5, input6, input7, input8):
        # NOTE(review): the 'after' builders pass inputs positionally
        # (pack(a, b, c)) while 'before' passes a tuple — presumably this
        # mirrors the dynamic-input node layout produced by the pass, since
        # these graphs are compared without a kernel-graph conversion; confirm.
        pack1 = pack(input0, input1, input2)
        pack2 = pack(input3, input4, input5)
        pack3 = pack(input6, input7, input8)
        return make_tuple(concat(pack1, pack2, pack3))

    @fns
    def after_divided_by_4(input0, input1, input2, input3, input4, input5, input6, input7, input8):
        # Divisor 4 over nine inputs: two full groups plus a 1-input remainder
        # pack (PackFission packs the remainder too).
        pack1 = pack(input0, input1, input2, input3)
        pack2 = pack(input4, input5, input6, input7)
        pack3 = pack(input8)
        return make_tuple(concat(pack1, pack2, pack3))

    return fns[tag]