
!6853 Add AccumulateNV2 eliminator pass.

Merge pull request !6853 from 张清华/master2
tags/v1.1.0
mindspore-ci-bot committed on Gitee 5 years ago
parent commit 1b07551bd4
6 changed files with 113 additions and 0 deletions
1. +5 -0 mindspore/ccsrc/frontend/optimizer/irpass.cc
2. +3 -0 mindspore/ccsrc/frontend/optimizer/irpass.h
3. +96 -0 mindspore/ccsrc/frontend/optimizer/irpass/accumulaten_eliminate.h
4. +1 -0 mindspore/ccsrc/pipeline/jit/pass.cc
5. +1 -0 mindspore/core/base/core_ops.h
6. +7 -0 mindspore/ops/operations/math_ops.py

+5 -0 mindspore/ccsrc/frontend/optimizer/irpass.cc

@@ -28,6 +28,7 @@
#include "frontend/optimizer/irpass/item_tuple_eliminate.h"
#include "frontend/optimizer/irpass/mark_interface_fusion.h"
#include "frontend/optimizer/irpass/merge_addn.h"
#include "frontend/optimizer/irpass/accumulaten_eliminate.h"
#include "frontend/optimizer/irpass/minmax_grad.h"
#include "frontend/optimizer/irpass/param_replace.h"
#include "frontend/optimizer/irpass/partial_eliminate.h"
@@ -129,6 +130,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
  merge_addn_ = MakeSubstitution(std::make_shared<MergeAddN>(), "merge_addn", prim::kPrimAddN);
  addn_zero_filter_ = MakeSubstitution(std::make_shared<AddNZeroFilter>(), "addn_zero_filter", prim::kPrimAddN);

  // AccumulateNV2
  accumulaten_eliminater_ =
    MakeSubstitution(std::make_shared<AccumulateNV2Eliminater>(), "accumulaten_eliminater", prim::kPrimAccumulateNV2);

  // inline
  inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph);
  inline_without_move_ = MakeSubstitution(std::make_shared<DirectInliner>(false), "inline", IsCNodeGraph);


+3 -0 mindspore/ccsrc/frontend/optimizer/irpass.h

@@ -77,6 +77,9 @@ class OptimizeIRPassLib {
  SubstitutionPtr merge_addn_;
  SubstitutionPtr addn_zero_filter_;

  // AccumulateNV2
  SubstitutionPtr accumulaten_eliminater_;

  // Gradient irpasses
  SubstitutionPtr expand_jprim_;
  SubstitutionPtr minmaximum_grad_;


+96 -0 mindspore/ccsrc/frontend/optimizer/irpass/accumulaten_eliminate.h

@@ -0,0 +1,96 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ACCUMULATEN_ELIMINATE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ACCUMULATEN_ELIMINATE_H_

#include <vector>
#include <algorithm>
#include <memory>

#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"

namespace mindspore {
namespace opt {
namespace irpass {
// {PrimAccumulateNV2, {kPrimMakeTuple, inputs}}
class AccumulateNV2Eliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    Reset();
    AnfVisitor::Match(prim::kPrimAccumulateNV2, {IsCNode})(node);

    if (inputs_.empty() || node->func_graph() == nullptr) {
      return nullptr;
    }

    // After filtering, inputs_ is {make_tuple, x1, x2, ...}. If only
    // {make_tuple, x} remains, return x directly.
    if (inputs_.size() == 2) {
      return inputs_[1];
    }

    // If only the make_tuple node remains, every input was a ZerosLike;
    // return one of the original inputs.
    if (inputs_.size() == 1 && args_.size() > 0) {
      return args_[0];
    }

    // Nothing was filtered out; leave the node unchanged.
    if (!has_zero_like_) {
      return nullptr;
    }

    // Rebuild AccumulateNV2 over the inputs that survived filtering.
    auto cnode = node->cast<CNodePtr>();
    auto accumulaten = NewValueNode(GetValueNode(cnode->input(0)));
    auto fg = node->func_graph();
    auto make_tuple = fg->NewCNode(inputs_);
    return fg->NewCNode({accumulaten, make_tuple});
  }

  void Visit(const CNodePtr &cnode) override {
    if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
      return;
    }

    // Collect the raw inputs of the make_tuple node.
    auto &inputs = cnode->inputs();
    (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args_));

    // Rebuild {kPrimMakeTuple, X1, X2, ...}, keeping only inputs that are
    // not produced by ZerosLike.
    inputs_.push_back(NewValueNode(prim::kPrimMakeTuple));
    for (auto &x : args_) {
      if (!IsPrimitiveCNode(x, prim::kPrimZerosLike)) {
        inputs_.push_back(x);
      } else {
        has_zero_like_ = true;
      }
    }
  }

  void Reset() {
    args_.clear();
    inputs_.clear();
    has_zero_like_ = false;
  }

 private:
  std::vector<AnfNodePtr> inputs_{}, args_{};
  bool has_zero_like_{false};
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ACCUMULATEN_ELIMINATE_H_
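Taken together, the eliminator rewrites {PrimAccumulateNV2, {kPrimMakeTuple, X1, X2, ...}} by dropping every input produced by ZerosLike; if exactly one real input survives, the whole AccumulateNV2 node is replaced by that input. Below is a minimal sketch of the intended effect in MindSpore Python, assuming graph mode (the pass runs during graph compilation) and a backend that implements AccumulateNV2; the network and tensor names are placeholders, not part of this change:

import numpy as np
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import Tensor, context

context.set_context(mode=context.GRAPH_MODE)

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.accumulate = P.AccumulateNV2()
        self.zeros_like = P.ZerosLike()

    def construct(self, x, y):
        # Built as {PrimAccumulateNV2, {kPrimMakeTuple, x, ZerosLike(y), x}};
        # the pass filters the ZerosLike input, leaving AccumulateNV2((x, x)).
        # With only one surviving input, e.g. (x, ZerosLike(y)), the whole
        # node would collapse to x.
        return self.accumulate((x, self.zeros_like(y), x))

x = Tensor(np.ones([2, 3]).astype(np.float32))
y = Tensor(np.ones([2, 3]).astype(np.float32))
out = Net()(x, y)  # numerically unchanged: ZerosLike adds nothing to the sum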

+1 -0 mindspore/ccsrc/pipeline/jit/pass.cc

@@ -113,6 +113,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
    irpass.arithmetic_simplify_,
    irpass.addn_zero_filter_,
    irpass.adjust_all_reduce_mul_add_,
    irpass.accumulaten_eliminater_,

    // Safe inlining
    irpass.inline_,


+1 -0 mindspore/core/base/core_ops.h

@@ -98,6 +98,7 @@ inline const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("Conca
inline const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape");
inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile");
inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN");
inline const PrimitivePtr kPrimAccumulateNV2 = std::make_shared<Primitive>("AccumulateNV2");
inline const PrimitivePtr KPrimTransData = std::make_shared<Primitive>("TransData");
inline const PrimitivePtr kPrimNMSWithMask = std::make_shared<Primitive>("NMSWithMask");
inline const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad");


+7 -0 mindspore/ops/operations/math_ops.py

@@ -893,6 +893,13 @@ class AccumulateNV2(PrimitiveWithInfer):
        self.__setattr_flag__ = True
        self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])

    def check_elim(self, inputs):
        if len(inputs) != 1:
            return (False, None)
        if isinstance(inputs[0], Tensor):
            return (True, inputs[0])
        raise TypeError("Expecting Tensor, got : {}".format(type(inputs[0])))

    def infer_shape(self, inputs):
        cls_name = self.name
        validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
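The new check_elim hook lets the compiler frontend fold the operator away at the Python level: with a tuple of exactly one Tensor there is nothing to accumulate, so the op can be replaced by that tensor. A small illustration of the contract, calling the hook directly (this mirrors what the frontend does; it is not normal user-facing API usage):

import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

acc = P.AccumulateNV2()
x = Tensor(np.ones([2, 3]).astype(np.float32))

# A single Tensor input: eliminate the op and reuse the tensor itself.
elim, val = acc.check_elim((x,))
assert elim and val is x

# Two or more inputs: keep the op in the graph.
elim, val = acc.check_elim((x, x))
assert not elim and val is None

# A single non-Tensor input raises TypeError, per the check above.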

