Merge pull request !6853 from 张清华/master2tags/v1.1.0
| @@ -28,6 +28,7 @@ | |||||
| #include "frontend/optimizer/irpass/item_tuple_eliminate.h" | #include "frontend/optimizer/irpass/item_tuple_eliminate.h" | ||||
| #include "frontend/optimizer/irpass/mark_interface_fusion.h" | #include "frontend/optimizer/irpass/mark_interface_fusion.h" | ||||
| #include "frontend/optimizer/irpass/merge_addn.h" | #include "frontend/optimizer/irpass/merge_addn.h" | ||||
| #include "frontend/optimizer/irpass/accumulaten_eliminate.h" | |||||
| #include "frontend/optimizer/irpass/minmax_grad.h" | #include "frontend/optimizer/irpass/minmax_grad.h" | ||||
| #include "frontend/optimizer/irpass/param_replace.h" | #include "frontend/optimizer/irpass/param_replace.h" | ||||
| #include "frontend/optimizer/irpass/partial_eliminate.h" | #include "frontend/optimizer/irpass/partial_eliminate.h" | ||||
| @@ -129,6 +130,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||||
| merge_addn_ = MakeSubstitution(std::make_shared<MergeAddN>(), "merge_addn", prim::kPrimAddN); | merge_addn_ = MakeSubstitution(std::make_shared<MergeAddN>(), "merge_addn", prim::kPrimAddN); | ||||
| addn_zero_filter_ = MakeSubstitution(std::make_shared<AddNZeroFilter>(), "addn_zero_filter", prim::kPrimAddN); | addn_zero_filter_ = MakeSubstitution(std::make_shared<AddNZeroFilter>(), "addn_zero_filter", prim::kPrimAddN); | ||||
| // AccumulateNV2 | |||||
| accumulaten_eliminater_ = | |||||
| MakeSubstitution(std::make_shared<AccumulateNV2Eliminater>(), "accumulaten_eliminater", prim::kPrimAccumulateNV2); | |||||
| // inline | // inline | ||||
| inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph); | inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph); | ||||
| inline_without_move_ = MakeSubstitution(std::make_shared<DirectInliner>(false), "inline", IsCNodeGraph); | inline_without_move_ = MakeSubstitution(std::make_shared<DirectInliner>(false), "inline", IsCNodeGraph); | ||||
| @@ -77,6 +77,9 @@ class OptimizeIRPassLib { | |||||
| SubstitutionPtr merge_addn_; | SubstitutionPtr merge_addn_; | ||||
| SubstitutionPtr addn_zero_filter_; | SubstitutionPtr addn_zero_filter_; | ||||
| // AccumulateNV2 | |||||
| SubstitutionPtr accumulaten_eliminater_; | |||||
| // Gradient irpasses | // Gradient irpasses | ||||
| SubstitutionPtr expand_jprim_; | SubstitutionPtr expand_jprim_; | ||||
| SubstitutionPtr minmaximum_grad_; | SubstitutionPtr minmaximum_grad_; | ||||
| @@ -0,0 +1,96 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ACCUMULATEN_ELIMINATE_H_ | |||||
| #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ACCUMULATEN_ELIMINATE_H_ | |||||
| #include <vector> | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include "frontend/optimizer/irpass.h" | |||||
| #include "frontend/optimizer/optimizer.h" | |||||
| #include "frontend/optimizer/anf_visitor.h" | |||||
| #include "frontend/operator/ops.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace irpass { | |||||
// Eliminates or simplifies AccumulateNV2 nodes whose tuple input contains
// ZerosLike elements, matching the pattern:
//   {PrimAccumulateNV2, {kPrimMakeTuple, inputs}}
// NOTE(review): relies on the AnfVisitor::Match convention that a successful
// match invokes Visit() on the matched input — confirm against AnfVisitor.
class AccumulateNV2Eliminater : public AnfVisitor {
 public:
  // Entry point of the pass. Returns the replacement node, or nullptr when
  // no rewrite applies.
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    // Clear per-node state so results from a previous node don't leak in.
    Reset();
    // Match AccumulateNV2 whose single operand is a CNode; on a match,
    // Visit() below fills args_ (all tuple elements) and inputs_
    // (make_tuple prim + non-ZerosLike elements).
    AnfVisitor::Match(prim::kPrimAccumulateNV2, {IsCNode})(node);

    // inputs_ empty means the operand was not a make_tuple CNode; also bail
    // out if the node is not attached to a graph.
    if (inputs_.empty() || node->func_graph() == nullptr) {
      return nullptr;
    }

    // If only two filtered inputs nodes, as {make_tuple, x}, return x.
    // (Accumulating a single element yields that element.)
    if (inputs_.size() == 2) {
      return inputs_[1];
    }

    // If only one filtered node, all inputs nodes are zerolike, return one of the input.
    // (inputs_ holds only the make_tuple primitive; any ZerosLike input is
    // as good as the sum of all of them.)
    if (inputs_.size() == 1 && args_.size() > 0) {
      return args_[0];
    }

    // Nothing was filtered out, so rebuilding the node would change nothing.
    if (!has_zero_like_) {
      return nullptr;
    }

    // Rebuild AccumulateNV2 over a new make_tuple that contains only the
    // surviving (non-ZerosLike) elements.
    auto cnode = node->cast<CNodePtr>();
    auto accumulaten = NewValueNode(GetValueNode(cnode->input(0)));
    auto fg = node->func_graph();
    auto make_tuple = fg->NewCNode(inputs_);
    return fg->NewCNode({accumulaten, make_tuple});
  }

  // Called by Match() on the operand of AccumulateNV2. Collects every tuple
  // element into args_ and the non-ZerosLike subset (prefixed with the
  // make_tuple primitive) into inputs_; sets has_zero_like_ when at least
  // one element was dropped.
  void Visit(const CNodePtr &cnode) override {
    // Only a make_tuple operand is understood; anything else leaves the
    // state empty, which operator() treats as "no rewrite".
    if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
      return;
    }

    // inputs[0] is the make_tuple primitive itself; the elements start at 1.
    auto &inputs = cnode->inputs();
    (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args_));

    // {kPrimMakeTuple, X1, X2, ...}
    inputs_.push_back(NewValueNode(prim::kPrimMakeTuple));
    for (auto &x : args_) {
      if (!IsPrimitiveCNode(x, prim::kPrimZerosLike)) {
        inputs_.push_back(x);
      } else {
        has_zero_like_ = true;
      }
    }
  }

  // Resets all per-node state; must run before each Match.
  void Reset() {
    args_.clear();
    inputs_.clear();
    has_zero_like_ = false;
  }

 private:
  // inputs_: make_tuple primitive + elements that survived filtering.
  // args_: every element of the original tuple, unfiltered.
  std::vector<AnfNodePtr> inputs_{}, args_{};
  // True when at least one ZerosLike element was filtered out.
  bool has_zero_like_{false};
};
| } // namespace irpass | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ACCUMULATEN_ELIMINATE_H_ | |||||
| @@ -113,6 +113,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||||
| irpass.arithmetic_simplify_, | irpass.arithmetic_simplify_, | ||||
| irpass.addn_zero_filter_, | irpass.addn_zero_filter_, | ||||
| irpass.adjust_all_reduce_mul_add_, | irpass.adjust_all_reduce_mul_add_, | ||||
| irpass.accumulaten_eliminater_, | |||||
| // Safe inlining | // Safe inlining | ||||
| irpass.inline_, | irpass.inline_, | ||||
| @@ -98,6 +98,7 @@ inline const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("Conca | |||||
| inline const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape"); | inline const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape"); | ||||
| inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile"); | inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile"); | ||||
| inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN"); | inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN"); | ||||
| inline const PrimitivePtr kPrimAccumulateNV2 = std::make_shared<Primitive>("AccumulateNV2"); | |||||
| inline const PrimitivePtr KPrimTransData = std::make_shared<Primitive>("TransData"); | inline const PrimitivePtr KPrimTransData = std::make_shared<Primitive>("TransData"); | ||||
| inline const PrimitivePtr kPrimNMSWithMask = std::make_shared<Primitive>("NMSWithMask"); | inline const PrimitivePtr kPrimNMSWithMask = std::make_shared<Primitive>("NMSWithMask"); | ||||
| inline const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad"); | inline const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad"); | ||||
| @@ -893,6 +893,13 @@ class AccumulateNV2(PrimitiveWithInfer): | |||||
| self.__setattr_flag__ = True | self.__setattr_flag__ = True | ||||
| self.init_prim_io_names(inputs=["inputs"], outputs=["sum"]) | self.init_prim_io_names(inputs=["inputs"], outputs=["sum"]) | ||||
| def check_elim(self, inputs): | |||||
| if len(inputs) != 1: | |||||
| return (False, None) | |||||
| if isinstance(inputs[0], Tensor): | |||||
| return (True, inputs[0]) | |||||
| raise TypeError("Expecting Tensor, got : {}".format(type(inputs[0]))) | |||||
| def infer_shape(self, inputs): | def infer_shape(self, inputs): | ||||
| cls_name = self.name | cls_name = self.name | ||||
| validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name) | validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name) | ||||