/** * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include "ir/anf.h" #include "pybind_api/ir/primitive_py.h" #include "ir/meta_func_graph.h" #include "ir/func_graph_cloner.h" #include "ir/manager.h" #include "pipeline/jit/resource.h" #include "pipeline/jit/parse/parse.h" #include "frontend/optimizer/ad/dfunctor.h" #include "frontend/operator/ops.h" #include "frontend/operator/composite/composite.h" #include "utils/symbolic.h" #include "utils/primitive_utils.h" #include "utils/ms_context.h" #include "utils/info.h" #include "debug/trace.h" namespace mindspore { namespace ad { using PatternListType = std::initializer_list; KPrim g_k_prims; FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) { // Set a child scope named "grad'PrimitiveName'" for the bprop function, // and add "Gradients" to the front. static const std::string gradients_scope = "Gradients/"; static const std::string grad_op_child_scope_prefix = "/grad"; MS_EXCEPTION_IF_NULL(prim); auto scope = std::make_shared(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() + grad_op_child_scope_prefix + prim->name()); ScopeGuard scope_guard(scope); py::function fn; if (prim->is_base()) { fn = GetBpropFunction(prim->name()); } else { fn = prim->cast()->GetBpropFunction(); if (py::isinstance(fn)) { fn = GetBpropFunction(prim->name()); } } if (!fn || py::isinstance(fn)) { MS_LOG(DEBUG) << "Fail to find bprop function for " << prim->name() << "."; return nullptr; } FuncGraphPtr func_graph = parse::ParsePythonCode(fn); if (func_graph == nullptr) { MS_LOG(ERROR) << "Fail to parse bprop function for " << prim->name() << "."; return nullptr; } auto bprop_flag = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP); if (bprop_flag) { func_graph->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true); } return func_graph; } FuncGraphPtr KPrim::GetFprop(const PrimitivePtr &prim) { static const std::string ad_module = "mindspore.ops._grad.grad_implementations"; std::string func_name = "_fprop_" + prim->name(); py::function fn = parse::python_adapter::GetPyFn(ad_module, func_name); auto func_graph = parse::ParsePythonCode(fn); MS_EXCEPTION_IF_NULL(func_graph); return BasicClone(func_graph); } MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) { MS_EXCEPTION_IF_NULL(prim); auto iter = bprop_registry_meta_.find(prim); if (iter != bprop_registry_meta_.end()) { return iter->second; } if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) { MetaFuncGraphPtr meta = std::make_shared("make_tuple_gradient"); bprop_registry_meta_[prim::kPrimMakeTuple] = meta; return meta; } if (prim->Hash() == prim::kPrimMakeList->Hash() && prim->name() == prim::kPrimMakeList->name()) { MetaFuncGraphPtr meta = std::make_shared("make_list_gradient"); bprop_registry_meta_[prim::kPrimMakeList] = meta; return meta; } MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << "."; } static void AppendMonadOutput(const FuncGraphPtr &bprop_fg, const AnfNodePtr &monad) { const auto &output = bprop_fg->output(); MS_EXCEPTION_IF_NULL(output); auto output_cnode = output->cast(); if (output_cnode != nullptr) { // If output_cnode has the form like (make_tuple, x, y). output_cnode->add_input(monad); return; } // If output is an empty tuple, create a (make_tuple, monad) as the new output. auto make_tuple = NewValueNode(prim::kPrimMakeTuple); output_cnode = bprop_fg->NewCNode({make_tuple, monad}); bprop_fg->set_output(output_cnode); } // Append U or/and IO monad to output of Bprop funcgraph. static void AdjustForAutoMonad(const PrimitivePtr &prim, const FuncGraphPtr &bprop_fg) { auto effect_info = GetPrimEffectInfo(prim); if (effect_info.memory) { MS_LOG(DEBUG) << "Append U monad for Bprop FuncGraph of Primitive " << prim->ToString(); auto u = NewValueNode(kUMonad); u->set_abstract(kUMonad->ToAbstract()); AppendMonadOutput(bprop_fg, u); } if (effect_info.io) { MS_LOG(DEBUG) << "Append IO monad for Bprop FuncGraph of Primitive " << prim->ToString(); auto io = NewValueNode(kIOMonad); io->set_abstract(kIOMonad->ToAbstract()); AppendMonadOutput(bprop_fg, io); } } FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { if (!IsValueNode(value_node)) { MS_LOG(EXCEPTION) << "Primitive node is not valid."; } auto prim = GetValueNode(value_node); if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == prim::kPrimSwitchLayer->name()) { auto fprop = GetFprop(prim); fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer)); return fprop; } else if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) { return nullptr; } else if (prim->Hash() == prim::kPrimMakeList->Hash() && prim->name() == prim::kPrimMakeList->name()) { return nullptr; } FuncGraphPtr bprop_fg = nullptr; if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) { if (MsContext::GetInstance()->get_param(MsCtxParam::MS_CTX_EXECUTION_MODE) == kGraphMode) { MS_LOG(EXCEPTION) << "HookBackward is not supported in graph mode."; } bprop_fg = BpropCut(value_node, resources); } else { auto iter = bprop_registry_.find(prim); if (iter != bprop_registry_.end()) { bprop_fg = iter->second; } if (bprop_fg == nullptr) { bprop_fg = GetBprop(prim); if (bprop_fg != nullptr) { // Set bprop_g graph cache bprop_registry_[prim] = bprop_fg; } else { bprop_fg = FakeBprop(value_node, resources); } } } AdjustForAutoMonad(prim, bprop_fg); auto expanded_fg = BpropToK(prim, bprop_fg, nullptr, cnode); if (expanded_fg == nullptr) { MS_LOG(EXCEPTION) << "Failed convert " << prim->name() << " prim bprop function to J expanded func graph. NodeInfo: " << trace::GetDebugInfo(bprop_fg->debug_info()); } return expanded_fg; } AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr ¤t_primal_fg) { // current_primal_fg may have extra parameters like u_monad, io_monad std::vector extra_args; // caller had checked size() - 2 is greater than 0. auto bprop_fg_param_size = bprop_fg->parameters().size() - 2; if (current_primal_fg != nullptr && bprop_fg_param_size < current_primal_fg->parameters().size()) { auto current_primal_fg_param_size = current_primal_fg->parameters().size(); MS_LOG(DEBUG) << "Current Primal FuncGraph may have extra parameters(U or IO monad) which bprop don't define, so " "Insert it. Extra parameters size: " << current_primal_fg_param_size - bprop_fg_param_size; for (auto i = bprop_fg_param_size; i < current_primal_fg_param_size; ++i) { const auto &primal_node = current_primal_fg->parameters()[i]; auto extra_node = bprop_fg->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), primal_node}); extra_args.push_back(extra_node); MS_LOG(DEBUG) << "Insert to bprop_fg for node: " << primal_node->DebugString(); } } // bprop_fg has been checked in caller if (IsPrimitiveCNode(bprop_fg->output(), prim::kPrimMakeTuple)) { // Set bprop output as (env, dx, dy, dz, ...) auto cbprop = bprop_fg->output()->cast(); auto &inputs = cbprop->inputs(); std::vector args; args.push_back(NewValueNode(prim::kPrimMakeTuple)); args.push_back(NewValueNode(newenv)); (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); if (!extra_args.empty()) { args.insert(args.end(), extra_args.cbegin(), extra_args.cend()); } return NewCNode(args, bprop_fg); } // Set bprop output as (env, dx) std::string model_name("mindspore.ops.composite.multitype_ops.add_impl"); std::string python_ops("_tuple_add"); auto tuple_env = NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(newenv)}, bprop_fg); auto tuple_add_ops = NewValueNode(prim::GetPythonOps(python_ops, model_name)); if (!extra_args.empty()) { extra_args.insert(extra_args.begin(), NewValueNode(prim::kPrimMakeTuple)); auto extra_tuple = NewCNode(extra_args, bprop_fg); auto old_output_extra = NewCNode({tuple_add_ops, bprop_fg->output(), extra_tuple}, bprop_fg); return NewCNode({tuple_add_ops, tuple_env, old_output_extra}, bprop_fg); } return NewCNode({tuple_add_ops, tuple_env, bprop_fg->output()}, bprop_fg); } static void TransformNormalArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer, std::vector *const transf_args) { // bprop_fg has been checked in caller // transform except the last 2 parameters: out, dout. auto bprop_fg_param_size = bprop_fg->parameters().size() - 2; for (size_t i = 0; i < bprop_fg_param_size; ++i) { auto p = bprop_fg->parameters()[i]; MS_EXCEPTION_IF_NULL(p); TraceGuard trace_guard(std::make_shared(p->debug_info())); auto transf_p = outer->add_parameter(); (void)mng->Replace(p, transf_p); transf_args->push_back(transf_p); } } void KPrim::TransformArgsForPrimitive(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const PrimitivePtr &primitive, const FuncGraphPtr &outer, std::vector *const transf_args) { MS_EXCEPTION_IF_NULL(mng); TransformNormalArgs(mng, bprop_fg, outer, transf_args); // Fprop_fg for Primitive with side effect should append extra U or IO monad parameter. auto effect_info = GetPrimEffectInfo(primitive); if (effect_info.memory) { MS_LOG(DEBUG) << "Append U monad to Fprop FuncGraph for Primitive " << primitive->ToString(); auto transf_p = outer->add_parameter(); transf_args->push_back(transf_p); } if (effect_info.io) { MS_LOG(DEBUG) << "Append IO monad to Fprop FuncGraph for Primitive " << primitive->ToString(); auto transf_p = outer->add_parameter(); transf_args->push_back(transf_p); } } template void KPrim::TransformArgsForFuncGraph(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const T ¤t_primal_fg, const FuncGraphPtr &outer, std::vector *const transf_args) { MS_EXCEPTION_IF_NULL(mng); TransformNormalArgs(mng, bprop_fg, outer, transf_args); auto bprop_fg_param_size = bprop_fg->parameters().size() - 2; // current_primal_fg may have extra parameters after AutoMonad const auto ¤t_primal_fg_params = current_primal_fg->parameters(); if (bprop_fg_param_size < current_primal_fg_params.size()) { for (auto i = bprop_fg_param_size; i < current_primal_fg_params.size(); ++i) { auto p = current_primal_fg_params[i]; MS_EXCEPTION_IF_NULL(p); // extra parameters should be Monad. if (!HasAbstractMonad(p)) { continue; } MS_LOG(DEBUG) << "Function " << current_primal_fg->ToString() << ", has extra monad parameter: " << p->DebugString() << ", abstract: " << p->abstract()->ToString(); TraceGuard trace_guard(std::make_shared(p->debug_info())); auto transf_p = outer->add_parameter(); (void)mng->Replace(p, transf_p); transf_args->push_back(transf_p); } } if (transf_args->size() != current_primal_fg_params.size()) { MS_EXCEPTION(TypeError) << "Function " << current_primal_fg->ToString() << ", The number of parameter of this primal function is " << current_primal_fg_params.size() << ", but the number of parameters of bprop is " << bprop_fg_param_size; } } void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) { auto context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context); bool check_bprop_flag = context->get_param(MS_CTX_CHECK_BPROP_FLAG); // Skip checking if check_bprop not set if (!check_bprop_flag) { return; } // bprop_fg has been checked in caller auto check_bprop_class = prim::GetPythonOps("CheckBprop", "mindspore.ops.operations.other_ops"); MS_EXCEPTION_IF_NULL(check_bprop_class); auto check_bprop = bprop_fg->NewCNode({NewValueNode(check_bprop_class), NewValueNode(std::make_shared(prim_to_check))}); std::vector inputs; inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); inputs.insert(inputs.begin() + 1, bprop_fg->parameters().begin(), bprop_fg->parameters().end() - 2); AnfNodePtr params = bprop_fg->NewCNode(inputs); inputs.clear(); inputs.push_back(check_bprop); inputs.push_back(bprop_fg->output()); inputs.push_back(params); AnfNodePtr bprop_out = bprop_fg->NewCNode(inputs); bprop_fg->set_output(bprop_out); } FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr &bprop_fg, const FuncGraphPtr ¤t_primal_fg) { MS_EXCEPTION_IF_NULL(bprop_fg); // primal_fg is FuncGraph just after convert. Refer ConvertCellObjToFuncGraph. // current_primal_fg is specalized and AutoMoaded primal_fg; auto primal_fg = bprop_fg->transforms().find("primal")->second.func_graph(); auto expanded_fg = BpropToK(primal_fg, bprop_fg, current_primal_fg, nullptr); if (expanded_fg == nullptr) { MS_LOG(EXCEPTION) << "Failed convert " << primal_fg->ToString() << " Cell bprop function to K expanded func graph. NodeInfo: " << trace::GetDebugInfo(primal_fg->debug_info()); } return expanded_fg; } FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { auto prim = GetValueNode(value_node); MS_EXCEPTION_IF_NULL(prim); auto &node_users = resources->manager()->node_users(); auto &users = node_users[value_node]; auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair &user) -> bool { return IsPrimitiveCNode(user.first, prim); }); if (cnode == users.end()) { MS_LOG(EXCEPTION) << "Fail to find cnode."; } auto inputs_num = cnode->first->cast()->size() - 1; auto func_graph = std::make_shared(); std::vector outputs; auto bprop_cut = std::make_shared("bprop_cut", py::object()); bprop_cut->CopyHookFunction(prim); auto cell_id = GetValue(prim->GetAttr("cell_id")); if (cell_id != "") { (void)bprop_cut->AddAttr("cell_hook", MakeValue(true)); (void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id)); } outputs.push_back(NewValueNode(bprop_cut)); for (size_t i = 0; i < inputs_num; ++i) { auto param = func_graph->add_parameter(); outputs.push_back(param); } auto p1 = func_graph->add_parameter(); auto p2 = func_graph->add_parameter(); outputs.push_back(p1); outputs.push_back(p2); func_graph->set_output(func_graph->NewCNode(outputs)); return func_graph; } FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { auto prim = value_node->value()->cast(); MS_EXCEPTION_IF_NULL(prim); auto &node_users = resources->manager()->node_users(); auto &users = node_users[value_node]; auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair &user) -> bool { return IsPrimitiveCNode(user.first, prim); }); if (cnode == users.end()) { MS_LOG(EXCEPTION) << "Fail to find user for " << prim->ToString(); } auto inputs_num = cnode->first->cast()->inputs().size() - 1; auto effect_info = GetPrimEffectInfo(prim); // Don't add U or IO monad parameters as it will be added later. size_t monad_params_size = 0; if (effect_info.memory) { monad_params_size++; } if (effect_info.io) { monad_params_size++; } if (inputs_num < monad_params_size) { MS_LOG(EXCEPTION) << "Arguments number should be greater than or equal to " << monad_params_size << ", but the CNode is: " << cnode->first->DebugString(); } inputs_num -= monad_params_size; auto func_graph = std::make_shared(); std::vector outputs; outputs.push_back(NewValueNode(prim::kPrimMakeTuple)); auto fake_bprop = std::make_shared("fake_bprop"); (void)fake_bprop->AddAttr("info", MakeValue("Primitive " + prim->name() + "'s bprop not defined.")); for (size_t i = 0; i < inputs_num; ++i) { // Mock params for inputs auto param = func_graph->add_parameter(); // Mock derivatives for each inputs outputs.push_back(func_graph->NewCNode({NewValueNode(fake_bprop), param})); } // mock params for out and dout (void)func_graph->add_parameter(); (void)func_graph->add_parameter(); func_graph->set_output(func_graph->NewCNode(outputs)); return func_graph; } } // namespace ad } // namespace mindspore