/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ #include #include #include #include "optimizer/optimizer.h" #include "optimizer/irpass.h" #include "optimizer/irpass/prim_eliminate.h" #include "ir/visitor.h" #include "operator/ops.h" namespace mindspore { namespace opt { namespace irpass { // {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0} // {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1} class MultiplyByZeroOrOne : public AnfVisitor { public: MultiplyByZeroOrOne() : zero_(MakeValue(0)), one_(MakeValue(1)) {} ~MultiplyByZeroOrOne() override = default; AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); AnfVisitor::Match(prim::kPrimScalarMul)(node); if (is_zero_) { return NewValueNode(zero_); } if (is_one_) { return x_; } return nullptr; } void Visit(const AnfNodePtr &node) override { if (is_one_ || node->isa()) { x_ = node; return; } AnfVisitor::Visit(node); if (!is_one_) { x_ = node; } } void Visit(const ValueNodePtr &vnode) override { auto value = vnode->value(); if (*value == *zero_) { is_zero_ = true; } else if (*value == *one_) { is_one_ = true; } } void Reset() { x_ = nullptr; is_one_ = false; is_zero_ = false; } private: bool is_zero_{false}, is_one_{false}; ValuePtr zero_, one_; AnfNodePtr x_{nullptr}; }; // {prim::kPrimScalarAdd, X, 0} // {prim::kPrimScalarAdd, 0, X} class AddByZero : public AnfVisitor { public: AddByZero() : zero_(MakeValue(0)) {} ~AddByZero() override = default; AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); AnfVisitor::Match(prim::kPrimScalarAdd)(node); if (is_zero_) { return x_; } return nullptr; } void Visit(const AnfNodePtr &node) override { if (node->isa() && *GetValueNode(node) == *zero_) { is_zero_ = true; return; } x_ = node; } void Reset() { x_ = nullptr; is_zero_ = false; } private: bool is_zero_{false}; ValuePtr zero_; AnfNodePtr x_{nullptr}; }; // {prim::kPrimTensorAdd, {PrimZerosLikeTensor, Y}, X}, // {prim::kPrimTensorAdd, X, {PrimZerosLikeTensor, Y}} class TensorAddByZero : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); AnfVisitor::Match(prim::kPrimTensorAdd)(node); if (is_zero_) { return x_; } return nullptr; } void Visit(const AnfNodePtr &node) override { if (IsPrimitive(node, prim::kPrimZerosLikeTensor)) { is_zero_ = true; return; } x_ = node; } void Reset() { x_ = nullptr; is_zero_ = false; } private: bool is_zero_{false}; AnfNodePtr x_{nullptr}; }; // {PrimMomentum, {PrimZerosLikeTensor, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y} class OptUpdateZeroTensor : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { if (!IsPrimitiveCNode(node, prim::kPrimMomentum) || node->func_graph() == nullptr) { return nullptr; } // {PrimMomentum, {...}, Y, Z, Xs} auto &inputs = node->cast()->inputs(); if (inputs.size() < 4 || !IsPrimitiveCNode(inputs[1], prim::kPrimZerosLikeTensor)) { return nullptr; } auto y = inputs[2]; auto z = inputs[3]; // {PrimZerosLikeTensor, X} if (inputs[1]->cast()->size() != 2) { return nullptr; } // {prim::kPrimMakeTuple, Z, Y} return node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), z, y}); } }; // {prim::kPrimMul, Tensor1, {orim::kPrimMul, Tensor2, {...}}} -> // {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}} class ConstantDuplicateMul : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); // {prim::kPrimMul, Tensor1, {...}} AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node); if (vnode_ == nullptr || cnode_ == nullptr) { return nullptr; } auto tensor1 = vnode_; auto mul = cnode_; Reset(); // {prim::kPrimMul, Tensor2, {...}} AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul); if (vnode_ == nullptr || cnode_ == nullptr) { return nullptr; } auto tensor2 = vnode_; auto cnode = cnode_; auto PrimMul = GetValueNode(mul->input(0)); auto fg = node->func_graph(); auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg); return NewCNode({NewValueNode(PrimMul), cnode, ttmul}, fg); } void Visit(const AnfNodePtr &node) override { if (IsValueNode(node)) { vnode_ = node; } if (IsCNode(node)) { cnode_ = node->cast(); } } void Reset() { vnode_ = nullptr; cnode_ = nullptr; } private: AnfNodePtr vnode_; CNodePtr cnode_; }; class ArithmeticSimplify { public: ArithmeticSimplify() : multiply_by_zero_or_one_(), add_by_zero_(), tensor_add_by_zero_(), identity_(prim::kPrimIdentity), opt_update_zero_tensor_(), constant_duplicate_mul_() { eliminaters_.emplace_back(multiply_by_zero_or_one_); eliminaters_.emplace_back(add_by_zero_); eliminaters_.emplace_back(tensor_add_by_zero_); eliminaters_.emplace_back(identity_); eliminaters_.emplace_back(opt_update_zero_tensor_); eliminaters_.emplace_back(constant_duplicate_mul_); } ~ArithmeticSimplify() = default; AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { AnfNodePtr new_node; for (auto &eliminater : eliminaters_) { new_node = eliminater(optimizer, node); if (new_node != nullptr) { return new_node; } } return nullptr; } private: MultiplyByZeroOrOne multiply_by_zero_or_one_; AddByZero add_by_zero_; TensorAddByZero tensor_add_by_zero_; PrimEliminater identity_; OptUpdateZeroTensor opt_update_zero_tensor_; ConstantDuplicateMul constant_duplicate_mul_; std::vector eliminaters_{}; }; } // namespace irpass } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_