@@ -16,7 +16,6 @@
 #include "backend/optimizer/ascend/ascend_backend_optimization.h"
 #include <memory>
 #include <string>
-#include <set>
 #include "backend/optimizer/common/optimizer.h"
 #include "backend/optimizer/ascend/ir_fission/bn_split.h"
 #include "backend/optimizer/ascend/ir_fission/bn_grad_split.h"
@@ -24,6 +23,7 @@
 #include "backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.h"
 #include "backend/optimizer/ascend/ir_fission/single_batch_norm_fission.h"
 #include "backend/optimizer/ascend/ir_fission/tensor_scatter_update_fission.h"
+#include "backend/optimizer/ascend/ir_fission/reduce_min_fission.h"
 #include "backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h"
 #include "backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h"
 #include "backend/optimizer/pass/communication_op_fusion.h"
@@ -111,18 +111,9 @@
 namespace mindspore {
 namespace opt {
 namespace {
-void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
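+// Rule-based fusion patterns (Lamb/Adam optimizer updates plus the clip and
+// square-sum rules), split out of the former AddAscendBackendOptionalIRFusion.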
+void AddAscendIRFusionRulesPass(PassManager *ir_fusion_pm) {
   MS_EXCEPTION_IF_NULL(ir_fusion_pm);
-  ir_fusion_pm->AddPass(std::make_shared<BatchNormBertFission>());
-  ir_fusion_pm->AddPass(std::make_shared<SingleBatchNormFission>());
-  ir_fusion_pm->AddPass(std::make_shared<SquareSumFusion>());
-  ir_fusion_pm->AddPass(std::make_shared<ClipByNormNoDivSquareSumFusion>());
   ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>());
-  ir_fusion_pm->AddPass(std::make_shared<SoftmaxGradExtFusion>());
-  ir_fusion_pm->AddPass(std::make_shared<SoftmaxGradExtFusionV2>());
-  ir_fusion_pm->AddPass(std::make_shared<SoftmaxGradExtFusionV3>());
-  ir_fusion_pm->AddPass(std::make_shared<ConfusionMulGradFusion>());
-  ir_fusion_pm->AddPass(std::make_shared<ConfusionSoftmaxGradRule>());
   ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond1>());
   ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond2>());
   ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond3>());
@@ -133,10 +124,6 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
   ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond4>());
   ir_fusion_pm->AddPass(std::make_shared<LambNextRightRule>());
   ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLrV2>());
-  ir_fusion_pm->AddPass(std::make_shared<ReshapeTransposeFusion>());
-  ir_fusion_pm->AddPass(std::make_shared<TransposeReshapeFusion>());
-  ir_fusion_pm->AddPass(std::make_shared<ClipByValueFusion>());
-  ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
   ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond1Fusion>());
   ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond2Fusion>());
   ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond3Fusion>());
@@ -146,6 +133,27 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
   ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRuleCond3>());
   ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRuleCond4>());
   ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRuleCond5>());
+  ir_fusion_pm->AddPass(std::make_shared<ClipByNormNoDivSquareSumFusion>());
+  ir_fusion_pm->AddPass(std::make_shared<SquareSumFusion>());
+  ir_fusion_pm->AddPass(std::make_shared<ClipByValueFusion>());
+}
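+
+// Structural fission/fusion passes (BatchNorm, Softmax gradient,
+// Reshape/Transpose, TopK), regrouped from the former
+// AddAscendBackendOptionalIRFusion.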
+void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
+  MS_EXCEPTION_IF_NULL(ir_fusion_pm);
+  ir_fusion_pm->AddPass(std::make_shared<BatchNormBertFission>());
+  ir_fusion_pm->AddPass(std::make_shared<SingleBatchNormFission>());
+  ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>());
+  ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BNInferGrad>());
+  ir_fusion_pm->AddPass(std::make_shared<BatchNormGradInferFission>());
+  ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
+  ir_fusion_pm->AddPass(std::make_shared<SoftmaxGradExtFusion>());
+  ir_fusion_pm->AddPass(std::make_shared<SoftmaxGradExtFusionV2>());
+  ir_fusion_pm->AddPass(std::make_shared<SoftmaxGradExtFusionV3>());
+  ir_fusion_pm->AddPass(std::make_shared<ConfusionMulGradFusion>());
+  ir_fusion_pm->AddPass(std::make_shared<ConfusionSoftmaxGradRule>());
+  ir_fusion_pm->AddPass(std::make_shared<ReshapeTransposeFusion>());
+  ir_fusion_pm->AddPass(std::make_shared<TransposeReshapeFusion>());
+  ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
   ir_fusion_pm->AddPass(std::make_shared<MomentumLossscaleFusion>());
   ir_fusion_pm->AddPass(std::make_shared<MulAddFusion>());
   ir_fusion_pm->AddPass(std::make_shared<MulAddNFusion>());
@@ -153,15 +161,12 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
   ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
   ir_fusion_pm->AddPass(std::make_shared<DereluFusion>());
   ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>());
-  ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
-  ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>());
-  ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BNInferGrad>());
-  ir_fusion_pm->AddPass(std::make_shared<BatchNormGradInferFission>());
   ir_fusion_pm->AddPass(std::make_shared<SplitFission>());
   ir_fusion_pm->AddPass(std::make_shared<TensorScatterUpdateFission>());
   ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
   ir_fusion_pm->AddPass(std::make_shared<PackFission>());
   ir_fusion_pm->AddPass(std::make_shared<ConcatFission>());
+  ir_fusion_pm->AddPass(std::make_shared<ReduceMinFission>());
 }
 
 }  // namespace
@@ -265,9 +270,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
     ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
   }
   ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
-  if (context_ptr->ir_fusion_flag()) {
-    AddAscendBackendOptionalIRFusion(ir_fusion_pm.get());
-  }
+  AddAscendIRFusionRulesPass(ir_fusion_pm.get());
+  AddAscendIRFusionPass(ir_fusion_pm.get());
   if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) {
     ir_fusion_pm->AddPass(std::make_shared<InsertMemcpyAsyncForGetNext>());
   }
@@ -0,0 +1,144 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fission/reduce_min_fission.h"
+#include <algorithm>
+#include <memory>
+#include <vector>
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
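+// Builds a ReduceMin CNode reading from `input`, inheriting the scope and the
+// keep_dims attribute of the node being split.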
+CNodePtr CreateReduceMin(const FuncGraphPtr &graph, const AnfNodePtr &input, const CNodePtr &old_node) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(input);
+  MS_EXCEPTION_IF_NULL(old_node);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceMin->name())), input};
+  CNodePtr reduce_min = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(reduce_min);
+  reduce_min->set_scope(old_node->scope());
+  AnfAlgo::CopyNodeAttr(kAttrKeepDims, old_node, reduce_min);
+  return reduce_min;
+}
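+
+// The split is applied only to a float32 ReduceMin of rank >= 2 that reduces
+// more than one axis, one of which is the last axis.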
+bool NeedOptimize(const TypeId &dtype, const std::vector<size_t> &shape, const std::vector<int> &axis) {
+  if (dtype != kNumberTypeFloat32) {
+    MS_LOG(INFO) << "ReduceMin's input dtype is not float32, no need to optimize!";
+    return false;
+  }
+  if (shape.size() <= 1) {
+    MS_LOG(INFO) << "ReduceMin's input shape size is " << shape.size() << ", no need to optimize!";
+    return false;
+  }
+  if (axis.size() == 1) {
+    MS_LOG(INFO) << "ReduceMin's axis size is 1, no need to optimize!";
+    return false;
+  }
+  int last_dim = SizeToInt(shape.size() - 1);
+  if (std::find(axis.begin(), axis.end(), -1) == axis.end() &&
+      std::find(axis.begin(), axis.end(), last_dim) == axis.end()) {
+    MS_LOG(INFO) << "The axis attribute does not contain the last axis, no need to optimize!";
+    return false;
+  }
+  return true;
+}
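+
+// Axes for the first ReduceMin: every reduced axis except the last dimension,
+// validated and normalized to non-negative values. If nothing remains, all
+// leading dimensions are reduced.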
+std::vector<int> CalFirstAxis(const std::vector<size_t> &shape, const std::vector<int> &axis) {
+  std::vector<int> axis_first;
+  int last_dim = SizeToInt(shape.size() - 1);
+  std::copy_if(axis.begin(), axis.end(), std::back_inserter(axis_first),
+               [&last_dim](int v) { return v != -1 && v != last_dim; });
+
+  int dim_size = SizeToInt(shape.size());
+  if (axis_first.empty()) {
+    for (int i = 0; i < dim_size - 1; ++i) {
+      axis_first.push_back(i);
+    }
+  }
+
+  for (size_t i = 0; i < axis_first.size(); ++i) {
+    if (axis_first[i] < -dim_size || axis_first[i] > dim_size - 1) {
+      MS_LOG(EXCEPTION) << "The axis " << axis_first[i] << " of ReduceMin is out of range, quit optimizing";
+    }
+    if (axis_first[i] < 0) {
+      axis_first[i] = dim_size + axis_first[i];
+    }
+  }
+  return axis_first;
+}
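+
+// Output shape of the first ReduceMin: reduced dimensions are dropped, or kept
+// with size 1 when keep_dims is true.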
+std::vector<size_t> GetInferShape(const std::vector<size_t> &shape, const std::vector<int> &axis_first,
+                                  bool keep_dims) {
+  std::vector<size_t> shape_first;
+  for (size_t item = 0; item < shape.size(); ++item) {
+    if (axis_first.end() != std::find(axis_first.begin(), axis_first.end(), SizeToInt(item))) {
+      if (keep_dims) {
+        // If keep_dims is true, the reduced dimension is kept with size 1
+        shape_first.push_back(1);
+      }
+    } else {
+      // This dimension is not reduced, so keep its original size
+      shape_first.push_back(shape[item]);
+    }
+  }
+  return shape_first;
+}
+}  // namespace
+
+const BaseRef ReduceMinFission::DefinePattern() const {
+  VarPtr X = std::make_shared<Var>();
+  return VectorRef({prim::kPrimReduceMin, X});
+}
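+
+// Rewrites ReduceMin(x, axes containing the last dim) into two chained nodes:
+// reduce_min1 over the non-last axes, then reduce_min2 over the last axis.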
+const AnfNodePtr ReduceMinFission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
+  if (graph == nullptr || node == nullptr) {
+    return nullptr;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  CheckCNodeInputSize(cnode, 2);
+  auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
+  auto dtype = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 0);
+  if (!AnfAlgo::HasNodeAttr(kAttrAxis, cnode)) {
+    MS_LOG(INFO) << "ReduceMin has no axis attribute, no need to optimize!";
+    return nullptr;
+  }
+  auto axis = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrAxis);
+  if (!AnfAlgo::HasNodeAttr(kAttrKeepDims, cnode)) {
+    MS_LOG(INFO) << "ReduceMin has no keep_dims attribute, no need to optimize!";
+    return nullptr;
+  }
+  auto keep_dims = AnfAlgo::GetNodeAttr<bool>(cnode, kAttrKeepDims);
+  if (!NeedOptimize(dtype, shape, axis)) {
+    MS_LOG(INFO) << "No need to optimize this ReduceMin. " << cnode->DebugString();
+    return nullptr;
+  }
+
+  // Create reduce_min1, which reduces every requested axis except the last one
+  CNodePtr reduce_min1 = CreateReduceMin(graph, cnode->input(1), cnode);
+  std::vector<int> axis_first = CalFirstAxis(shape, axis);
+  std::vector<size_t> shape_first = GetInferShape(shape, axis_first, keep_dims);
+  AnfAlgo::SetOutputInferTypeAndShape({dtype}, {shape_first}, reduce_min1.get());
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis_first), reduce_min1);
+
+  // Create reduce_min2, which reduces the last axis and takes over the
+  // abstract (type and shape) of the original node
+  CNodePtr reduce_min2 = CreateReduceMin(graph, reduce_min1, cnode);
+  reduce_min2->set_abstract(cnode->abstract());
+  std::vector<int> axis_last = {-1};
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis_last), reduce_min2);
+  return reduce_min2;
+}
+}  // namespace opt
+}  // namespace mindspore
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_REDUCE_MIN_FISSION_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_REDUCE_MIN_FISSION_H_
+
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
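+// Splits a float32 ReduceMin whose axes include the last dimension into two
+// chained ReduceMin operators; see reduce_min_fission.cc for the conditions.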
+class ReduceMinFission : public PatternProcessPass {
+ public:
+  explicit ReduceMinFission(bool multigraph = true) : PatternProcessPass("reduce_min_fission", multigraph) {}
+  ~ReduceMinFission() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_REDUCE_MIN_FISSION_H_
@@ -0,0 +1,56 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "common/backend_common_test.h"
+#include "common/py_func_graph_fetcher.h"
+#include "debug/anf_ir_dump.h"
+
+#define private public
+#define protected public
+#include "backend/optimizer/ascend/ir_fission/reduce_min_fission.h"
+#undef private
+#undef protected
+
+namespace mindspore {
+namespace opt {
+class TestHWOptReduceMinFission : public BackendCommon {
+ public:
+  TestHWOptReduceMinFission() : get_py_fun_("gtest_input.pre_activate.reduce_min_fission_test", true) {}
+  ~TestHWOptReduceMinFission() override = default;
+
+  UT::PyFuncGraphFetcher get_py_fun_;
+};
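+
+// A 4-d float32 input reduced over axes (2, 3) meets every precondition of the
+// pass, so the single ReduceMin should be split in two.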
+TEST_F(TestHWOptReduceMinFission, test_fission) {
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_reduce_min_fission", "before");
+  EXPECT_NE(g, nullptr);
+  std::vector<int> shp{32, 32, 32, 32};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
+  AbstractBasePtrList args_spec_list;
+  args_spec_list.push_back(x_abstract);
+  auto kg = GetKernelGraph(g, args_spec_list);
+
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  auto reduce_min_fission = std::make_shared<opt::ReduceMinFission>();
+  pm->AddPass(reduce_min_fission);
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(kg);
+
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_reduce_min_fission", "after");
+  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
+}
+}  // namespace opt
+}  // namespace mindspore
@@ -0,0 +1,51 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from mindspore.ops import Primitive
+from mindspore.ops import operations as P
+
+make_tuple = Primitive('make_tuple')
+tuple_getitem = Primitive('tuple_getitem')
+reduce_min = P.ReduceMin(keep_dims=False)
+reduce_min1 = Primitive('ReduceMin')
+reduce_min2 = Primitive('ReduceMin')
+
+
+class FnDict:
+    def __init__(self):
+        self.fnDict = {}
+
+    def __call__(self, fn):
+        self.fnDict[fn.__name__] = fn
+
+    def __getitem__(self, name):
+        return self.fnDict[name]
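+
+
+# `before` reduces axes (2, 3) of a 4-d tensor; `after` is the expected graph
+# once the pass has split it into two chained ReduceMin primitives.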
+def test_reduce_min_fission(tag):
+    fns = FnDict()
+
+    @fns
+    def before(x):
+        res = reduce_min(x, (2, 3))
+        return res
+
+    @fns
+    def after(x):
+        res = reduce_min1(x)
+        res = reduce_min2(res)
+        return make_tuple(res)
+
+    return fns[tag]