/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MINMAX_GRAD_H_ #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MINMAX_GRAD_H_ #include #include #include "frontend/optimizer/optimizer.h" #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/anf_visitor.h" #include "frontend/operator/ops.h" namespace mindspore { namespace opt { namespace irpass { namespace internal { // check if node is MinimumGrad() or MaximumGrad() bool IsOriginMaxMinGrad(const AnfNodePtr &node) { if (!IsPrimitiveCNode(node, prim::kPrimMaximumGrad) && !IsPrimitiveCNode(node, prim::kPrimMinimumGrad)) { return false; } auto cnode = node->cast(); auto prim = GetValueNode(cnode->input(0)); auto x_v = prim->GetAttr("grad_x"); auto y_v = prim->GetAttr("grad_y"); if (x_v == nullptr || y_v == nullptr || !x_v->isa() || !y_v->isa()) { return false; } bool x = GetValue(x_v); bool y = GetValue(y_v); return x && y; } } // namespace internal // {prim::kPrimTupleGetItem, {target_grad, Xs}, C} class MinMaximumGrad : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { Reset(); AnfVisitor::Match(prim::kPrimTupleGetItem, {internal::IsOriginMaxMinGrad, IsValueNode})(node); if (grad_ == nullptr || idx_ < 0 || idx_ > 1 || node->func_graph() == nullptr) { return nullptr; } // check single use auto mng = optimizer->resource()->manager(); auto &users = mng->node_users(); if (users.find(grad_) == users.end() || users[grad_].size() != 1) { return nullptr; } // {target_grad, Xs} auto &inputs = grad_->inputs(); auto prim = GetValueNode(inputs[0]); auto new_prim = std::make_shared(prim->name()); new_prim->set_attr("grad_x", MakeValue(true)); new_prim->set_attr("grad_y", MakeValue(true)); if (idx_ == 0) { new_prim->set_attr("grad_y", MakeValue(false)); } if (idx_ == 1) { new_prim->set_attr("grad_x", MakeValue(false)); } std::vector args; args.push_back(NewValueNode(new_prim)); (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); auto fg = node->func_graph(); auto tuple = fg->NewCNode(args); return fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), tuple, NewValueNode(MakeValue(idx_))}); } void Visit(const CNodePtr &cnode) override { grad_ = cnode; } void Visit(const ValueNodePtr &vnode) override { idx_ = GetValue(vnode->value()); } void Reset() { idx_ = -1; grad_ = nullptr; } private: int idx_{-1}; CNodePtr grad_{nullptr}; }; } // namespace irpass } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MINMAX_GRAD_H_