/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ #include #include #include #include "optimizer/optimizer.h" #include "optimizer/irpass.h" #include "optimizer/irpass/prim_eliminate.h" #include "ir/visitor.h" #include "operator/ops.h" namespace mindspore { namespace opt { namespace irpass { // {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0} // {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1} class MultiplyByZeroOrOne : public AnfVisitor { public: MultiplyByZeroOrOne() : zero_(MakeValue(0)), one_(MakeValue(1)) {} ~MultiplyByZeroOrOne() override = default; AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); AnfVisitor::Match(prim::kPrimScalarMul)(node); if (is_zero_) { return NewValueNode(zero_); } if (is_one_) { return x_; } return nullptr; } void Visit(const AnfNodePtr &node) override { if (is_one_ || node->isa()) { x_ = node; return; } AnfVisitor::Visit(node); if (!is_one_) { x_ = node; } } void Visit(const ValueNodePtr &vnode) override { auto value = vnode->value(); if (*value == *zero_) { is_zero_ = true; } else if (*value == *one_) { is_one_ = true; } } void Reset() { x_ = nullptr; is_one_ = false; is_zero_ = false; } private: bool is_zero_{false}, is_one_{false}; ValuePtr zero_, one_; AnfNodePtr x_{nullptr}; }; // {prim::kPrimScalarAdd, X, 0} // {prim::kPrimScalarAdd, 0, X} class AddByZero : public AnfVisitor { public: AddByZero() : zero_(MakeValue(0)) {} ~AddByZero() override = default; AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); AnfVisitor::Match(prim::kPrimScalarAdd)(node); if (is_zero_) { return x_; } return nullptr; } void Visit(const AnfNodePtr &node) override { if (node->isa() && *GetValueNode(node) == *zero_) { is_zero_ = true; return; } x_ = node; } void Reset() { x_ = nullptr; is_zero_ = false; } private: bool is_zero_{false}; ValuePtr zero_; AnfNodePtr x_{nullptr}; }; // {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X}, // {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}} class TensorAddByZero : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); AnfVisitor::Match(prim::kPrimTensorAdd)(node); if (is_zero_) { return x_; } return nullptr; } void Visit(const AnfNodePtr &node) override { if (IsPrimitive(node, prim::kPrimZerosLike)) { is_zero_ = true; return; } x_ = node; } void Reset() { x_ = nullptr; is_zero_ = false; } private: bool is_zero_{false}; AnfNodePtr x_{nullptr}; }; // {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y} class OptUpdateZeroTensor : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { if (!IsPrimitiveCNode(node, prim::kPrimMomentum) || node->func_graph() == nullptr) { return nullptr; } // {PrimMomentum, {...}, Y, Z, Xs} auto &inputs = node->cast()->inputs(); if (inputs.size() < 4 || !IsPrimitiveCNode(inputs[1], prim::kPrimZerosLike)) { return nullptr; } auto y = inputs[2]; auto z = inputs[3]; // {kPrimZerosLike, X} if (inputs[1]->cast()->size() != 2) { return nullptr; } // {prim::kPrimMakeTuple, Z, Y} return node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), z, y}); } }; // {prim::kPrimMul, Tensor1, {orim::kPrimMul, Tensor2, {...}}} -> // {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}} class ConstantDuplicateMul : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); // {prim::kPrimMul, Tensor1, {...}} AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node); if (vnode_ == nullptr || cnode_ == nullptr) { return nullptr; } auto tensor1 = vnode_; auto mul = cnode_; Reset(); // {prim::kPrimMul, Tensor2, {...}} AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul); if (vnode_ == nullptr || cnode_ == nullptr) { return nullptr; } auto tensor2 = vnode_; auto cnode = cnode_; auto PrimMul = GetValueNode(mul->input(0)); auto fg = node->func_graph(); auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg); return NewCNode({NewValueNode(PrimMul), cnode, ttmul}, fg); } void Visit(const AnfNodePtr &node) override { if (IsValueNode(node)) { vnode_ = node; } if (IsCNode(node)) { cnode_ = node->cast(); } } void Reset() { vnode_ = nullptr; cnode_ = nullptr; } private: AnfNodePtr vnode_; CNodePtr cnode_; }; // grad = AllReduce(grad) / worker_number // grad = grad + weight * decy // -> // grad = grad + weight * decy // grad = AllReduce(grad) / worker_number // {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} -> // {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y} class AdjustAllReduceMulAdd : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); // {prim::kPrimAddN, Zs} if (!IsPrimitiveCNode(node, prim::kPrimAddN)) { return nullptr; } auto addn = node->cast(); if (addn->size() != 2) { return nullptr; } AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1)); if (x_ == nullptr || y_ == nullptr || z_ == nullptr || all_reduce_fg_ == nullptr) { return nullptr; } auto addn_maketuple = addn->input(1); auto fg = all_reduce_fg_; // addn inputs cross the graph, make the inputs same as allreduce node. if (z_->isa() && fg != z_->func_graph()) { auto cnode_z = z_->cast(); z_ = NewCNode(cnode_z->inputs(), fg); } auto addn_op_node = addn->input(0); auto make_tuple_op_node = addn->input(1)->cast()->input(0); AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg); AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg); AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg); AnfNodePtr mul = NewCNode({mul_, all_reduce, y_}, fg); ProcessDependEdge(fg, addn_maketuple, all_reduce); return mul; } void ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, const AnfNodePtr &new_node) { // If has dynamic loss scale. auto &users_map = fg->manager()->node_users(); auto it = users_map.find(mul_cnode_); if (it != users_map.end()) { auto users = it->second; for (auto &user_pair : users) { auto node = user_pair.first; if (node != addn_maketuple) { if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { fg->manager()->SetEdge(node, user_pair.second, new_node); } } } } } void Visit(const AnfNodePtr &node) override { if (level_ == 0) { level_ = 1; is_reduce_match_ = false; // {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y} AnfVisitor::Match(prim::kPrimMul)(node); level_ = 0; if (is_reduce_match_) { mul_ = node->cast()->input(0); mul_cnode_ = node->cast(); y_ = tmp_; } else { z_ = node; } } if (level_ == 1) { // {prim::kPrimAllReduce, X} if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) { auto cnode = node->cast(); if (cnode->size() > 1) { all_reduce_ = cnode->input(0); x_ = cnode->input(1); is_reduce_match_ = true; all_reduce_fg_ = cnode->func_graph(); } } else { tmp_ = node; } } } void Reset() { level_ = 0; is_reduce_match_ = false; x_ = nullptr; y_ = nullptr; z_ = nullptr; tmp_ = nullptr; all_reduce_fg_ = nullptr; } private: int level_{0}; bool is_reduce_match_{false}; AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr}; AnfNodePtr all_reduce_{nullptr}, mul_{nullptr}, mul_cnode_{nullptr}; FuncGraphPtr all_reduce_fg_{nullptr}; }; class ArithmeticSimplify { public: ArithmeticSimplify() : multiply_by_zero_or_one_(), add_by_zero_(), tensor_add_by_zero_(), identity_(prim::kPrimIdentity), opt_update_zero_tensor_(), constant_duplicate_mul_() { eliminaters_.emplace_back(multiply_by_zero_or_one_); eliminaters_.emplace_back(add_by_zero_); eliminaters_.emplace_back(tensor_add_by_zero_); eliminaters_.emplace_back(identity_); eliminaters_.emplace_back(opt_update_zero_tensor_); eliminaters_.emplace_back(constant_duplicate_mul_); } ~ArithmeticSimplify() = default; AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { AnfNodePtr new_node; for (auto &eliminater : eliminaters_) { new_node = eliminater(optimizer, node); if (new_node != nullptr) { return new_node; } } return nullptr; } private: MultiplyByZeroOrOne multiply_by_zero_or_one_; AddByZero add_by_zero_; TensorAddByZero tensor_add_by_zero_; PrimEliminater identity_; OptUpdateZeroTensor opt_update_zero_tensor_; ConstantDuplicateMul constant_duplicate_mul_; std::vector eliminaters_{}; }; } // namespace irpass } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_