Merge pull request !3270 from thlinh/dev_Jul17_removeSelecttags/v0.7.0-beta
| @@ -41,6 +41,7 @@ | |||||
| #include "frontend/optimizer/irpass/symbol_resolver.h" | #include "frontend/optimizer/irpass/symbol_resolver.h" | ||||
| #include "frontend/optimizer/irpass/tile_eliminate.h" | #include "frontend/optimizer/irpass/tile_eliminate.h" | ||||
| #include "frontend/optimizer/irpass/transpose_eliminate.h" | #include "frontend/optimizer/irpass/transpose_eliminate.h" | ||||
| #include "frontend/optimizer/irpass/value_based_eliminate.h" | |||||
| #include "frontend/optimizer/opt.h" | #include "frontend/optimizer/opt.h" | ||||
| #include "frontend/optimizer/irpass/indexed_slices_eliminate.h" | #include "frontend/optimizer/irpass/indexed_slices_eliminate.h" | ||||
| #include "frontend/optimizer/irpass/sparse_tensor_eliminate.h" | #include "frontend/optimizer/irpass/sparse_tensor_eliminate.h" | ||||
| @@ -165,6 +166,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||||
| sparse_tensor_eliminate_ = MakeSubstitution( | sparse_tensor_eliminate_ = MakeSubstitution( | ||||
| std::make_shared<SparseTensorEliminater>(), "sparse_tensor_eliminate", | std::make_shared<SparseTensorEliminater>(), "sparse_tensor_eliminate", | ||||
| {prim::kPrimSparseTensorGetIndices, prim::kPrimSparseTensorGetValues, prim::kPrimSparseTensorGetDenseShape}); | {prim::kPrimSparseTensorGetIndices, prim::kPrimSparseTensorGetValues, prim::kPrimSparseTensorGetDenseShape}); | ||||
| // Value_Based Eliminate | |||||
| value_based_eliminate_ = | |||||
| MakeSubstitution(std::make_shared<ValueBasedEliminate>(), "value_based_eliminate", {prim::kPrimSelect}); | |||||
| } | } | ||||
| ResolveIRPassLib::ResolveIRPassLib() { | ResolveIRPassLib::ResolveIRPassLib() { | ||||
| @@ -110,6 +110,9 @@ class OptimizeIRPassLib { | |||||
| // SparseTensor Eliminate | // SparseTensor Eliminate | ||||
| SubstitutionPtr sparse_tensor_eliminate_; | SubstitutionPtr sparse_tensor_eliminate_; | ||||
| // Value_Based Eliminate | |||||
| SubstitutionPtr value_based_eliminate_; | |||||
| }; | }; | ||||
| // the collection of irpass for resolve action | // the collection of irpass for resolve action | ||||
| @@ -0,0 +1,48 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "frontend/optimizer/irpass/value_based_eliminate.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace irpass { | |||||
| bool IsCNodePositive(const AnfNodePtr &node) { | |||||
| if (IsPrimitiveCNode(node, prim::kPrimReduceSum) || IsPrimitiveCNode(node, prim::kPrimSqueeze)) { | |||||
| return IsCNodePositive(node->cast<CNodePtr>()->input(1)); | |||||
| } | |||||
| if (IsPrimitiveCNode(node, prim::kPrimSquare) || IsPrimitiveCNode(node, prim::kPrimSqrt)) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| AnfNodePtr ValueBasedEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | |||||
| PatternNode x, y, z; | |||||
| PConstant zero_(node, false, 0); | |||||
| PConstant zero_scalar_(node, false, 0, true); | |||||
| MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSelect, PPrimitive(prim::kPrimGreater, x, zero_), y, z), y, | |||||
| IsCNodePositive(x.GetNode(node))); | |||||
| MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSelect, PPrimitive(prim::kPrimGreater, x, zero_scalar_), y, z), y, | |||||
| IsCNodePositive(x.GetNode(node))); | |||||
| return nullptr; | |||||
| } | |||||
| } // namespace irpass | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,42 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_VALUE_BASED_ELIMINATE_H_ | |||||
| #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_VALUE_BASED_ELIMINATE_H_ | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "frontend/optimizer/irpass.h" | |||||
| #include "frontend/optimizer/irpass/prim_eliminate.h" | |||||
| #include "frontend/optimizer/optimizer_caller.h" | |||||
| #include "frontend/optimizer/anf_visitor.h" | |||||
| #include "ir/pattern_matcher.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace irpass { | |||||
| // {prim::kPrimSelect, {prim::kPrimGreater, X, 0}, Y, Z}} -> Y when X is always greater than 0 | |||||
| class ValueBasedEliminate : public OptimizerCaller { | |||||
| public: | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; | |||||
| }; | |||||
| } // namespace irpass | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_VALUE_BASED_ELIMINATE_H_ | |||||
| @@ -162,15 +162,10 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||||
| } | } | ||||
| OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { | OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { | ||||
| opt::OptPassConfig b_1 = opt::OptPassConfig({ | |||||
| irpass.zero_like_fill_zero_, | |||||
| irpass.item_tuple_eliminate_, | |||||
| irpass.float_tuple_getitem_switch_, | |||||
| irpass.reset_defer_inline_, | |||||
| irpass.inline_, | |||||
| irpass.special_op_eliminate_, | |||||
| irpass.get_make_ref_eliminate_, | |||||
| }); | |||||
| opt::OptPassConfig b_1 = | |||||
| opt::OptPassConfig({irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_, | |||||
| irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, | |||||
| irpass.get_make_ref_eliminate_, irpass.value_based_eliminate_}); | |||||
| opt::OptPassConfig b_2 = opt::OptPassConfig({ | opt::OptPassConfig b_2 = opt::OptPassConfig({ | ||||
| irpass.replace_refkey_by_param_, | irpass.replace_refkey_by_param_, | ||||
| irpass.make_ref_eliminate_, | irpass.make_ref_eliminate_, | ||||
| @@ -22,14 +22,15 @@ import os | |||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| import mindspore.context as context | import mindspore.context as context | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore.ops import operations as P | |||||
| from mindspore.nn.optim import AdamWeightDecay | from mindspore.nn.optim import AdamWeightDecay | ||||
| from mindspore.train.loss_scale_manager import DynamicLossScaleManager | from mindspore.train.loss_scale_manager import DynamicLossScaleManager | ||||
| from mindspore.nn import learning_rate_schedule as lr_schedules | from mindspore.nn import learning_rate_schedule as lr_schedules | ||||
| from mindspore.ops import operations as P | |||||
| from model_zoo.official.nlp.bert.src import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell | from model_zoo.official.nlp.bert.src import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell | ||||
| from ...dataset_mock import MindData | from ...dataset_mock import MindData | ||||
| from ...ops_common import nn, np, batch_tuple_tensor, build_construct_graph | from ...ops_common import nn, np, batch_tuple_tensor, build_construct_graph | ||||
| _current_dir = os.path.dirname(os.path.realpath(__file__)) + "/../python/test_data" | _current_dir = os.path.dirname(os.path.realpath(__file__)) + "/../python/test_data" | ||||
| context.set_context(mode=context.GRAPH_MODE) | context.set_context(mode=context.GRAPH_MODE) | ||||