/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/optimizer/irpass/arithmetic_simplify.h"

// NOTE(review): several calls in this file -- cast(), isa(), get_param(), GetValueNode() and the IsValueNode
// predicate -- appear without explicit template argument lists. Presumably those lists were lost when this text
// was extracted; confirm against the upstream file before building.
//
// NOTE(review): MATCH_REPLACE, MATCH_REPLACE_IF and MATCH_REPLACE_LAMBDA are macros defined outside this file.
// Presumably each one returns the replacement node from the enclosing function when its pattern matches, which
// would make the order of the statements below significant -- confirm in the pattern matcher header.

namespace mindspore {
namespace opt {
namespace irpass {
// Tries a fixed sequence of arithmetic rewrite patterns against `node`.
//
// The first (unnamed) OptimizerPtr parameter is not used.
// Falls through to `return nullptr` when execution reaches the end of the function, and also returns nullptr
// early in PyNative mode (after the identity pattern) or when `node` has no func graph.
AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  // Pattern placeholders. `xs` only appears in the OptUpdateZeroTensor comment below, not in any pattern here.
  PatternNode x, y, z, xs;
  // Constant patterns for the values 1 and 0; the *_scalar_ variants pass an additional `true` argument.
  PConstant one_(node, false, 1);
  PConstant one_scalar_(node, false, 1, true);
  PConstant zero_(node, false, 0);
  PConstant zero_scalar_(node, false, 0, true);
  // Constant patterns constructed without a fixed value.
  PConstant const_(node);
  PConstant const_2(node);
  PConstant any_const(node);

  // These rewrites are only attempted when the execution mode is not PyNative.
  if (MsContext::GetInstance()->get_param(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
    MATCH_REPLACE(node, x + zero_, x);                                                    // Add by zero
    MATCH_REPLACE(node, x + zero_scalar_, x);                                             // Add by zero (scalar)
    MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarAdd, x, zero_scalar_, true), x);   // Scalar Add by zero
    // Multiply by one; the condition requires that IsParam does not hold for the `one_` operand.
    MATCH_REPLACE_IF(node, x * one_, any_const.WithValueOf(x), !one_.CheckFunc(IsParam, node));
    MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, one_scalar_, true), x);    // Scalar Mul by one

    // Scalar Mul by zero
    MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, zero_scalar_, true), zero_scalar_.NewValue());
  }
  // Prim Eliminate (identity); attempted in every execution mode.
  MATCH_REPLACE(node, PPrimitive(prim::kPrimIdentity, x), x);
  // None of the remaining patterns are attempted in PyNative mode.
  if (MsContext::GetInstance()->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
    return nullptr;
  }

  // ConstantDuplicateMul
  // Builds the replacement for `const_ * (const_2 * x)`.
  auto const_dup_lambda = [&node, &x, &const_, &const_2]() -> AnfNodePtr {
    auto new_mul_tensor = const_.MulByPatternConst(const_2, x.GetNode(node));
    // Input 0 of the matched node is reused as the operator of the new CNodes (presumably the Mul primitive).
    auto mul_node = node->cast()->inputs()[0];
    if (new_mul_tensor == nullptr) {
      // MulByPatternConst produced nothing: emit (const_ * const_2) as a CNode and multiply x by it.
      auto ttmul = NewCNode({mul_node, const_.GetNode(node), const_2.GetNode(node)}, node->func_graph());
      return NewCNode({mul_node, x.GetNode(node), ttmul}, node->func_graph());
    }
    // Otherwise multiply x by the returned value and copy the abstract of the original node onto the result.
    auto new_cnode = NewCNode({mul_node, x.GetNode(node), new_mul_tensor}, node->func_graph());
    new_cnode->set_abstract(node->abstract());
    return new_cnode;
  };
  MATCH_REPLACE_LAMBDA(node, const_ * (const_2 * x), const_dup_lambda);

  // The patterns below are skipped for nodes that do not belong to a func graph.
  if (node->func_graph() == nullptr) {
    return nullptr;
  }

  // OptUpdateZeroTensor: {kPrimMomentum, {kPrimZerosLike, x}, y, z, xs} -> {kPrimMakeTuple, z, y}
  MATCH_REPLACE(node, PPrimitive(prim::kPrimMomentum, PPrimitive(prim::kPrimZerosLike, x), y, z).MinExtraNodes(0),
                PPrimitive(prim::kPrimMakeTuple, z, y));

  // PowerOneEliminate: Pow(x, 1) -> x, conditioned on the IsValueNode check of the `one_scalar_` operand.
  MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimPow, x, one_scalar_), x, one_scalar_.CheckFunc(IsValueNode, node));

  return nullptr;
}

// Multiply-by-zero rewrites.
//
// Returns nullptr immediately in PyNative mode, and also when execution reaches the end of the function.
// The first (unnamed) OptimizerPtr parameter is not used.
AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  if (MsContext::GetInstance()->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
    return nullptr;
  }
  PatternNode x, y;
  PConstant zero_(node, false, 0);

  // Multiply by zero
  // Replacement is zero_.WithShapeAs(node); the condition requires that IsParam does not hold for the zero
  // operand and that the node bound to x is in the same func graph as `node`.
  MATCH_REPLACE_IF(node, x * zero_, zero_.WithShapeAs(node),
                   !zero_.CheckFunc(IsParam, node) && x.GetNode(node)->func_graph() == node->func_graph());
  // Same rewrite where the zero operand is a ZerosLike(y) primitive instead of a constant.
  auto zero_prim = PPrimitive(prim::kPrimZerosLike, y);
  MATCH_REPLACE_IF(node, x * zero_prim, zero_.WithShapeAs(node),
                   !zero_prim.CheckFunc(IsParam, node) && x.GetNode(node)->func_graph() == node->func_graph());

  return nullptr;
}

// grad = AllReduce(grad) / worker_number
// grad = grad + weight * decay
// ->
// grad = grad + weight * decay
// grad = AllReduce(grad) / worker_number

// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y}
//
// Side effect: the replacement lambda stores the matched Mul node in the member `mul_cnode_`, which
// ProcessDependEdge() reads afterwards.
// The first (unnamed) OptimizerPtr parameter is not used.
AnfNodePtr AdjustAllReduceMulAdd::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  PatternNode x, y, z;
  // Nested patterns, innermost first: AllReduce(x); Mul(AllReduce(x), y); MakeTuple(Mul(...), z); AddN(tuple).
  // NOTE(review): the trailing `true` passed to PBinOperation presumably marks the operands as commutative --
  // confirm in the pattern matcher header.
  auto all_reduce_pat = PPrimitive(prim::kPrimAllReduce, x);
  auto mul_pat = PBinOperation(prim::kPrimMul, all_reduce_pat, y, true);
  auto admktup_pat = PBinOperation(prim::kPrimMakeTuple, mul_pat, z, true);
  auto addn_pat = PPrimitive(prim::kPrimAddN, admktup_pat);
  // Builds Mul(AllReduce(AddN(MakeTuple(z, x))), y) in the func graph reported by the AllReduce pattern.
  // Returns nullptr when the shape of x or z cannot be obtained as checked below, or when the shapes differ.
  auto adjust_lambda = [&node, &x, &y, &z, &addn_pat, &all_reduce_pat, &admktup_pat, &mul_pat, this]() -> AnfNodePtr {
    auto fg = all_reduce_pat.GetFuncGraph();
    auto z_ = z.GetNode(node);
    auto x_ = x.GetNode(node);

    // If addn inputs cross the graph, make the inputs same as allreduce node.
    // (z_ is re-created in `fg` from the inputs of the original z_ node.)
    if (z_->isa() && fg != z_->func_graph()) {
      auto cnode_z = z_->cast();
      z_ = NewCNode(cnode_z->inputs(), fg);
    }

    // Reuse input 0 of each matched node as the operator of the corresponding new node.
    auto addn_cnode = addn_pat.GetOriginalNode()->cast();
    auto addn_op_node = addn_cnode->input(0);
    auto make_tuple_op_node = addn_cnode->input(1)->cast()->input(0);
    auto all_reduce_prim = all_reduce_pat.GetOriginalNode()->cast()->input(0);
    // Remember the matched Mul node on the object; ProcessDependEdge() looks up its users.
    mul_cnode_ = mul_pat.GetOriginalNode();
    auto mul_prim = mul_cnode_->cast()->input(0);
    auto addn_maketuple = admktup_pat.GetOriginalNode();

    ShapeVector x_shape, z_shape;
    // Shape of x_: taken from its abstract in the first branch, from its held value in the else branch.
    // NOTE(review): the isa()/cast() type arguments are not visible here, so which node and value kinds each
    // branch tests for cannot be stated from this text -- confirm upstream.
    if (!x_->isa()) {
      if ((x_->abstract() == nullptr) || !x_->abstract()->isa()) {
        return nullptr;
      }
      auto x_abstract = x_->abstract()->cast();
      x_shape = x_abstract->shape()->shape();
    } else {
      ValuePtr x_value = x_->cast()->value();
      if (!x_value->isa()) {
        return nullptr;
      }
      auto x_tensor = GetValueNode(x_->cast());
      x_shape = x_tensor->shape();
    }
    // Shape of z_: same procedure as for x_.
    if (!z_->isa()) {
      if ((z_->abstract() == nullptr) || !z_->abstract()->isa()) {
        return nullptr;
      }
      auto z_abstract = z_->abstract()->cast();
      z_shape = z_abstract->shape()->shape();
    } else {
      ValuePtr z_value = z_->cast()->value();
      if (!z_value->isa()) {
        return nullptr;
      }
      auto z_tensor = GetValueNode(z_->cast());
      z_shape = z_tensor->shape();
    }

    if (x_shape != z_shape) {
      // AddN requires x_ and z_ have the same shape.
      // If broadcasting TensorAdd is supported, the following could be used instead:
      // AnfNodePtr add = NewCNode({NewValueNode(prim::kPrimTensorAdd), z_, x_}, fg);
      return nullptr;
    }
    // Assemble the reordered expression: AddN first, then AllReduce, then the Mul by y.
    AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg);
    AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg);
    AnfNodePtr all_reduce = NewCNode({all_reduce_prim, add}, fg);
    AnfNodePtr mul = NewCNode({mul_prim, all_reduce, y.GetNode(node)}, fg);
    // Redirect other MakeTuple users of the old Mul node to the new AllReduce node.
    ProcessDependEdge(fg, addn_maketuple, all_reduce);
    return mul;
  };
  MATCH_REPLACE_LAMBDA(node, addn_pat, adjust_lambda);
  return nullptr;
}

// Re-points MakeTuple users of the stored Mul node (`mul_cnode_`) at `new_node`.
//
// fg:             func graph whose manager supplies the user map and performs the edge updates.
// addn_maketuple: the MakeTuple node that fed the matched AddN; it is left untouched.
// new_node:       node that replaces `mul_cnode_` as the input of the affected users.
//
// NOTE(review): fg->manager() is dereferenced without a null check -- presumably the manager is always set
// when this pass runs; confirm.
void AdjustAllReduceMulAdd::ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple,
                                              const AnfNodePtr &new_node) {
  // If has dynamic loss scale.
  auto &users_map = fg->manager()->node_users();
  auto it = users_map.find(mul_cnode_);
  if (it != users_map.end()) {
    // The user set is copied by value before iterating.
    // NOTE(review): presumably deliberate, so that SetEdge() cannot invalidate the iteration -- confirm.
    auto users = it->second;
    for (auto &user_pair : users) {
      // user_pair.first is the user node; user_pair.second is passed to SetEdge as the input index.
      auto node = user_pair.first;
      if (node != addn_maketuple) {
        if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
          fg->manager()->SetEdge(node, user_pair.second, new_node);
        }
      }
    }
  }
}
}  // namespace irpass
}  // namespace opt
}  // namespace mindspore